diff --git a/.asf.yaml b/.asf.yaml index a3522af9efd0f..c605a4692974e 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -21,8 +21,8 @@ # https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features notifications: - commits: commits@arrow.apache.org - issues: github@arrow.apache.org + commits: commits@arrow.apache.org + issues: github@arrow.apache.org pullrequests: github@arrow.apache.org jira_options: link label worklog github: @@ -44,6 +44,10 @@ github: rebase: false features: issues: true + protected_branches: + main: + required_pull_request_reviews: + required_approving_review_count: 1 # publishes the content of the `asf-site` branch to # https://arrow.apache.org/datafusion/ diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index a46513fc39d9d..5a93f6f27b436 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -43,6 +43,11 @@ runs: # "1" means line tables only, which is useful for panic tracebacks. shell: bash run: echo "RUSTFLAGS=-C debuginfo=1" >> $GITHUB_ENV + - name: Disable incremental compilation + # Disable incremental compilation to save diskspace (the CI doesn't recompile modified files) + # https://github.com/apache/arrow-datafusion/issues/6676 + shell: bash + run: echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV - name: Enable backtraces shell: bash run: echo "RUST_BACKTRACE=1" >> $GITHUB_ENV diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index ec66a1270cb85..06db092d6fc89 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,4 +1,4 @@ -# Which issue does this PR close? +## Which issue does this PR close? -# What changes are included in this PR? +## What changes are included in this PR? -# Are these changes tested? +## Are these changes tested? -# Are there any user-facing changes? +## Are there any user-facing changes? \ No newline at end of file +--> diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 8854f53a40bc4..19af21ec910be 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -28,9 +28,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Audit licenses @@ -40,10 +40,10 @@ jobs: name: Use prettier to check formatting of documents runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-node@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 with: - node-version: "14" + node-version: "20" - name: Prettier check run: | # if you encounter error, rerun the command below and commit the changes diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 7499239a9d122..77b257743331e 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -39,14 +39,14 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@4.1.0 + uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index bcfb1b4791704..34a37948785b5 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -16,35 +16,37 @@ # under the License. development-process: - - dev/**.* - - .github/**.* - - ci/**.* - - .asf.yaml +- changed-files: + - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] documentation: - - docs/**.* - - README.md - - ./**/README.md - - DEVELOPERS.md - - datafusion/docs/**.* +- changed-files: + - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] sql: - - datafusion/sql/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sql/**/*'] logical-expr: - - datafusion/expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/expr/**/*'] physical-expr: - - datafusion/physical-expr/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/physical-expr/**/*'] optimizer: - - datafusion/optimizer/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/optimizer/**/*'] core: - - datafusion/core/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/core/**/*'] substrait: - - datafusion/substrait/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/substrait/**/*'] sqllogictest: - - datafusion/core/tests/sqllogictests/**/* +- changed-files: + - any-glob-to-any-file: ['datafusion/sqllogictest/**/*'] diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 998c62034804a..ab6a615ab60be 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -15,16 +15,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Checkout asf-site branch - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: asf-site path: asf-site - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index 821321c8c571d..c2f3dd684a23e 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -40,14 +40,13 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - # Note: this does not include dictionary_expressions to reduce codegen - name: Run doctests run: cargo test --doc --features avro,json - name: Verify Working Directory Clean diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 7c6f468482431..099aab0614357 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -46,26 +46,31 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 - - name: Cache Cargo - uses: actions/cache@v3 - with: - # these represent dependencies downloaded by cargo - # and thus do not depend on the OS, arch nor rust version. - path: /github/home/.cargo - key: cargo-cache- + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Cache Cargo + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + ./target/ + ./datafusion-cli/target/ + # this key equals the ones on `linux-build-lib` for re-use + key: cargo-cache-benchmark-${{ hashFiles('datafusion/**/Cargo.toml', 'benchmarks/Cargo.toml', 'datafusion-cli/Cargo.toml') }} + - name: Check workspace without default features run: cargo check --no-default-features -p datafusion - name: Check workspace in debug mode run: cargo check - # Note: this does not include dictionary_expressions to reduce codegen - name: Check workspace with all features run: cargo check --workspace --benches --features avro,json - name: Check Cargo.lock for datafusion-cli @@ -82,21 +87,36 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- + rust-version: stable + - name: Run tests (excluding doctests) + run: cargo test --lib --tests --bins --features avro,json,backtrace + - name: Verify Working Directory Clean + run: git diff --exit-code + + linux-test-datafusion-cli: + name: cargo test datafusion-cli (amd64) + needs: [ linux-build-lib ] + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Run tests (excluding doctests) - run: cargo test --lib --tests --bins --features avro,json,dictionary_expressions + run: | + cd datafusion-cli + cargo test --lib --tests --bins --all-features - name: Verify Working Directory Clean run: git diff --exit-code @@ -107,7 +127,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain @@ -119,19 +139,7 @@ jobs: # test datafusion-sql examples cargo run --example sql # test datafusion-examples - cargo run --example avro_sql --features=datafusion/avro - cargo run --example csv_sql - cargo run --example custom_datasource - cargo run --example dataframe - cargo run --example dataframe_in_memory - cargo run --example deserialize_to_struct - cargo run --example expr_api - cargo run --example parquet_sql - cargo run --example parquet_sql_multiple_files - cargo run --example memtable - cargo run --example rewrite_expr - cargo run --example simple_udf - cargo run --example simple_udaf + ci/scripts/rust_example.sh - name: Verify Working Directory Clean run: git diff --exit-code @@ -143,16 +151,18 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - # Note: this does not include dictionary_expressions to reduce codegen - name: Run doctests - run: cargo test --doc --features avro,json + run: | + cargo test --doc --features avro,json + cd datafusion-cli + cargo test --doc --all-features - name: Verify Working Directory Clean run: git diff --exit-code @@ -164,7 +174,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -173,6 +183,25 @@ jobs: run: | export RUSTDOCFLAGS="-D warnings -A rustdoc::private-intra-doc-links" cargo doc --document-private-items --no-deps --workspace + cd datafusion-cli + cargo doc --document-private-items --no-deps + + linux-wasm-pack: + name: build with wasm-pack + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Install wasm-pack + run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + - name: Build with wasm-pack + working-directory: ./datafusion/wasmtest + run: wasm-pack build --dev # verify that the benchmark queries return the correct results verify-benchmark-results: @@ -182,32 +211,27 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - name: Generate benchmark data and expected query results run: | - mkdir -p datafusion/core/tests/sqllogictests/test_files/tpch/data + mkdir -p datafusion/sqllogictest/test_files/tpch/data git clone https://github.com/databricks/tpch-dbgen.git cd tpch-dbgen make ./dbgen -f -s 0.1 - mv *.tbl ../datafusion/core/tests/sqllogictests/test_files/tpch/data + mv *.tbl ../datafusion/sqllogictest/test_files/tpch/data - name: Verify that benchmark queries return expected results run: | - export TPCH_DATA=`realpath datafusion/core/tests/sqllogictests/test_files/tpch/data` - cargo test serde_q --profile release-nonlto --features=ci -- --test-threads=1 - INCLUDE_TPCH=true cargo test -p datafusion --test sqllogictests + export TPCH_DATA=`realpath datafusion/sqllogictest/test_files/tpch/data` + # use release build for plan verificaton because debug build causes stack overflow + cargo test plan_q --package datafusion-benchmarks --profile release-nonlto --features=ci -- --test-threads=1 + INCLUDE_TPCH=true cargo test --test sqllogictests - name: Verify Working Directory Clean run: git diff --exit-code @@ -230,7 +254,7 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Setup toolchain @@ -238,7 +262,7 @@ jobs: rustup toolchain install stable rustup default stable - name: Run sqllogictest - run: PG_COMPAT=true PG_URI="postgresql://postgres:postgres@localhost:$POSTGRES_PORT/db_test" cargo test -p datafusion --test sqllogictests + run: PG_COMPAT=true PG_URI="postgresql://postgres:postgres@localhost:$POSTGRES_PORT/db_test" cargo test --features=postgres --test sqllogictests env: POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} @@ -246,7 +270,7 @@ jobs: name: cargo test (win64) runs-on: windows-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protobuf compiler @@ -270,16 +294,19 @@ jobs: shell: bash run: | export PATH=$PATH:$HOME/d/protoc/bin - cargo test --lib --tests --bins --features avro,json,dictionary_expressions + cargo test --lib --tests --bins --features avro,json,backtrace + cd datafusion-cli + cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down RUSTFLAGS: "-C debuginfo=0" + RUST_BACKTRACE: "1" macos: name: cargo test (mac) runs-on: macos-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - name: Install protobuf compiler @@ -303,28 +330,25 @@ jobs: - name: Run tests (excluding doctests) shell: bash run: | - cargo test --lib --tests --bins --features avro,json,dictionary_expressions + cargo test --lib --tests --bins --features avro,json,backtrace + cd datafusion-cli + cargo test --lib --tests --bins --all-features env: # do not produce debug symbols to keep memory usage down RUSTFLAGS: "-C debuginfo=0" + RUST_BACKTRACE: "1" test-datafusion-pyarrow: name: cargo test pyarrow (amd64) needs: [ linux-build-lib ] runs-on: ubuntu-20.04 container: - image: amd64/rust + image: amd64/rust:bullseye # Workaround https://github.com/actions/setup-python/issues/721 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: "3.8" - name: Install PyArrow @@ -344,7 +368,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Run gen @@ -359,7 +383,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -376,7 +400,7 @@ jobs: # name: coverage # runs-on: ubuntu-latest # steps: - # - uses: actions/checkout@v3 + # - uses: actions/checkout@v4 # with: # submodules: true # - name: Install protobuf compiler @@ -418,15 +442,9 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -444,15 +462,9 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -469,15 +481,9 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -487,12 +493,11 @@ jobs: - name: Check Cargo.toml formatting run: | - # if you encounter error, try rerun the command below, finally run 'git diff' to - # check which Cargo.toml introduces formatting violation + # if you encounter an error, try running 'cargo tomlfmt -p path/to/Cargo.toml' to fix the formatting automatically. + # If the error still persists, you need to manually edit the Cargo.toml file, which introduces formatting violation. # # ignore ./Cargo.toml because putting workspaces in multi-line lists make it easy to read ci/scripts/rust_toml_fmt.sh - git diff --exit-code config-docs-check: name: check configs.md is up-to-date @@ -501,24 +506,43 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache- - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: - node-version: "14" + node-version: "20" - name: Check if configs.md has been modified run: | # If you encounter an error, run './dev/update_config_docs.sh' and commit ./dev/update_config_docs.sh git diff --exit-code + + # Verify MSRV for the crates which are directly used by other projects. + msrv: + name: Verify MSRV + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Install cargo-msrv + run: cargo install cargo-msrv + - name: Check datafusion + working-directory: datafusion/core + run: cargo msrv verify + - name: Check datafusion-substrait + working-directory: datafusion/substrait + run: cargo msrv verify + - name: Check datafusion-proto + working-directory: datafusion/proto + run: cargo msrv verify + - name: Check datafusion-cli + working-directory: datafusion-cli + run: cargo msrv verify diff --git a/.gitignore b/.gitignore index 65d3c0f345e39..203455e4a796e 100644 --- a/.gitignore +++ b/.gitignore @@ -103,4 +103,7 @@ datafusion/CHANGELOG.md.bak .githubchangeloggenerator.cache* # Generated tpch data -datafusion/core/tests/sqllogictests/test_files/tpch/data/* +datafusion/sqllogictests/test_files/tpch/data/* + +# Scratch temp dir for sqllogictests +datafusion/sqllogictest/test_files/scratch* diff --git a/Cargo.toml b/Cargo.toml index 6b24a44e9ad79..2bcbe059ab25f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,33 +24,72 @@ members = [ "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", + "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", - "datafusion/row", "datafusion/sql", + "datafusion/sqllogictest", "datafusion/substrait", + "datafusion/wasmtest", "datafusion-examples", + "docs", "test-utils", "benchmarks", ] +resolver = "2" [workspace.package] -version = "26.0.0" -edition = "2021" -readme = "README.md" authors = ["Apache Arrow "] -license = "Apache-2.0" +edition = "2021" homepage = "https://github.com/apache/arrow-datafusion" +license = "Apache-2.0" +readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.64" +rust-version = "1.70" +version = "33.0.0" [workspace.dependencies] -arrow = { version = "41.0.0", features = ["prettyprint"] } -arrow-flight = { version = "41.0.0", features = ["flight-sql-experimental"] } -arrow-buffer = { version = "41.0.0", default-features = false } -arrow-schema = { version = "41.0.0", default-features = false } -arrow-array = { version = "41.0.0", default-features = false, features = ["chrono-tz"] } -parquet = { version = "41.0.0", features = ["arrow", "async", "object_store"] } +arrow = { version = "49.0.0", features = ["prettyprint"] } +arrow-array = { version = "49.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "49.0.0", default-features = false } +arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] } +arrow-ord = { version = "49.0.0", default-features = false } +arrow-schema = { version = "49.0.0", default-features = false } +async-trait = "0.1.73" +bigdecimal = "0.4.1" +bytes = "1.4" +ctor = "0.2.0" +datafusion = { path = "datafusion/core", version = "33.0.0" } +datafusion-common = { path = "datafusion/common", version = "33.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "33.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "33.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "33.0.0" } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "33.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "33.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "33.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "33.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "33.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "33.0.0" } +dashmap = "5.4.0" +doc-comment = "0.3" +env_logger = "0.10" +futures = "0.3" +half = "2.2.1" +indexmap = "2.0.0" +itertools = "0.12" +log = "^0.4" +num_cpus = "1.13.0" +object_store = { version = "0.8.0", default-features = false } +parking_lot = "0.12" +parquet = { version = "49.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +rand = "0.8" +rstest = "0.18.0" +serde_json = "1" +sqlparser = { version = "0.40.0", features = ["visitor"] } +tempfile = "3" +thiserror = "1.0.44" +chrono = { version = "0.4.31", default-features = false } +url = "2.2" [profile.release] codegen-units = 1 diff --git a/README.md b/README.md index 2ddb4e6630a5b..883700a39355a 100644 --- a/README.md +++ b/README.md @@ -19,26 +19,61 @@ # DataFusion -[![Coverage Status](https://codecov.io/gh/apache/arrow-datafusion/rust/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/arrow-datafusion?branch=master) - logo DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) -in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-python) are also available. +in-memory format. [Python Bindings](https://github.com/apache/arrow-datafusion-python) are also available. DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. -DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. +Here are links to some important information -https://arrow.apache.org/datafusion/ contains the project's documentation. +- [Project Site](https://arrow.apache.org/datafusion) +- [Installation](https://arrow.apache.org/datafusion/user-guide/cli.html#installation) +- [Rust Getting Started](https://arrow.apache.org/datafusion/user-guide/example-usage.html) +- [Rust DataFrame API](https://arrow.apache.org/datafusion/user-guide/dataframe.html) +- [Rust API docs](https://docs.rs/datafusion/latest/datafusion) +- [Rust Examples](https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples) +- [Python DataFrame API](https://arrow.apache.org/datafusion-python/) +- [Architecture](https://docs.rs/datafusion/latest/datafusion/index.html#architecture) -## Using DataFusion +## What can you do with this crate? -The [example usage] section in the user guide and the [datafusion-examples] code in the crate contain information on using DataFusion. +DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://arrow.apache.org/datafusion/user-guide/introduction.html#known-users) to see a list known users. ## Contributing to DataFusion -The [developer’s guide] contains information on how to contribute. +Please see the [developer’s guide] for contributing and [communication] for getting in touch with us. -[example usage]: https://arrow.apache.org/datafusion/user-guide/example-usage.html -[datafusion-examples]: https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples [developer’s guide]: https://arrow.apache.org/datafusion/contributor-guide/index.html#developer-s-guide +[communication]: https://arrow.apache.org/datafusion/contributor-guide/communication.html + +## Crate features + +This crate has several [features] which can be specified in your `Cargo.toml`. + +[features]: https://doc.rust-lang.org/cargo/reference/features.html + +Default features: + +- `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` +- `crypto_expressions`: cryptographic functions such as `md5` and `sha256` +- `encoding_expressions`: `encode` and `decode` functions +- `parquet`: support for reading the [Apache Parquet] format +- `regex_expressions`: regular expression functions, such as `regexp_match` +- `unicode_expressions`: Include unicode aware functions such as `character_length` + +Optional features: + +- `avro`: support for reading the [Apache Avro] format +- `backtrace`: include backtrace information in error messages +- `pyarrow`: conversions between PyArrow and DataFusion types +- `serde`: enable arrow-schema's `serde` feature +- `simd`: enable arrow-rs's manual `SIMD` kernels (requires Rust `nightly`) + +[apache avro]: https://avro.apache.org/ +[apache parquet]: https://parquet.apache.org/ + +## Rust Version Compatibility + +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 0ccb4fb84a1ad..c5a24a0a5cf91 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,13 +18,13 @@ [package] name = "datafusion-benchmarks" description = "DataFusion Benchmarks" -version = "26.0.0" -edition = "2021" +version = "33.0.0" +edition = { workspace = true } authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" license = "Apache-2.0" -rust-version = "1.62" +rust-version = "1.70" [features] ci = [] @@ -34,18 +34,20 @@ snmalloc = ["snmalloc-rs"] [dependencies] arrow = { workspace = true } -datafusion = { path = "../datafusion/core", version = "26.0.0" } -env_logger = "0.10" -futures = "0.3" +datafusion = { path = "../datafusion/core", version = "33.0.0" } +datafusion-common = { path = "../datafusion/common", version = "33.0.0" } +env_logger = { workspace = true } +futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } -num_cpus = "1.13.0" -parquet = { workspace = true } +num_cpus = { workspace = true } +parquet = { workspace = true, default-features = true } serde = { version = "1.0.136", features = ["derive"] } -serde_json = "1.0.78" +serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } [dev-dependencies] -datafusion-proto = { path = "../datafusion/proto", version = "26.0.0" } +datafusion-proto = { path = "../datafusion/proto", version = "33.0.0" } diff --git a/benchmarks/README.md b/benchmarks/README.md index cf8a20a823f58..c0baa43ab8708 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -20,11 +20,14 @@ # DataFusion Benchmarks This crate contains benchmarks based on popular public data sets and -open source benchmark suites, making it easy to run more realistic -benchmarks to help with performance and scalability testing of DataFusion. +open source benchmark suites, to help with performance and scalability +testing of DataFusion. -# Benchmarks Against Other Engines +## Other engines + +The benchmarks measure changes to DataFusion itself, rather than +its performance against other engines. For competitive benchmarking, DataFusion is included in the benchmark setups for several popular benchmarks that compare performance with other engines. For example: @@ -36,30 +39,35 @@ benchmarks that compare performance with other engines. For example: # Running the benchmarks -## Running Benchmarks +## `bench.sh` -The easiest way to run benchmarks from DataFusion source checkouts is -to use the [bench.sh](bench.sh) script. Usage instructions can be -found with: +The easiest way to run benchmarks is the [bench.sh](bench.sh) +script. Usage instructions can be found with: ```shell # show usage ./bench.sh ``` -## Generating Data +## Generating data + +You can create / download the data for these benchmarks using the [bench.sh](bench.sh) script: -You can create data for all these benchmarks using the [bench.sh](bench.sh) script: +Create / download all datasets ```shell ./bench.sh data ``` -Data is generated in the `data` subdirectory and will not be checked -in because this directory has been added to the `.gitignore` file. +Create / download a specific dataset (TPCH) + +```shell +./bench.sh data tpch +``` +Data is placed in the `data` subdirectory. -## Example to compare peformance on main to a branch +## Comparing performance of main and a branch ```shell git checkout main @@ -143,27 +151,17 @@ Benchmark tpch_mem.json ``` -# Benchmark Descriptions: - -## `tpch` Benchmark derived from TPC-H +### Running Benchmarks Manually -These benchmarks are derived from the [TPC-H][1] benchmark. And we use this repo as the source of tpch-gen and answers: -https://github.com/databricks/tpch-dbgen.git, based on [2.17.1](https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf) version of TPC-H. - - -### Running the DataFusion Benchmarks Manually - -The benchmark can then be run (assuming the data created from `dbgen` is in `./data`) with a command such as: +Assuming data in the `data` directory, the `tpch` benchmark can be run with a command like this ```bash -cargo run --release --bin tpch -- benchmark datafusion --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 +cargo run --release --bin dfbench -- tpch --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 ``` -If you omit `--query=` argument, then all benchmarks will be run one by one (from query 1 to query 22). +See the help for more details -```bash -cargo run --release --bin tpch -- benchmark datafusion --iterations 1 --path ./data --format tbl --batch-size 4096 -``` +### Different features You can enable the features `simd` (to use SIMD instructions, `cargo nightly` is required.) and/or `mimalloc` or `snmalloc` (to use either the mimalloc or snmalloc allocator) as features by passing them in as `--features`: @@ -171,12 +169,6 @@ You can enable the features `simd` (to use SIMD instructions, `cargo nightly` is cargo run --release --features "simd mimalloc" --bin tpch -- benchmark datafusion --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 ``` -If you want to disable collection of statistics (and thus cost based optimizers), you can pass `--disable-statistics` flag. - -```bash -cargo run --release --bin tpch -- benchmark datafusion --iterations 3 --path /mnt/tpch-parquet --format parquet --query 17 --disable-statistics -``` - The benchmark program also supports CSV and Parquet input file formats and a utility is provided to convert from `tbl` (generated by the `dbgen` utility) to CSV and Parquet. @@ -188,9 +180,10 @@ Or if you want to verify and run all the queries in the benchmark, you can just ### Comparing results between runs -Any `tpch` execution with `-o ` argument will produce a summary file right under the `` -directory. It is a JSON serialized form of all the runs that happened as well as the runtime metadata -(number of cores, DataFusion version, etc.). +Any `dfbench` execution with `-o ` argument will produce a +summary JSON in the specified directory. This file contains a +serialized form of all the runs that happened and runtime +metadata (number of cores, DataFusion version, etc.). ```shell $ git checkout main @@ -236,89 +229,71 @@ This will produce output like └──────────────┴──────────────┴──────────────┴───────────────┘ ``` -### Expected output - -The result of query 1 should produce the following output when executed against the SF=1 dataset. +# Benchmark Runner -``` -+--------------+--------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------+-------------+ -| l_returnflag | l_linestatus | sum_qty | sum_base_price | sum_disc_price | sum_charge | avg_qty | avg_price | avg_disc | count_order | -+--------------+--------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------+-------------+ -| A | F | 37734107 | 56586554400.73001 | 53758257134.870026 | 55909065222.82768 | 25.522005853257337 | 38273.12973462168 | 0.049985295838396455 | 1478493 | -| N | F | 991417 | 1487504710.3799996 | 1413082168.0541 | 1469649223.1943746 | 25.516471920522985 | 38284.467760848296 | 0.05009342667421622 | 38854 | -| N | O | 74476023 | 111701708529.50996 | 106118209986.10472 | 110367023144.56622 | 25.502229680934594 | 38249.1238377803 | 0.049996589476752576 | 2920373 | -| R | F | 37719753 | 56568041380.90001 | 53741292684.60399 | 55889619119.83194 | 25.50579361269077 | 38250.854626099666 | 0.05000940583012587 | 1478870 | -+--------------+--------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------+-------------+ -Query 1 iteration 0 took 1956.1 ms -Query 1 avg time: 1956.11 ms -``` +The `dfbench` program contains subcommands to run the various +benchmarks. When benchmarking, it should always be built in release +mode using `--release`. -## NYC Taxi Benchmark +Full help for each benchmark can be found in the relevant sub +command. For example to get help for tpch, run -These benchmarks are based on the [New York Taxi and Limousine Commission][2] data set. +```shell +cargo run --release --bin dfbench --help +... +datafusion-benchmarks 27.0.0 +benchmark command -```bash -cargo run --release --bin nyctaxi -- --iterations 3 --path /mnt/nyctaxi/csv --format csv --batch-size 4096 -``` +USAGE: + dfbench -Example output: +SUBCOMMANDS: + clickbench Run the clickbench benchmark + help Prints this message or the help of the given subcommand(s) + parquet-filter Test performance of parquet filter pushdown + sort Test performance of parquet filter pushdown + tpch Run the tpch benchmark. + tpch-convert Convert tpch .slt files to .parquet or .csv files -```bash -Running benchmarks with the following options: Opt { debug: false, iterations: 3, batch_size: 4096, path: "/mnt/nyctaxi/csv", file_format: "csv" } -Executing 'fare_amt_by_passenger' -Query 'fare_amt_by_passenger' iteration 0 took 7138 ms -Query 'fare_amt_by_passenger' iteration 1 took 7599 ms -Query 'fare_amt_by_passenger' iteration 2 took 7969 ms ``` -## h2o benchmarks +# Benchmarks -```bash -cargo run --release --bin h2o group-by --query 1 --path /mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv --mem-table --debug -``` +The output of `dfbench` help includes a descripion of each benchmark, which is reproducedd here for convenience -Example run: +## ClickBench -``` -Running benchmarks with the following options: GroupBy(GroupBy { query: 1, path: "/mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv", debug: false }) -Executing select id1, sum(v1) as v1 from x group by id1 -+-------+--------+ -| id1 | v1 | -+-------+--------+ -| id063 | 199420 | -| id094 | 200127 | -| id044 | 198886 | -... -| id093 | 200132 | -| id003 | 199047 | -+-------+--------+ +The ClickBench[1] benchmarks are widely cited in the industry and +focus on grouping / aggregation / filtering. This runner uses the +scripts and queries from [2]. -h2o groupby query 1 took 1669 ms -``` +[1]: https://github.com/ClickHouse/ClickBench +[2]: https://github.com/ClickHouse/ClickBench/tree/main/datafusion -[1]: http://www.tpc.org/tpch/ -[2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page +## Parquet Filter -## Parquet benchmarks +Test performance of parquet filter pushdown -This is a set of benchmarks for testing and verifying performance of parquet filtering and sorting. -The queries are executed on a synthetic dataset generated during the benchmark execution and designed to simulate web server access logs. +The queries are executed on a synthetic dataset generated during +the benchmark execution and designed to simulate web server access +logs. -To run filter benchmarks, run: +Example -```base -cargo run --release --bin parquet -- filter --path ./data --scale-factor 1.0 -``` +dfbench parquet-filter --path ./data --scale-factor 1.0 -This will generate the synthetic dataset at `./data/logs.parquet`. The size of the dataset can be controlled through the `size_factor` +generates the synthetic dataset at `./data/logs.parquet`. The size +of the dataset can be controlled through the `size_factor` (with the default value of `1.0` generating a ~1GB parquet file). -For each filter we will run the query using different `ParquetScanOption` settings. +For each filter we will run the query using different +`ParquetScanOption` settings. -Example run: +Example output: ``` -Running benchmarks with the following options: Opt { debug: false, iterations: 3, partitions: 2, path: "./data", batch_size: 8192, scale_factor: 1.0 } +Running benchmarks with the following options: Opt { debug: false, iterations: 3, partitions: 2, path: "./data", +batch_size: 8192, scale_factor: 1.0 } Generated test dataset with 10699521 rows Executing with filter 'request_method = Utf8("GET")' Using scan options ParquetScanOptions { pushdown_filters: false, reorder_predicates: false, enable_page_index: false } @@ -336,12 +311,56 @@ Iteration 2 returned 1781686 rows in 1947 ms ... ``` -Similarly, to run sorting benchmarks, run: +## Sort +Test performance of sorting large datasets + +This test sorts a a synthetic dataset generated during the +benchmark execution, designed to simulate sorting web server +access logs. Such sorting is often done during data transformation +steps. + +The tests sort the entire dataset using several different sort +orders. + +## TPCH + +Run the tpch benchmark. -```base -cargo run --release --bin parquet -- sort --path ./data --scale-factor 1.0 +This benchmarks is derived from the [TPC-H][1] version +[2.17.1]. The data and answers are generated using `tpch-gen` from +[2]. + +[1]: http://www.tpc.org/tpch/ +[2]: https://github.com/databricks/tpch-dbgen.git, +[2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf + + +# Older Benchmarks + +## h2o benchmarks + +```bash +cargo run --release --bin h2o group-by --query 1 --path /mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv --mem-table --debug ``` -This proceeds in the same way as the filter benchmarks: each sort expression -combination will be run using the same set of `ParquetScanOption` as the -filter benchmarks. +Example run: + +``` +Running benchmarks with the following options: GroupBy(GroupBy { query: 1, path: "/mnt/bigdata/h2oai/N_1e7_K_1e2_single.csv", debug: false }) +Executing select id1, sum(v1) as v1 from x group by id1 ++-------+--------+ +| id1 | v1 | ++-------+--------+ +| id063 | 199420 | +| id094 | 200127 | +| id044 | 198886 | +... +| id093 | 200132 | +| id003 | 199047 | ++-------+--------+ + +h2o groupby query 1 took 1669 ms +``` + +[1]: http://www.tpc.org/tpch/ +[2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index dee6896aec388..bdbdc0e517625 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -18,7 +18,9 @@ # This script is meant for developers of DataFusion -- it is runnable # from the standard DataFusion development environment and uses cargo, -# etc. +# etc and orchestrates gathering data and run the benchmark binary in +# different configurations. + # Exit on error set -e @@ -33,7 +35,7 @@ BENCHMARK=all DATAFUSION_DIR=${DATAFUSION_DIR:-$SCRIPT_DIR/..} DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} #CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} -CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --profile release-nonlto"} # TEMP: for faster iterations +CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --profile release-nonlto"} # for faster iterations usage() { echo " @@ -64,10 +66,14 @@ compare: Comares results from benchmark runs * Benchmarks ********** all(default): Data/Run/Compare for all benchmarks -tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table -tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory -parquet: Benchmark of parquet reader's filtering speed -sort: Benchmark of sorting speed +tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table +tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory +tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table +tpch10_mem: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory +parquet: Benchmark of parquet reader's filtering speed +sort: Benchmark of sorting speed +clickbench_1: ClickBench queries against a single parquet file +clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet ********** * Supported Configuration (Environment Variables) @@ -116,7 +122,7 @@ main() { data) BENCHMARK=${ARG2:-"${BENCHMARK}"} echo "***************************" - echo "DataFusion Benchmark Data Generation" + echo "DataFusion Benchmark Runner and Data Generator" echo "COMMAND: ${COMMAND}" echo "BENCHMARK: ${BENCHMARK}" echo "DATA_DIR: ${DATA_DIR}" @@ -124,14 +130,30 @@ main() { echo "***************************" case "$BENCHMARK" in all) - data_tpch + data_tpch "1" + data_tpch "10" + data_clickbench_1 + data_clickbench_partitioned ;; tpch) - data_tpch + data_tpch "1" ;; tpch_mem) - # same data for tpch_mem - data_tpch + # same data as for tpch + data_tpch "1" + ;; + tpch10) + data_tpch "10" + ;; + tpch_mem10) + # same data as for tpch10 + data_tpch "10" + ;; + clickbench_1) + data_clickbench_1 + ;; + clickbench_partitioned) + data_clickbench_partitioned ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" @@ -160,18 +182,29 @@ main() { # navigate to the appropriate directory pushd "${DATAFUSION_DIR}/benchmarks" > /dev/null mkdir -p "${RESULTS_DIR}" + mkdir -p "${DATA_DIR}" case "$BENCHMARK" in all) - run_tpch - run_tpch_mem + run_tpch "1" + run_tpch_mem "1" + run_tpch "10" + run_tpch_mem "10" run_parquet run_sort + run_clickbench_1 + run_clickbench_partitioned ;; tpch) - run_tpch + run_tpch "1" ;; tpch_mem) - run_tpch_mem + run_tpch_mem "1" + ;; + tpch10) + run_tpch "10" + ;; + tpch_mem10) + run_tpch_mem "10" ;; parquet) run_parquet @@ -179,6 +212,12 @@ main() { sort) run_sort ;; + clickbench_1) + run_clickbench_1 + ;; + clickbench_partitioned) + run_clickbench_partitioned + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -201,60 +240,87 @@ main() { -# Creates TPCH data if it doesn't already exist +# Creates TPCH data at a certain scale factor, if it doesn't already +# exist +# +# call like: data_tpch($scale_factor) +# +# Creates data in $DATA_DIR/tpch_sf1 for scale factor 1 +# Creates data in $DATA_DIR/tpch_sf10 for scale factor 10 +# etc data_tpch() { - echo "Creating tpch dataset..." + SCALE_FACTOR=$1 + if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 + fi + + TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" + echo "Creating tpch dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." # Ensure the target data directory exists - mkdir -p "${DATA_DIR}" + mkdir -p "${TPCH_DIR}" # Create 'tbl' (CSV format) data into $DATA_DIR if it does not already exist - SCALE_FACTOR=1 - FILE="${DATA_DIR}/supplier.tbl" + FILE="${TPCH_DIR}/supplier.tbl" if test -f "${FILE}"; then echo " tbl files exist ($FILE exists)." else echo " creating tbl files with tpch_dbgen..." - docker run -v "${DATA_DIR}":/data -it --rm ghcr.io/databloom-ai/tpch-docker:main -vf -s ${SCALE_FACTOR} + docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/databloom-ai/tpch-docker:main -vf -s ${SCALE_FACTOR} fi # Copy expected answers into the ./data/answers directory if it does not already exist - FILE="${DATA_DIR}/answers/q1.out" + FILE="${TPCH_DIR}/answers/q1.out" if test -f "${FILE}"; then echo " Expected answers exist (${FILE} exists)." else - echo " Copying answers to ${DATA_DIR}/answers" - mkdir -p "${DATA_DIR}/answers" - docker run -v "${DATA_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/databloom-ai/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" + echo " Copying answers to ${TPCH_DIR}/answers" + mkdir -p "${TPCH_DIR}/answers" + docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/databloom-ai/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi # Create 'parquet' files from tbl - FILE="${DATA_DIR}/supplier" + FILE="${TPCH_DIR}/supplier" if test -d "${FILE}"; then echo " parquet files exist ($FILE exists)." else echo " creating parquet files using benchmark binary ..." pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${DATA_DIR}" --output "${DATA_DIR}" --format parquet + $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet popd > /dev/null fi } # Runs the tpch benchmark run_tpch() { + SCALE_FACTOR=$1 + if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 + fi + TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" + RESULTS_FILE="${RESULTS_DIR}/tpch.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${DATA_DIR}" --format parquet -o ${RESULTS_FILE} + $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --format parquet -o ${RESULTS_FILE} } # Runs the tpch in memory run_tpch_mem() { + SCALE_FACTOR=$1 + if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 + fi + TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" + RESULTS_FILE="${RESULTS_DIR}/tpch_mem.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${DATA_DIR}" -m --format parquet -o ${RESULTS_FILE} + $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" -m --format parquet -o ${RESULTS_FILE} } # Runs the parquet filter benchmark @@ -273,6 +339,68 @@ run_sort() { $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o ${RESULTS_FILE} } + +# Downloads the single file hits.parquet ClickBench datasets from +# https://github.com/ClickHouse/ClickBench/tree/main#data-loading +# +# Creates data in $DATA_DIR/hits.parquet +data_clickbench_1() { + pushd "${DATA_DIR}" > /dev/null + + # Avoid downloading if it already exists and is the right size + OUTPUT_SIZE=`wc -c hits.parquet 2>/dev/null | awk '{print $1}' || true` + echo -n "Checking hits.parquet..." + if test "${OUTPUT_SIZE}" = "14779976446"; then + echo -n "... found ${OUTPUT_SIZE} bytes ..." + else + URL="https://datasets.clickhouse.com/hits_compatible/hits.parquet" + echo -n "... downloading ${URL} (14GB) ... " + wget --continue ${URL} + fi + echo " Done" + popd > /dev/null +} + +# Downloads the 100 file partitioned ClickBench datasets from +# https://github.com/ClickHouse/ClickBench/tree/main#data-loading +# +# Creates data in $DATA_DIR/hits_partitioned +data_clickbench_partitioned() { + MAX_CONCURRENT_DOWNLOADS=10 + + mkdir -p "${DATA_DIR}/hits_partitioned" + pushd "${DATA_DIR}/hits_partitioned" > /dev/null + + echo -n "Checking hits_partitioned..." + OUTPUT_SIZE=`wc -c * 2>/dev/null | tail -n 1 | awk '{print $1}' || true` + if test "${OUTPUT_SIZE}" = "14737666736"; then + echo -n "... found ${OUTPUT_SIZE} bytes ..." + else + echo -n " downloading with ${MAX_CONCURRENT_DOWNLOADS} parallel workers" + seq 0 99 | xargs -P${MAX_CONCURRENT_DOWNLOADS} -I{} bash -c 'wget -q --continue https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hits_{}.parquet && echo -n "."' + fi + + echo " Done" + popd > /dev/null +} + + +# Runs the clickbench benchmark with a single large parquet file +run_clickbench_1() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running clickbench (1 file) benchmark..." + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} +} + + # Runs the clickbench benchmark with a single large parquet file +run_clickbench_partitioned() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running clickbench (partitioned, 100 files) benchmark..." + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} +} + compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" BRANCH1="${ARG2}" diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 80aa3c76b754c..ec2b28fa0556c 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -109,7 +109,6 @@ def compare( noise_threshold: float, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) - comparison = BenchmarkRun.load_from_file(comparison_path) console = Console() @@ -124,27 +123,57 @@ def compare( table.add_column(comparison_header, justify="right", style="dim") table.add_column("Change", justify="right", style="dim") + faster_count = 0 + slower_count = 0 + no_change_count = 0 + total_baseline_time = 0 + total_comparison_time = 0 + for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query + total_baseline_time += baseline_result.execution_time + total_comparison_time += comparison_result.execution_time + change = comparison_result.execution_time / baseline_result.execution_time if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): - change = "no change" + change_text = "no change" + no_change_count += 1 elif change < 1.0: - change = f"+{(1 / change):.2f}x faster" + change_text = f"+{(1 / change):.2f}x faster" + faster_count += 1 else: - change = f"{change:.2f}x slower" + change_text = f"{change:.2f}x slower" + slower_count += 1 table.add_row( f"Q{baseline_result.query}", f"{baseline_result.execution_time:.2f}ms", f"{comparison_result.execution_time:.2f}ms", - change, + change_text, ) console.print(table) + # Calculate averages + avg_baseline_time = total_baseline_time / len(baseline.queries) + avg_comparison_time = total_comparison_time / len(comparison.queries) + + # Summary table + summary_table = Table(show_header=True, header_style="bold magenta") + summary_table.add_column("Benchmark Summary", justify="left", style="dim") + summary_table.add_column("", justify="right", style="dim") + + summary_table.add_row(f"Total Time ({baseline_header})", f"{total_baseline_time:.2f}ms") + summary_table.add_row(f"Total Time ({comparison_header})", f"{total_comparison_time:.2f}ms") + summary_table.add_row(f"Average Time ({baseline_header})", f"{avg_baseline_time:.2f}ms") + summary_table.add_row(f"Average Time ({comparison_header})", f"{avg_comparison_time:.2f}ms") + summary_table.add_row("Queries Faster", str(faster_count)) + summary_table.add_row("Queries Slower", str(slower_count)) + summary_table.add_row("Queries with No Change", str(no_change_count)) + + console.print(summary_table) def main() -> None: parser = ArgumentParser() diff --git a/benchmarks/queries/clickbench/README.txt b/benchmarks/queries/clickbench/README.txt new file mode 100644 index 0000000000000..b46900956e54b --- /dev/null +++ b/benchmarks/queries/clickbench/README.txt @@ -0,0 +1 @@ +Downloaded from https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql diff --git a/benchmarks/queries/clickbench/queries.sql b/benchmarks/queries/clickbench/queries.sql new file mode 100644 index 0000000000000..52e72e02e1e0d --- /dev/null +++ b/benchmarks/queries/clickbench/queries.sql @@ -0,0 +1,43 @@ +SELECT COUNT(*) FROM hits; +SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; +SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; +SELECT AVG("UserID") FROM hits; +SELECT COUNT(DISTINCT "UserID") FROM hits; +SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; +SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits; +SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; +SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; +SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; +SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; +SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; +SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; +SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; +SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; +SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; +SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; +SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; +SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime"), "SearchPhrase" LIMIT 10; +SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; +SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; +SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; +SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; +SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; +SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; +SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100; +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/q8.sql b/benchmarks/queries/q8.sql index 4f34dca6a0458..6ddb2a6747589 100644 --- a/benchmarks/queries/q8.sql +++ b/benchmarks/queries/q8.sql @@ -1,9 +1,9 @@ select o_year, - cast(cast(sum(case - when nation = 'BRAZIL' then volume - else 0 - end) as decimal(12,2)) / cast(sum(volume) as decimal(12,2)) as decimal(15,2)) as mkt_share + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share from ( select diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs new file mode 100644 index 0000000000000..441b6cdc02933 --- /dev/null +++ b/benchmarks/src/bin/dfbench.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DataFusion benchmark runner +use datafusion::error::Result; + +use structopt::StructOpt; + +#[cfg(feature = "snmalloc")] +#[global_allocator] +static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; + +#[cfg(feature = "mimalloc")] +#[global_allocator] +static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; + +use datafusion_benchmarks::{clickbench, parquet_filter, sort, tpch}; + +#[derive(Debug, StructOpt)] +#[structopt(about = "benchmark command")] +enum Options { + Tpch(tpch::RunOpt), + TpchConvert(tpch::ConvertOpt), + Clickbench(clickbench::RunOpt), + ParquetFilter(parquet_filter::RunOpt), + Sort(sort::RunOpt), +} + +// Main benchmark runner entrypoint +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + + match Options::from_args() { + Options::Tpch(opt) => opt.run().await, + Options::TpchConvert(opt) => opt.run().await, + Options::Clickbench(opt) => opt.run().await, + Options::ParquetFilter(opt) => opt.run().await, + Options::Sort(opt) => opt.run().await, + } +} diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs index d75f9a30b4e9d..1bb8cb9d43e4b 100644 --- a/benchmarks/src/bin/h2o.rs +++ b/benchmarks/src/bin/h2o.rs @@ -72,7 +72,7 @@ async fn group_by(opt: &GroupBy) -> Result<()> { let mut config = ConfigOptions::from_env()?; config.execution.batch_size = 65535; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let schema = Schema::new(vec![ Field::new("id1", DataType::Utf8, false), diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs deleted file mode 100644 index 1de490c905e5e..0000000000000 --- a/benchmarks/src/bin/nyctaxi.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Apache Arrow Rust Benchmarks - -use std::collections::HashMap; -use std::path::PathBuf; -use std::process; -use std::time::Instant; - -use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::arrow::util::pretty; - -use datafusion::error::Result; -use datafusion::execution::context::{SessionConfig, SessionContext}; - -use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; -use datafusion_benchmarks::BenchmarkRun; -use structopt::StructOpt; - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[derive(Debug, StructOpt)] -#[structopt(name = "Benchmarks", about = "Apache Arrow Rust Benchmarks.")] -struct Opt { - /// Activate debug mode to see query results - #[structopt(short, long)] - debug: bool, - - /// Number of iterations of each test run - #[structopt(short = "i", long = "iterations", default_value = "3")] - iterations: usize, - - /// Number of partitions to process in parallel - #[structopt(short = "p", long = "partitions", default_value = "2")] - partitions: usize, - - /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - batch_size: usize, - - /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - - /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] - file_format: String, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -#[tokio::main] -async fn main() -> Result<()> { - let opt = Opt::from_args(); - println!("Running benchmarks with the following options: {opt:?}"); - - let config = SessionConfig::new() - .with_target_partitions(opt.partitions) - .with_batch_size(opt.batch_size); - let mut ctx = SessionContext::with_config(config); - - let path = opt.path.to_str().unwrap(); - - match opt.file_format.as_str() { - "csv" => { - let schema = nyctaxi_schema(); - let options = CsvReadOptions::new().schema(&schema).has_header(true); - ctx.register_csv("tripdata", path, options).await? - } - "parquet" => { - ctx.register_parquet("tripdata", path, ParquetReadOptions::default()) - .await? - } - other => { - println!("Invalid file format '{other}'"); - process::exit(-1); - } - } - - datafusion_sql_benchmarks(&mut ctx, opt).await -} - -async fn datafusion_sql_benchmarks(ctx: &mut SessionContext, opt: Opt) -> Result<()> { - let iterations = opt.iterations; - let debug = opt.debug; - let output = opt.output_path; - let mut rundata = BenchmarkRun::new(); - let mut queries = HashMap::new(); - queries.insert("fare_amt_by_passenger", "SELECT passenger_count, MIN(fare_amount), MAX(fare_amount), SUM(fare_amount) FROM tripdata GROUP BY passenger_count"); - for (name, sql) in &queries { - println!("Executing '{name}'"); - rundata.start_new_case(name); - for i in 0..iterations { - let (rows, elapsed) = execute_sql(ctx, sql, debug).await?; - println!( - "Query '{}' iteration {} took {} ms", - name, - i, - elapsed.as_secs_f64() * 1000.0 - ); - rundata.write_iter(elapsed, rows); - } - } - rundata.maybe_write_json(output.as_ref())?; - Ok(()) -} - -async fn execute_sql( - ctx: &SessionContext, - sql: &str, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let dataframe = ctx.sql(sql).await?; - if debug { - println!("Optimized logical plan:\n{:?}", dataframe.logical_plan()); - } - let result = dataframe.collect().await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rowcount = result.iter().map(|b| b.num_rows()).sum(); - Ok((rowcount, elapsed)) -} - -fn nyctaxi_schema() -> Schema { - Schema::new(vec![ - Field::new("VendorID", DataType::Utf8, true), - Field::new("tpep_pickup_datetime", DataType::Utf8, true), - Field::new("tpep_dropoff_datetime", DataType::Utf8, true), - Field::new("passenger_count", DataType::Int32, true), - Field::new("trip_distance", DataType::Utf8, true), - Field::new("RatecodeID", DataType::Utf8, true), - Field::new("store_and_fwd_flag", DataType::Utf8, true), - Field::new("PULocationID", DataType::Utf8, true), - Field::new("DOLocationID", DataType::Utf8, true), - Field::new("payment_type", DataType::Utf8, true), - Field::new("fare_amount", DataType::Float64, true), - Field::new("extra", DataType::Float64, true), - Field::new("mta_tax", DataType::Float64, true), - Field::new("tip_amount", DataType::Float64, true), - Field::new("tolls_amount", DataType::Float64, true), - Field::new("improvement_surcharge", DataType::Float64, true), - Field::new("total_amount", DataType::Float64, true), - ]) -} diff --git a/benchmarks/src/bin/parquet.rs b/benchmarks/src/bin/parquet.rs index 98b2da7c2fcd7..18871803fc0b6 100644 --- a/benchmarks/src/bin/parquet.rs +++ b/benchmarks/src/bin/parquet.rs @@ -15,22 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::util::pretty; use datafusion::common::Result; -use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::optimizer::utils::disjunction; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::prelude::{col, SessionConfig, SessionContext}; -use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_benchmarks::BenchmarkRun; -use parquet::file::properties::WriterProperties; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::Instant; + +use datafusion_benchmarks::{parquet_filter, sort}; use structopt::StructOpt; -use test_utils::AccessLogGenerator; #[cfg(feature = "snmalloc")] #[global_allocator] @@ -40,63 +28,9 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[structopt(name = "Benchmarks", about = "Apache Arrow Rust Benchmarks.")] enum ParquetBenchCmd { /// Benchmark sorting parquet files - Sort(Opt), + Sort(sort::RunOpt), /// Benchmark parquet filter pushdown - Filter(Opt), -} - -#[derive(Debug, StructOpt, Clone)] -struct Opt { - /// Activate debug mode to see query results - #[structopt(short, long)] - debug: bool, - - /// Number of iterations of each test run - #[structopt(short = "i", long = "iterations", default_value = "3")] - iterations: usize, - - /// Number of partitions to process in parallel - #[structopt(long = "partitions", default_value = "2")] - partitions: usize, - - /// Path to folder where access log file will be generated - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - /// Data page size of the generated parquet file - #[structopt(long = "page-size")] - page_size: Option, - - /// Data page size of the generated parquet file - #[structopt(long = "row-group-size")] - row_group_size: Option, - - /// Total size of generated dataset. The default scale factor of 1.0 will generate a roughly 1GB parquet file - #[structopt(short = "s", long = "scale-factor", default_value = "1.0")] - scale_factor: f32, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} -impl Opt { - /// Initialize parquet test file given options. - fn init_file(&self) -> Result { - let path = self.path.join("logs.parquet"); - - let mut props_builder = WriterProperties::builder(); - - if let Some(s) = self.page_size { - props_builder = props_builder - .set_data_page_size_limit(s) - .set_write_batch_size(s); - } - - if let Some(s) = self.row_group_size { - props_builder = props_builder.set_max_row_group_size(s); - } - - gen_data(path, self.scale_factor, props_builder.build()) - } + Filter(parquet_filter::RunOpt), } #[tokio::main] @@ -105,245 +39,11 @@ async fn main() -> Result<()> { match cmd { ParquetBenchCmd::Filter(opt) => { println!("running filter benchmarks"); - let test_file = opt.init_file()?; - run_filter_benchmarks(opt, &test_file).await?; + opt.run().await } ParquetBenchCmd::Sort(opt) => { println!("running sort benchmarks"); - let test_file = opt.init_file()?; - run_sort_benchmarks(opt, &test_file).await?; - } - } - Ok(()) -} - -async fn run_sort_benchmarks(opt: Opt, test_file: &TestParquetFile) -> Result<()> { - use datafusion::physical_expr::expressions::col; - let mut rundata = BenchmarkRun::new(); - let schema = test_file.schema(); - let sort_cases = vec![ - ( - "sort utf8", - vec![PhysicalSortExpr { - expr: col("request_method", &schema)?, - options: Default::default(), - }], - ), - ( - "sort int", - vec![PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }], - ), - ( - "sort decimal", - vec![ - // sort decimal - PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }, - ], - ), - ( - "sort integer tuple", - vec![ - PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("response_bytes", &schema)?, - options: Default::default(), - }, - ], - ), - ( - "sort utf8 tuple", - vec![ - // sort utf8 tuple - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("host", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("pod", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("image", &schema)?, - options: Default::default(), - }, - ], - ), - ( - "sort mixed tuple", - vec![ - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }, - ], - ), - ]; - for (title, expr) in sort_cases { - println!("Executing '{title}' (sorting by: {expr:?})"); - rundata.start_new_case(title); - for i in 0..opt.iterations { - let config = SessionConfig::new().with_target_partitions(opt.partitions); - let ctx = SessionContext::with_config(config); - let (rows, elapsed) = exec_sort(&ctx, &expr, test_file, opt.debug).await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} finished in {ms} ms"); - rundata.write_iter(elapsed, rows); - } - println!("\n"); - } - if let Some(path) = &opt.output_path { - std::fs::write(path, rundata.to_json())?; - } - Ok(()) -} -fn parquet_scan_disp(opts: &ParquetScanOptions) -> String { - format!( - "pushdown_filters={}, reorder_filters={}, page_index={}", - opts.pushdown_filters, opts.reorder_filters, opts.enable_page_index - ) -} -async fn run_filter_benchmarks(opt: Opt, test_file: &TestParquetFile) -> Result<()> { - let mut rundata = BenchmarkRun::new(); - let scan_options_matrix = vec![ - ParquetScanOptions { - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: true, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: false, - }, - ]; - - let filter_matrix = vec![ - ("Selective-ish filter", col("request_method").eq(lit("GET"))), - ( - "Non-selective filter", - col("request_method").not_eq(lit("GET")), - ), - ( - "Basic conjunction", - col("request_method") - .eq(lit("POST")) - .and(col("response_status").eq(lit(503_u16))), - ), - ( - "Nested filters", - col("request_method").eq(lit("POST")).and(or( - col("response_status").eq(lit(503_u16)), - col("response_status").eq(lit(403_u16)), - )), - ), - ( - "Many filters", - disjunction([ - col("request_method").not_eq(lit("GET")), - col("response_status").eq(lit(400_u16)), - col("service").eq(lit("backend")), - ]) - .unwrap(), - ), - ("Filter everything", col("response_status").eq(lit(429_u16))), - ("Filter nothing", col("response_status").gt(lit(0_u16))), - ]; - - for (name, filter_expr) in &filter_matrix { - println!("Executing '{name}' (filter: {filter_expr})"); - for scan_options in &scan_options_matrix { - println!("Using scan options {scan_options:?}"); - rundata - .start_new_case(&format!("{name}: {}", parquet_scan_disp(scan_options))); - for i in 0..opt.iterations { - let config = scan_options.config().with_target_partitions(opt.partitions); - let ctx = SessionContext::with_config(config); - - let (rows, elapsed) = - exec_scan(&ctx, test_file, filter_expr.clone(), opt.debug).await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} returned {rows} rows in {ms} ms"); - rundata.write_iter(elapsed, rows); - } + opt.run().await } - println!("\n"); - } - rundata.maybe_write_json(opt.output_path.as_ref())?; - Ok(()) -} - -async fn exec_scan( - ctx: &SessionContext, - test_file: &TestParquetFile, - filter: Expr, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let exec = test_file.create_scan(Some(filter)).await?; - - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} - -async fn exec_sort( - ctx: &SessionContext, - expr: &[PhysicalSortExpr], - test_file: &TestParquetFile, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let scan = test_file.create_scan(None).await?; - let exec = Arc::new(SortExec::new(expr.to_owned(), scan)); - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} - -fn gen_data( - path: PathBuf, - scale_factor: f32, - props: WriterProperties, -) -> Result { - let generator = AccessLogGenerator::new(); - - let num_batches = 100_f32 * scale_factor; - - TestParquetFile::try_new(path, props, generator.take(num_batches as usize)) } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 4ba8b26bba335..95480935700d7 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -15,28 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. +//! tpch binary only entrypoint -use datafusion::datasource::file_format::{csv::CsvFormat, FileFormat}; -use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::parquet::basic::Compression; -use datafusion::physical_plan::display::DisplayableExecutionPlan; -use datafusion::physical_plan::{collect, displayable}; -use datafusion::prelude::*; -use datafusion::{ - arrow::record_batch::RecordBatch, datasource::file_format::parquet::ParquetFormat, -}; -use datafusion::{ - arrow::util::pretty, - datasource::listing::{ListingOptions, ListingTable, ListingTableConfig}, -}; -use datafusion_benchmarks::{tpch::*, BenchmarkRun}; -use std::{iter::Iterator, path::PathBuf, sync::Arc, time::Instant}; - -use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; -use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; -use datafusion::datasource::listing::ListingTableUrl; +use datafusion::error::Result; +use datafusion_benchmarks::tpch; use structopt::StructOpt; #[cfg(feature = "snmalloc")] @@ -47,500 +29,32 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -#[derive(Debug, StructOpt, Clone)] -struct DataFusionBenchmarkOpt { - /// Query number. If not specified, runs all queries - #[structopt(short, long)] - query: Option, - - /// Activate debug mode to see query results - #[structopt(short, long)] - debug: bool, - - /// Number of iterations of each test run - #[structopt(short = "i", long = "iterations", default_value = "3")] - iterations: usize, - - /// Number of partitions to process in parallel - #[structopt(short = "n", long = "partitions", default_value = "2")] - partitions: usize, - - /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - batch_size: usize, - - /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - - /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] - file_format: String, - - /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] - mem_table: bool, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, - - /// Whether to disable collection of statistics (and cost based optimizations) or not. - #[structopt(short = "S", long = "disable-statistics")] - disable_statistics: bool, -} - -#[derive(Debug, StructOpt)] -struct ConvertOpt { - /// Path to csv files - #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] - input_path: PathBuf, - - /// Output path - #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] - output_path: PathBuf, - - /// Output file format: `csv` or `parquet` - #[structopt(short = "f", long = "format")] - file_format: String, - - /// Compression to use when writing Parquet files - #[structopt(short = "c", long = "compression", default_value = "zstd")] - compression: String, - - /// Number of partitions to produce - #[structopt(short = "n", long = "partitions", default_value = "1")] - partitions: usize, - - /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - batch_size: usize, -} - #[derive(Debug, StructOpt)] #[structopt(about = "benchmark command")] enum BenchmarkSubCommandOpt { #[structopt(name = "datafusion")] - DataFusionBenchmark(DataFusionBenchmarkOpt), + DataFusionBenchmark(tpch::RunOpt), } #[derive(Debug, StructOpt)] #[structopt(name = "TPC-H", about = "TPC-H Benchmarks.")] enum TpchOpt { Benchmark(BenchmarkSubCommandOpt), - Convert(ConvertOpt), + Convert(tpch::ConvertOpt), } +/// 'tpch' entry point, with tortured command line arguments. Please +/// use `dbbench` instead. +/// +/// Note: this is kept to be backwards compatible with the benchmark names prior to +/// #[tokio::main] async fn main() -> Result<()> { - use BenchmarkSubCommandOpt::*; - env_logger::init(); match TpchOpt::from_args() { - TpchOpt::Benchmark(DataFusionBenchmark(opt)) => { - benchmark_datafusion(opt).await.map(|_| ()) - } - TpchOpt::Convert(opt) => { - let compression = match opt.compression.as_str() { - "none" => Compression::UNCOMPRESSED, - "snappy" => Compression::SNAPPY, - "brotli" => Compression::BROTLI(Default::default()), - "gzip" => Compression::GZIP(Default::default()), - "lz4" => Compression::LZ4, - "lz0" => Compression::LZO, - "zstd" => Compression::ZSTD(Default::default()), - other => { - return Err(DataFusionError::NotImplemented(format!( - "Invalid compression format: {other}" - ))); - } - }; - convert_tbl( - opt.input_path.to_str().unwrap(), - opt.output_path.to_str().unwrap(), - &opt.file_format, - opt.partitions, - opt.batch_size, - compression, - ) - .await - } - } -} - -const TPCH_QUERY_START_ID: usize = 1; -const TPCH_QUERY_END_ID: usize = 22; - -async fn benchmark_datafusion( - opt: DataFusionBenchmarkOpt, -) -> Result>> { - println!("Running benchmarks with the following options: {opt:?}"); - let query_range = match opt.query { - Some(query_id) => query_id..=query_id, - None => TPCH_QUERY_START_ID..=TPCH_QUERY_END_ID, - }; - - let mut benchmark_run = BenchmarkRun::new(); - let mut results = vec![]; - for query_id in query_range { - benchmark_run.start_new_case(&format!("Query {query_id}")); - let (query_run, result) = benchmark_query(&opt, query_id).await?; - results.push(result); - for iter in query_run { - benchmark_run.write_iter(iter.elapsed, iter.row_count); - } - } - benchmark_run.maybe_write_json(opt.output_path.as_ref())?; - Ok(results) -} - -async fn benchmark_query( - opt: &DataFusionBenchmarkOpt, - query_id: usize, -) -> Result<(Vec, Vec)> { - let mut query_results = vec![]; - let config = SessionConfig::new() - .with_target_partitions(opt.partitions) - .with_batch_size(opt.batch_size) - .with_collect_statistics(!opt.disable_statistics); - let ctx = SessionContext::with_config(config); - - // register tables - register_tables(opt, &ctx).await?; - - let mut millis = vec![]; - // run benchmark - let mut result: Vec = Vec::with_capacity(1); - for i in 0..opt.iterations { - let start = Instant::now(); - - let sql = &get_query_sql(query_id)?; - - // query 15 is special, with 3 statements. the second statement is the one from which we - // want to capture the results - if query_id == 15 { - for (n, query) in sql.iter().enumerate() { - if n == 1 { - result = execute_query(&ctx, query, opt.debug).await?; - } else { - execute_query(&ctx, query, opt.debug).await?; - } - } - } else { - for query in sql { - result = execute_query(&ctx, query, opt.debug).await?; - } - } - - let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; - let ms = elapsed.as_secs_f64() * 1000.0; - millis.push(ms); - let row_count = result.iter().map(|b| b.num_rows()).sum(); - println!( - "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" - ); - query_results.push(QueryResult { elapsed, row_count }); - } - - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {query_id} avg time: {avg:.2} ms"); - - Ok((query_results, result)) -} - -async fn register_tables( - opt: &DataFusionBenchmarkOpt, - ctx: &SessionContext, -) -> Result<()> { - for table in TPCH_TABLES { - let table_provider = { - get_table( - ctx, - opt.path.to_str().unwrap(), - table, - opt.file_format.as_str(), - opt.partitions, - ) - .await? - }; - - if opt.mem_table { - println!("Loading table '{table}' into memory"); - let start = Instant::now(); - let memtable = - MemTable::load(table_provider, Some(opt.partitions), &ctx.state()) - .await?; - println!( - "Loaded table '{}' into memory in {} ms", - table, - start.elapsed().as_millis() - ); - ctx.register_table(*table, Arc::new(memtable))?; - } else { - ctx.register_table(*table, table_provider)?; - } - } - Ok(()) -} - -async fn execute_query( - ctx: &SessionContext, - sql: &str, - debug: bool, -) -> Result> { - let plan = ctx.sql(sql).await?; - let (state, plan) = plan.into_parts(); - - if debug { - println!("=== Logical plan ===\n{plan:?}\n"); - } - - let plan = state.optimize(&plan)?; - if debug { - println!("=== Optimized logical plan ===\n{plan:?}\n"); - } - let physical_plan = state.create_physical_plan(&plan).await?; - if debug { - println!( - "=== Physical plan ===\n{}\n", - displayable(physical_plan.as_ref()).indent() - ); - } - let result = collect(physical_plan.clone(), state.task_ctx()).await?; - if debug { - println!( - "=== Physical plan with metrics ===\n{}\n", - DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent() - ); - if !result.is_empty() { - // do not call print_batches if there are no batches as the result is confusing - // and makes it look like there is a batch with no columns - pretty::print_batches(&result)?; - } - } - Ok(result) -} - -async fn get_table( - ctx: &SessionContext, - path: &str, - table: &str, - table_format: &str, - target_partitions: usize, -) -> Result> { - // Obtain a snapshot of the SessionState - let state = ctx.state(); - let (format, path, extension): (Arc, String, &'static str) = - match table_format { - // dbgen creates .tbl ('|' delimited) files without header - "tbl" => { - let path = format!("{path}/{table}.tbl"); - - let format = CsvFormat::default() - .with_delimiter(b'|') - .with_has_header(false); - - (Arc::new(format), path, ".tbl") - } - "csv" => { - let path = format!("{path}/{table}"); - let format = CsvFormat::default() - .with_delimiter(b',') - .with_has_header(true); - - (Arc::new(format), path, DEFAULT_CSV_EXTENSION) - } - "parquet" => { - let path = format!("{path}/{table}"); - let format = ParquetFormat::default().with_enable_pruning(Some(true)); - - (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) - } - other => { - unimplemented!("Invalid file format '{}'", other); - } - }; - - let options = ListingOptions::new(format) - .with_file_extension(extension) - .with_target_partitions(target_partitions) - .with_collect_stat(state.config().collect_statistics()); - - let table_path = ListingTableUrl::parse(path)?; - let config = ListingTableConfig::new(table_path).with_listing_options(options); - - let config = match table_format { - "parquet" => config.infer_schema(&state).await?, - "tbl" => config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))), - "csv" => config.with_schema(Arc::new(get_tpch_table_schema(table))), - _ => unreachable!(), - }; - - Ok(Arc::new(ListingTable::try_new(config)?)) -} - -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - -#[cfg(test)] -#[cfg(feature = "ci")] -/// CI checks -mod tests { - use std::path::Path; - - use super::*; - use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; - - async fn serde_round_trip(query: usize) -> Result<()> { - let ctx = SessionContext::default(); - let path = get_tpch_data_path()?; - let opt = DataFusionBenchmarkOpt { - query: Some(query), - debug: false, - iterations: 1, - partitions: 2, - batch_size: 8192, - path: PathBuf::from(path.to_string()), - file_format: "tbl".to_string(), - mem_table: false, - output_path: None, - disable_statistics: false, - }; - register_tables(&opt, &ctx).await?; - let queries = get_query_sql(query)?; - for query in queries { - let plan = ctx.sql(&query).await?; - let plan = plan.into_optimized_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; - let plan_formatted = format!("{}", plan.display_indent()); - let plan2_formatted = format!("{}", plan2.display_indent()); - assert_eq!(plan_formatted, plan2_formatted); - } - Ok(()) - } - - #[tokio::test] - async fn serde_q1() -> Result<()> { - serde_round_trip(1).await - } - - #[tokio::test] - async fn serde_q2() -> Result<()> { - serde_round_trip(2).await - } - - #[tokio::test] - async fn serde_q3() -> Result<()> { - serde_round_trip(3).await - } - - #[tokio::test] - async fn serde_q4() -> Result<()> { - serde_round_trip(4).await - } - - #[tokio::test] - async fn serde_q5() -> Result<()> { - serde_round_trip(5).await - } - - #[tokio::test] - async fn serde_q6() -> Result<()> { - serde_round_trip(6).await - } - - #[tokio::test] - async fn serde_q7() -> Result<()> { - serde_round_trip(7).await - } - - #[tokio::test] - async fn serde_q8() -> Result<()> { - serde_round_trip(8).await - } - - #[tokio::test] - async fn serde_q9() -> Result<()> { - serde_round_trip(9).await - } - - #[tokio::test] - async fn serde_q10() -> Result<()> { - serde_round_trip(10).await - } - - #[tokio::test] - async fn serde_q11() -> Result<()> { - serde_round_trip(11).await - } - - #[tokio::test] - async fn serde_q12() -> Result<()> { - serde_round_trip(12).await - } - - #[tokio::test] - async fn serde_q13() -> Result<()> { - serde_round_trip(13).await - } - - #[tokio::test] - async fn serde_q14() -> Result<()> { - serde_round_trip(14).await - } - - #[tokio::test] - async fn serde_q15() -> Result<()> { - serde_round_trip(15).await - } - - #[tokio::test] - async fn serde_q16() -> Result<()> { - serde_round_trip(16).await - } - - #[tokio::test] - async fn serde_q17() -> Result<()> { - serde_round_trip(17).await - } - - #[tokio::test] - async fn serde_q18() -> Result<()> { - serde_round_trip(18).await - } - - #[tokio::test] - async fn serde_q19() -> Result<()> { - serde_round_trip(19).await - } - - #[tokio::test] - async fn serde_q20() -> Result<()> { - serde_round_trip(20).await - } - - #[tokio::test] - async fn serde_q21() -> Result<()> { - serde_round_trip(21).await - } - - #[tokio::test] - async fn serde_q22() -> Result<()> { - serde_round_trip(22).await - } - - fn get_tpch_data_path() -> Result { - let path = - std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); - if !Path::new(&path).exists() { - return Err(DataFusionError::Execution(format!( - "Benchmark data not found (set TPCH_DATA env var to override): {}", - path - ))); + TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { + opt.run().await } - Ok(path) + TpchOpt::Convert(opt) => opt.run().await, } } diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs new file mode 100644 index 0000000000000..a6d32eb39f311 --- /dev/null +++ b/benchmarks/src/clickbench.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{path::PathBuf, time::Instant}; + +use datafusion::{ + common::exec_err, + error::{DataFusionError, Result}, + prelude::SessionContext, +}; +use structopt::StructOpt; + +use crate::{BenchmarkRun, CommonOpt}; + +/// Run the clickbench benchmark +/// +/// The ClickBench[1] benchmarks are widely cited in the industry and +/// focus on grouping / aggregation / filtering. This runner uses the +/// scripts and queries from [2]. +/// +/// [1]: https://github.com/ClickHouse/ClickBench +/// [2]: https://github.com/ClickHouse/ClickBench/tree/main/datafusion +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 0 and 42). If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to hits.parquet (single file) or `hits_partitioned` + /// (partitioned, 100 files) + #[structopt( + parse(from_os_str), + short = "p", + long = "path", + default_value = "benchmarks/data/hits.parquet" + )] + path: PathBuf, + + /// Path to queries.sql (single file) + #[structopt( + parse(from_os_str), + short = "r", + long = "queries-path", + default_value = "benchmarks/queries/clickbench/queries.sql" + )] + queries_path: PathBuf, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +const CLICKBENCH_QUERY_START_ID: usize = 0; +const CLICKBENCH_QUERY_END_ID: usize = 42; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => CLICKBENCH_QUERY_START_ID..=CLICKBENCH_QUERY_END_ID, + }; + + let config = self.common.config(); + let ctx = SessionContext::new_with_config(config); + self.register_hits(&ctx).await?; + + let iterations = self.common.iterations; + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let sql = self.get_query(query_id)?; + println!("Q{query_id}: {sql}"); + + for i in 0..iterations { + let start = Instant::now(); + let results = ctx.sql(&sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + benchmark_run.write_iter(elapsed, row_count); + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + /// Registrs the `hits.parquet` as a table named `hits` + async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { + let options = Default::default(); + let path = self.path.as_os_str().to_str().unwrap(); + ctx.register_parquet("hits", path, options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'hits' as {path}"), + Box::new(e), + ) + }) + } + + /// Returns the text of query `query_id` + fn get_query(&self, query_id: usize) -> Result { + if query_id > CLICKBENCH_QUERY_END_ID { + return exec_err!( + "Invalid query id {query_id}. Must be between {CLICKBENCH_QUERY_START_ID} and {CLICKBENCH_QUERY_END_ID}" + ); + } + + let path = self.queries_path.as_path(); + + // ClickBench has all queries in a single file identified by line number + let all_queries = std::fs::read_to_string(path).map_err(|e| { + DataFusionError::Execution(format!("Could not open {path:?}: {e}")) + })?; + let all_queries: Vec<_> = all_queries.lines().collect(); + + Ok(all_queries.get(query_id).map(|s| s.to_string()).unwrap()) + } +} diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index c2f4e876ce700..f81220aa2c944 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -15,143 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{error::Result, DATAFUSION_VERSION}; -use serde::{Serialize, Serializer}; -use serde_json::Value; -use std::{ - collections::HashMap, - path::Path, - time::{Duration, SystemTime}, -}; - +//! DataFusion benchmark runner +pub mod clickbench; +pub mod parquet_filter; +pub mod sort; pub mod tpch; - -fn serialize_start_time(start_time: &SystemTime, ser: S) -> Result -where - S: Serializer, -{ - ser.serialize_u64( - start_time - .duration_since(SystemTime::UNIX_EPOCH) - .expect("current time is later than UNIX_EPOCH") - .as_secs(), - ) -} -fn serialize_elapsed(elapsed: &Duration, ser: S) -> Result -where - S: Serializer, -{ - let ms = elapsed.as_secs_f64() * 1000.0; - ser.serialize_f64(ms) -} -#[derive(Debug, Serialize)] -pub struct RunContext { - /// Benchmark crate version - pub benchmark_version: String, - /// DataFusion crate version - pub datafusion_version: String, - /// Number of CPU cores - pub num_cpus: usize, - /// Start time - #[serde(serialize_with = "serialize_start_time")] - pub start_time: SystemTime, - /// CLI arguments - pub arguments: Vec, -} - -impl Default for RunContext { - fn default() -> Self { - Self::new() - } -} - -impl RunContext { - pub fn new() -> Self { - Self { - benchmark_version: env!("CARGO_PKG_VERSION").to_owned(), - datafusion_version: DATAFUSION_VERSION.to_owned(), - num_cpus: num_cpus::get(), - start_time: SystemTime::now(), - arguments: std::env::args().skip(1).collect::>(), - } - } -} - -/// A single iteration of a benchmark query -#[derive(Debug, Serialize)] -struct QueryIter { - #[serde(serialize_with = "serialize_elapsed")] - elapsed: Duration, - row_count: usize, -} -/// A single benchmark case -#[derive(Debug, Serialize)] -pub struct BenchQuery { - query: String, - iterations: Vec, - #[serde(serialize_with = "serialize_start_time")] - start_time: SystemTime, -} - -/// collects benchmark run data and then serializes it at the end -pub struct BenchmarkRun { - context: RunContext, - queries: Vec, - current_case: Option, -} - -impl Default for BenchmarkRun { - fn default() -> Self { - Self::new() - } -} - -impl BenchmarkRun { - // create new - pub fn new() -> Self { - Self { - context: RunContext::new(), - queries: vec![], - current_case: None, - } - } - /// begin a new case. iterations added after this will be included in the new case - pub fn start_new_case(&mut self, id: &str) { - self.queries.push(BenchQuery { - query: id.to_owned(), - iterations: vec![], - start_time: SystemTime::now(), - }); - if let Some(c) = self.current_case.as_mut() { - *c += 1; - } else { - self.current_case = Some(0); - } - } - /// Write a new iteration to the current case - pub fn write_iter(&mut self, elapsed: Duration, row_count: usize) { - if let Some(idx) = self.current_case { - self.queries[idx] - .iterations - .push(QueryIter { elapsed, row_count }) - } else { - panic!("no cases existed yet"); - } - } - - /// Stringify data into formatted json - pub fn to_json(&self) -> String { - let mut output = HashMap::<&str, Value>::new(); - output.insert("context", serde_json::to_value(&self.context).unwrap()); - output.insert("queries", serde_json::to_value(&self.queries).unwrap()); - serde_json::to_string_pretty(&output).unwrap() - } - - /// Write data as json into output path if it exists. - pub fn maybe_write_json(&self, maybe_path: Option>) -> Result<()> { - if let Some(path) = maybe_path { - std::fs::write(path, self.to_json())?; - }; - Ok(()) - } -} +mod util; +pub use util::*; diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs new file mode 100644 index 0000000000000..1d816908e2b04 --- /dev/null +++ b/benchmarks/src/parquet_filter.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::AccessLogOpt; +use crate::{BenchmarkRun, CommonOpt}; +use arrow::util::pretty; +use datafusion::common::Result; +use datafusion::logical_expr::utils::disjunction; +use datafusion::logical_expr::{lit, or, Expr}; +use datafusion::physical_plan::collect; +use datafusion::prelude::{col, SessionContext}; +use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; +use std::path::PathBuf; +use std::time::Instant; +use structopt::StructOpt; + +/// Test performance of parquet filter pushdown +/// +/// The queries are executed on a synthetic dataset generated during +/// the benchmark execution and designed to simulate web server access +/// logs. +/// +/// Example +/// +/// dfbench parquet-filter --path ./data --scale-factor 1.0 +/// +/// generates the synthetic dataset at `./data/logs.parquet`. The size +/// of the dataset can be controlled through the `size_factor` +/// (with the default value of `1.0` generating a ~1GB parquet file). +/// +/// For each filter we will run the query using different +/// `ParquetScanOption` settings. +/// +/// Example output: +/// +/// Running benchmarks with the following options: Opt { debug: false, iterations: 3, partitions: 2, path: "./data", batch_size: 8192, scale_factor: 1.0 } +/// Generated test dataset with 10699521 rows +/// Executing with filter 'request_method = Utf8("GET")' +/// Using scan options ParquetScanOptions { pushdown_filters: false, reorder_predicates: false, enable_page_index: false } +/// Iteration 0 returned 10699521 rows in 1303 ms +/// Iteration 1 returned 10699521 rows in 1288 ms +/// Iteration 2 returned 10699521 rows in 1266 ms +/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: true, enable_page_index: true } +/// Iteration 0 returned 1781686 rows in 1970 ms +/// Iteration 1 returned 1781686 rows in 2002 ms +/// Iteration 2 returned 1781686 rows in 1988 ms +/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: false, enable_page_index: true } +/// Iteration 0 returned 1781686 rows in 1940 ms +/// Iteration 1 returned 1781686 rows in 1986 ms +/// Iteration 2 returned 1781686 rows in 1947 ms +/// ... +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Create data files + #[structopt(flatten)] + access_log: AccessLogOpt, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + let test_file = self.access_log.build()?; + + let mut rundata = BenchmarkRun::new(); + let scan_options_matrix = vec![ + ParquetScanOptions { + pushdown_filters: false, + reorder_filters: false, + enable_page_index: false, + }, + ParquetScanOptions { + pushdown_filters: true, + reorder_filters: true, + enable_page_index: true, + }, + ParquetScanOptions { + pushdown_filters: true, + reorder_filters: true, + enable_page_index: false, + }, + ]; + + let filter_matrix = vec![ + ("Selective-ish filter", col("request_method").eq(lit("GET"))), + ( + "Non-selective filter", + col("request_method").not_eq(lit("GET")), + ), + ( + "Basic conjunction", + col("request_method") + .eq(lit("POST")) + .and(col("response_status").eq(lit(503_u16))), + ), + ( + "Nested filters", + col("request_method").eq(lit("POST")).and(or( + col("response_status").eq(lit(503_u16)), + col("response_status").eq(lit(403_u16)), + )), + ), + ( + "Many filters", + disjunction([ + col("request_method").not_eq(lit("GET")), + col("response_status").eq(lit(400_u16)), + col("service").eq(lit("backend")), + ]) + .unwrap(), + ), + ("Filter everything", col("response_status").eq(lit(429_u16))), + ("Filter nothing", col("response_status").gt(lit(0_u16))), + ]; + + for (name, filter_expr) in &filter_matrix { + println!("Executing '{name}' (filter: {filter_expr})"); + for scan_options in &scan_options_matrix { + println!("Using scan options {scan_options:?}"); + rundata.start_new_case(&format!( + "{name}: {}", + parquet_scan_disp(scan_options) + )); + for i in 0..self.common.iterations { + let config = self.common.update_config(scan_options.config()); + let ctx = SessionContext::new_with_config(config); + + let (rows, elapsed) = exec_scan( + &ctx, + &test_file, + filter_expr.clone(), + self.common.debug, + ) + .await?; + let ms = elapsed.as_secs_f64() * 1000.0; + println!("Iteration {i} returned {rows} rows in {ms} ms"); + rundata.write_iter(elapsed, rows); + } + } + println!("\n"); + } + rundata.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } +} + +fn parquet_scan_disp(opts: &ParquetScanOptions) -> String { + format!( + "pushdown_filters={}, reorder_filters={}, page_index={}", + opts.pushdown_filters, opts.reorder_filters, opts.enable_page_index + ) +} + +async fn exec_scan( + ctx: &SessionContext, + test_file: &TestParquetFile, + filter: Expr, + debug: bool, +) -> Result<(usize, std::time::Duration)> { + let start = Instant::now(); + let exec = test_file.create_scan(Some(filter)).await?; + + let task_ctx = ctx.task_ctx(); + let result = collect(exec, task_ctx).await?; + let elapsed = start.elapsed(); + if debug { + pretty::print_batches(&result)?; + } + let rows = result.iter().map(|b| b.num_rows()).sum(); + Ok((rows, elapsed)) +} diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs new file mode 100644 index 0000000000000..224f2b19c72e5 --- /dev/null +++ b/benchmarks/src/sort.rs @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::AccessLogOpt; +use crate::BenchmarkRun; +use crate::CommonOpt; +use arrow::util::pretty; +use datafusion::common::Result; +use datafusion::physical_expr::PhysicalSortExpr; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion::test_util::parquet::TestParquetFile; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; +use structopt::StructOpt; + +/// Test performance of sorting large datasets +/// +/// This test sorts a a synthetic dataset generated during the +/// benchmark execution, designed to simulate sorting web server +/// access logs. Such sorting is often done during data transformation +/// steps. +/// +/// The tests sort the entire dataset using several different sort +/// orders. +/// +/// Example: +/// +/// dfbench sort --path ./data --scale-factor 1.0 +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Create data files + #[structopt(flatten)] + access_log: AccessLogOpt, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + let test_file = self.access_log.build()?; + + use datafusion::physical_expr::expressions::col; + let mut rundata = BenchmarkRun::new(); + let schema = test_file.schema(); + let sort_cases = vec![ + ( + "sort utf8", + vec![PhysicalSortExpr { + expr: col("request_method", &schema)?, + options: Default::default(), + }], + ), + ( + "sort int", + vec![PhysicalSortExpr { + expr: col("request_bytes", &schema)?, + options: Default::default(), + }], + ), + ( + "sort decimal", + vec![ + // sort decimal + PhysicalSortExpr { + expr: col("decimal_price", &schema)?, + options: Default::default(), + }, + ], + ), + ( + "sort integer tuple", + vec![ + PhysicalSortExpr { + expr: col("request_bytes", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("response_bytes", &schema)?, + options: Default::default(), + }, + ], + ), + ( + "sort utf8 tuple", + vec![ + // sort utf8 tuple + PhysicalSortExpr { + expr: col("service", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("host", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("pod", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("image", &schema)?, + options: Default::default(), + }, + ], + ), + ( + "sort mixed tuple", + vec![ + PhysicalSortExpr { + expr: col("service", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("request_bytes", &schema)?, + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("decimal_price", &schema)?, + options: Default::default(), + }, + ], + ), + ]; + for (title, expr) in sort_cases { + println!("Executing '{title}' (sorting by: {expr:?})"); + rundata.start_new_case(title); + for i in 0..self.common.iterations { + let config = SessionConfig::new().with_target_partitions( + self.common.partitions.unwrap_or(num_cpus::get()), + ); + let ctx = SessionContext::new_with_config(config); + let (rows, elapsed) = + exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; + let ms = elapsed.as_secs_f64() * 1000.0; + println!("Iteration {i} finished in {ms} ms"); + rundata.write_iter(elapsed, rows); + } + println!("\n"); + } + if let Some(path) = &self.output_path { + std::fs::write(path, rundata.to_json())?; + } + Ok(()) + } +} + +async fn exec_sort( + ctx: &SessionContext, + expr: &[PhysicalSortExpr], + test_file: &TestParquetFile, + debug: bool, +) -> Result<(usize, std::time::Duration)> { + let start = Instant::now(); + let scan = test_file.create_scan(None).await?; + let exec = Arc::new(SortExec::new(expr.to_owned(), scan)); + let task_ctx = ctx.task_ctx(); + let result = collect(exec, task_ctx).await?; + let elapsed = start.elapsed(); + if debug { + pretty::print_batches(&result)?; + } + let rows = result.iter().map(|b| b.num_rows()).sum(); + Ok((rows, elapsed)) +} diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs new file mode 100644 index 0000000000000..2fc74ce38888f --- /dev/null +++ b/benchmarks/src/tpch/convert.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::Instant; + +use datafusion::common::not_impl_err; +use datafusion::error::DataFusionError; +use datafusion::error::Result; +use datafusion::prelude::*; +use parquet::basic::Compression; +use parquet::file::properties::WriterProperties; +use structopt::StructOpt; + +use super::get_tbl_tpch_table_schema; +use super::TPCH_TABLES; + +/// Convert tpch .slt files to .parquet or .csv files +#[derive(Debug, StructOpt)] +pub struct ConvertOpt { + /// Path to csv files + #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] + input_path: PathBuf, + + /// Output path + #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] + output_path: PathBuf, + + /// Output file format: `csv` or `parquet` + #[structopt(short = "f", long = "format")] + file_format: String, + + /// Compression to use when writing Parquet files + #[structopt(short = "c", long = "compression", default_value = "zstd")] + compression: String, + + /// Number of partitions to produce + #[structopt(short = "n", long = "partitions", default_value = "1")] + partitions: usize, + + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, +} + +impl ConvertOpt { + pub async fn run(self) -> Result<()> { + let compression = self.compression()?; + + let input_path = self.input_path.to_str().unwrap(); + let output_path = self.output_path.to_str().unwrap(); + + let output_root_path = Path::new(output_path); + for table in TPCH_TABLES { + let start = Instant::now(); + let schema = get_tbl_tpch_table_schema(table); + + let input_path = format!("{input_path}/{table}.tbl"); + let options = CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .delimiter(b'|') + .file_extension(".tbl"); + + let config = SessionConfig::new().with_batch_size(self.batch_size); + let ctx = SessionContext::new_with_config(config); + + // build plan to read the TBL file + let mut csv = ctx.read_csv(&input_path, options).await?; + + // Select all apart from the padding column + let selection = csv + .schema() + .fields() + .iter() + .take(schema.fields.len() - 1) + .map(|d| Expr::Column(d.qualified_column())) + .collect(); + + csv = csv.select(selection)?; + // optionally, repartition the file + let partitions = self.partitions; + if partitions > 1 { + csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? + } + + // create the physical plan + let csv = csv.create_physical_plan().await?; + + let output_path = output_root_path.join(table); + let output_path = output_path.to_str().unwrap().to_owned(); + fs::create_dir_all(&output_path)?; + println!( + "Converting '{}' to {} files in directory '{}'", + &input_path, self.file_format, &output_path + ); + match self.file_format.as_str() { + "csv" => ctx.write_csv(csv, output_path).await?, + "parquet" => { + let props = WriterProperties::builder() + .set_compression(compression) + .build(); + ctx.write_parquet(csv, output_path, Some(props)).await? + } + other => { + return not_impl_err!("Invalid output format: {other}"); + } + } + println!("Conversion completed in {} ms", start.elapsed().as_millis()); + } + + Ok(()) + } + + /// return the compression method to use when writing parquet + fn compression(&self) -> Result { + Ok(match self.compression.as_str() { + "none" => Compression::UNCOMPRESSED, + "snappy" => Compression::SNAPPY, + "brotli" => Compression::BROTLI(Default::default()), + "gzip" => Compression::GZIP(Default::default()), + "lz4" => Compression::LZ4, + "lz0" => Compression::LZO, + "zstd" => Compression::ZSTD(Default::default()), + other => { + return not_impl_err!("Invalid compression format: {other}"); + } + }) + } +} diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch/mod.rs similarity index 72% rename from benchmarks/src/tpch.rs rename to benchmarks/src/tpch/mod.rs index 58b9c3637c4e9..8965ebea7ff63 100644 --- a/benchmarks/src/tpch.rs +++ b/benchmarks/src/tpch/mod.rs @@ -15,18 +15,20 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::SchemaBuilder; -use std::fs; -use std::path::Path; -use std::time::Instant; +//! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. -use datafusion::prelude::*; +use arrow::datatypes::SchemaBuilder; use datafusion::{ arrow::datatypes::{DataType, Field, Schema}, + common::plan_err, error::{DataFusionError, Result}, }; -use parquet::basic::Compression; -use parquet::file::properties::WriterProperties; +use std::fs; +mod run; +pub use run::RunOpt; + +mod convert; +pub use convert::ConvertOpt; pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", @@ -156,88 +158,12 @@ pub fn get_query_sql(query: usize) -> Result> { Err(e) => errors.push(format!("{filename}: {e}")), }; } - Err(DataFusionError::Plan(format!( - "invalid query. Could not find query: {errors:?}" - ))) + plan_err!("invalid query. Could not find query: {:?}", errors) } else { - Err(DataFusionError::Plan( - "invalid query. Expected value between 1 and 22".to_owned(), - )) + plan_err!("invalid query. Expected value between 1 and 22") } } -/// Conver tbl (csv) file to parquet -pub async fn convert_tbl( - input_path: &str, - output_path: &str, - file_format: &str, - partitions: usize, - batch_size: usize, - compression: Compression, -) -> Result<()> { - let output_root_path = Path::new(output_path); - for table in TPCH_TABLES { - let start = Instant::now(); - let schema = get_tbl_tpch_table_schema(table); - - let input_path = format!("{input_path}/{table}.tbl"); - let options = CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .delimiter(b'|') - .file_extension(".tbl"); - - let config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::with_config(config); - - // build plan to read the TBL file - let mut csv = ctx.read_csv(&input_path, options).await?; - - // Select all apart from the padding column - let selection = csv - .schema() - .fields() - .iter() - .take(schema.fields.len() - 1) - .map(|d| Expr::Column(d.qualified_column())) - .collect(); - - csv = csv.select(selection)?; - // optionally, repartition the file - if partitions > 1 { - csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? - } - - // create the physical plan - let csv = csv.create_physical_plan().await?; - - let output_path = output_root_path.join(table); - let output_path = output_path.to_str().unwrap().to_owned(); - - println!( - "Converting '{}' to {} files in directory '{}'", - &input_path, &file_format, &output_path - ); - match file_format { - "csv" => ctx.write_csv(csv, output_path).await?, - "parquet" => { - let props = WriterProperties::builder() - .set_compression(compression) - .build(); - ctx.write_parquet(csv, output_path, Some(props)).await? - } - other => { - return Err(DataFusionError::NotImplemented(format!( - "Invalid output format: {other}" - ))); - } - } - println!("Conversion completed in {} ms", start.elapsed().as_millis()); - } - - Ok(()) -} - pub const QUERY_LIMIT: [Option; 22] = [ None, Some(100), diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs new file mode 100644 index 0000000000000..5193d578fb486 --- /dev/null +++ b/benchmarks/src/tpch/run.rs @@ -0,0 +1,453 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::get_query_sql; +use crate::{BenchmarkRun, CommonOpt}; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::csv::CsvFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use log::info; + +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; + +use datafusion::error::Result; +use datafusion::prelude::*; +use structopt::StructOpt; + +use super::{get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES}; + +/// Run the tpch benchmark. +/// +/// This benchmarks is derived from the [TPC-H][1] version +/// [2.17.1]. The data and answers are generated using `tpch-gen` from +/// [2]. +/// +/// [1]: http://www.tpc.org/tpch/ +/// [2]: https://github.com/databricks/tpch-dbgen.git, +/// [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// File format: `csv` or `parquet` + #[structopt(short = "f", long = "format", default_value = "csv")] + file_format: String, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Whether to disable collection of statistics (and cost based optimizations) or not. + #[structopt(short = "S", long = "disable-statistics")] + disable_statistics: bool, +} + +const TPCH_QUERY_START_ID: usize = 1; +const TPCH_QUERY_END_ID: usize = 22; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => TPCH_QUERY_START_ID..=TPCH_QUERY_END_ID, + }; + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id).await?; + for iter in query_run { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query(&self, query_id: usize) -> Result> { + let config = self + .common + .config() + .with_collect_statistics(!self.disable_statistics); + let ctx = SessionContext::new_with_config(config); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let sql = &get_query_sql(query_id)?; + + // query 15 is special, with 3 statements. the second statement is the one from which we + // want to capture the results + let mut result = vec![]; + if query_id == 15 { + for (n, query) in sql.iter().enumerate() { + if n == 1 { + result = self.execute_query(&ctx, query).await?; + } else { + self.execute_query(&ctx, query).await?; + } + } + } else { + for query in sql { + result = self.execute_query(&ctx, query).await?; + } + } + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in TPCH_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(*table, Arc::new(memtable))?; + } else { + ctx.register_table(*table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan:?}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan:?}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + let table_format = self.file_format.as_str(); + let target_partitions = self.partitions(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let (format, path, extension): (Arc, String, &'static str) = + match table_format { + // dbgen creates .tbl ('|' delimited) files without header + "tbl" => { + let path = format!("{path}/{table}.tbl"); + + let format = CsvFormat::default() + .with_delimiter(b'|') + .with_has_header(false); + + (Arc::new(format), path, ".tbl") + } + "csv" => { + let path = format!("{path}/{table}"); + let format = CsvFormat::default() + .with_delimiter(b',') + .with_has_header(true); + + (Arc::new(format), path, DEFAULT_CSV_EXTENSION) + } + "parquet" => { + let path = format!("{path}/{table}"); + let format = ParquetFormat::default().with_enable_pruning(Some(true)); + + (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) + } + other => { + unimplemented!("Invalid file format '{}'", other); + } + }; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_target_partitions(target_partitions) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + + let config = match table_format { + "parquet" => config.infer_schema(&state).await?, + "tbl" => config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))), + "csv" => config.with_schema(Arc::new(get_tpch_table_schema(table))), + _ => unreachable!(), + }; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common.partitions.unwrap_or(num_cpus::get()) + } +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +#[cfg(test)] +// Only run with "ci" mode when we have the data +#[cfg(feature = "ci")] +mod tests { + use super::*; + use datafusion::common::exec_err; + use datafusion::error::{DataFusionError, Result}; + use std::path::Path; + + use datafusion_proto::bytes::{ + logical_plan_from_bytes, logical_plan_to_bytes, physical_plan_from_bytes, + physical_plan_to_bytes, + }; + + fn get_tpch_data_path() -> Result { + let path = + std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); + if !Path::new(&path).exists() { + return exec_err!( + "Benchmark data not found (set TPCH_DATA env var to override): {}", + path + ); + } + Ok(path) + } + + async fn round_trip_logical_plan(query: usize) -> Result<()> { + let ctx = SessionContext::default(); + let path = get_tpch_data_path()?; + let common = CommonOpt { + iterations: 1, + partitions: Some(2), + batch_size: 8192, + debug: false, + }; + let opt = RunOpt { + query: Some(query), + common, + path: PathBuf::from(path.to_string()), + file_format: "tbl".to_string(), + mem_table: false, + output_path: None, + disable_statistics: false, + }; + opt.register_tables(&ctx).await?; + let queries = get_query_sql(query)?; + for query in queries { + let plan = ctx.sql(&query).await?; + let plan = plan.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan_formatted = format!("{}", plan.display_indent()); + let plan2_formatted = format!("{}", plan2.display_indent()); + assert_eq!(plan_formatted, plan2_formatted); + } + Ok(()) + } + + async fn round_trip_physical_plan(query: usize) -> Result<()> { + let ctx = SessionContext::default(); + let path = get_tpch_data_path()?; + let common = CommonOpt { + iterations: 1, + partitions: Some(2), + batch_size: 8192, + debug: false, + }; + let opt = RunOpt { + query: Some(query), + common, + path: PathBuf::from(path.to_string()), + file_format: "tbl".to_string(), + mem_table: false, + output_path: None, + disable_statistics: false, + }; + opt.register_tables(&ctx).await?; + let queries = get_query_sql(query)?; + for query in queries { + let plan = ctx.sql(&query).await?; + let plan = plan.create_physical_plan().await?; + let bytes = physical_plan_to_bytes(plan.clone())?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); + let plan2_formatted = + format!("{}", displayable(plan2.as_ref()).indent(false)); + assert_eq!(plan_formatted, plan2_formatted); + } + Ok(()) + } + + macro_rules! test_round_trip_logical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_logical_plan($query).await + } + }; + } + + macro_rules! test_round_trip_physical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_physical_plan($query).await + } + }; + } + + // logical plan tests + test_round_trip_logical!(round_trip_logical_plan_q1, 1); + test_round_trip_logical!(round_trip_logical_plan_q2, 2); + test_round_trip_logical!(round_trip_logical_plan_q3, 3); + test_round_trip_logical!(round_trip_logical_plan_q4, 4); + test_round_trip_logical!(round_trip_logical_plan_q5, 5); + test_round_trip_logical!(round_trip_logical_plan_q6, 6); + test_round_trip_logical!(round_trip_logical_plan_q7, 7); + test_round_trip_logical!(round_trip_logical_plan_q8, 8); + test_round_trip_logical!(round_trip_logical_plan_q9, 9); + test_round_trip_logical!(round_trip_logical_plan_q10, 10); + test_round_trip_logical!(round_trip_logical_plan_q11, 11); + test_round_trip_logical!(round_trip_logical_plan_q12, 12); + test_round_trip_logical!(round_trip_logical_plan_q13, 13); + test_round_trip_logical!(round_trip_logical_plan_q14, 14); + test_round_trip_logical!(round_trip_logical_plan_q15, 15); + test_round_trip_logical!(round_trip_logical_plan_q16, 16); + test_round_trip_logical!(round_trip_logical_plan_q17, 17); + test_round_trip_logical!(round_trip_logical_plan_q18, 18); + test_round_trip_logical!(round_trip_logical_plan_q19, 19); + test_round_trip_logical!(round_trip_logical_plan_q20, 20); + test_round_trip_logical!(round_trip_logical_plan_q21, 21); + test_round_trip_logical!(round_trip_logical_plan_q22, 22); + + // physical plan tests + test_round_trip_physical!(round_trip_physical_plan_q1, 1); + test_round_trip_physical!(round_trip_physical_plan_q2, 2); + test_round_trip_physical!(round_trip_physical_plan_q3, 3); + test_round_trip_physical!(round_trip_physical_plan_q4, 4); + test_round_trip_physical!(round_trip_physical_plan_q5, 5); + test_round_trip_physical!(round_trip_physical_plan_q6, 6); + test_round_trip_physical!(round_trip_physical_plan_q7, 7); + test_round_trip_physical!(round_trip_physical_plan_q8, 8); + test_round_trip_physical!(round_trip_physical_plan_q9, 9); + test_round_trip_physical!(round_trip_physical_plan_q10, 10); + test_round_trip_physical!(round_trip_physical_plan_q11, 11); + test_round_trip_physical!(round_trip_physical_plan_q12, 12); + test_round_trip_physical!(round_trip_physical_plan_q13, 13); + test_round_trip_physical!(round_trip_physical_plan_q14, 14); + test_round_trip_physical!(round_trip_physical_plan_q15, 15); + test_round_trip_physical!(round_trip_physical_plan_q16, 16); + test_round_trip_physical!(round_trip_physical_plan_q17, 17); + test_round_trip_physical!(round_trip_physical_plan_q18, 18); + test_round_trip_physical!(round_trip_physical_plan_q19, 19); + test_round_trip_physical!(round_trip_physical_plan_q20, 20); + test_round_trip_physical!(round_trip_physical_plan_q21, 21); + test_round_trip_physical!(round_trip_physical_plan_q22, 22); +} diff --git a/benchmarks/src/util/access_log.rs b/benchmarks/src/util/access_log.rs new file mode 100644 index 0000000000000..2b29465ee20e3 --- /dev/null +++ b/benchmarks/src/util/access_log.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark data generation + +use datafusion::common::Result; +use datafusion::test_util::parquet::TestParquetFile; +use parquet::file::properties::WriterProperties; +use std::path::PathBuf; +use structopt::StructOpt; +use test_utils::AccessLogGenerator; + +// Options and builder for making an access log test file +// Note don't use docstring or else it ends up in help +#[derive(Debug, StructOpt, Clone)] +pub struct AccessLogOpt { + /// Path to folder where access log file will be generated + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Data page size of the generated parquet file + #[structopt(long = "page-size")] + page_size: Option, + + /// Data page size of the generated parquet file + #[structopt(long = "row-group-size")] + row_group_size: Option, + + /// Total size of generated dataset. The default scale factor of 1.0 will generate a roughly 1GB parquet file + #[structopt(long = "scale-factor", default_value = "1.0")] + scale_factor: f32, +} + +impl AccessLogOpt { + /// Create the access log and return the file. + /// + /// See [`TestParquetFile`] for more details + pub fn build(self) -> Result { + let path = self.path.join("logs.parquet"); + + let mut props_builder = WriterProperties::builder(); + + if let Some(s) = self.page_size { + props_builder = props_builder + .set_data_page_size_limit(s) + .set_write_batch_size(s); + } + + if let Some(s) = self.row_group_size { + props_builder = props_builder.set_max_row_group_size(s); + } + let props = props_builder.build(); + + let generator = AccessLogGenerator::new(); + + let num_batches = 100_f32 * self.scale_factor; + + TestParquetFile::try_new(path, props, generator.take(num_batches as usize)) + } +} diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs new file mode 100644 index 0000000000000..95c6e5f53d0f0 --- /dev/null +++ b/benchmarks/src/util/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared benchmark utilities +mod access_log; +mod options; +mod run; + +pub use access_log::AccessLogOpt; +pub use options::CommonOpt; +pub use run::{BenchQuery, BenchmarkRun}; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs new file mode 100644 index 0000000000000..b9398e5b522f2 --- /dev/null +++ b/benchmarks/src/util/options.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::SessionConfig; +use structopt::StructOpt; + +// Common benchmark options (don't use doc comments otherwise this doc +// shows up in help files) +#[derive(Debug, StructOpt, Clone)] +pub struct CommonOpt { + /// Number of iterations of each test run + #[structopt(short = "i", long = "iterations", default_value = "3")] + pub iterations: usize, + + /// Number of partitions to process in parallel. Defaults to number of available cores. + #[structopt(short = "n", long = "partitions")] + pub partitions: Option, + + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + pub batch_size: usize, + + /// Activate debug mode to see more details + #[structopt(short, long)] + pub debug: bool, +} + +impl CommonOpt { + /// Return an appropriately configured `SessionConfig` + pub fn config(&self) -> SessionConfig { + self.update_config(SessionConfig::new()) + } + + /// Modify the existing config appropriately + pub fn update_config(&self, config: SessionConfig) -> SessionConfig { + config + .with_target_partitions(self.partitions.unwrap_or(num_cpus::get())) + .with_batch_size(self.batch_size) + } +} diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs new file mode 100644 index 0000000000000..5ee6691576b44 --- /dev/null +++ b/benchmarks/src/util/run.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{error::Result, DATAFUSION_VERSION}; +use serde::{Serialize, Serializer}; +use serde_json::Value; +use std::{ + collections::HashMap, + path::Path, + time::{Duration, SystemTime}, +}; + +fn serialize_start_time(start_time: &SystemTime, ser: S) -> Result +where + S: Serializer, +{ + ser.serialize_u64( + start_time + .duration_since(SystemTime::UNIX_EPOCH) + .expect("current time is later than UNIX_EPOCH") + .as_secs(), + ) +} +fn serialize_elapsed(elapsed: &Duration, ser: S) -> Result +where + S: Serializer, +{ + let ms = elapsed.as_secs_f64() * 1000.0; + ser.serialize_f64(ms) +} +#[derive(Debug, Serialize)] +pub struct RunContext { + /// Benchmark crate version + pub benchmark_version: String, + /// DataFusion crate version + pub datafusion_version: String, + /// Number of CPU cores + pub num_cpus: usize, + /// Start time + #[serde(serialize_with = "serialize_start_time")] + pub start_time: SystemTime, + /// CLI arguments + pub arguments: Vec, +} + +impl Default for RunContext { + fn default() -> Self { + Self::new() + } +} + +impl RunContext { + pub fn new() -> Self { + Self { + benchmark_version: env!("CARGO_PKG_VERSION").to_owned(), + datafusion_version: DATAFUSION_VERSION.to_owned(), + num_cpus: num_cpus::get(), + start_time: SystemTime::now(), + arguments: std::env::args().skip(1).collect::>(), + } + } +} + +/// A single iteration of a benchmark query +#[derive(Debug, Serialize)] +struct QueryIter { + #[serde(serialize_with = "serialize_elapsed")] + elapsed: Duration, + row_count: usize, +} +/// A single benchmark case +#[derive(Debug, Serialize)] +pub struct BenchQuery { + query: String, + iterations: Vec, + #[serde(serialize_with = "serialize_start_time")] + start_time: SystemTime, +} + +/// collects benchmark run data and then serializes it at the end +pub struct BenchmarkRun { + context: RunContext, + queries: Vec, + current_case: Option, +} + +impl Default for BenchmarkRun { + fn default() -> Self { + Self::new() + } +} + +impl BenchmarkRun { + // create new + pub fn new() -> Self { + Self { + context: RunContext::new(), + queries: vec![], + current_case: None, + } + } + /// begin a new case. iterations added after this will be included in the new case + pub fn start_new_case(&mut self, id: &str) { + self.queries.push(BenchQuery { + query: id.to_owned(), + iterations: vec![], + start_time: SystemTime::now(), + }); + if let Some(c) = self.current_case.as_mut() { + *c += 1; + } else { + self.current_case = Some(0); + } + } + /// Write a new iteration to the current case + pub fn write_iter(&mut self, elapsed: Duration, row_count: usize) { + if let Some(idx) = self.current_case { + self.queries[idx] + .iterations + .push(QueryIter { elapsed, row_count }) + } else { + panic!("no cases existed yet"); + } + } + + /// Stringify data into formatted json + pub fn to_json(&self) -> String { + let mut output = HashMap::<&str, Value>::new(); + output.insert("context", serde_json::to_value(&self.context).unwrap()); + output.insert("queries", serde_json::to_value(&self.queries).unwrap()); + serde_json::to_string_pretty(&output).unwrap() + } + + /// Write data as json into output path if it exists. + pub fn maybe_write_json(&self, maybe_path: Option>) -> Result<()> { + if let Some(path) = maybe_path { + std::fs::write(path, self.to_json())?; + }; + Ok(()) + } +} diff --git a/ci/scripts/rust_clippy.sh b/ci/scripts/rust_clippy.sh index dfd2916981dd1..f5c8b61e1c06f 100755 --- a/ci/scripts/rust_clippy.sh +++ b/ci/scripts/rust_clippy.sh @@ -19,3 +19,5 @@ set -ex cargo clippy --all-targets --workspace --features avro,pyarrow -- -D warnings +cd datafusion-cli +cargo clippy --all-targets --all-features -- -D warnings diff --git a/ci/scripts/rust_docs.sh b/ci/scripts/rust_docs.sh index 033d6e890ffc2..cf83b80b5132e 100755 --- a/ci/scripts/rust_docs.sh +++ b/ci/scripts/rust_docs.sh @@ -20,3 +20,5 @@ set -ex export RUSTDOCFLAGS="-D warnings -A rustdoc::private-intra-doc-links" cargo doc --document-private-items --no-deps --workspace +cd datafusion-cli +cargo doc --document-private-items --no-deps diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh new file mode 100755 index 0000000000000..fe3696f208652 --- /dev/null +++ b/ci/scripts/rust_example.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex +cd datafusion-examples/examples/ +cargo fmt --all -- --check + +files=$(ls .) +for filename in $files +do + example_name=`basename $filename ".rs"` + # Skip tests that rely on external storage and flight + # todo: Currently, catalog.rs is placed in the external-dependence directory because there is a problem parsing + # the parquet file of the external parquet-test that it currently relies on. + # We will wait for this issue[https://github.com/apache/arrow-datafusion/issues/8041] to be resolved. + if [ ! -d $filename ]; then + cargo run --example $example_name + fi +done diff --git a/ci/scripts/rust_fmt.sh b/ci/scripts/rust_fmt.sh index 9d8325877aad5..cb9bb5e877e77 100755 --- a/ci/scripts/rust_fmt.sh +++ b/ci/scripts/rust_fmt.sh @@ -19,3 +19,5 @@ set -ex cargo fmt --all -- --check +cd datafusion-cli +cargo fmt --all -- --check diff --git a/ci/scripts/rust_toml_fmt.sh b/ci/scripts/rust_toml_fmt.sh index e297ef0015941..0a8cc346a37dc 100755 --- a/ci/scripts/rust_toml_fmt.sh +++ b/ci/scripts/rust_toml_fmt.sh @@ -17,5 +17,11 @@ # specific language governing permissions and limitations # under the License. +# Run cargo-tomlfmt with flag `-d` in dry run to check formatting +# without overwritng the file. If any error occur, you may want to +# rerun 'cargo tomlfmt -p path/to/Cargo.toml' without '-d' to fix +# the formatting automatically. set -ex -find . -mindepth 2 -name 'Cargo.toml' -exec cargo tomlfmt -p {} \; +for toml in $(find . -mindepth 2 -name 'Cargo.toml'); do + cargo tomlfmt -d -p $toml +done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 71b18f71a5fb9..76be04d5ef670 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -2,30 +2,46 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" dependencies = [ "cfg-if", "const-random", "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -47,9 +63,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4f263788a35611fba42eb41ff811c5d0360c58b97402570312a350736e2542e" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "android-tzdata" @@ -66,6 +82,40 @@ dependencies = [ "libc", ] +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + +[[package]] +name = "apache-avro" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +dependencies = [ + "bzip2", + "crc32fast", + "digest", + "lazy_static", + "libflate", + "log", + "num-bigint", + "quad-rand", + "rand", + "regex-lite", + "serde", + "serde_json", + "snap", + "strum", + "strum_macros", + "thiserror", + "typed-builder", + "uuid", + "xz2", + "zstd 0.12.4", +] + [[package]] name = "arrayref" version = "0.3.7" @@ -74,15 +124,15 @@ checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" [[package]] name = "arrayvec" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a46441ae78c0c5915f62aa32cad9910647c19241456dd24039646dd96d494a5" +checksum = "5bc25126d18a012146a888a0298f2c22e1150327bd2765fc76d710a556b2d614" dependencies = [ "ahash", "arrow-arith", @@ -102,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350c5067470aeeb38dcfcc1f7e9c397098116409c9087e43ca99c231020635d9" +checksum = "34ccd45e217ffa6e53bbb0080990e77113bdd4e91ddb84e97b77649810bcf1a7" dependencies = [ "arrow-array", "arrow-buffer", @@ -117,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6049e031521c4e7789b7530ea5991112c0a375430094191f3b74bdf37517c9a9" +checksum = "6bda9acea48b25123c08340f3a8ac361aa0f74469bb36f5ee9acf923fce23e9d" dependencies = [ "ahash", "arrow-buffer", @@ -128,42 +178,45 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.13.2", + "hashbrown 0.14.3", "num", ] [[package]] name = "arrow-buffer" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a83450b94b9fe018b65ba268415aaab78757636f68b7f37b6bc1f2a3888af0a0" +checksum = "01a0fc21915b00fc6c2667b069c1b64bdd920982f426079bc4a7cab86822886c" dependencies = [ + "bytes", "half", "num", ] [[package]] name = "arrow-cast" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "249198411254530414805f77e88e1587b0914735ea180f906506905721f7a44a" +checksum = "5dc0368ed618d509636c1e3cc20db1281148190a78f43519487b2daf07b63b4a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "base64", "chrono", "comfy-table", + "half", "lexical-core", "num", ] [[package]] name = "arrow-csv" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec9ee134298aa895ef9d791dc9cc557cecd839108843830bd35824fcd8d7f721" +checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" dependencies = [ "arrow-array", "arrow-buffer", @@ -180,9 +233,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d48dcbed83d741d4af712af17f6d952972b8f6491b24ee2415243a7e37c6438" +checksum = "907fafe280a3874474678c1858b9ca4cb7fd83fb8034ff5b6d6376205a08c634" dependencies = [ "arrow-buffer", "arrow-schema", @@ -192,9 +245,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8d7b138c5414aeef5dd08abacf362f87ed9b1168ea38d60a6f67590c3f7d99" +checksum = "79a43d6808411886b8c7d4f6f7dd477029c1e77ffffffb7923555cc6579639cd" dependencies = [ "arrow-array", "arrow-buffer", @@ -206,9 +259,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3a597fdca885a81f2e7ab0bacaa0bd2dfefb4cd6a2e5a3d1677396a68673101" +checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" dependencies = [ "arrow-array", "arrow-buffer", @@ -217,7 +270,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap", + "indexmap 2.1.0", "lexical-core", "num", "serde", @@ -226,9 +279,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29be2d5fadaab29e4fa6a7e527ceaa1c2cddc57dc6d86c062f7a05adcd8df71e" +checksum = "9b23b0e53c0db57c6749997fd343d4c0354c994be7eca67152dd2bdb9a3e1bb4" dependencies = [ "arrow-array", "arrow-buffer", @@ -241,9 +294,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6e0bd6ad24d56679b3317b499b0de61bca16d3142896908cce1aa943e56e981" +checksum = "361249898d2d6d4a6eeb7484be6ac74977e48da12a4dd81a708d620cc558117a" dependencies = [ "ahash", "arrow-array", @@ -251,21 +304,22 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown 0.13.2", + "hashbrown 0.14.3", ] [[package]] name = "arrow-schema" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b71d8d68d0bc2e648e4e395896dc518be8b90c5f0f763c59083187c3d46184b" +checksum = "09e28a5e781bf1b0f981333684ad13f5901f4cd2f20589eab7cf1797da8fc167" [[package]] name = "arrow-select" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470cb8610bdfda56554a436febd4e457e506f3c42e01e545a1ea7ecf2a4c8823" +checksum = "4f6208466590960efc1d2a7172bc4ff18a67d6e25c529381d7f96ddaf0dc4036" dependencies = [ + "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -275,24 +329,40 @@ dependencies = [ [[package]] name = "arrow-string" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70f8a2e4ff9dbbd51adbabf92098b71e3eb2ef0cfcb75236ca7c3ce087cce038" +checksum = "a4a48149c63c11c9ff571e50ab8f017d2a7cb71037a882b42f6354ed2da9acc7" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "num", "regex", "regex-syntax", ] +[[package]] +name = "assert_cmd" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88903cb14723e4d4003335bb7f8a14f27691649105346a0f0957466c096adfe6" +dependencies = [ + "anstyle", + "bstr", + "doc-comment", + "predicates", + "predicates-core", + "predicates-tree", + "wait-timeout", +] + [[package]] name = "async-compression" -version = "0.4.0" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0122885821398cc923ece939e24d1056a2384ee719432397fa9db87230ff11" +checksum = "bc2d0cfb2a7388d34f590e76686704c494ed7aaceed62ee1ba35cbf363abc2a5" dependencies = [ "bzip2", "flate2", @@ -302,19 +372,19 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd", - "zstd-safe", + "zstd 0.13.0", + "zstd-safe 7.0.0", ] [[package]] name = "async-trait" -version = "0.1.68" +version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] @@ -352,11 +422,11 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand", + "fastrand 1.9.0", "hex", "http", "hyper", - "ring", + "ring 0.16.20", "time", "tokio", "tower", @@ -372,7 +442,7 @@ checksum = "1fcdb2f7acbc076ff5ad05e7864bdb191ca70a6fd07668dc3a1a8bcd051de5ae" dependencies = [ "aws-smithy-async", "aws-smithy-types", - "fastrand", + "fastrand 1.9.0", "tokio", "tracing", "zeroize", @@ -518,14 +588,14 @@ dependencies = [ "aws-smithy-http-tower", "aws-smithy-types", "bytes", - "fastrand", + "fastrand 1.9.0", "http", "http-body", "hyper", "hyper-rustls 0.23.2", "lazy_static", "pin-project-lite", - "rustls 0.20.8", + "rustls 0.20.9", "tokio", "tower", "tracing", @@ -624,11 +694,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64-simd" @@ -646,6 +731,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + [[package]] name = "blake2" version = "0.10.6" @@ -657,16 +748,15 @@ dependencies = [ [[package]] name = "blake3" -version = "1.3.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ae2468a89544a466886840aa467a25b766499f4f04bf7d9fcd10ecee9fccef" +checksum = "0231f06152bf547e9c2b5194f247cd97aacf6dcd8b15d8e5ec0663f64580da87" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", - "digest", ] [[package]] @@ -680,9 +770,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.3.4" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" +checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -691,37 +781,48 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.3.4" +version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6561fd3f895a11e8f72af2cb7d22e08366bebc2b6b57f7744c4bda27034744" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", ] +[[package]] +name = "bstr" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bytes-utils" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e47d3a8076e283f3acd27400535992edb3ba4b5bb72f8891ad8fbe7932a7d4b9" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" dependencies = [ "bytes", "either", @@ -750,11 +851,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", + "libc", ] [[package]] @@ -765,22 +867,22 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "winapi", + "windows-targets 0.48.5", ] [[package]] name = "chrono-tz" -version = "0.8.2" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9cc2b23599e6d7479755f3594285efb3f74a1bdca7a7374948bc831e23a552" +checksum = "e23185c0e21df6ed832a12e2bda87c7d1def6842881fb634a8511ced741b0d76" dependencies = [ "chrono", "chrono-tz-build", @@ -789,9 +891,9 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9998fb9f7e9b2111641485bf8beb32f92945f97f92a3d061f744cfef335f751" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" dependencies = [ "parse-zoneinfo", "phf", @@ -805,10 +907,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" dependencies = [ "atty", - "bitflags", + "bitflags 1.3.2", "clap_derive", "clap_lex", - "indexmap", + "indexmap 1.9.3", "once_cell", "strsim", "termcolor", @@ -850,9 +952,9 @@ dependencies = [ [[package]] name = "comfy-table" -version = "6.2.0" +version = "7.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e959d788268e3bf9d35ace83e81b124190378e4c91c9067524675e33394b8ba" +checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" dependencies = [ "strum", "strum_macros", @@ -861,37 +963,35 @@ dependencies = [ [[package]] name = "const-random" -version = "0.1.15" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368a7a772ead6ce7e1de82bfb04c485f3db8ec744f72925af5735e29a22cc18e" +checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" dependencies = [ "const-random-macro", - "proc-macro-hack", ] [[package]] name = "const-random-macro" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d7d6ab3c3a2282db210df5f02c4dab6e0a7057af0fb7ebd4070f30fe05c0ddb" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ "getrandom", "once_cell", - "proc-macro-hack", "tiny-keccak", ] [[package]] name = "constant_time_eq" -version = "0.2.5" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13418e745008f7349ec7e449155f419a61b92b58a99cc3616942b926825ec76b" +checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -899,15 +999,24 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "core2" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] @@ -939,9 +1048,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.2.2" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" dependencies = [ "csv-core", "itoa", @@ -951,21 +1060,37 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" dependencies = [ "memchr", ] +[[package]] +name = "ctor" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" +dependencies = [ + "quote", + "syn 2.0.39", +] + +[[package]] +name = "dary_heap" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" + [[package]] name = "dashmap" -version = "5.4.0" +version = "5.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.12.3", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -973,9 +1098,10 @@ dependencies = [ [[package]] name = "datafusion" -version = "26.0.0" +version = "33.0.0" dependencies = [ "ahash", + "apache-avro", "arrow", "arrow-array", "arrow-schema", @@ -990,24 +1116,23 @@ dependencies = [ "datafusion-expr", "datafusion-optimizer", "datafusion-physical-expr", - "datafusion-row", + "datafusion-physical-plan", "datafusion-sql", "flate2", "futures", "glob", - "hashbrown 0.14.0", - "indexmap", - "itertools", - "lazy_static", + "half", + "hashbrown 0.14.3", + "indexmap 2.1.0", + "itertools 0.12.0", "log", + "num-traits", "num_cpus", "object_store", "parking_lot", "parquet", - "percent-encoding", "pin-project-lite", "rand", - "smallvec", "sqlparser", "tempfile", "tokio", @@ -1015,24 +1140,31 @@ dependencies = [ "url", "uuid", "xz2", - "zstd", + "zstd 0.13.0", ] [[package]] name = "datafusion-cli" -version = "26.0.0" +version = "33.0.0" dependencies = [ "arrow", + "assert_cmd", "async-trait", "aws-config", "aws-credential-types", "clap", + "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", "mimalloc", "object_store", "parking_lot", + "parquet", + "predicates", + "regex", + "rstest", "rustyline", "tokio", "url", @@ -1040,11 +1172,17 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "26.0.0" +version = "33.0.0" dependencies = [ + "ahash", + "apache-avro", "arrow", "arrow-array", + "arrow-buffer", + "arrow-schema", "chrono", + "half", + "libc", "num_cpus", "object_store", "parquet", @@ -1053,12 +1191,15 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "26.0.0" +version = "33.0.0" dependencies = [ + "arrow", + "chrono", "dashmap", "datafusion-common", "datafusion-expr", - "hashbrown 0.14.0", + "futures", + "hashbrown 0.14.3", "log", "object_store", "parking_lot", @@ -1069,12 +1210,13 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "26.0.0" +version = "33.0.0" dependencies = [ "ahash", "arrow", + "arrow-array", "datafusion-common", - "lazy_static", + "paste", "sqlparser", "strum", "strum_macros", @@ -1082,7 +1224,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "26.0.0" +version = "33.0.0" dependencies = [ "arrow", "async-trait", @@ -1090,33 +1232,34 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.0", - "itertools", + "hashbrown 0.14.3", + "itertools 0.12.0", "log", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "26.0.0" +version = "33.0.0" dependencies = [ "ahash", "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", + "base64", "blake2", "blake3", "chrono", "datafusion-common", "datafusion-expr", - "datafusion-row", "half", - "hashbrown 0.14.0", - "indexmap", - "itertools", - "lazy_static", - "libc", + "hashbrown 0.14.3", + "hex", + "indexmap 2.1.0", + "itertools 0.12.0", + "log", "md-5", "paste", "petgraph", @@ -1128,18 +1271,37 @@ dependencies = [ ] [[package]] -name = "datafusion-row" -version = "26.0.0" +name = "datafusion-physical-plan" +version = "33.0.0" dependencies = [ + "ahash", "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema", + "async-trait", + "chrono", "datafusion-common", - "paste", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + "hashbrown 0.14.3", + "indexmap 2.1.0", + "itertools 0.12.0", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", "rand", + "tokio", + "uuid", ] [[package]] name = "datafusion-sql" -version = "26.0.0" +version = "33.0.0" dependencies = [ "arrow", "arrow-schema", @@ -1149,6 +1311,21 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "deranged" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + [[package]] name = "digest" version = "0.10.7" @@ -1209,15 +1386,15 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "encoding_rs" -version = "0.8.32" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" dependencies = [ "cfg-if", ] @@ -1242,24 +1419,19 @@ dependencies = [ ] [[package]] -name = "errno" -version = "0.3.1" +name = "equivalent" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" -dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys 0.48.0", -] +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "errno" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "cc", "libc", + "windows-sys 0.52.0", ] [[package]] @@ -1281,11 +1453,17 @@ dependencies = [ "instant", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + [[package]] name = "fd-lock" -version = "3.0.12" +version = "3.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" +checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", "rustix", @@ -1304,20 +1482,29 @@ version = "23.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" dependencies = [ - "bitflags", + "bitflags 1.3.2", "rustc_version", ] [[package]] name = "flate2" -version = "1.0.26" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", ] +[[package]] +name = "float-cmp" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +dependencies = [ + "num-traits", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1326,18 +1513,18 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1350,9 +1537,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1360,15 +1547,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1377,38 +1564,44 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" + +[[package]] +name = "futures-timer" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -1434,15 +1627,21 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", "wasi", ] +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + [[package]] name = "glob" version = "0.3.1" @@ -1451,9 +1650,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.19" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d357c7ae988e7d2182f7d7871d0b963962420b0678b0997ce7de72001aeab782" +checksum = "4d6250322ef6e60f93f9a2162799302cd6f68f79f6e5d85c8c16f14d1d958178" dependencies = [ "bytes", "fnv", @@ -1461,7 +1660,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 2.1.0", "slab", "tokio", "tokio-util", @@ -1470,10 +1669,11 @@ dependencies = [ [[package]] name = "half" -version = "2.2.1" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" dependencies = [ + "cfg-if", "crunchy", "num-traits", ] @@ -1489,12 +1689,15 @@ name = "hashbrown" version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ "ahash", "allocator-api2", @@ -1517,18 +1720,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "hex" @@ -1547,9 +1741,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -1575,9 +1769,9 @@ checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" @@ -1587,9 +1781,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.26" +version = "0.14.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" dependencies = [ "bytes", "futures-channel", @@ -1602,7 +1796,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.10", "tokio", "tower-service", "tracing", @@ -1618,7 +1812,7 @@ dependencies = [ "http", "hyper", "log", - "rustls 0.20.8", + "rustls 0.20.9", "rustls-native-certs", "tokio", "tokio-rustls 0.23.4", @@ -1626,29 +1820,30 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.24.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ + "futures-util", "http", "hyper", - "rustls 0.21.1", + "rustls 0.21.9", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", ] [[package]] name = "iana-time-zone" -version = "0.1.56" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -1662,9 +1857,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1680,6 +1875,16 @@ dependencies = [ "hashbrown 0.12.3", ] +[[package]] +name = "indexmap" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" +dependencies = [ + "equivalent", + "hashbrown 0.14.3", +] + [[package]] name = "instant" version = "0.1.12" @@ -1696,51 +1901,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] -name = "io-lifetimes" -version = "1.0.11" +name = "ipnet" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.1", - "libc", - "windows-sys 0.48.0", -] +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] -name = "ipnet" -version = "2.7.2" +name = "itertools" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] [[package]] name = "itertools" -version = "0.10.5" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "jobserver" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" dependencies = [ "libc", ] [[package]] name = "js-sys" -version = "0.3.63" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f37a4a5928311ac501dee68b3c7613a1037d0edb30c8e5427bd832d55d1b790" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -1817,37 +2020,72 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.146" +version = "0.2.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" + +[[package]] +name = "libflate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +dependencies = [ + "core2", + "hashbrown 0.13.2", + "rle-decode-fast", +] [[package]] name = "libm" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.33" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4ac0e912c8ef1b735e92369695618dc5b1819f5a7bf3f167301a3ba1cea515e" +checksum = "3979b5c37ece694f1f5e51e7ecc871fdb0f517ed04ee45f88d15d6d553cb9664" dependencies = [ "cc", "libc", ] +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.1", + "libc", + "redox_syscall", +] + [[package]] name = "linux-raw-sys" -version = "0.3.8" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -1855,61 +2093,51 @@ dependencies = [ [[package]] name = "log" -version = "0.4.18" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "518ef76f2f87365916b142844c16d8fefd85039bc5699050210a7778ee1cd1de" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] -name = "lz4" -version = "1.24.0" +name = "lz4_flex" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8" dependencies = [ - "libc", - "lz4-sys", + "twox-hash", ] [[package]] -name = "lz4-sys" -version = "1.9.4" +name = "lzma-sys" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" dependencies = [ "cc", "libc", -] - -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", + "pkg-config", ] [[package]] name = "md-5" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ + "cfg-if", "digest", ] [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "mimalloc" -version = "0.1.37" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2894987a3459f3ffb755608bd82188f8ed00d0ae077f1edea29c068d639d98" +checksum = "fa01922b5ea280a911e323e4d2fd24b7fe5cc4042e0d2cda3c40775cdc4bdc9c" dependencies = [ "libmimalloc-sys", ] @@ -1931,9 +2159,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", @@ -1951,21 +2179,26 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "libc", - "static_assertions", ] +[[package]] +name = "normalize-line-endings" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" + [[package]] name = "num" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" dependencies = [ "num-bigint", "num-complex", @@ -1977,9 +2210,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", @@ -1988,9 +2221,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" dependencies = [ "num-traits", ] @@ -2030,9 +2263,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", @@ -2040,19 +2273,28 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi 0.3.3", "libc", ] +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + [[package]] name = "object_store" -version = "0.6.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c776db4f332b571958444982ff641d2531417a326ca368995073b639205d58" +checksum = "2524735495ea1268be33d200e1ee97455096a0846295a21548cd2f3541de7050" dependencies = [ "async-trait", "base64", @@ -2061,13 +2303,13 @@ dependencies = [ "futures", "humantime", "hyper", - "itertools", + "itertools 0.11.0", "parking_lot", "percent-encoding", "quick-xml", "rand", "reqwest", - "ring", + "ring 0.17.7", "rustls-pemfile", "serde", "serde_json", @@ -2092,18 +2334,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "ordered-float" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" dependencies = [ "num-traits", ] [[package]] name = "os_str_bytes" -version = "6.5.0" +version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" +checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" [[package]] name = "outref" @@ -2123,22 +2365,22 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.8" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.3.5", + "redox_syscall", "smallvec", - "windows-targets 0.48.0", + "windows-targets 0.48.5", ] [[package]] name = "parquet" -version = "41.0.0" +version = "49.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6880c32d81884ac4441d9f4b027df8561be23b54f3ac1e62086fa42753dd3faa" +checksum = "af88740a842787da39b3d69ce5fbf6fce97d20211d3b299fee0a0da6430c74d4" dependencies = [ "ahash", "arrow-array", @@ -2154,8 +2396,8 @@ dependencies = [ "chrono", "flate2", "futures", - "hashbrown 0.13.2", - "lz4", + "hashbrown 0.14.3", + "lz4_flex", "num", "num-bigint", "object_store", @@ -2165,7 +2407,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd", + "zstd 0.13.0", ] [[package]] @@ -2179,40 +2421,40 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.12" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 2.1.0", ] [[package]] name = "phf" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56ac890c5e3ca598bbdeaa99964edb5b0258a583a9eb6ef4e89fc85d9224770" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" dependencies = [ "phf_generator", "phf_shared", @@ -2220,9 +2462,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" dependencies = [ "phf_shared", "rand", @@ -2230,38 +2472,38 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" dependencies = [ "siphasher", ] [[package]] name = "pin-project" -version = "1.1.0" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.0" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -2275,12 +2517,49 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dfc28575c2e3f19cb3c73b93af36460ae898d426eba6fc15b9bd2a5220758a0" +dependencies = [ + "anstyle", + "difflib", + "float-cmp", + "itertools 0.11.0", + "normalize-line-endings", + "predicates-core", + "regex", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2305,26 +2584,26 @@ dependencies = [ "version_check", ] -[[package]] -name = "proc-macro-hack" -version = "0.5.20+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" - [[package]] name = "proc-macro2" -version = "1.0.59" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] +[[package]] +name = "quad-rand" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" + [[package]] name = "quick-xml" -version = "0.28.2" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce5e73202a820a31f8a0ee32ada5e21029c81fd9e3ebf668a40832e4219d9d1" +checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" dependencies = [ "memchr", "serde", @@ -2332,9 +2611,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.28" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -2381,55 +2660,64 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] -name = "redox_syscall" -version = "0.3.5" +name = "redox_users" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" dependencies = [ - "bitflags", + "getrandom", + "libredox", + "thiserror", ] [[package]] -name = "redox_users" -version = "0.4.3" +name = "regex" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "regex" -version = "1.8.4" +name = "regex-automata" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" + [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ "base64", "bytes", @@ -2440,7 +2728,7 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls 0.24.0", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -2448,13 +2736,14 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.1", + "rustls 0.21.9", "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", "tokio-util", "tower-service", "url", @@ -2475,12 +2764,64 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", - "untrusted", + "spin 0.5.2", + "untrusted 0.7.1", "web-sys", "winapi", ] +[[package]] +name = "ring" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +dependencies = [ + "cc", + "getrandom", + "libc", + "spin 0.9.8", + "untrusted 0.9.0", + "windows-sys 0.48.0", +] + +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + +[[package]] +name = "rstest" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de1bb486a691878cd320c2f0d319ba91eeaa2e894066d8b5f8f117c000e9d962" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290ca1a1c8ca7edb7c3283bd44dc35dd54fdec6253a3912e201ba1072018fca8" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", + "unicode-ident", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc_version" version = "0.4.0" @@ -2492,47 +2833,46 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.19" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ - "bitflags", + "bitflags 2.4.1", "errno", - "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.20.8" +version = "0.20.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" +checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" dependencies = [ "log", - "ring", + "ring 0.16.20", "sct", "webpki", ] [[package]] name = "rustls" -version = "0.21.1" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e" +checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" dependencies = [ "log", - "ring", + "ring 0.17.7", "rustls-webpki", "sct", ] [[package]] name = "rustls-native-certs" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" dependencies = [ "openssl-probe", "rustls-pemfile", @@ -2542,28 +2882,28 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ "base64", ] [[package]] name = "rustls-webpki" -version = "0.100.1" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] name = "rustversion" -version = "1.0.12" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "rustyline" @@ -2571,7 +2911,7 @@ version = "11.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfc8644681285d1fb67a467fb3021bfea306b99b4146b166a1fe3ada965eece" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "clipboard-win", "dirs-next", @@ -2590,9 +2930,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "same-file" @@ -2605,36 +2945,36 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] name = "security-framework" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -2643,9 +2983,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", @@ -2653,41 +2993,41 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.17" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] name = "seq-macro" -version = "0.3.3" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.163" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "itoa", "ryu", @@ -2708,9 +3048,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -2719,30 +3059,30 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" [[package]] name = "slab" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] [[package]] name = "smallvec" -version = "1.10.0" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "snafu" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0656e7e3ffb70f6c39b3c2a86332bb74aa3c679da781642590f3c1118c5045" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" dependencies = [ "doc-comment", "snafu-derive", @@ -2750,9 +3090,9 @@ dependencies = [ [[package]] name = "snafu-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "475b3bbe5245c26f2d8a6f62d67c1f30eb9fffeccee721c45d162c3ebbdf81b2" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" dependencies = [ "heck", "proc-macro2", @@ -2762,31 +3102,47 @@ dependencies = [ [[package]] name = "snap" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e9f0ab6ef7eb7353d9119c170a436d1bf248eea575ac42d19d12f4e34130831" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" dependencies = [ "libc", "winapi", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "sqlparser" -version = "0.34.0" +version = "0.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3706eefb17039056234df6b566b0014f303f867f2656108334a55b8096f59" +checksum = "7c80afe31cdb649e56c0d9bb5503be9166600d68a852c38dd445636d126858e5" dependencies = [ "log", "sqlparser_derive", @@ -2794,9 +3150,9 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +checksum = "3e9c2e1dde0efa87003e7923d94a90f46e3274ad1649f51de96812be561f041f" dependencies = [ "proc-macro2", "quote", @@ -2823,24 +3179,24 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "strum" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.24.3" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 1.0.109", + "syn 2.0.39", ] [[package]] @@ -2862,37 +3218,64 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.18" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" -version = "3.5.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" dependencies = [ "cfg-if", - "fastrand", - "redox_syscall 0.3.5", + "fastrand 2.0.1", + "redox_syscall", "rustix", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] name = "termcolor" -version = "1.2.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +checksum = "ff1bc3d3f05aff0403e8ac0d92ced918ec05b666a43f83297ccef5bea8a3d449" dependencies = [ "winapi-util", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "textwrap" version = "0.16.0" @@ -2901,22 +3284,22 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] @@ -2932,10 +3315,12 @@ dependencies = [ [[package]] name = "time" -version = "0.3.21" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc" +checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" dependencies = [ + "deranged", + "powerfmt", "serde", "time-core", "time-macros", @@ -2943,15 +3328,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" dependencies = [ "time-core", ] @@ -2982,31 +3367,31 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", "mio", "num_cpus", "parking_lot", "pin-project-lite", - "socket2", + "socket2 0.5.5", "tokio-macros", "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] @@ -3015,18 +3400,18 @@ version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ - "rustls 0.20.8", + "rustls 0.20.9", "tokio", "webpki", ] [[package]] name = "tokio-rustls" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.1", + "rustls 0.21.9", "tokio", ] @@ -3043,9 +3428,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", @@ -3085,11 +3470,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-attributes", @@ -3098,20 +3482,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.24" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] @@ -3132,11 +3516,31 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "typed-builder" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" @@ -3146,9 +3550,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -3167,9 +3571,9 @@ checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "untrusted" @@ -3177,11 +3581,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", @@ -3190,9 +3600,9 @@ dependencies = [ [[package]] name = "urlencoding" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "utf8parse" @@ -3202,11 +3612,12 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.3.3" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", + "serde", ] [[package]] @@ -3221,11 +3632,20 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "walkdir" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" dependencies = [ "same-file", "winapi-util", @@ -3233,11 +3653,10 @@ dependencies = [ [[package]] name = "want" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" dependencies = [ - "log", "try-lock", ] @@ -3249,9 +3668,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.86" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bba0e8cb82ba49ff4e229459ff22a191bbe9a1cb3a341610c9c33efc27ddf73" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3259,24 +3678,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.86" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b04bc93f9d6bdee709f6bd2118f57dd6679cf1176a1af464fca3ab0d66d8fb" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.36" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d1985d03709c53167ce907ff394f5316aa22cb4e12761295c5dc57dacb6297e" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ "cfg-if", "js-sys", @@ -3286,9 +3705,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.86" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14d6b024f1a526bb0234f52840389927257beb670610081360e5a03c5df9c258" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3296,28 +3715,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.86" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.39", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.86" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "wasm-streams" -version = "0.2.3" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" dependencies = [ "futures-util", "js-sys", @@ -3328,9 +3747,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.63" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bdd9ef4e984da1187bf8110c5cf5b845fbc87a23602cdf912386a76fcd3a7c2" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3338,22 +3757,19 @@ dependencies = [ [[package]] name = "webpki" -version = "0.22.0" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring", - "untrusted", + "ring 0.17.7", + "untrusted 0.9.0", ] [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "winapi" @@ -3373,9 +3789,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] @@ -3387,175 +3803,161 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" -dependencies = [ - "windows-targets 0.48.0", -] - -[[package]] -name = "windows-sys" -version = "0.42.0" +name = "windows-core" +version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows-targets 0.48.5", ] [[package]] name = "windows-sys" -version = "0.45.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.42.2", + "windows-targets 0.48.5", ] [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-targets" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] name = "xmlparser" -version = "0.13.5" +version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d25c75bf9ea12c4040a97f829154768bbbce366287e2dc044af160cd79a13fd" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "xz2" @@ -3566,38 +3968,75 @@ dependencies = [ "lzma-sys", ] +[[package]] +name = "zerocopy" +version = "0.7.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d075cf85bbb114e933343e087b92f2146bac0d55b534cbb8188becf0039948e" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86cd5ca076997b97ef09d3ad65efe811fa68c9e874cb636ccb211223a813b0c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + [[package]] name = "zeroize" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe 6.0.6", +] [[package]] name = "zstd" -version = "0.12.3+zstd.1.5.2" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76eea132fb024e0e13fd9c2f5d5d595d8a967aa72382ac2f9d39fcc95afd0806" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" dependencies = [ - "zstd-safe", + "zstd-safe 7.0.0", ] [[package]] name = "zstd-safe" -version = "6.0.5+zstd.1.5.4" +version = "6.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" dependencies = [ "libc", "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" -version = "2.0.8+zstd.1.5.5" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" dependencies = [ "cc", - "libc", "pkg-config", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index a604e017b7447..5ce318aea3ac3 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,28 +18,37 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "26.0.0" +version = "33.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] license = "Apache-2.0" homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.62" +rust-version = "1.70" readme = "README.md" [dependencies] -arrow = "41.0.0" +arrow = "49.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "26.0.0" } +datafusion = { path = "../datafusion/core", version = "33.0.0", features = ["avro", "crypto_expressions", "encoding_expressions", "parquet", "regex_expressions", "unicode_expressions", "compression"] } dirs = "4.0.0" env_logger = "0.9" mimalloc = { version = "0.1", default-features = false } -object_store = { version = "0.6.1", features = ["aws", "gcp"] } +object_store = { version = "0.8.0", features = ["aws", "gcp"] } parking_lot = { version = "0.12" } +parquet = { version = "49.0.0", default-features = false } +regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } url = "2.2" + +[dev-dependencies] +assert_cmd = "2.0" +ctor = "0.2.0" +datafusion-common = { path = "../datafusion/common" } +predicates = "3.0" +rstest = "0.17" diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index bc0f821bbdd1d..07f2b888158c9 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.62 as builder +FROM rust:1.70 as builder -COPY ./datafusion /usr/src/datafusion +COPY . /usr/src/arrow-datafusion +COPY ./datafusion /usr/src/arrow-datafusion/datafusion -COPY ./datafusion-cli /usr/src/datafusion-cli +COPY ./datafusion-cli /usr/src/arrow-datafusion/datafusion-cli -WORKDIR /usr/src/datafusion-cli +WORKDIR /usr/src/arrow-datafusion/datafusion-cli RUN rustup component add rustfmt @@ -29,7 +30,7 @@ RUN cargo build --release FROM debian:bullseye-slim -COPY --from=builder /usr/src/datafusion-cli/target/release/datafusion-cli /usr/local/bin +COPY --from=builder /usr/src/arrow-datafusion/datafusion-cli/target/release/datafusion-cli /usr/local/bin ENTRYPOINT ["datafusion-cli"] diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index dbd6751a4f769..d790e3118a116 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -16,8 +16,8 @@ // under the License. use async_trait::async_trait; -use datafusion::catalog::catalog::{CatalogList, CatalogProvider}; use datafusion::catalog::schema::SchemaProvider; +use datafusion::catalog::{CatalogList, CatalogProvider}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableUrl, }; diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 5563c31bcc53d..f7f36b6f9d512 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -25,6 +25,7 @@ use clap::ArgEnum; use datafusion::arrow::array::{ArrayRef, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::common::exec_err; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; use std::fs::File; @@ -81,9 +82,7 @@ impl Command { exec_from_lines(ctx, &mut BufReader::new(file), print_options).await; Ok(()) } else { - Err(DataFusionError::Execution( - "Required filename argument is missing".into(), - )) + exec_err!("Required filename argument is missing") } } Self::QuietMode(quiet) => { @@ -101,9 +100,7 @@ impl Command { } Ok(()) } - Self::Quit => Err(DataFusionError::Execution( - "Unexpected quit, this should be handled outside".into(), - )), + Self::Quit => exec_err!("Unexpected quit, this should be handled outside"), Self::ListFunctions => display_all_functions(), Self::SearchFunctions(function) => { if let Ok(func) = function.parse::() { @@ -111,13 +108,12 @@ impl Command { println!("{}", details); Ok(()) } else { - let msg = format!("{} is not a supported function", function); - Err(DataFusionError::Execution(msg)) + exec_err!("{function} is not a supported function") } } - Self::OutputFormat(_) => Err(DataFusionError::Execution( - "Unexpected change output format, this should be handled outside".into(), - )), + Self::OutputFormat(_) => exec_err!( + "Unexpected change output format, this should be handled outside" + ), } } @@ -230,11 +226,11 @@ impl OutputFormat { println!("Output format is {:?}.", print_options.format); Ok(()) } else { - Err(DataFusionError::Execution(format!( + exec_err!( "{:?} is not a valid format type [possible values: {:?}]", format, PrintFormat::value_variants() - ))) + ) } } } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index cec0fe03739a5..8af534cd13754 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -24,8 +24,10 @@ use crate::{ get_gcs_object_store_builder, get_oss_object_store_builder, get_s3_object_store_builder, }, - print_options::PrintOptions, + print_options::{MaxRows, PrintOptions}, }; +use datafusion::common::plan_datafusion_err; +use datafusion::sql::{parser::DFParser, sqlparser::dialect::dialect_from_str}; use datafusion::{ datasource::listing::ListingTableUrl, error::{DataFusionError, Result}, @@ -41,6 +43,20 @@ use std::time::Instant; use std::{fs::File, sync::Arc}; use url::Url; +/// run and execute SQL statements and commands, against a context with the given print options +pub async fn exec_from_commands( + ctx: &mut SessionContext, + print_options: &PrintOptions, + commands: Vec, +) { + for sql in commands { + match exec_and_print(ctx, print_options, sql).await { + Ok(_) => {} + Err(err) => println!("{err}"), + } + } +} + /// run and execute SQL statements and commands from a file, against a context with the given print options pub async fn exec_from_lines( ctx: &mut SessionContext, @@ -58,11 +74,8 @@ pub async fn exec_from_lines( let line = line.trim_end(); query.push_str(line); if line.ends_with(';') { - match unescape_input(line) { - Ok(sql) => match exec_and_print(ctx, print_options, sql).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), - }, + match exec_and_print(ctx, print_options, query).await { + Ok(_) => {} Err(err) => eprintln!("{err}"), } query = "".to_owned(); @@ -77,7 +90,8 @@ pub async fn exec_from_lines( } // run the left over query if the last statement doesn't contain ‘;’ - if !query.is_empty() { + // ignore if it only consists of '\n' + if query.contains(|c| c != '\n') { match exec_and_print(ctx, print_options, query).await { Ok(_) => {} Err(err) => println!("{err}"), @@ -106,7 +120,9 @@ pub async fn exec_from_repl( print_options: &mut PrintOptions, ) -> rustyline::Result<()> { let mut rl = Editor::new()?; - rl.set_helper(Some(CliHelper::default())); + rl.set_helper(Some(CliHelper::new( + &ctx.task_ctx().session_config().options().sql_parser.dialect, + ))); rl.load_history(".history").ok(); let mut print_options = print_options.clone(); @@ -149,13 +165,14 @@ pub async fn exec_from_repl( } Ok(line) => { rl.add_history_entry(line.trim_end())?; - match unescape_input(&line) { - Ok(sql) => match exec_and_print(ctx, &print_options, sql).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), - }, + match exec_and_print(ctx, &print_options, line).await { + Ok(_) => {} Err(err) => eprintln!("{err}"), } + // dialect might have changed + rl.helper_mut().unwrap().set_dialect( + &ctx.task_ctx().session_config().options().sql_parser.dialect, + ); } Err(ReadlineError::Interrupted) => { println!("^C"); @@ -182,24 +199,54 @@ async fn exec_and_print( ) -> Result<()> { let now = Instant::now(); - let plan = ctx.state().create_logical_plan(&sql).await?; - let df = match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { + let sql = unescape_input(&sql)?; + let task_ctx = ctx.task_ctx(); + let dialect = &task_ctx.session_config().options().sql_parser.dialect; + let dialect = dialect_from_str(dialect).ok_or_else(|| { + plan_datafusion_err!( + "Unsupported SQL dialect: {dialect}. Available dialects: \ + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." + ) + })?; + let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; + for statement in statements { + let mut plan = ctx.state().statement_to_plan(statement).await?; + + // For plans like `Explain` ignore `MaxRows` option and always display all rows + let should_ignore_maxrows = matches!( + plan, + LogicalPlan::Explain(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Analyze(_) + ); + + // Note that cmd is a mutable reference so that create_external_table function can remove all + // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion + // will raise Configuration errors. + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { create_external_table(ctx, cmd).await?; - ctx.execute_logical_plan(plan).await? } - _ => ctx.execute_logical_plan(plan).await?, - }; + let df = ctx.execute_logical_plan(plan).await?; + let results = df.collect().await?; - let results = df.collect().await?; - print_options.print_batches(&results, now)?; + let print_options = if should_ignore_maxrows { + PrintOptions { + maxrows: MaxRows::Unlimited, + ..print_options.clone() + } + } else { + print_options.clone() + }; + print_options.print_batches(&results, now)?; + } Ok(()) } async fn create_external_table( ctx: &SessionContext, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result<()> { let table_path = ListingTableUrl::parse(&cmd.location)?; let scheme = table_path.scheme(); @@ -240,18 +287,35 @@ async fn create_external_table( #[cfg(test)] mod tests { + use std::str::FromStr; + use super::*; + use datafusion::common::plan_err; + use datafusion_common::{file_options::StatementOptions, FileTypeWriterOptions}; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; - - match &plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { - create_external_table(&ctx, cmd).await?; - } - _ => assert!(false), - }; + let mut plan = ctx.state().create_logical_plan(sql).await?; + + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { + create_external_table(&ctx, cmd).await?; + let options: Vec<_> = cmd + .options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let statement_options = StatementOptions::new(options); + let file_type = + datafusion_common::FileType::from_str(cmd.file_type.as_str())?; + + let _file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + ctx.state().config_options(), + &statement_options, + )?; + } else { + return plan_err!("LogicalPlan is not a CreateExternalTable"); + } ctx.runtime_env() .object_store(ListingTableUrl::parse(location)?)?; @@ -302,7 +366,7 @@ mod tests { async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -312,14 +376,15 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("No such file or directory")); + assert!(err.to_string().contains("os error 2")); // for service_account_key let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_key' '{service_account_key}') LOCATION '{location}'"); let err = create_external_table_test(location, &sql) .await - .unwrap_err(); - assert!(err.to_string().contains("No RSA key found in pem file")); + .unwrap_err() + .to_string(); + assert!(err.contains("No RSA key found in pem file"), "{err}"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -327,22 +392,19 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("No such file or directory")); + assert!(err.to_string().contains("os error 2")); Ok(()) } #[tokio::test] async fn create_external_table_local_file() -> Result<()> { - let location = "/path/to/file.parquet"; + let location = "path/to/file.parquet"; // Ensure that local files are also registered let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); - let err = create_external_table_test(location, &sql) - .await - .unwrap_err(); - assert!(err.to_string().contains("No such file or directory")); + create_external_table_test(location, &sql).await.unwrap(); Ok(()) } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index eeebe713d716e..24f3399ee2be5 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -16,12 +16,26 @@ // under the License. //! Functions that are query-able and searchable via the `\h` command -use arrow::array::StringArray; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::{Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use datafusion::common::DataFusionError; +use datafusion::common::{plan_err, Column}; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::scalar::ScalarValue; +use parquet::file::reader::FileReader; +use parquet::file::serialized_reader::SerializedFileReader; +use parquet::file::statistics::Statistics; use std::fmt; +use std::fs::File; use std::str::FromStr; use std::sync::Arc; @@ -196,3 +210,208 @@ pub fn display_all_functions() -> Result<()> { println!("{}", pretty_format_batches(&[batch]).unwrap()); Ok(()) } + +/// PARQUET_META table function +struct ParquetMetadataTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ParquetMetadataTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(MemoryExec::try_new( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +pub struct ParquetMetadataFunc {} + +impl TableFunctionImpl for ParquetMetadataFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let filename = match exprs.get(0) { + Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") + _ => { + return plan_err!( + "parquet_metadata requires string argument as its input" + ); + } + }; + + let file = File::open(filename.clone())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("filename", DataType::Utf8, true), + Field::new("row_group_id", DataType::Int64, true), + Field::new("row_group_num_rows", DataType::Int64, true), + Field::new("row_group_num_columns", DataType::Int64, true), + Field::new("row_group_bytes", DataType::Int64, true), + Field::new("column_id", DataType::Int64, true), + Field::new("file_offset", DataType::Int64, true), + Field::new("num_values", DataType::Int64, true), + Field::new("path_in_schema", DataType::Utf8, true), + Field::new("type", DataType::Utf8, true), + Field::new("stats_min", DataType::Utf8, true), + Field::new("stats_max", DataType::Utf8, true), + Field::new("stats_null_count", DataType::Int64, true), + Field::new("stats_distinct_count", DataType::Int64, true), + Field::new("stats_min_value", DataType::Utf8, true), + Field::new("stats_max_value", DataType::Utf8, true), + Field::new("compression", DataType::Utf8, true), + Field::new("encodings", DataType::Utf8, true), + Field::new("index_page_offset", DataType::Int64, true), + Field::new("dictionary_page_offset", DataType::Int64, true), + Field::new("data_page_offset", DataType::Int64, true), + Field::new("total_compressed_size", DataType::Int64, true), + Field::new("total_uncompressed_size", DataType::Int64, true), + ])); + + // construct recordbatch from metadata + let mut filename_arr = vec![]; + let mut row_group_id_arr = vec![]; + let mut row_group_num_rows_arr = vec![]; + let mut row_group_num_columns_arr = vec![]; + let mut row_group_bytes_arr = vec![]; + let mut column_id_arr = vec![]; + let mut file_offset_arr = vec![]; + let mut num_values_arr = vec![]; + let mut path_in_schema_arr = vec![]; + let mut type_arr = vec![]; + let mut stats_min_arr = vec![]; + let mut stats_max_arr = vec![]; + let mut stats_null_count_arr = vec![]; + let mut stats_distinct_count_arr = vec![]; + let mut stats_min_value_arr = vec![]; + let mut stats_max_value_arr = vec![]; + let mut compression_arr = vec![]; + let mut encodings_arr = vec![]; + let mut index_page_offset_arr = vec![]; + let mut dictionary_page_offset_arr = vec![]; + let mut data_page_offset_arr = vec![]; + let mut total_compressed_size_arr = vec![]; + let mut total_uncompressed_size_arr = vec![]; + for (rg_idx, row_group) in metadata.row_groups().iter().enumerate() { + for (col_idx, column) in row_group.columns().iter().enumerate() { + filename_arr.push(filename.clone()); + row_group_id_arr.push(rg_idx as i64); + row_group_num_rows_arr.push(row_group.num_rows()); + row_group_num_columns_arr.push(row_group.num_columns() as i64); + row_group_bytes_arr.push(row_group.total_byte_size()); + column_id_arr.push(col_idx as i64); + file_offset_arr.push(column.file_offset()); + num_values_arr.push(column.num_values()); + path_in_schema_arr.push(column.column_path().to_string()); + type_arr.push(column.column_type().to_string()); + if let Some(s) = column.statistics() { + let (min_val, max_val) = if s.has_min_max_set() { + let (min_val, max_val) = match s { + Statistics::Boolean(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int32(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int64(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Int96(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Float(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::Double(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::ByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + Statistics::FixedLenByteArray(val) => { + (val.min().to_string(), val.max().to_string()) + } + }; + (Some(min_val), Some(max_val)) + } else { + (None, None) + }; + stats_min_arr.push(min_val.clone()); + stats_max_arr.push(max_val.clone()); + stats_null_count_arr.push(Some(s.null_count() as i64)); + stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_min_value_arr.push(min_val); + stats_max_value_arr.push(max_val); + } else { + stats_min_arr.push(None); + stats_max_arr.push(None); + stats_null_count_arr.push(None); + stats_distinct_count_arr.push(None); + stats_min_value_arr.push(None); + stats_max_value_arr.push(None); + }; + compression_arr.push(format!("{:?}", column.compression())); + encodings_arr.push(format!("{:?}", column.encodings())); + index_page_offset_arr.push(column.index_page_offset()); + dictionary_page_offset_arr.push(column.dictionary_page_offset()); + data_page_offset_arr.push(column.data_page_offset()); + total_compressed_size_arr.push(column.compressed_size()); + total_uncompressed_size_arr.push(column.uncompressed_size()); + } + } + + let rb = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(filename_arr)), + Arc::new(Int64Array::from(row_group_id_arr)), + Arc::new(Int64Array::from(row_group_num_rows_arr)), + Arc::new(Int64Array::from(row_group_num_columns_arr)), + Arc::new(Int64Array::from(row_group_bytes_arr)), + Arc::new(Int64Array::from(column_id_arr)), + Arc::new(Int64Array::from(file_offset_arr)), + Arc::new(Int64Array::from(num_values_arr)), + Arc::new(StringArray::from(path_in_schema_arr)), + Arc::new(StringArray::from(type_arr)), + Arc::new(StringArray::from(stats_min_arr)), + Arc::new(StringArray::from(stats_max_arr)), + Arc::new(Int64Array::from(stats_null_count_arr)), + Arc::new(Int64Array::from(stats_distinct_count_arr)), + Arc::new(StringArray::from(stats_min_value_arr)), + Arc::new(StringArray::from(stats_max_value_arr)), + Arc::new(StringArray::from(compression_arr)), + Arc::new(StringArray::from(encodings_arr)), + Arc::new(Int64Array::from(index_page_offset_arr)), + Arc::new(Int64Array::from(dictionary_page_offset_arr)), + Arc::new(Int64Array::from(data_page_offset_arr)), + Arc::new(Int64Array::from(total_compressed_size_arr)), + Arc::new(Int64Array::from(total_uncompressed_size_arr)), + ], + )?; + + let parquet_metadata = ParquetMetadataTable { schema, batch: rb }; + Ok(Arc::new(parquet_metadata)) + } +} diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 15464eec13a05..69d412db5afae 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -18,8 +18,10 @@ //! Helper that helps with interactive editing, including multi-line parsing and validation, //! and auto-completion for file name during creating external table. +use datafusion::common::sql_err; use datafusion::error::DataFusionError; use datafusion::sql::parser::{DFParser, Statement}; +use datafusion::sql::sqlparser::dialect::dialect_from_str; use datafusion::sql::sqlparser::parser::ParserError; use rustyline::completion::Completer; use rustyline::completion::FilenameCompleter; @@ -34,12 +36,25 @@ use rustyline::Context; use rustyline::Helper; use rustyline::Result; -#[derive(Default)] pub struct CliHelper { completer: FilenameCompleter, + dialect: String, } impl CliHelper { + pub fn new(dialect: &str) -> Self { + Self { + completer: FilenameCompleter::new(), + dialect: dialect.into(), + } + } + + pub fn set_dialect(&mut self, dialect: &str) { + if dialect != self.dialect { + self.dialect = dialect.to_string(); + } + } + fn validate_input(&self, input: &str) -> Result { if let Some(sql) = input.strip_suffix(';') { let sql = match unescape_input(sql) { @@ -50,13 +65,21 @@ impl CliHelper { )))) } }; - match DFParser::parse_sql(&sql) { + + let dialect = match dialect_from_str(&self.dialect) { + Some(dialect) => dialect, + None => { + return Ok(ValidationResult::Invalid(Some(format!( + " 🤔 Invalid dialect: {}", + self.dialect + )))) + } + }; + + match DFParser::parse_sql_with_dialect(&sql, dialect.as_ref()) { Ok(statements) if statements.is_empty() => Ok(ValidationResult::Invalid( Some(" 🤔 You entered an empty statement".to_string()), )), - Ok(statements) if statements.len() > 1 => Ok(ValidationResult::Invalid( - Some(" 🤔 You entered more than one statement".to_string()), - )), Ok(_statements) => Ok(ValidationResult::Valid(None)), Err(err) => Ok(ValidationResult::Invalid(Some(format!( " 🤔 Invalid statement: {err}", @@ -71,6 +94,12 @@ impl CliHelper { } } +impl Default for CliHelper { + fn default() -> Self { + Self::new("generic") + } +} + impl Highlighter for CliHelper {} impl Hinter for CliHelper { @@ -134,9 +163,10 @@ pub fn unescape_input(input: &str) -> datafusion::error::Result { 't' => '\t', '\\' => '\\', _ => { - return Err(DataFusionError::SQL(ParserError::TokenizerError( - format!("unsupported escape char: '\\{}'", next_char), - ))) + return sql_err!(ParserError::TokenizerError(format!( + "unsupported escape char: '\\{}'", + next_char + ),)) } }); } @@ -223,4 +253,24 @@ mod tests { Ok(()) } + + #[test] + fn sql_dialect() -> Result<()> { + let mut validator = CliHelper::default(); + + // shoule be invalid in generic dialect + let result = + readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; + assert!( + matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement")) + ); + + // valid in postgresql dialect + validator.set_dialect("postgresql"); + let result = + readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; + assert!(matches!(result, ValidationResult::Valid(None))); + + Ok(()) + } } diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 4c1dd2f94e059..8b1a9816afc09 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -18,20 +18,45 @@ use clap::Parser; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; +use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ - exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION, + exec, + print_format::PrintFormat, + print_options::{MaxRows, PrintOptions}, + DATAFUSION_CLI_VERSION, }; use mimalloc::MiMalloc; +use std::collections::HashMap; use std::env; use std::path::Path; -use std::sync::Arc; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; +#[derive(PartialEq, Debug)] +enum PoolType { + Greedy, + Fair, +} + +impl FromStr for PoolType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Greedy" | "greedy" => Ok(PoolType::Greedy), + "Fair" | "fair" => Ok(PoolType::Fair), + _ => Err(format!("Invalid memory pool type '{}'", s)), + } + } +} + #[derive(Debug, Parser, PartialEq)] #[clap(author, version, about, long_about= None)] struct Args { @@ -44,13 +69,29 @@ struct Args { data_path: Option, #[clap( - short = 'c', + short = 'b', long, help = "The batch size of each query, or use DataFusion default", validator(is_valid_batch_size) )] batch_size: Option, + #[clap( + short = 'c', + long, + multiple_values = true, + help = "Execute the given command string(s), then exit" + )] + command: Vec, + + #[clap( + short = 'm', + long, + help = "The memory pool limitation (e.g. '10g'), default to None (no limit)", + validator(is_valid_memory_pool_size) + )] + memory_limit: Option, + #[clap( short, long, @@ -79,6 +120,19 @@ struct Args { help = "Reduce printing other than the results and work quietly" )] quiet: bool, + + #[clap( + long, + help = "Specify the memory pool type 'greedy' or 'fair', default to 'greedy'" + )] + mem_pool_type: Option, + + #[clap( + long, + help = "The max number of rows to display for 'Table' format\n[default: 40] [possible values: numbers(0/10/...), inf(no limit)]", + default_value = "40" + )] + maxrows: MaxRows, } #[tokio::main] @@ -101,21 +155,47 @@ pub async fn main() -> Result<()> { session_config = session_config.with_batch_size(batch_size); }; - let runtime_env = create_runtime_env()?; + let rn_config = RuntimeConfig::new(); + let rn_config = + // set memory pool size + if let Some(memory_limit) = args.memory_limit { + let memory_limit = extract_memory_pool_size(&memory_limit).unwrap(); + // set memory pool type + if let Some(mem_pool_type) = args.mem_pool_type { + match mem_pool_type { + PoolType::Greedy => rn_config + .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))), + PoolType::Fair => rn_config + .with_memory_pool(Arc::new(FairSpillPool::new(memory_limit))), + } + } else { + rn_config + .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))) + } + } else { + rn_config + }; + + let runtime_env = create_runtime_env(rn_config.clone())?; + let mut ctx = - SessionContext::with_config_rt(session_config.clone(), Arc::new(runtime_env)); + SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env)); ctx.refresh_catalogs().await?; // install dynamic catalog provider that knows how to open files ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( ctx.state().catalog_list(), ctx.state_weak_ref(), ))); + // register `parquet_metadata` table function to get metadata from parquet files + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, + maxrows: args.maxrows, }; + let commands = args.command; let files = args.file; let rc = match args.rc { Some(file) => file, @@ -132,22 +212,28 @@ pub async fn main() -> Result<()> { } }; - if !files.is_empty() { - exec::exec_from_files(files, &mut ctx, &print_options).await; - Ok(()) - } else { + if commands.is_empty() && files.is_empty() { if !rc.is_empty() { exec::exec_from_files(rc, &mut ctx, &print_options).await } // TODO maybe we can have thiserror for cli but for now let's keep it simple - exec::exec_from_repl(&mut ctx, &mut print_options) + return exec::exec_from_repl(&mut ctx, &mut print_options) .await - .map_err(|e| DataFusionError::External(Box::new(e))) + .map_err(|e| DataFusionError::External(Box::new(e))); + } + + if !files.is_empty() { + exec::exec_from_files(files, &mut ctx, &print_options).await; + } + + if !commands.is_empty() { + exec::exec_from_commands(&mut ctx, &print_options, commands).await; } + + Ok(()) } -fn create_runtime_env() -> Result { - let rn_config = RuntimeConfig::new(); +fn create_runtime_env(rn_config: RuntimeConfig) -> Result { RuntimeEnv::new(rn_config) } @@ -173,3 +259,165 @@ fn is_valid_batch_size(size: &str) -> Result<(), String> { _ => Err(format!("Invalid batch size '{}'", size)), } } + +fn is_valid_memory_pool_size(size: &str) -> Result<(), String> { + match extract_memory_pool_size(size) { + Ok(_) => Ok(()), + Err(e) => Err(e), + } +} + +#[derive(Debug, Clone, Copy)] +enum ByteUnit { + Byte, + KiB, + MiB, + GiB, + TiB, +} + +impl ByteUnit { + fn multiplier(&self) -> usize { + match self { + ByteUnit::Byte => 1, + ByteUnit::KiB => 1 << 10, + ByteUnit::MiB => 1 << 20, + ByteUnit::GiB => 1 << 30, + ByteUnit::TiB => 1 << 40, + } + } +} + +fn extract_memory_pool_size(size: &str) -> Result { + fn byte_suffixes() -> &'static HashMap<&'static str, ByteUnit> { + static BYTE_SUFFIXES: OnceLock> = OnceLock::new(); + BYTE_SUFFIXES.get_or_init(|| { + let mut m = HashMap::new(); + m.insert("b", ByteUnit::Byte); + m.insert("k", ByteUnit::KiB); + m.insert("kb", ByteUnit::KiB); + m.insert("m", ByteUnit::MiB); + m.insert("mb", ByteUnit::MiB); + m.insert("g", ByteUnit::GiB); + m.insert("gb", ByteUnit::GiB); + m.insert("t", ByteUnit::TiB); + m.insert("tb", ByteUnit::TiB); + m + }) + } + + fn suffix_re() -> &'static regex::Regex { + static SUFFIX_REGEX: OnceLock = OnceLock::new(); + SUFFIX_REGEX.get_or_init(|| regex::Regex::new(r"^(-?[0-9]+)([a-z]+)?$").unwrap()) + } + + let lower = size.to_lowercase(); + if let Some(caps) = suffix_re().captures(&lower) { + let num_str = caps.get(1).unwrap().as_str(); + let num = num_str.parse::().map_err(|_| { + format!("Invalid numeric value in memory pool size '{}'", size) + })?; + + let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b"); + let unit = byte_suffixes() + .get(suffix) + .ok_or_else(|| format!("Invalid memory pool size '{}'", size))?; + + Ok(num * unit.multiplier()) + } else { + Err(format!("Invalid memory pool size '{}'", size)) + } +} + +#[cfg(test)] +mod tests { + use datafusion::assert_batches_eq; + + use super::*; + + fn assert_conversion(input: &str, expected: Result) { + let result = extract_memory_pool_size(input); + match expected { + Ok(v) => assert_eq!(result.unwrap(), v), + Err(e) => assert_eq!(result.unwrap_err(), e), + } + } + + #[test] + fn memory_pool_size() -> Result<(), String> { + // Test basic sizes without suffix, assumed to be bytes + assert_conversion("5", Ok(5)); + assert_conversion("100", Ok(100)); + + // Test various units + assert_conversion("5b", Ok(5)); + assert_conversion("4k", Ok(4 * 1024)); + assert_conversion("4kb", Ok(4 * 1024)); + assert_conversion("20m", Ok(20 * 1024 * 1024)); + assert_conversion("20mb", Ok(20 * 1024 * 1024)); + assert_conversion("2g", Ok(2 * 1024 * 1024 * 1024)); + assert_conversion("2gb", Ok(2 * 1024 * 1024 * 1024)); + assert_conversion("3t", Ok(3 * 1024 * 1024 * 1024 * 1024)); + assert_conversion("4tb", Ok(4 * 1024 * 1024 * 1024 * 1024)); + + // Test case insensitivity + assert_conversion("4K", Ok(4 * 1024)); + assert_conversion("4KB", Ok(4 * 1024)); + assert_conversion("20M", Ok(20 * 1024 * 1024)); + assert_conversion("20MB", Ok(20 * 1024 * 1024)); + assert_conversion("2G", Ok(2 * 1024 * 1024 * 1024)); + assert_conversion("2GB", Ok(2 * 1024 * 1024 * 1024)); + assert_conversion("2T", Ok(2 * 1024 * 1024 * 1024 * 1024)); + + // Test invalid input + assert_conversion( + "invalid", + Err("Invalid memory pool size 'invalid'".to_string()), + ); + assert_conversion("4kbx", Err("Invalid memory pool size '4kbx'".to_string())); + assert_conversion( + "-20mb", + Err("Invalid numeric value in memory pool size '-20mb'".to_string()), + ); + assert_conversion( + "-100", + Err("Invalid numeric value in memory pool size '-100'".to_string()), + ); + assert_conversion( + "12k12k", + Err("Invalid memory pool size '12k12k'".to_string()), + ); + + Ok(()) + } + + #[tokio::test] + async fn test_parquet_metadata_works() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + + // input with single quote + let sql = + "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + let excepted = [ + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + "| ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | \"f0.list.item\" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 |", + "+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+", + ]; + assert_batches_eq!(excepted, &rbs); + + // input with double quote + let sql = + "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_batches_eq!(excepted, &rbs); + + Ok(()) + } +} diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 46b03a0a36a25..9d79c7e0ec78e 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -30,20 +30,23 @@ use url::Url; pub async fn get_s3_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + // These options are datafusion-cli specific and must be removed before passing through to datafusion. + // Otherwise, a Configuration error will be raised. + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { + println!("removing secret access key!"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); - if let Some(session_token) = cmd.options.get("session_token") { + if let Some(session_token) = cmd.options.remove("session_token") { builder = builder.with_token(session_token); } } else { @@ -57,8 +60,7 @@ pub async fn get_s3_object_store_builder( .ok_or_else(|| { DataFusionError::ObjectStore(object_store::Error::Generic { store: "S3", - source: format!("Failed to get S3 credentials from environment") - .into(), + source: "Failed to get S3 credentials from environment".into(), }) })? .clone(); @@ -67,7 +69,7 @@ pub async fn get_s3_object_store_builder( builder = builder.with_credentials(credentials); } - if let Some(region) = cmd.options.get("region") { + if let Some(region) = cmd.options.remove("region") { builder = builder.with_region(region); } @@ -100,7 +102,7 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = AmazonS3Builder::from_env() @@ -110,15 +112,15 @@ pub fn get_oss_object_store_builder( .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = ( - cmd.options.get("access_key_id"), - cmd.options.get("secret_access_key"), + cmd.options.remove("access_key_id"), + cmd.options.remove("secret_access_key"), ) { builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); } - if let Some(endpoint) = cmd.options.get("endpoint") { + if let Some(endpoint) = cmd.options.remove("endpoint") { builder = builder.with_endpoint(endpoint); } @@ -127,21 +129,21 @@ pub fn get_oss_object_store_builder( pub fn get_gcs_object_store_builder( url: &Url, - cmd: &CreateExternalTable, + cmd: &mut CreateExternalTable, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = GoogleCloudStorageBuilder::from_env().with_bucket_name(bucket_name); - if let Some(service_account_path) = cmd.options.get("service_account_path") { + if let Some(service_account_path) = cmd.options.remove("service_account_path") { builder = builder.with_service_account_path(service_account_path); } - if let Some(service_account_key) = cmd.options.get("service_account_key") { + if let Some(service_account_key) = cmd.options.remove("service_account_key") { builder = builder.with_service_account_key(service_account_key); } if let Some(application_credentials_path) = - cmd.options.get("application_credentials_path") + cmd.options.remove("application_credentials_path") { builder = builder.with_application_credentials(application_credentials_path); } @@ -160,6 +162,8 @@ fn get_bucket_name(url: &Url) -> Result<&str> { #[cfg(test)] mod tests { + use super::*; + use datafusion::common::plan_err; use datafusion::{ datasource::listing::ListingTableUrl, logical_expr::{DdlStatement, LogicalPlan}, @@ -167,8 +171,6 @@ mod tests { }; use object_store::{aws::AmazonS3ConfigKey, gcp::GoogleConfigKey}; - use super::*; - #[tokio::test] async fn s3_object_store_builder() -> Result<()> { let access_key_id = "fake_access_key_id"; @@ -181,9 +183,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'region' '{region}', 'session_token' {session_token}) LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_s3_object_store_builder(table_url.as_ref(), cmd).await?; // get the actual configuration information, then assert_eq! let config = [ @@ -196,9 +198,7 @@ mod tests { assert_eq!(value, builder.get_config_value(&key).unwrap()); } } else { - return Err(DataFusionError::Plan( - "LogicalPlan is not a CreateExternalTable".to_string(), - )); + return plan_err!("LogicalPlan is not a CreateExternalTable"); } Ok(()) @@ -215,9 +215,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('access_key_id' '{access_key_id}', 'secret_access_key' '{secret_access_key}', 'endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_oss_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ @@ -229,9 +229,7 @@ mod tests { assert_eq!(value, builder.get_config_value(&key).unwrap()); } } else { - return Err(DataFusionError::Plan( - "LogicalPlan is not a CreateExternalTable".to_string(), - )); + return plan_err!("LogicalPlan is not a CreateExternalTable"); } Ok(()) @@ -249,9 +247,9 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('service_account_path' '{service_account_path}', 'service_account_key' '{service_account_key}', 'application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let mut plan = ctx.state().create_logical_plan(&sql).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { let builder = get_gcs_object_store_builder(table_url.as_ref(), cmd)?; // get the actual configuration information, then assert_eq! let config = [ @@ -266,9 +264,7 @@ mod tests { assert_eq!(value, builder.get_config_value(&key).unwrap()); } } else { - return Err(DataFusionError::Plan( - "LogicalPlan is not a CreateExternalTable".to_string(), - )); + return plan_err!("LogicalPlan is not a CreateExternalTable"); } Ok(()) diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index de9e140f5c0fd..0738bf6f9b47c 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -16,10 +16,12 @@ // under the License. //! Print format variants +use crate::print_options::MaxRows; use arrow::csv::writer::WriterBuilder; use arrow::json::{ArrayWriter, LineDelimitedWriter}; +use arrow::util::pretty::pretty_format_batches_with_options; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::arrow::util::pretty; +use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; use datafusion::error::{DataFusionError, Result}; use std::str::FromStr; @@ -57,7 +59,7 @@ fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result Result String { + let lines: Vec = s.lines().map(String::from).collect(); + + assert!(lines.len() >= maxrows + 4); // 4 lines for top and bottom border + + let last_line = &lines[lines.len() - 1]; // bottom border line + + let spaces = last_line.len().saturating_sub(4); + let dotted_line = format!("| .{: Result { + match maxrows { + MaxRows::Limited(maxrows) => { + // Only format enough batches for maxrows + let mut filtered_batches = Vec::new(); + let mut batches = batches; + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); + if row_count > maxrows { + let mut accumulated_rows = 0; + + for batch in batches { + filtered_batches.push(batch.clone()); + if accumulated_rows + batch.num_rows() > maxrows { + break; + } + accumulated_rows += batch.num_rows(); + } + + batches = &filtered_batches; + } + + let mut formatted = format!( + "{}", + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, + ); + + if row_count > maxrows { + formatted = keep_only_maxrows(&formatted, maxrows); + } + + Ok(formatted) + } + MaxRows::Unlimited => { + // maxrows not specified, print all rows + Ok(format!( + "{}", + pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS)?, + )) + } + } +} + impl PrintFormat { /// print the batches to stdout using the specified format - pub fn print_batches(&self, batches: &[RecordBatch]) -> Result<()> { + /// `maxrows` option is only used for `Table` format: + /// If `maxrows` is Some(n), then at most n rows will be displayed + /// If `maxrows` is None, then every row will be displayed + pub fn print_batches(&self, batches: &[RecordBatch], maxrows: MaxRows) -> Result<()> { + if batches.is_empty() { + return Ok(()); + } + match self { Self::Csv => println!("{}", print_batches_with_sep(batches, b',')?), Self::Tsv => println!("{}", print_batches_with_sep(batches, b'\t')?), - Self::Table => pretty::print_batches(batches)?, + Self::Table => { + if maxrows == MaxRows::Limited(0) { + return Ok(()); + } + println!("{}", format_batches_with_maxrows(batches, maxrows)?,) + } Self::Json => println!("{}", batches_to_json!(ArrayWriter, batches)), Self::NdJson => { println!("{}", batches_to_json!(LineDelimitedWriter, batches)) @@ -90,7 +166,6 @@ mod tests { use super::*; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::from_slice::FromSlice; use std::sync::Arc; #[test] @@ -107,9 +182,9 @@ mod tests { let batch = RecordBatch::try_new( schema, vec![ - Arc::new(Int32Array::from_slice([1, 2, 3])), - Arc::new(Int32Array::from_slice([4, 5, 6])), - Arc::new(Int32Array::from_slice([7, 8, 9])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + Arc::new(Int32Array::from(vec![7, 8, 9])), ], ) .unwrap(); @@ -137,9 +212,9 @@ mod tests { let batch = RecordBatch::try_new( schema, vec![ - Arc::new(Int32Array::from_slice([1, 2, 3])), - Arc::new(Int32Array::from_slice([4, 5, 6])), - Arc::new(Int32Array::from_slice([7, 8, 9])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + Arc::new(Int32Array::from(vec![7, 8, 9])), ], ) .unwrap(); @@ -152,4 +227,72 @@ mod tests { assert_eq!("{\"a\":1,\"b\":4,\"c\":7}\n{\"a\":2,\"b\":5,\"c\":8}\n{\"a\":3,\"b\":6,\"c\":9}\n", r); Ok(()) } + + #[test] + fn test_format_batches_with_maxrows() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]) + .unwrap(); + + #[rustfmt::skip] + let all_rows_expected = [ + "+---+", + "| a |", + "+---+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---+", + ].join("\n"); + + #[rustfmt::skip] + let one_row_expected = [ + "+---+", + "| a |", + "+---+", + "| 1 |", + "| . |", + "| . |", + "| . |", + "+---+", + ].join("\n"); + + #[rustfmt::skip] + let multi_batches_expected = [ + "+---+", + "| a |", + "+---+", + "| 1 |", + "| 2 |", + "| 3 |", + "| 1 |", + "| 2 |", + "| . |", + "| . |", + "| . |", + "+---+", + ].join("\n"); + + let no_limit = format_batches_with_maxrows(&[batch.clone()], MaxRows::Unlimited)?; + assert_eq!(all_rows_expected, no_limit); + + let maxrows_less_than_actual = + format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(1))?; + assert_eq!(one_row_expected, maxrows_less_than_actual); + let maxrows_more_than_actual = + format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(5))?; + assert_eq!(all_rows_expected, maxrows_more_than_actual); + let maxrows_equals_actual = + format_batches_with_maxrows(&[batch.clone()], MaxRows::Limited(3))?; + assert_eq!(all_rows_expected, maxrows_equals_actual); + let multi_batches = format_batches_with_maxrows( + &[batch.clone(), batch.clone(), batch.clone()], + MaxRows::Limited(5), + )?; + assert_eq!(multi_batches_expected, multi_batches); + + Ok(()) + } } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 5e3792634a4e9..0a6c8d4c36fce 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -18,37 +18,89 @@ use crate::print_format::PrintFormat; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::Result; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; use std::time::Instant; +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum MaxRows { + /// show all rows in the output + Unlimited, + /// Only show n rows + Limited(usize), +} + +impl FromStr for MaxRows { + type Err = String; + + fn from_str(maxrows: &str) -> Result { + if maxrows.to_lowercase() == "inf" + || maxrows.to_lowercase() == "infinite" + || maxrows.to_lowercase() == "none" + { + Ok(Self::Unlimited) + } else { + match maxrows.parse::() { + Ok(nrows) => Ok(Self::Limited(nrows)), + _ => Err(format!("Invalid maxrows {}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.", maxrows)), + } + } + } +} + +impl Display for MaxRows { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Unlimited => write!(f, "unlimited"), + Self::Limited(max_rows) => write!(f, "at most {max_rows}"), + } + } +} + #[derive(Debug, Clone)] pub struct PrintOptions { pub format: PrintFormat, pub quiet: bool, + pub maxrows: MaxRows, } -fn print_timing_info(row_count: usize, now: Instant) { - println!( - "{} {} in set. Query took {:.3} seconds.", +fn get_timing_info_str( + row_count: usize, + maxrows: MaxRows, + query_start_time: Instant, +) -> String { + let row_word = if row_count == 1 { "row" } else { "rows" }; + let nrows_shown_msg = match maxrows { + MaxRows::Limited(nrows) if nrows < row_count => format!(" ({} shown)", nrows), + _ => String::new(), + }; + + format!( + "{} {} in set{}. Query took {:.3} seconds.\n", row_count, - if row_count == 1 { "row" } else { "rows" }, - now.elapsed().as_secs_f64() - ); + row_word, + nrows_shown_msg, + query_start_time.elapsed().as_secs_f64() + ) } impl PrintOptions { /// print the batches to stdout using the specified format - pub fn print_batches(&self, batches: &[RecordBatch], now: Instant) -> Result<()> { - if batches.is_empty() { - if !self.quiet { - print_timing_info(0, now); - } - } else { - self.format.print_batches(batches)?; - if !self.quiet { - let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - print_timing_info(row_count, now); - } + pub fn print_batches( + &self, + batches: &[RecordBatch], + query_start_time: Instant, + ) -> Result<()> { + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); + // Elapsed time should not count time for printing batches + let timing_info = get_timing_info_str(row_count, self.maxrows, query_start_time); + + self.format.print_batches(batches, self.maxrows)?; + + if !self.quiet { + println!("{timing_info}"); } + Ok(()) } } diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs new file mode 100644 index 0000000000000..119a0aa39d3c0 --- /dev/null +++ b/datafusion-cli/tests/cli_integration.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::process::Command; + +use assert_cmd::prelude::{CommandCargoExt, OutputAssertExt}; +use predicates::prelude::predicate; +use rstest::rstest; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for tests + let _ = env_logger::try_init(); +} + +#[rstest] +#[case::exec_from_commands( + ["--command", "select 1", "--format", "json", "-q"], + "[{\"Int64(1)\":1}]\n" +)] +#[case::exec_multiple_statements( + ["--command", "select 1; select 2;", "--format", "json", "-q"], + "[{\"Int64(1)\":1}]\n[{\"Int64(2)\":2}]\n" +)] +#[case::exec_from_files( + ["--file", "tests/data/sql.txt", "--format", "json", "-q"], + "[{\"Int64(1)\":1}]\n" +)] +#[case::set_batch_size( + ["--command", "show datafusion.execution.batch_size", "--format", "json", "-q", "-b", "1"], + "[{\"name\":\"datafusion.execution.batch_size\",\"value\":\"1\"}]\n" +)] +#[test] +fn cli_quick_test<'a>( + #[case] args: impl IntoIterator, + #[case] expected: &str, +) { + let mut cmd = Command::cargo_bin("datafusion-cli").unwrap(); + cmd.args(args); + cmd.assert().stdout(predicate::eq(expected)); +} diff --git a/datafusion-cli/tests/data/sql.txt b/datafusion-cli/tests/data/sql.txt new file mode 100644 index 0000000000000..9e13a3eff4a73 --- /dev/null +++ b/datafusion-cli/tests/data/sql.txt @@ -0,0 +1 @@ +select 1; \ No newline at end of file diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 31595c980a30d..676b4aaa78c09 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -20,43 +20,39 @@ name = "datafusion-examples" description = "DataFusion usage examples" keywords = ["arrow", "query", "sql"] publish = false +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } -[[example]] -name = "avro_sql" -path = "examples/avro_sql.rs" -required-features = ["datafusion/avro"] - [dev-dependencies] arrow = { workspace = true } arrow-flight = { workspace = true } arrow-schema = { workspace = true } -async-trait = "0.1.41" -bytes = "1.4" -dashmap = "5.4" -datafusion = { path = "../datafusion/core" } +async-trait = { workspace = true } +bytes = { workspace = true } +dashmap = { workspace = true } +datafusion = { path = "../datafusion/core", features = ["avro"] } datafusion-common = { path = "../datafusion/common" } datafusion-expr = { path = "../datafusion/expr" } datafusion-optimizer = { path = "../datafusion/optimizer" } datafusion-sql = { path = "../datafusion/sql" } -env_logger = "0.10" -futures = "0.3" -log = "0.4" +env_logger = { workspace = true } +futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", default-features = false } -num_cpus = "1.13.0" -object_store = { version = "0.6.1", features = ["aws"] } -prost = { version = "0.11", default-features = false } +num_cpus = { workspace = true } +object_store = { workspace = true, features = ["aws", "http"] } +prost = { version = "0.12", default-features = false } prost-derive = { version = "0.11", default-features = false } serde = { version = "1.0.136", features = ["derive"] } -serde_json = "1.0.82" +serde_json = { workspace = true } +tempfile = { workspace = true } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } -tonic = "0.9" -url = "2.2" +tonic = "0.10" +url = { workspace = true } uuid = "1.2" diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index df6ad5a467b60..305422ccd0be0 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -44,20 +44,24 @@ cargo run --example csv_sql - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file +- [`catalog.rs`](examples/external_dependency/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file +- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde -- [`expr_api.rs`](examples/expr_api.rs): Use the `Expr` construction and simplification API -- [`flight_sql_server.rs`](examples/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients +- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and anaylze `Expr`s +- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files -- [`query-aws-s3.rs`](examples/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 +- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 +- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF) +- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) ## Distributed -- [`flight_client.rs`](examples/flight_client.rs) and [`flight_server.rs`](examples/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. +- [`flight_client.rs`](examples/flight/flight_client.rs) and [`flight_server.rs`](examples/flight/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index f2982522a7cdd..15fb07ded4811 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -17,10 +17,11 @@ use std::{sync::Arc, vec}; +use datafusion::common::Statistics; use datafusion::{ assert_batches_eq, datasource::{ - file_format::file_type::FileCompressionType, + file_format::file_compression_type::FileCompressionType, listing::PartitionedFile, object_store::ObjectStoreUrl, physical_plan::{CsvConfig, CsvOpener, FileScanConfig, FileStream}, @@ -29,6 +30,7 @@ use datafusion::{ physical_plan::metrics::ExecutionPlanMetricsSet, test_util::aggr_test_schema, }; + use futures::StreamExt; use object_store::local::LocalFileSystem; @@ -45,6 +47,7 @@ async fn main() -> Result<()> { Some(vec![12, 0]), true, b',', + b'"', object_store, ); @@ -59,7 +62,7 @@ async fn main() -> Result<()> { object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new(path.display().to_string(), 10)]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: Some(vec![12, 0]), limit: Some(5), table_partition_cols: vec![], diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs index c883a2076d134..851fdcb626d2f 100644 --- a/datafusion-examples/examples/csv_sql.rs +++ b/datafusion-examples/examples/csv_sql.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::datasource::file_format::file_type::FileCompressionType; +use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::error::Result; use datafusion::prelude::*; diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index c426d9611c608..69f9c9530e871 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -15,27 +15,29 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; +use std::any::Any; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::{self, Debug, Formatter}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::dataframe::DataFrame; -use datafusion::datasource::provider_as_source; -use datafusion::datasource::{TableProvider, TableType}; +use datafusion::datasource::{provider_as_source, TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionState, TaskContext}; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ - project_schema, ExecutionPlan, SendableRecordBatchStream, Statistics, + project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, + SendableRecordBatchStream, }; use datafusion::prelude::*; use datafusion_expr::{Expr, LogicalPlanBuilder}; -use std::any::Any; -use std::collections::{BTreeMap, HashMap}; -use std::fmt::{Debug, Formatter}; -use std::sync::{Arc, Mutex}; -use std::time::Duration; + +use async_trait::async_trait; use tokio::time::timeout; /// This example demonstrates executing a simple query against a custom datasource @@ -78,7 +80,7 @@ async fn search_accounts( timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(expected_result_length, record_batch.column(1).len()); dbg!(record_batch.columns()); @@ -204,6 +206,12 @@ impl CustomExec { } } +impl DisplayAs for CustomExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { + write!(f, "CustomExec") + } +} + impl ExecutionPlan for CustomExec { fn as_any(&self) -> &dyn Any { self @@ -262,8 +270,4 @@ impl ExecutionPlan for CustomExec { None, )?)) } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 26fddcd226a98..ea01c53b1c624 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -18,7 +18,9 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::error::Result; use datafusion::prelude::*; -use std::fs; +use std::fs::File; +use std::io::Write; +use tempfile::tempdir; /// This example demonstrates executing a simple query against an Arrow data source (Parquet) and /// fetching results, using the DataFrame trait @@ -41,12 +43,19 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // create a csv file waiting to be written + let dir = tempdir()?; + let file_path = dir.path().join("example.csv"); + let file = File::create(&file_path)?; + write_csv_file(file); + // Reading CSV file with inferred schema example - let csv_df = example_read_csv_file_with_inferred_schema().await; + let csv_df = + example_read_csv_file_with_inferred_schema(file_path.to_str().unwrap()).await; csv_df.show().await?; // Reading CSV file with defined schema - let csv_df = example_read_csv_file_with_schema().await; + let csv_df = example_read_csv_file_with_schema(file_path.to_str().unwrap()).await; csv_df.show().await?; // Reading PARQUET file and print describe @@ -59,31 +68,28 @@ async fn main() -> Result<()> { } // Function to create an test CSV file -fn create_csv_file(path: String) { +fn write_csv_file(mut file: File) { // Create the data to put into the csv file with headers let content = r#"id,time,vote,unixtime,rating a1,"10 6, 2013",3,1381017600,5.0 a2,"08 9, 2013",2,1376006400,4.5"#; // write the data - fs::write(path, content).expect("Problem with writing file!"); + file.write_all(content.as_ref()) + .expect("Problem with writing file!"); } // Example to read data from a csv file with inferred schema -async fn example_read_csv_file_with_inferred_schema() -> DataFrame { - let path = "example.csv"; - // Create a csv file using the predefined function - create_csv_file(path.to_string()); +async fn example_read_csv_file_with_inferred_schema(file_path: &str) -> DataFrame { // Create a session context let ctx = SessionContext::new(); // Register a lazy DataFrame using the context - ctx.read_csv(path, CsvReadOptions::default()).await.unwrap() + ctx.read_csv(file_path, CsvReadOptions::default()) + .await + .unwrap() } // Example to read csv file with a defined schema for the csv file -async fn example_read_csv_file_with_schema() -> DataFrame { - let path = "example.csv"; - // Create a csv file using the predefined function - create_csv_file(path.to_string()); +async fn example_read_csv_file_with_schema(file_path: &str) -> DataFrame { // Create a session context let ctx = SessionContext::new(); // Define the schema @@ -101,5 +107,5 @@ async fn example_read_csv_file_with_schema() -> DataFrame { ..Default::default() }; // Register a lazy DataFrame by using the context and option provider - ctx.read_csv(path, csv_read_option).await.unwrap() + ctx.read_csv(file_path, csv_read_option).await.unwrap() } diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 94049e59b3ab8..9fb61008b9f69 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; @@ -38,7 +39,7 @@ async fn main() -> Result<()> { Ok(()) } -//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 10; +//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -46,7 +47,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(col("t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? .aggregate(vec![], vec![avg(col("t2.c2"))])? .select(vec![avg(col("t2.c2"))])? .into_unoptimized_plan(), @@ -60,7 +61,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { Ok(()) } -//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 10 +//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -82,14 +83,14 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { Ok(()) } -//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 10 +//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(exists(Arc::new( ctx.table("t2") .await? - .filter(col("t1.c1").eq(col("t2.c1")))? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? .select(vec![col("t2.c2")])? .into_unoptimized_plan(), )))? diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 97abf4d552a9d..715e1ff2dce60 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -15,28 +15,43 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{BooleanArray, Int32Array}; +use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::physical_expr::{ + analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, +}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use std::sync::Arc; /// This example demonstrates the DataFusion [`Expr`] API. /// /// DataFusion comes with a powerful and extensive system for /// representing and manipulating expressions such as `A + 5` and `X -/// IN ('foo', 'bar', 'baz')` and many other constructs. +/// IN ('foo', 'bar', 'baz')`. +/// +/// In addition to building and manipulating [`Expr`]s, DataFusion +/// also comes with APIs for evaluation, simplification, and analysis. +/// +/// The code in this example shows how to: +/// 1. Create [`Exprs`] using different APIs: [`main`]` +/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] +/// 3. Simplify expressions: [`simplify_demo`] +/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the - // "fluent"-style API, like this: + // "fluent"-style API: let expr = col("a") + lit(5); - // this creates the same expression as the following though with - // much less code, + // The same same expression can be created directly, with much more code: let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, @@ -44,15 +59,51 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to evaluate expressions + evaluate_demo()?; + + // See how to simplify expressions simplify_demo()?; + // See how to analyze ranges in expressions + range_analysis_demo()?; + + Ok(()) +} + +/// DataFusion can also evaluate arbitrary expressions on Arrow arrays. +fn evaluate_demo() -> Result<()> { + // For example, let's say you have some integers in an array + let batch = RecordBatch::try_from_iter([( + "a", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 8, 7, 4])) as _, + )])?; + + // If you want to find all rows where the expression `a < 5 OR a = 8` is true + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + + // First, you make a "physical expression" from the logical `Expr` + let physical_expr = physical_expr(&batch.schema(), expr)?; + + // Now, you can evaluate the expression against the RecordBatch + let result = physical_expr.evaluate(&batch)?; + + // The result contain an array that is true only for where `a < 5 OR a = 8` + let expected_result = Arc::new(BooleanArray::from(vec![ + true, false, false, false, true, false, true, + ])) as _; + assert!( + matches!(&result, ColumnarValue::Array(r) if r == &expected_result), + "result: {:?}", + result + ); + Ok(()) } -/// In addition to easy construction, DataFusion exposes APIs for -/// working with and simplifying such expressions that call into the -/// same powerful and extensive implementation used for the query -/// engine. +/// In addition to easy construction, DataFusion exposes APIs for simplifying +/// such expression so they are more efficient to evaluate. This code is also +/// used by the query engine to optimize queries. fn simplify_demo() -> Result<()> { // For example, lets say you have has created an expression such // ts = to_timestamp("2020-09-08T12:00:00+00:00") @@ -94,7 +145,7 @@ fn simplify_demo() -> Result<()> { make_field("b", DataType::Boolean), ]) .to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -120,6 +171,64 @@ fn simplify_demo() -> Result<()> { col("i").lt(lit(10)) ); + // String --> Date simplification + // `cast('2020-09-01' as date)` --> 18500 + assert_eq!( + simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, + lit(ScalarValue::Date32(Some(18506))) + ); + + Ok(()) +} + +/// DataFusion also has APIs for analyzing predicates (boolean expressions) to +/// determine any ranges restrictions on the inputs required for the predicate +/// evaluate to true. +fn range_analysis_demo() -> Result<()> { + // For example, let's say you are interested in finding data for all days + // in the month of September, 2020 + let september_1 = ScalarValue::Date32(Some(18506)); // 2020-09-01 + let october_1 = ScalarValue::Date32(Some(18536)); // 2020-10-01 + + // The predicate to find all such days could be + // `date > '2020-09-01' AND date < '2020-10-01'` + let expr = col("date") + .gt(lit(september_1.clone())) + .and(col("date").lt(lit(october_1.clone()))); + + // Using the analysis API, DataFusion can determine that the value of `date` + // must be in the range `['2020-09-01', '2020-10-01']`. If your data is + // organized in files according to day, this information permits skipping + // entire files without reading them. + // + // While this simple example could be handled with a special case, the + // DataFusion API handles arbitrary expressions (so for example, you don't + // have to handle the case where the predicate clauses are reversed such as + // `date < '2020-10-01' AND date > '2020-09-01'` + + // As always, we need to tell DataFusion the type of column "date" + let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + + // You can provide DataFusion any known boundaries on the values of `date` + // (for example, maybe you know you only have data up to `2020-09-15`), but + // in this case, let's say we don't know any boundaries beforehand so we use + // `try_new_unknown` + let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; + + // Now, we invoke the analysis code to perform the range analysis + let physical_expr = physical_expr(&schema, expr)?; + let analysis_result = + analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + + // The results of the analysis is an range, encoded as an `Interval`, for + // each column in the schema, that must be true in order for the predicate + // to be true. + // + // In this case, we can see that, as expected, `analyze` has figured out + // that in this case, `date` must be in the range `['2020-09-01', '2020-10-01']` + let expected_range = Interval::try_new(september_1, october_1)?; + assert_eq!(analysis_result.boundaries[0].interval, expected_range); + Ok(()) } @@ -132,3 +241,18 @@ fn make_ts_field(name: &str) -> Field { let tz = None; make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } + +/// Build a physical expression from a logical one, after applying simplification and type coercion +pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { + let df_schema = schema.clone().to_dfschema_ref()?; + + // Simplify + let props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); + + // apply type coercion here to ensure types match + let expr = simplifier.coerce(expr, df_schema.clone())?; + + create_physical_expr(&expr, df_schema.as_ref(), schema, &props) +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/external_dependency/catalog.rs similarity index 99% rename from datafusion-examples/examples/catalog.rs rename to datafusion-examples/examples/external_dependency/catalog.rs index 30cc2c8bd6180..aa9fd103a50c0 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/external_dependency/catalog.rs @@ -23,8 +23,8 @@ use async_trait::async_trait; use datafusion::{ arrow::util::pretty, catalog::{ - catalog::{CatalogList, CatalogProvider}, schema::SchemaProvider, + {CatalogList, CatalogProvider}, }, datasource::{ file_format::{csv::CsvFormat, parquet::ParquetFormat, FileFormat}, @@ -58,7 +58,7 @@ async fn main() -> Result<()> { // context will by default have MemoryCatalogList ctx.register_catalog_list(catlist.clone()); - // intitialize our catalog and schemas + // initialize our catalog and schemas let catalog = DirCatalog::new(); let parquet_schema = DirSchema::create( &state, diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs new file mode 100644 index 0000000000000..883da7d0d13d0 --- /dev/null +++ b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{FileType, GetExt}; + +use object_store::aws::AmazonS3Builder; +use std::env; +use std::sync::Arc; +use url::Url; + +/// This example demonstrates querying data from AmazonS3 and writing +/// the result of a query back to AmazonS3 +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + //enter region and bucket to which your credentials have GET and PUT access + let region = ""; + let bucket_name = ""; + + let s3 = AmazonS3Builder::new() + .with_bucket_name(bucket_name) + .with_region(region) + .with_access_key_id(env::var("AWS_ACCESS_KEY_ID").unwrap()) + .with_secret_access_key(env::var("AWS_SECRET_ACCESS_KEY").unwrap()) + .build()?; + + let path = format!("s3://{bucket_name}"); + let s3_url = Url::parse(&path).unwrap(); + let arc_s3 = Arc::new(s3); + ctx.runtime_env() + .register_object_store(&s3_url, arc_s3.clone()); + + let path = format!("s3://{bucket_name}/test_data/"); + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + ctx.register_listing_table("test", &path, listing_options, None, None) + .await?; + + // execute the query + let df = ctx.sql("SELECT * from test").await?; + + let out_path = format!("s3://{bucket_name}/test_write/"); + df.clone() + .write_parquet(&out_path, DataFrameWriteOptions::new(), None) + .await?; + + //write as JSON to s3 + let json_out = format!("s3://{bucket_name}/json_out"); + df.clone() + .write_json(&json_out, DataFrameWriteOptions::new()) + .await?; + + //write as csv to s3 + let csv_out = format!("s3://{bucket_name}/csv_out"); + df.write_csv(&csv_out, DataFrameWriteOptions::new(), None) + .await?; + + let file_format = ParquetFormat::default().with_enable_pruning(Some(true)); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::PARQUET.get_ext()); + ctx.register_listing_table("test2", &out_path, listing_options, None, None) + .await?; + + let df = ctx + .sql( + "SELECT * \ + FROM test2 \ + ", + ) + .await?; + + df.show_limit(20).await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query-aws-s3.rs similarity index 100% rename from datafusion-examples/examples/query-aws-s3.rs rename to datafusion-examples/examples/external_dependency/query-aws-s3.rs diff --git a/datafusion-examples/examples/flight_client.rs b/datafusion-examples/examples/flight/flight_client.rs similarity index 100% rename from datafusion-examples/examples/flight_client.rs rename to datafusion-examples/examples/flight/flight_client.rs diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs similarity index 100% rename from datafusion-examples/examples/flight_server.rs rename to datafusion-examples/examples/flight/flight_server.rs diff --git a/datafusion-examples/examples/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs similarity index 98% rename from datafusion-examples/examples/flight_sql_server.rs rename to datafusion-examples/examples/flight/flight_sql_server.rs index 1cf288b7d24d5..ed5b86d0b66ca 100644 --- a/datafusion-examples/examples/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -20,7 +20,7 @@ use arrow::record_batch::RecordBatch; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; -use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; use arrow_flight::sql::{ ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, ActionBeginTransactionResult, @@ -36,7 +36,7 @@ use arrow_flight::sql::{ TicketStatementQuery, }; use arrow_flight::{ - Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; use arrow_schema::Schema; @@ -105,7 +105,7 @@ impl FlightSqlServiceImpl { let session_config = SessionConfig::from_env() .map_err(|e| Status::internal(format!("Error building plan: {e}")))? .with_information_schema(true); - let ctx = Arc::new(SessionContext::with_config(session_config)); + let ctx = Arc::new(SessionContext::new_with_config(session_config)); let testdata = datafusion::test_util::parquet_test_data(); @@ -547,7 +547,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { info!("do_put_statement_update"); Err(Status::unimplemented("Implement do_put_statement_update")) @@ -556,7 +556,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Request>, + _request: Request, ) -> Result::DoPutStream>, Status> { info!("do_put_prepared_statement_query"); Err(Status::unimplemented( @@ -567,7 +567,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_prepared_statement_update( &self, _handle: CommandPreparedStatementUpdate, - _request: Request>, + _request: Request, ) -> Result { info!("do_put_prepared_statement_update"); // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function @@ -578,7 +578,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_substrait_plan( &self, _query: CommandStatementSubstraitPlan, - _request: Request>, + _request: Request, ) -> Result { info!("do_put_prepared_statement_update"); Err(Status::unimplemented( diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index 39013455da358..1a3dbe57be75e 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -21,7 +21,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::{ assert_batches_eq, datasource::{ - file_format::file_type::FileCompressionType, + file_format::file_compression_type::FileCompressionType, listing::PartitionedFile, object_store::ObjectStoreUrl, physical_plan::{FileScanConfig, FileStream, JsonOpener}, @@ -29,6 +29,8 @@ use datafusion::{ error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, }; +use datafusion_common::Statistics; + use futures::StreamExt; use object_store::ObjectStore; @@ -63,7 +65,7 @@ async fn main() -> Result<()> { object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new(path.to_string(), 10)]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: Some(vec![1, 0]), limit: Some(5), table_partition_cols: vec![], diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index bef8f3e5bb8f5..5cce578039e74 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { timeout(Duration::from_secs(10), async move { let result = dataframe.collect().await.unwrap(); - let record_batch = result.get(0).unwrap(); + let record_batch = result.first().unwrap(); assert_eq!(1, record_batch.column(0).len()); dbg!(record_batch.columns()); diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 7bd35a7844fcf..451de96f2e914 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use datafusion::datasource::file_format::file_type::{FileType, GetExt}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_common::{FileType, GetExt}; use std::sync::Arc; /// This example demonstrates executing a simple query against an Arrow data source (a directory diff --git a/datafusion-examples/examples/query-http-csv.rs b/datafusion-examples/examples/query-http-csv.rs new file mode 100644 index 0000000000000..928d702711591 --- /dev/null +++ b/datafusion-examples/examples/query-http-csv.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; +use datafusion::prelude::*; +use object_store::http::HttpBuilder; +use std::sync::Arc; +use url::Url; + +/// This example demonstrates executing a simple query against an Arrow data source (CSV) and +/// fetching results +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // setup http object store + let base_url = Url::parse("https://github.com").unwrap(); + let http_store = HttpBuilder::new() + .with_url(base_url.clone()) + .build() + .unwrap(); + ctx.runtime_env() + .register_object_store(&base_url, Arc::new(http_store)); + + // register csv file with the execution context + ctx.register_csv( + "aggregate_test_100", + "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv", + CsvReadOptions::new(), + ) + .await?; + + // execute the query + let df = ctx + .sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5") + .await?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 2777781eb98db..5e95562033e60 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -18,9 +18,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, + AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, }; use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; use datafusion_optimizer::optimizer::Optimizer; @@ -191,7 +191,7 @@ struct MyContextProvider { } impl ContextProvider for MyContextProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { if name.table() == "person" { Ok(Arc::new(MyTableSource { schema: Arc::new(Schema::new(vec![ @@ -200,7 +200,7 @@ impl ContextProvider for MyContextProvider { ])), })) } else { - Err(DataFusionError::Plan("table not found".to_string())) + plan_err!("table not found") } } @@ -216,6 +216,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + fn options(&self) -> &ConfigOptions { &self.options } diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index e3b290cc54d78..2c797f221b2cc 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -145,7 +145,7 @@ async fn main() -> Result<()> { // the name; used to represent it in plan descriptions and in the registry, to use in SQL. "geo_mean", // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. - DataType::Float64, + vec![DataType::Float64], // the return type; DataFusion expects this to match the type returned by `evaluate`. Arc::new(DataType::Float64), Volatility::Immutable, @@ -154,6 +154,10 @@ async fn main() -> Result<()> { // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); + ctx.register_udaf(geometric_mean.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + sql_df.show().await?; // get a DataFrame from the context // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs new file mode 100644 index 0000000000000..e120c5e7bf8e9 --- /dev/null +++ b/datafusion-examples/examples/simple_udtf.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{plan_err, DataFusionError, ScalarValue}; +use datafusion_expr::{Expr, TableType}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +// To define your own table function, you only need to do the following 3 things: +// 1. Implement your own [`TableProvider`] +// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] +// 3. Register the function using [`SessionContext::register_udtf`] + +/// This example demonstrates how to register a TableFunction +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let ctx = SessionContext::new(); + + // register the table function that will be called in SQL statements by `read_csv` + ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); + + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + + // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .await?; + df.show().await?; + + // just run, return all rows + let df = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await?; + df.show().await?; + + Ok(()) +} + +/// Table Function that mimics the [`read_csv`] function in DuckDB. +/// +/// Usage: `read_csv(filename, [limit])` +/// +/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +struct LocalCsvTable { + schema: SchemaRef, + limit: Option, + batches: Vec, +} + +#[async_trait] +impl TableProvider for LocalCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if let Some(max_return_lines) = self.limit { + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines { + let batch_lines = max_return_lines - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} +struct LocalCsvTableFunc {} + +impl TableFunctionImpl for LocalCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + return plan_err!("read_csv requires at least one string argument"); + }; + + let limit = exprs + .get(1) + .map(|expr| { + // try to simpify the expression, so 1+2 becomes 3, for example + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props); + let expr = ExprSimplifier::new(info).simplify(expr.clone())?; + + if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + Ok(limit as usize) + } else { + plan_err!("Limit must be an integer") + } + }) + .transpose()?; + + let (schema, batches) = read_csv_batches(path)?; + + let table = LocalCsvTable { + schema, + limit, + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default().infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs new file mode 100644 index 0000000000000..0d04c093e1478 --- /dev/null +++ b/datafusion-examples/examples/simple_udwf.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{ArrayRef, AsArray, Float64Array}, + datatypes::Float64Type, +}; +use arrow_schema::DataType; +use datafusion::datasource::file_format::options::CsvReadOptions; + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::ScalarValue; +use datafusion_expr::{PartitionEvaluator, Volatility, WindowFrame}; + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +/// In this example we will declare a user defined window function that computes a moving average and then run it using SQL +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + + // here is where we define the UDWF. We also declare its signature: + let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), + ); + + // register the window function with DataFusion so we can call it + ctx.register_udwf(smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + // Use SQL to run the new window function: + // + // `PARTITION BY car`:each distinct value of car (red, and green) + // should be treated as a separate partition (and will result in + // creating a new `PartitionEvaluator`) + // + // `ORDER BY time`: within each partition ('green' or 'red') the + // rows will be be ordered by the value in the `time` column + // + // `evaluate_inside_range` is invoked with a window defined by the + // SQL. In this case: + // + // The first invocation will be passed row 0, the first row in the + // partition. + // + // The second invocation will be passed rows 0 and 1, the first + // two rows in the partition. + // + // etc. + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + // this time, call the new widow function with an explicit + // window so evaluate will be invoked with each window. + // + // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation + // sees at most 3 rows: the row before, the current row, and the 1 + // row afterward. + let df = ctx.sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ).await?; + // print the results + df.show().await?; + + // Now, run the function using the DataFrame API: + let window_expr = smooth_it.call( + vec![col("speed")], // smooth_it(speed) + vec![col("car")], // PARTITION BY car + vec![col("time").sort(true, true)], // ORDER BY time ASC + WindowFrame::new(false), + ); + let df = ctx.table("cars").await?.window(vec![window_expr])?; + + // print the results + df.show().await?; + + Ok(()) +} + +/// Create a `PartitionEvalutor` to evaluate this function on a new +/// partition. +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} + +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` (each car type in our example) +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 8235ab0cf32b8..e224b93876551 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,13 @@ # Changelog +- [33.0.0](../dev/changelog/33.0.0.md) +- [32.0.0](../dev/changelog/32.0.0.md) +- [31.0.0](../dev/changelog/31.0.0.md) +- [30.0.0](../dev/changelog/30.0.0.md) +- [29.0.0](../dev/changelog/29.0.0.md) +- [28.0.0](../dev/changelog/28.0.0.md) +- [27.0.0](../dev/changelog/27.0.0.md) - [26.0.0](../dev/changelog/26.0.0.md) - [25.0.0](../dev/changelog/25.0.0.md) - [24.0.0](../dev/changelog/24.0.0.md) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 6943fc7263e25..b69e1f7f3d108 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-common" description = "Common functionality for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -34,19 +34,31 @@ path = "src/lib.rs" [features] avro = ["apache-avro"] -default = [] -pyarrow = ["pyo3", "arrow/pyarrow"] +backtrace = [] +pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] [dependencies] -apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +apache-avro = { version = "0.16", default-features = false, features = [ + "bzip", + "snappy", + "xz", + "zstandard", +], optional = true } arrow = { workspace = true } arrow-array = { workspace = true } -chrono = { version = "0.4", default-features = false } -num_cpus = "1.13.0" -object_store = { version = "0.6.1", default-features = false, optional = true } -parquet = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", optional = true } -sqlparser = "0.34" +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +half = { version = "2.1", default-features = false } +libc = "0.2.140" +num_cpus = { workspace = true } +object_store = { workspace = true, optional = true } +parquet = { workspace = true, optional = true, default-features = true } +pyo3 = { version = "0.20.0", optional = true } +sqlparser = { workspace = true } [dev-dependencies] rand = "0.8.4" diff --git a/datafusion/common/README.md b/datafusion/common/README.md index 9bccf3f18b7f4..524ab4420d2a8 100644 --- a/datafusion/common/README.md +++ b/datafusion/common/README.md @@ -19,7 +19,7 @@ # DataFusion Common -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides common data types and utilities. diff --git a/datafusion/optimizer/src/alias.rs b/datafusion/common/src/alias.rs similarity index 98% rename from datafusion/optimizer/src/alias.rs rename to datafusion/common/src/alias.rs index 6420cc685e25e..2ee2cb4dc7add 100644 --- a/datafusion/optimizer/src/alias.rs +++ b/datafusion/common/src/alias.rs @@ -18,6 +18,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; /// A utility struct that can be used to generate unique aliases when optimizing queries +#[derive(Debug)] pub struct AliasGenerator { next_id: AtomicUsize, } diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 04ae32ec35aad..088f03e002ed3 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -34,6 +34,7 @@ use arrow::{ }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; +use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -65,6 +66,11 @@ pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) } +// Downcast ArrayRef to Decimal256Array +pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { + Ok(downcast_value!(array, Decimal256Array)) +} + // Downcast ArrayRef to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) @@ -175,23 +181,17 @@ pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondAr } // Downcast ArrayRef to IntervalYearMonthArray -pub fn as_interval_ym_array( - array: &dyn Array, -) -> Result<&IntervalYearMonthArray, DataFusionError> { +pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } // Downcast ArrayRef to IntervalDayTimeArray -pub fn as_interval_dt_array( - array: &dyn Array, -) -> Result<&IntervalDayTimeArray, DataFusionError> { +pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } // Downcast ArrayRef to IntervalMonthDayNanoArray -pub fn as_interval_mdn_array( - array: &dyn Array, -) -> Result<&IntervalMonthDayNanoArray, DataFusionError> { +pub fn as_interval_mdn_array(array: &dyn Array) -> Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index d138bb06cd840..2e729c128e73a 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -67,11 +67,7 @@ impl Column { } } - /// Deserialize a fully qualified name string into a column - pub fn from_qualified_name(flat_name: impl Into) -> Self { - let flat_name = flat_name.into(); - let mut idents = parse_identifiers_normalized(&flat_name); - + fn from_idents(idents: &mut Vec) -> Option { let (relation, name) = match idents.len() { 1 => (None, idents.remove(0)), 2 => ( @@ -97,9 +93,33 @@ impl Column { ), // any expression that failed to parse or has more than 4 period delimited // identifiers will be treated as an unqualified column name - _ => (None, flat_name), + _ => return None, }; - Self { relation, name } + Some(Self { relation, name }) + } + + /// Deserialize a fully qualified name string into a column + /// + /// Treats the name as a SQL identifier. For example + /// `foo.BAR` would be parsed to a reference to relation `foo`, column name `bar` (lower case) + /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` + pub fn from_qualified_name(flat_name: impl Into) -> Self { + let flat_name: &str = &flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(flat_name, false)) + .unwrap_or_else(|| Self { + relation: None, + name: flat_name.to_owned(), + }) + } + + /// Deserialize a fully qualified name string into a column preserving column text case + pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { + let flat_name: &str = &flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(flat_name, true)) + .unwrap_or_else(|| Self { + relation: None, + name: flat_name.to_owned(), + }) } /// Serialize column into a flat name string @@ -408,7 +428,7 @@ mod tests { ) .expect_err("should've failed to find field"); let expected = r#"Schema error: No field named z. Valid fields are t1.a, t1.b, t2.c, t2.d, t3.a, t3.b, t3.c, t3.d, t3.e."#; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); // ambiguous column reference let col = Column::from_name("a"); @@ -419,7 +439,7 @@ mod tests { ) .expect_err("should've found ambiguous field"); let expected = "Schema error: Ambiguous reference to unqualified field a"; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); Ok(()) } diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c5ce3540fce50..03fb5ea320a04 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion Configuration Options - +//! Runtime configuration, via [`ConfigOptions`] +use crate::error::_internal_err; use crate::{DataFusionError, Result}; use std::any::Any; use std::collections::{BTreeMap, HashMap}; @@ -65,10 +65,10 @@ use std::fmt::Display; /// "field1" => self.field1.set(rem, value), /// "field2" => self.field2.set(rem, value), /// "field3" => self.field3.set(rem, value), -/// _ => Err(DataFusionError::Internal(format!( +/// _ => _internal_err!( /// "Config value \"{}\" not found on MyConfig", /// key -/// ))), +/// ), /// } /// } /// @@ -126,9 +126,9 @@ macro_rules! config_namespace { $( stringify!($field_name) => self.$field_name.set(rem, value), )* - _ => Err(DataFusionError::Internal( - format!(concat!("Config value \"{}\" not found on ", stringify!($struct_name)), key) - )) + _ => _internal_err!( + "Config value \"{}\" not found on {}", key, stringify!($struct_name) + ) } } @@ -235,11 +235,49 @@ config_namespace! { /// /// Defaults to the number of CPU cores on the system pub planning_concurrency: usize, default = num_cpus::get() + + /// Specifies the reserved memory for each spillable sort operation to + /// facilitate an in-memory merge. + /// + /// When a sort operation spills to disk, the in-memory data must be + /// sorted and merged before being written to a file. This setting reserves + /// a specific amount of memory for that in-memory sort/merge process. + /// + /// Note: This setting is irrelevant if the sort operation cannot spill + /// (i.e., if there's no `DiskManager` configured). + pub sort_spill_reservation_bytes: usize, default = 10 * 1024 * 1024 + + /// When sorting, below what size should data be concatenated + /// and sorted in a single RecordBatch rather than sorted in + /// batches and merged. + pub sort_in_place_threshold_bytes: usize, default = 1024 * 1024 + + /// Number of files to read in parallel when inferring schema and statistics + pub meta_fetch_concurrency: usize, default = 32 + + /// Guarantees a minimum level of output files running in parallel. + /// RecordBatches will be distributed in round robin fashion to each + /// parallel writer. Each writer is closed and a new file opened once + /// soft_max_rows_per_output_file is reached. + pub minimum_parallel_output_files: usize, default = 4 + + /// Target number of rows in output files when writing multiple. + /// This is a soft max, so it can be exceeded slightly. There also + /// will be one file smaller than the limit if the total + /// number of rows written is not roughly divisible by the soft max + pub soft_max_rows_per_output_file: usize, default = 50000000 + + /// This is the maximum number of RecordBatches buffered + /// for each output file being worked. Higher values can potentially + /// give faster write performance at the cost of higher peak + /// memory consumption + pub max_buffered_batches_per_output_file: usize, default = 2 + } } config_namespace! { - /// Options related to reading of parquet files + /// Options related to parquet files pub struct ParquetOptions { /// If true, reads the Parquet data page level metadata (the /// Page Index), if present, to reduce the I/O and number of @@ -270,6 +308,102 @@ config_namespace! { /// will be reordered heuristically to minimize the cost of evaluation. If false, /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false + + // The following map to parquet::file::properties::WriterProperties + + /// Sets best effort maximum size of data page in bytes + pub data_pagesize_limit: usize, default = 1024 * 1024 + + /// Sets write_batch_size in bytes + pub write_batch_size: usize, default = 1024 + + /// Sets parquet writer version + /// valid values are "1.0" and "2.0" + pub writer_version: String, default = "1.0".into() + + /// Sets default parquet compression codec + /// Valid values are: uncompressed, snappy, gzip(level), + /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. + /// These values are not case sensitive. If NULL, uses + /// default parquet writer setting + pub compression: Option, default = Some("zstd(3)".into()) + + /// Sets if dictionary encoding is enabled. If NULL, uses + /// default parquet writer setting + pub dictionary_enabled: Option, default = None + + /// Sets best effort maximum dictionary page size, in bytes + pub dictionary_page_size_limit: usize, default = 1024 * 1024 + + /// Sets if statistics are enabled for any column + /// Valid values are: "none", "chunk", and "page" + /// These values are not case sensitive. If NULL, uses + /// default parquet writer setting + pub statistics_enabled: Option, default = None + + /// Sets max statistics size for any column. If NULL, uses + /// default parquet writer setting + pub max_statistics_size: Option, default = None + + /// Sets maximum number of rows in a row group + pub max_row_group_size: usize, default = 1024 * 1024 + + /// Sets "created by" property + pub created_by: String, default = concat!("datafusion version ", env!("CARGO_PKG_VERSION")).into() + + /// Sets column index truncate length + pub column_index_truncate_length: Option, default = None + + /// Sets best effort maximum number of rows in data page + pub data_page_row_count_limit: usize, default = usize::MAX + + /// Sets default encoding for any column + /// Valid values are: plain, plain_dictionary, rle, + /// bit_packed, delta_binary_packed, delta_length_byte_array, + /// delta_byte_array, rle_dictionary, and byte_stream_split. + /// These values are not case sensitive. If NULL, uses + /// default parquet writer setting + pub encoding: Option, default = None + + /// Sets if bloom filter is enabled for any column + pub bloom_filter_enabled: bool, default = false + + /// Sets bloom filter false positive probability. If NULL, uses + /// default parquet writer setting + pub bloom_filter_fpp: Option, default = None + + /// Sets bloom filter number of distinct values. If NULL, uses + /// default parquet writer setting + pub bloom_filter_ndv: Option, default = None + + /// Controls whether DataFusion will attempt to speed up writing + /// parquet files by serializing them in parallel. Each column + /// in each row group in each output file are serialized in parallel + /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. + pub allow_single_file_parallelism: bool, default = true + + /// By default parallel parquet writer is tuned for minimum + /// memory usage in a streaming execution plan. You may see + /// a performance benefit when writing large parquet files + /// by increasing maximum_parallel_row_group_writers and + /// maximum_buffered_record_batches_per_stream if your system + /// has idle cores and can tolerate additional memory usage. + /// Boosting these values is likely worthwhile when + /// writing out already in-memory data, such as from a cached + /// data frame. + pub maximum_parallel_row_group_writers: usize, default = 1 + + /// By default parallel parquet writer is tuned for minimum + /// memory usage in a streaming execution plan. You may see + /// a performance benefit when writing large parquet files + /// by increasing maximum_parallel_row_group_writers and + /// maximum_buffered_record_batches_per_stream if your system + /// has idle cores and can tolerate additional memory usage. + /// Boosting these values is likely worthwhile when + /// writing out already in-memory data, such as from a cached + /// data frame. + pub maximum_buffered_record_batches_per_stream: usize, default = 2 + } } @@ -293,10 +427,19 @@ config_namespace! { config_namespace! { /// Options related to query optimization pub struct OptimizerOptions { + /// When set to true, the optimizer will push a limit operation into + /// grouped aggregations which have no aggregate expressions, as a soft limit, + /// emitting groups once the limit is reached, before all rows in the group are read. + pub enable_distinct_aggregation_soft_limit: bool, default = true + /// When set to true, the physical plan optimizer will try to add round robin /// repartitioning to increase parallelism to leverage more CPU cores pub enable_round_robin_repartition: bool, default = true + /// When set to true, the optimizer will attempt to perform limit operations + /// during aggregations, if possible + pub enable_topk_aggregation: bool, default = true + /// When set to true, the optimizer will insert filters before a join between /// a nullable and non-nullable column to filter out nulls on the nullable side. This /// filter can add additional overhead when the file format does not fully support @@ -323,10 +466,13 @@ config_namespace! { /// long runner execution, all types of joins may encounter out-of-memory errors. pub allow_symmetric_joins_without_pruning: bool, default = true - /// When set to true, file groups will be repartitioned to achieve maximum parallelism. - /// Currently supported only for Parquet format in which case - /// multiple row groups from the same file may be read concurrently. If false then each - /// row group is read serially, though different files may be read in parallel. + /// When set to `true`, file groups will be repartitioned to achieve maximum parallelism. + /// Currently Parquet and CSV formats are supported. + /// + /// If set to `true`, all files will be repartitioned evenly (i.e., a single large file + /// might be partitioned into smaller chunks) for parallel scanning. + /// If set to `false`, different files will be read in parallel, but repartitioning won't + /// happen within a single file. pub repartition_file_scans: bool, default = true /// Should DataFusion repartition data using the partitions keys to execute window @@ -351,6 +497,14 @@ config_namespace! { /// ``` pub repartition_sorts: bool, default = true + /// When true, DataFusion will opportunistically remove sorts when the data is already sorted, + /// (i.e. setting `preserve_order` to true on `RepartitionExec` and + /// using `SortPreservingMergeExec`) + /// + /// When false, DataFusion will maximize plan parallelism using + /// `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. + pub prefer_existing_sort: bool, default = false + /// When set to true, the logical plan optimizer will produce warning /// messages if any optimization rules produce errors and then proceed to the next /// rule. When set to false, any rules that produce errors will cause the query to fail @@ -370,6 +524,11 @@ config_namespace! { /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 + + /// The default filter selectivity used by Filter Statistics + /// when an exact selectivity cannot be determined. Valid values are + /// between 0 (no selectivity) and 100 (all rows are selected). + pub default_filter_selectivity: u8, default = 20 } } @@ -381,6 +540,10 @@ config_namespace! { /// When set to true, the explain statement will only print physical plans pub physical_plan_only: bool, default = false + + /// When set to true, the explain statement will print operator statistics + /// for physical plans + pub show_statistics: bool, default = false } } @@ -425,9 +588,7 @@ impl ConfigField for ConfigOptions { "optimizer" => self.optimizer.set(rem, value), "explain" => self.explain.set(rem, value), "sql_parser" => self.sql_parser.set(rem, value), - _ => Err(DataFusionError::Internal(format!( - "Config value \"{key}\" not found on ConfigOptions" - ))), + _ => _internal_err!("Config value \"{key}\" not found on ConfigOptions"), } } @@ -720,6 +881,9 @@ macro_rules! config_field { config_field!(String); config_field!(bool); config_field!(usize); +config_field!(f64); +config_field!(u8); +config_field!(u64); /// An implementation trait used to recursively walk configuration trait Visit { diff --git a/datafusion/common/src/delta.rs b/datafusion/common/src/delta.rs deleted file mode 100644 index bb71e3eb935ea..0000000000000 --- a/datafusion/common/src/delta.rs +++ /dev/null @@ -1,336 +0,0 @@ -// MIT License -// -// Copyright (c) 2020-2022 Oliver Margetts -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// Copied from chronoutil crate - -//! Contains utility functions for shifting Date objects. -use chrono::Datelike; - -/// Returns true if the year is a leap-year, as naively defined in the Gregorian calendar. -#[inline] -fn is_leap_year(year: i32) -> bool { - year % 4 == 0 && (year % 100 != 0 || year % 400 == 0) -} - -// If the day lies within the month, this function has no effect. Otherwise, it shifts -// day backwards to the final day of the month. -// XXX: No attempt is made to handle days outside the 1-31 range. -#[inline] -fn normalise_day(year: i32, month: u32, day: u32) -> u32 { - if day <= 28 { - day - } else if month == 2 { - 28 + is_leap_year(year) as u32 - } else if day == 31 && (month == 4 || month == 6 || month == 9 || month == 11) { - 30 - } else { - day - } -} - -/// Shift a date by the given number of months. -/// Ambiguous month-ends are shifted backwards as necessary. -pub fn shift_months(date: D, months: i32, sign: i32) -> D { - let months = months * sign; - let mut year = date.year() + (date.month() as i32 + months) / 12; - let mut month = (date.month() as i32 + months) % 12; - let mut day = date.day(); - - if month < 1 { - year -= 1; - month += 12; - } - - day = normalise_day(year, month as u32, day); - - // This is slow but guaranteed to succeed (short of interger overflow) - if day <= 28 { - date.with_day(day) - .unwrap() - .with_month(month as u32) - .unwrap() - .with_year(year) - .unwrap() - } else { - date.with_day(1) - .unwrap() - .with_month(month as u32) - .unwrap() - .with_year(year) - .unwrap() - .with_day(day) - .unwrap() - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime}; - - use super::*; - - #[test] - fn test_leap_year_cases() { - let _leap_years: Vec = vec![ - 1904, 1908, 1912, 1916, 1920, 1924, 1928, 1932, 1936, 1940, 1944, 1948, 1952, - 1956, 1960, 1964, 1968, 1972, 1976, 1980, 1984, 1988, 1992, 1996, 2000, 2004, - 2008, 2012, 2016, 2020, - ]; - let leap_years_1900_to_2020: HashSet = _leap_years.into_iter().collect(); - - for year in 1900..2021 { - assert_eq!(is_leap_year(year), leap_years_1900_to_2020.contains(&year)) - } - } - - #[test] - fn test_shift_months() { - let base = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); - - assert_eq!( - shift_months(base, 0, 1), - NaiveDate::from_ymd_opt(2020, 1, 31).unwrap() - ); - assert_eq!( - shift_months(base, 1, 1), - NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() - ); - assert_eq!( - shift_months(base, 2, 1), - NaiveDate::from_ymd_opt(2020, 3, 31).unwrap() - ); - assert_eq!( - shift_months(base, 3, 1), - NaiveDate::from_ymd_opt(2020, 4, 30).unwrap() - ); - assert_eq!( - shift_months(base, 4, 1), - NaiveDate::from_ymd_opt(2020, 5, 31).unwrap() - ); - assert_eq!( - shift_months(base, 5, 1), - NaiveDate::from_ymd_opt(2020, 6, 30).unwrap() - ); - assert_eq!( - shift_months(base, 6, 1), - NaiveDate::from_ymd_opt(2020, 7, 31).unwrap() - ); - assert_eq!( - shift_months(base, 7, 1), - NaiveDate::from_ymd_opt(2020, 8, 31).unwrap() - ); - assert_eq!( - shift_months(base, 8, 1), - NaiveDate::from_ymd_opt(2020, 9, 30).unwrap() - ); - assert_eq!( - shift_months(base, 9, 1), - NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() - ); - assert_eq!( - shift_months(base, 10, 1), - NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() - ); - assert_eq!( - shift_months(base, 11, 1), - NaiveDate::from_ymd_opt(2020, 12, 31).unwrap() - ); - assert_eq!( - shift_months(base, 12, 1), - NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() - ); - assert_eq!( - shift_months(base, 13, 1), - NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() - ); - - assert_eq!( - shift_months(base, 1, -1), - NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() - ); - assert_eq!( - shift_months(base, 2, -1), - NaiveDate::from_ymd_opt(2019, 11, 30).unwrap() - ); - assert_eq!( - shift_months(base, 3, -1), - NaiveDate::from_ymd_opt(2019, 10, 31).unwrap() - ); - assert_eq!( - shift_months(base, 4, -1), - NaiveDate::from_ymd_opt(2019, 9, 30).unwrap() - ); - assert_eq!( - shift_months(base, 5, -1), - NaiveDate::from_ymd_opt(2019, 8, 31).unwrap() - ); - assert_eq!( - shift_months(base, 6, -1), - NaiveDate::from_ymd_opt(2019, 7, 31).unwrap() - ); - assert_eq!( - shift_months(base, 7, -1), - NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() - ); - assert_eq!( - shift_months(base, 8, -1), - NaiveDate::from_ymd_opt(2019, 5, 31).unwrap() - ); - assert_eq!( - shift_months(base, 9, -1), - NaiveDate::from_ymd_opt(2019, 4, 30).unwrap() - ); - assert_eq!( - shift_months(base, 10, -1), - NaiveDate::from_ymd_opt(2019, 3, 31).unwrap() - ); - assert_eq!( - shift_months(base, 11, -1), - NaiveDate::from_ymd_opt(2019, 2, 28).unwrap() - ); - assert_eq!( - shift_months(base, 12, -1), - NaiveDate::from_ymd_opt(2019, 1, 31).unwrap() - ); - assert_eq!( - shift_months(base, 13, -1), - NaiveDate::from_ymd_opt(2018, 12, 31).unwrap() - ); - - assert_eq!( - shift_months(base, 1265, 1), - NaiveDate::from_ymd_opt(2125, 6, 30).unwrap() - ); - } - - #[test] - fn test_shift_months_with_overflow() { - let base = NaiveDate::from_ymd_opt(2020, 12, 31).unwrap(); - - assert_eq!(shift_months(base, 0, 1), base); - assert_eq!( - shift_months(base, 1, 1), - NaiveDate::from_ymd_opt(2021, 1, 31).unwrap() - ); - assert_eq!( - shift_months(base, 2, 1), - NaiveDate::from_ymd_opt(2021, 2, 28).unwrap() - ); - assert_eq!( - shift_months(base, 12, 1), - NaiveDate::from_ymd_opt(2021, 12, 31).unwrap() - ); - assert_eq!( - shift_months(base, 18, 1), - NaiveDate::from_ymd_opt(2022, 6, 30).unwrap() - ); - - assert_eq!( - shift_months(base, 1, -1), - NaiveDate::from_ymd_opt(2020, 11, 30).unwrap() - ); - assert_eq!( - shift_months(base, 2, -1), - NaiveDate::from_ymd_opt(2020, 10, 31).unwrap() - ); - assert_eq!( - shift_months(base, 10, -1), - NaiveDate::from_ymd_opt(2020, 2, 29).unwrap() - ); - assert_eq!( - shift_months(base, 12, -1), - NaiveDate::from_ymd_opt(2019, 12, 31).unwrap() - ); - assert_eq!( - shift_months(base, 18, -1), - NaiveDate::from_ymd_opt(2019, 6, 30).unwrap() - ); - } - - #[test] - fn test_shift_months_datetime() { - let date = NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(); - let o_clock = NaiveTime::from_hms_opt(1, 2, 3).unwrap(); - - let base = NaiveDateTime::new(date, o_clock); - - assert_eq!( - shift_months(base, 0, 1).date(), - NaiveDate::from_ymd_opt(2020, 1, 31).unwrap(), - ); - assert_eq!( - shift_months(base, 1, 1).date(), - NaiveDate::from_ymd_opt(2020, 2, 29).unwrap(), - ); - assert_eq!( - shift_months(base, 2, 1).date(), - NaiveDate::from_ymd_opt(2020, 3, 31).unwrap(), - ); - assert_eq!(shift_months(base, 0, 1).time(), o_clock); - assert_eq!(shift_months(base, 1, 1).time(), o_clock); - assert_eq!(shift_months(base, 2, 1).time(), o_clock); - } - - #[test] - fn add_11_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 11, 1); - assert_eq!(format!("{actual:?}").as_str(), "2000-12-01"); - } - - #[test] - fn add_12_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 12, 1); - assert_eq!(format!("{actual:?}").as_str(), "2001-01-01"); - } - - #[test] - fn add_13_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 13, 1); - assert_eq!(format!("{actual:?}").as_str(), "2001-02-01"); - } - - #[test] - fn sub_11_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 11, -1); - assert_eq!(format!("{actual:?}").as_str(), "1999-02-01"); - } - - #[test] - fn sub_12_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 12, -1); - assert_eq!(format!("{actual:?}").as_str(), "1999-01-01"); - } - - #[test] - fn sub_13_months() { - let prior = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); - let actual = shift_months(prior, 13, -1); - assert_eq!(format!("{actual:?}").as_str(), "1998-12-01"); - } -} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 0416086d81f35..e06f947ad5e76 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -20,26 +20,97 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::sync::Arc; -use crate::error::{unqualified_field_not_found, DataFusionError, Result, SchemaError}; -use crate::{field_not_found, Column, OwnedTableReference, TableReference}; +use crate::error::{ + unqualified_field_not_found, DataFusionError, Result, SchemaError, _plan_err, +}; +use crate::{ + field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference, +}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -use std::fmt::{Display, Formatter}; -/// A reference-counted reference to a `DFSchema`. +/// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names +/// DFSchema wraps an Arrow schema and adds relation names. +/// +/// The schema may hold the fields across multiple tables. Some fields may be +/// qualified and some unqualified. A qualified field is a field that has a +/// relation name associated with it. +/// +/// Unqualified fields must be unique not only amongst themselves, but also must +/// have a distinct name from any qualified field names. This allows finding a +/// qualified field by name to be possible, so long as there aren't multiple +/// qualified fields with the same name. +/// +/// There is an alias to `Arc` named [DFSchemaRef]. +/// +/// # Creating qualified schemas +/// +/// Use [DFSchema::try_from_qualified_schema] to create a qualified schema from +/// an Arrow schema. +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// let column = Column::from_qualified_name("t1.c1"); +/// assert!(df_schema.has_column(&column)); +/// +/// // Can also access qualified fields with unqualified name, if it's unambiguous +/// let column = Column::from_qualified_name("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Creating unqualified schemas +/// +/// Create an unqualified schema using TryFrom: +/// +/// ```rust +/// use datafusion_common::{DFSchema, Column}; +/// use arrow_schema::{DataType, Field, Schema}; +/// +/// let arrow_schema = Schema::new(vec![ +/// Field::new("c1", DataType::Int32, false), +/// ]); +/// +/// let df_schema = DFSchema::try_from(arrow_schema).unwrap(); +/// let column = Column::new_unqualified("c1"); +/// assert!(df_schema.has_column(&column)); +/// ``` +/// +/// # Converting back to Arrow schema +/// +/// Use the `Into` trait to convert `DFSchema` into an Arrow schema: +/// +/// ```rust +/// use datafusion_common::{DFSchema, DFField}; +/// use arrow_schema::Schema; +/// +/// let df_schema = DFSchema::new(vec![ +/// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), +/// ]).unwrap(); +/// let schema = Schema::from(df_schema); +/// assert_eq!(schema.fields().len(), 1); +/// ``` #[derive(Debug, Clone, PartialEq, Eq)] pub struct DFSchema { /// Fields fields: Vec, /// Additional metadata in form of key value pairs metadata: HashMap, + /// Stores functional dependencies in the schema. + functional_dependencies: FunctionalDependencies, } impl DFSchema { @@ -48,6 +119,7 @@ impl DFSchema { Self { fields: vec![], metadata: HashMap::new(), + functional_dependencies: FunctionalDependencies::empty(), } } @@ -97,10 +169,17 @@ impl DFSchema { )); } } - Ok(Self { fields, metadata }) + Ok(Self { + fields, + metadata, + functional_dependencies: FunctionalDependencies::empty(), + }) } /// Create a `DFSchema` from an Arrow schema and a given qualifier + /// + /// To create a schema from an Arrow schema without a qualifier, use + /// `DFSchema::try_from`. pub fn try_from_qualified_schema<'a>( qualifier: impl Into>, schema: &Schema, @@ -116,6 +195,22 @@ impl DFSchema { ) } + /// Assigns functional dependencies. + pub fn with_functional_dependencies( + mut self, + functional_dependencies: FunctionalDependencies, + ) -> Result { + if functional_dependencies.is_valid(self.fields.len()) { + self.functional_dependencies = functional_dependencies; + Ok(self) + } else { + _plan_err!( + "Invalid functional dependency: {:?}", + functional_dependencies + ) + } + } + /// Create a new schema that contains the fields from this schema followed by the fields /// from the supplied schema. An error will be returned if there are duplicate field names. pub fn join(&self, schema: &DFSchema) -> Result { @@ -169,10 +264,10 @@ impl DFSchema { match &self.fields[i].qualifier { Some(qualifier) => { if (qualifier.to_string() + "." + self.fields[i].name()) == name { - return Err(DataFusionError::Plan(format!( + return _plan_err!( "Fully qualified field name '{name}' was supplied to `index_of` \ which is deprecated. Please use `index_of_column_by_name` instead" - ))); + ); } } None => (), @@ -360,23 +455,44 @@ impl DFSchema { .zip(arrow_schema.fields().iter()) .try_for_each(|(l_field, r_field)| { if !can_cast_types(r_field.data_type(), l_field.data_type()) { - Err(DataFusionError::Plan( - format!("Column {} (type: {}) is not compatible with column {} (type: {})", + _plan_err!("Column {} (type: {}) is not compatible with column {} (type: {})", r_field.name(), r_field.data_type(), l_field.name(), - l_field.data_type()))) + l_field.data_type()) } else { Ok(()) } }) } + /// Returns true if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns false otherwise. + /// + /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type + /// equivalence checking. + pub fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + if self.fields().len() != other.fields().len() { + return false; + } + let self_fields = self.fields().iter(); + let other_fields = other.fields().iter(); + self_fields.zip(other_fields).all(|(f1, f2)| { + f1.qualifier() == f2.qualifier() + && f1.name() == f2.name() + && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) + }) + } + /// Returns true if the two schemas have the same qualified named /// fields with the same data types. Returns false otherwise. /// /// This is a specialized version of Eq that ignores differences /// in nullability and metadata. + /// + /// Use [DFSchema]::logically_equivalent_names_and_types for a weaker + /// logical type checking, which for example would consider a dictionary + /// encoded UTF8 array to be equivalent to a plain UTF8 array. pub fn equivalent_names_and_types(&self, other: &Self) -> bool { if self.fields().len() != other.fields().len() { return false; @@ -390,6 +506,46 @@ impl DFSchema { }) } + /// Checks if two [`DataType`]s are logically equal. This is a notably weaker constraint + /// than datatype_is_semantically_equal in that a Dictionary type is logically + /// equal to a plain V type, but not semantically equal. Dictionary is also + /// logically equal to Dictionary. + fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { + // check nested fields + match (dt1, dt2) { + (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { + v1.as_ref() == v2.as_ref() + } + (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, + (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, + (DataType::List(f1), DataType::List(f2)) + | (DataType::LargeList(f1), DataType::LargeList(f2)) + | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) + | (DataType::Map(f1, _), DataType::Map(f2, _)) => { + Self::field_is_logically_equal(f1, f2) + } + (DataType::Struct(fields1), DataType::Struct(fields2)) => { + let iter1 = fields1.iter(); + let iter2 = fields2.iter(); + fields1.len() == fields2.len() && + // all fields have to be the same + iter1 + .zip(iter2) + .all(|(f1, f2)| Self::field_is_logically_equal(f1, f2)) + } + (DataType::Union(fields1, _), DataType::Union(fields2, _)) => { + let iter1 = fields1.iter(); + let iter2 = fields2.iter(); + fields1.len() == fields2.len() && + // all fields have to be the same + iter1 + .zip(iter2) + .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2)) + } + _ => dt1 == dt2, + } + } + /// Returns true of two [`DataType`]s are semantically equal (same /// name and type), ignoring both metadata and nullability. /// @@ -425,10 +581,23 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal128(_l_precision, _l_scale), + DataType::Decimal128(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal256(_l_precision, _l_scale), + DataType::Decimal256(_r_precision, _r_scale), + ) => true, _ => dt1 == dt2, } } + fn field_is_logically_equal(f1: &Field, f2: &Field) -> bool { + f1.name() == f2.name() + && Self::datatype_is_logically_equal(f1.data_type(), f2.data_type()) + } + fn field_is_semantically_equal(f1: &Field, f2: &Field) -> bool { f1.name() == f2.name() && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) @@ -471,6 +640,11 @@ impl DFSchema { pub fn metadata(&self) -> &HashMap { &self.metadata } + + /// Get functional dependencies + pub fn functional_dependencies(&self) -> &FunctionalDependencies { + &self.functional_dependencies + } } impl From for Schema { @@ -581,6 +755,9 @@ pub trait ExprSchema: std::fmt::Debug { /// What is the datatype of this column? fn data_type(&self, col: &Column) -> Result<&DataType>; + + /// Returns the column's optional metadata. + fn metadata(&self, col: &Column) -> Result<&HashMap>; } // Implement `ExprSchema` for `Arc` @@ -592,6 +769,10 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type(&self, col: &Column) -> Result<&DataType> { self.as_ref().data_type(col) } + + fn metadata(&self, col: &Column) -> Result<&HashMap> { + ExprSchema::metadata(self.as_ref(), col) + } } impl ExprSchema for DFSchema { @@ -602,6 +783,10 @@ impl ExprSchema for DFSchema { fn data_type(&self, col: &Column) -> Result<&DataType> { Ok(self.field_from_column(col)?.data_type()) } + + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.field_from_column(col)?.metadata()) + } } /// DFField wraps an Arrow field and adds an optional qualifier @@ -661,6 +846,10 @@ impl DFField { self.field.is_nullable() } + pub fn metadata(&self) -> &HashMap { + self.field.metadata() + } + /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { if let Some(qualifier) = &self.qualifier { @@ -708,6 +897,13 @@ impl DFField { self.field = f.into(); self } + + /// Return field with new metadata + pub fn with_metadata(mut self, metadata: HashMap) -> Self { + let f = self.field().as_ref().clone().with_metadata(metadata); + self.field = f.into(); + self + } } impl From for DFField { @@ -725,6 +921,58 @@ impl From for DFField { } } +/// DataFusion-specific extensions to [`Schema`]. +pub trait SchemaExt { + /// This is a specialized version of Eq that ignores differences + /// in nullability and metadata. + /// + /// It works the same as [`DFSchema::equivalent_names_and_types`]. + fn equivalent_names_and_types(&self, other: &Self) -> bool; + + /// Returns true if the two schemas have the same qualified named + /// fields with logically equivalent data types. Returns false otherwise. + /// + /// Use [DFSchema]::equivalent_names_and_types for stricter semantic type + /// equivalence checking. + fn logically_equivalent_names_and_types(&self, other: &Self) -> bool; +} + +impl SchemaExt for Schema { + fn equivalent_names_and_types(&self, other: &Self) -> bool { + if self.fields().len() != other.fields().len() { + return false; + } + + self.fields() + .iter() + .zip(other.fields().iter()) + .all(|(f1, f2)| { + f1.name() == f2.name() + && DFSchema::datatype_is_semantically_equal( + f1.data_type(), + f2.data_type(), + ) + }) + } + + fn logically_equivalent_names_and_types(&self, other: &Self) -> bool { + if self.fields().len() != other.fields().len() { + return false; + } + + self.fields() + .iter() + .zip(other.fields().iter()) + .all(|(f1, f2)| { + f1.name() == f2.name() + && DFSchema::datatype_is_logically_equal( + f1.data_type(), + f2.data_type(), + ) + }) + } +} + #[cfg(test)] mod tests { use crate::assert_contains; @@ -739,8 +987,8 @@ mod tests { // lookup with unqualified name "t1.c0" let err = schema.index_of_column(&col).unwrap_err(); assert_eq!( - err.to_string(), - "Schema error: No field named \"t1.c0\". Valid fields are t1.c0, t1.c1.", + err.strip_backtrace(), + "Schema error: No field named \"t1.c0\". Valid fields are t1.c0, t1.c1." ); Ok(()) } @@ -759,8 +1007,8 @@ mod tests { // lookup with unqualified name "t1.c0" let err = schema.index_of_column(&col).unwrap_err(); assert_eq!( - err.to_string(), - "Schema error: No field named \"t1.c0\". Valid fields are t1.\"CapitalColumn\", t1.\"field.with.period\".", + err.strip_backtrace(), + "Schema error: No field named \"t1.c0\". Valid fields are t1.\"CapitalColumn\", t1.\"field.with.period\"." ); Ok(()) } @@ -843,8 +1091,8 @@ mod tests { let right = DFSchema::try_from(test_schema_1())?; let join = left.join(&right); assert_eq!( - join.unwrap_err().to_string(), - "Schema error: Schema contains duplicate unqualified field name c0", + join.unwrap_err().strip_backtrace(), + "Schema error: Schema contains duplicate unqualified field name c0" ); Ok(()) } @@ -920,12 +1168,12 @@ mod tests { let col = Column::from_qualified_name("t1.c0"); let err = schema.index_of_column(&col).unwrap_err(); - assert_eq!(err.to_string(), "Schema error: No field named t1.c0."); + assert_eq!(err.strip_backtrace(), "Schema error: No field named t1.c0."); // the same check without qualifier let col = Column::from_name("c0"); let err = schema.index_of_column(&col).err().unwrap(); - assert_eq!("Schema error: No field named c0.", err.to_string()); + assert_eq!(err.strip_backtrace(), "Schema error: No field named c0."); } #[test] @@ -991,7 +1239,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t], fields2: vec![&field1_i16_t], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -999,7 +1248,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t_meta], fields2: vec![&field1_i16_t], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1007,7 +1257,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t], fields2: vec![&field2_i16_t], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1015,7 +1266,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t], fields2: vec![&field1_i32_t], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1023,7 +1275,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t], fields2: vec![&field1_i16_f], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1031,7 +1284,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t], fields2: vec![&field1_i16_t_qualified], - expected: false, + expected_dfschema: false, + expected_arrow: true, } .run(); @@ -1039,7 +1293,8 @@ mod tests { TestCase { fields1: vec![&field2_i16_t, &field1_i16_t], fields2: vec![&field2_i16_t, &field3_i16_t], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1047,7 +1302,8 @@ mod tests { TestCase { fields1: vec![&field1_i16_t, &field2_i16_t], fields2: vec![&field1_i16_t], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1055,7 +1311,8 @@ mod tests { TestCase { fields1: vec![&field_dict_t], fields2: vec![&field_dict_t], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1063,7 +1320,8 @@ mod tests { TestCase { fields1: vec![&field_dict_t], fields2: vec![&field_dict_f], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1071,7 +1329,8 @@ mod tests { TestCase { fields1: vec![&field_dict_t], fields2: vec![&field1_i16_t], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1079,7 +1338,8 @@ mod tests { TestCase { fields1: vec![&list_t], fields2: vec![&list_f], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1087,7 +1347,8 @@ mod tests { TestCase { fields1: vec![&list_t], fields2: vec![&list_f_name], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1095,7 +1356,8 @@ mod tests { TestCase { fields1: vec![&struct_t], fields2: vec![&struct_f], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1103,7 +1365,8 @@ mod tests { TestCase { fields1: vec![&struct_t], fields2: vec![&struct_f_meta], - expected: true, + expected_dfschema: true, + expected_arrow: true, } .run(); @@ -1111,7 +1374,8 @@ mod tests { TestCase { fields1: vec![&struct_t], fields2: vec![&struct_f_type], - expected: false, + expected_dfschema: false, + expected_arrow: false, } .run(); @@ -1119,7 +1383,8 @@ mod tests { struct TestCase<'a> { fields1: Vec<&'a DFField>, fields2: Vec<&'a DFField>, - expected: bool, + expected_dfschema: bool, + expected_arrow: bool, } impl<'a> TestCase<'a> { @@ -1129,13 +1394,25 @@ mod tests { let schema2 = to_df_schema(self.fields2); assert_eq!( schema1.equivalent_names_and_types(&schema2), - self.expected, + self.expected_dfschema, "Comparison did not match expected: {}\n\n\ schema1:\n\n{:#?}\n\nschema2:\n\n{:#?}", - self.expected, + self.expected_dfschema, schema1, schema2 ); + + let arrow_schema1 = Schema::from(schema1); + let arrow_schema2 = Schema::from(schema2); + assert_eq!( + arrow_schema1.equivalent_names_and_types(&arrow_schema2), + self.expected_arrow, + "Comparison did not match expected: {}\n\n\ + arrow schema1:\n\n{:#?}\n\n arrow schema2:\n\n{:#?}", + self.expected_arrow, + arrow_schema1, + arrow_schema2 + ); } } @@ -1206,8 +1483,8 @@ mod tests { DFSchema::new_with_metadata([a, b].to_vec(), HashMap::new()).unwrap(), ); let schema: Schema = df_schema.as_ref().clone().into(); - let a_df = df_schema.fields.get(0).unwrap().field(); - let a_arrow = schema.fields.get(0).unwrap(); + let a_df = df_schema.fields.first().unwrap().field(); + let a_arrow = schema.fields.first().unwrap(); assert_eq!(a_df.metadata(), a_arrow.metadata()) } diff --git a/datafusion/common/src/display/graphviz.rs b/datafusion/common/src/display/graphviz.rs new file mode 100644 index 0000000000000..f84490cd3ea4e --- /dev/null +++ b/datafusion/common/src/display/graphviz.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic related to creating DOT language graphs. + +use std::fmt; + +#[derive(Default)] +pub struct GraphvizBuilder { + id_gen: usize, +} + +impl GraphvizBuilder { + // Generate next id in graphviz. + pub fn next_id(&mut self) -> usize { + self.id_gen += 1; + self.id_gen + } + + // Write out the start of whole graph. + pub fn start_graph(&mut self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + r#" +// Begin DataFusion GraphViz Plan, +// display it online here: https://dreampuf.github.io/GraphvizOnline +"# + )?; + writeln!(f, "digraph {{") + } + + pub fn end_graph(&mut self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "}}")?; + writeln!(f, "// End DataFusion GraphViz Plan") + } + + // write out the start of the subgraph cluster + pub fn start_cluster(&mut self, f: &mut fmt::Formatter, title: &str) -> fmt::Result { + writeln!(f, " subgraph cluster_{}", self.next_id())?; + writeln!(f, " {{")?; + writeln!(f, " graph[label={}]", Self::quoted(title)) + } + + // write out the end of the subgraph cluster + pub fn end_cluster(&mut self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, " }}") + } + + /// makes a quoted string suitable for inclusion in a graphviz chart + pub fn quoted(label: &str) -> String { + let label = label.replace('"', "_"); + format!("\"{label}\"") + } + + pub fn add_node( + &self, + f: &mut fmt::Formatter, + id: usize, + label: &str, + tooltip: Option<&str>, + ) -> fmt::Result { + if let Some(tooltip) = tooltip { + writeln!( + f, + " {}[shape=box label={}, tooltip={}]", + id, + GraphvizBuilder::quoted(label), + GraphvizBuilder::quoted(tooltip), + ) + } else { + writeln!( + f, + " {}[shape=box label={}]", + id, + GraphvizBuilder::quoted(label), + ) + } + } + + pub fn add_edge( + &self, + f: &mut fmt::Formatter, + from_id: usize, + to_id: usize, + ) -> fmt::Result { + writeln!( + f, + " {from_id} -> {to_id} [arrowhead=none, arrowtail=normal, dir=back]" + ) + } +} diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs new file mode 100644 index 0000000000000..4d1d48bf9fcc7 --- /dev/null +++ b/datafusion/common/src/display/mod.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Types for plan display + +mod graphviz; +pub use graphviz::*; + +use std::{ + fmt::{self, Display, Formatter}, + sync::Arc, +}; + +/// Represents which type of plan, when storing multiple +/// for use in EXPLAIN plans +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PlanType { + /// The initial LogicalPlan provided to DataFusion + InitialLogicalPlan, + /// The LogicalPlan which results from applying an analyzer pass + AnalyzedLogicalPlan { + /// The name of the analyzer which produced this plan + analyzer_name: String, + }, + /// The LogicalPlan after all analyzer passes have been applied + FinalAnalyzedLogicalPlan, + /// The LogicalPlan which results from applying an optimizer pass + OptimizedLogicalPlan { + /// The name of the optimizer which produced this plan + optimizer_name: String, + }, + /// The final, fully optimized LogicalPlan that was converted to a physical plan + FinalLogicalPlan, + /// The initial physical plan, prepared for execution + InitialPhysicalPlan, + /// The initial physical plan with stats, prepared for execution + InitialPhysicalPlanWithStats, + /// The ExecutionPlan which results from applying an optimizer pass + OptimizedPhysicalPlan { + /// The name of the optimizer which produced this plan + optimizer_name: String, + }, + /// The final, fully optimized physical which would be executed + FinalPhysicalPlan, + /// The final with stats, fully optimized physical which would be executed + FinalPhysicalPlanWithStats, +} + +impl Display for PlanType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), + PlanType::AnalyzedLogicalPlan { analyzer_name } => { + write!(f, "logical_plan after {analyzer_name}") + } + PlanType::FinalAnalyzedLogicalPlan => write!(f, "analyzed_logical_plan"), + PlanType::OptimizedLogicalPlan { optimizer_name } => { + write!(f, "logical_plan after {optimizer_name}") + } + PlanType::FinalLogicalPlan => write!(f, "logical_plan"), + PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), + PlanType::InitialPhysicalPlanWithStats => { + write!(f, "initial_physical_plan_with_stats") + } + PlanType::OptimizedPhysicalPlan { optimizer_name } => { + write!(f, "physical_plan after {optimizer_name}") + } + PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), + PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), + } + } +} + +/// Represents some sort of execution plan, in String form +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StringifiedPlan { + /// An identifier of what type of plan this string represents + pub plan_type: PlanType, + /// The string representation of the plan + pub plan: Arc, +} + +impl StringifiedPlan { + /// Create a new Stringified plan of `plan_type` with string + /// representation `plan` + pub fn new(plan_type: PlanType, plan: impl Into) -> Self { + StringifiedPlan { + plan_type, + plan: Arc::new(plan.into()), + } + } + + /// Returns true if this plan should be displayed. Generally + /// `verbose_mode = true` will display all available plans + pub fn should_display(&self, verbose_mode: bool) -> bool { + match self.plan_type { + PlanType::FinalLogicalPlan | PlanType::FinalPhysicalPlan => true, + _ => verbose_mode, + } + } +} + +/// Trait for something that can be formatted as a stringified plan +pub trait ToStringifiedPlan { + /// Create a stringified plan with the specified type + fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan; +} diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 2074d35fb2f3f..4ae30ae86cddc 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -16,6 +16,8 @@ // under the License. //! DataFusion error types +#[cfg(feature = "backtrace")] +use std::backtrace::{Backtrace, BacktraceStatus}; use std::error::Error; use std::fmt::{Display, Formatter}; @@ -65,12 +67,17 @@ pub enum DataFusionError { NotImplemented(String), /// Error returned as a consequence of an error in DataFusion. /// This error should not happen in normal usage of DataFusion. - // DataFusions has internal invariants that we are unable to ask the compiler to check for us. - // This error is raised when one of those invariants is not verified during execution. + /// + /// DataFusions has internal invariants that the compiler is not + /// always able to check. This error is raised when one of those + /// invariants is not verified during execution. Internal(String), /// This error happens whenever a plan is not valid. Examples include /// impossible casts. Plan(String), + /// This error happens when an invalid or unsupported option is passed + /// in a SQL statement + Configuration(String), /// This error happens with schema-related errors, such as schema inference not possible /// and non-unique column names. SchemaError(SchemaError), @@ -97,18 +104,6 @@ macro_rules! context { }; } -#[macro_export] -macro_rules! plan_err { - ($desc:expr) => { - Err(datafusion_common::DataFusionError::Plan(format!( - "{} at {}:{}", - $desc, - file!(), - line!() - ))) - }; -} - /// Schema-related errors #[derive(Debug)] pub enum SchemaError { @@ -285,7 +280,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { - DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {desc}"), + DataFusionError::ArrowError(ref desc) => { + write!(f, "Arrow error: {desc}") + } #[cfg(feature = "parquet")] DataFusionError::ParquetError(ref desc) => { write!(f, "Parquet error: {desc}") @@ -294,15 +291,20 @@ impl Display for DataFusionError { DataFusionError::AvroError(ref desc) => { write!(f, "Avro error: {desc}") } - DataFusionError::IoError(ref desc) => write!(f, "IO error: {desc}"), + DataFusionError::IoError(ref desc) => { + write!(f, "IO error: {desc}") + } DataFusionError::SQL(ref desc) => { write!(f, "SQL error: {desc:?}") } + DataFusionError::Configuration(ref desc) => { + write!(f, "Invalid or Unsupported Configuration: {desc}") + } DataFusionError::NotImplemented(ref desc) => { write!(f, "This feature is not implemented: {desc}") } DataFusionError::Internal(ref desc) => { - write!(f, "Internal error: {desc}. This was likely caused by a bug in DataFusion's \ + write!(f, "Internal error: {desc}.\nThis was likely caused by a bug in DataFusion's \ code and we would welcome that you file an bug report in our issue tracker") } DataFusionError::Plan(ref desc) => { @@ -348,6 +350,7 @@ impl Error for DataFusionError { DataFusionError::SQL(e) => Some(e), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, + DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, DataFusionError::SchemaError(e) => Some(e), DataFusionError::Execution(_) => None, @@ -366,6 +369,8 @@ impl From for io::Error { } impl DataFusionError { + const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: "; + /// Get deepest underlying [`DataFusionError`] /// /// [`DataFusionError`]s sometimes form a chain, such as `DataFusionError::ArrowError()` in order to conform @@ -407,8 +412,126 @@ impl DataFusionError { pub fn context(self, description: impl Into) -> Self { Self::Context(description.into(), Box::new(self)) } + + pub fn strip_backtrace(&self) -> String { + self.to_string() + .split(Self::BACK_TRACE_SEP) + .collect::>() + .first() + .unwrap_or(&"") + .to_string() + } + + /// To enable optional rust backtrace in DataFusion: + /// - [`Setup Env Variables`] + /// - Enable `backtrace` cargo feature + /// + /// Example: + /// cargo build --features 'backtrace' + /// RUST_BACKTRACE=1 ./app + #[inline(always)] + pub fn get_back_trace() -> String { + #[cfg(feature = "backtrace")] + { + let back_trace = Backtrace::capture(); + if back_trace.status() == BacktraceStatus::Captured { + return format!("{}{}", Self::BACK_TRACE_SEP, back_trace); + } + + "".to_owned() + } + + #[cfg(not(feature = "backtrace"))] + "".to_owned() + } +} + +/// Unwrap an `Option` if possible. Otherwise return an `DataFusionError::Internal`. +/// In normal usage of DataFusion the unwrap should always succeed. +/// +/// Example: `let values = unwrap_or_internal_err!(values)` +#[macro_export] +macro_rules! unwrap_or_internal_err { + ($Value: ident) => { + $Value.ok_or_else(|| { + DataFusionError::Internal(format!( + "{} should not be None", + stringify!($Value) + )) + })? + }; +} + +macro_rules! with_dollar_sign { + ($($body:tt)*) => { + macro_rules! __with_dollar_sign { $($body)* } + __with_dollar_sign!($); + } +} + +/// Add a macros for concise DataFusionError::* errors declaration +/// supports placeholders the same way as `format!` +/// Examples: +/// plan_err!("Error") +/// plan_err!("Error {}", val) +/// plan_err!("Error {:?}", val) +/// plan_err!("Error {val}") +/// plan_err!("Error {val:?}") +/// +/// `NAME_ERR` - macro name for wrapping Err(DataFusionError::*) +/// `NAME_DF_ERR` - macro name for wrapping DataFusionError::*. Needed to keep backtrace opportunity +/// in construction where DataFusionError::* used directly, like `map_err`, `ok_or_else`, etc +macro_rules! make_error { + ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { + with_dollar_sign! { + ($d:tt) => { + /// Macro wraps `$ERR` to add backtrace feature + #[macro_export] + macro_rules! $NAME_DF_ERR { + ($d($d args:expr),*) => { + DataFusionError::$ERR(format!("{}{}", format!($d($d args),*), DataFusionError::get_back_trace()).into()) + } + } + + /// Macro wraps Err(`$ERR`) to add backtrace feature + #[macro_export] + macro_rules! $NAME_ERR { + ($d($d args:expr),*) => { + Err(DataFusionError::$ERR(format!("{}{}", format!($d($d args),*), DataFusionError::get_back_trace()).into())) + } + } + } + } + }; } +// Exposes a macro to create `DataFusionError::Plan` +make_error!(plan_err, plan_datafusion_err, Plan); + +// Exposes a macro to create `DataFusionError::Internal` +make_error!(internal_err, internal_datafusion_err, Internal); + +// Exposes a macro to create `DataFusionError::NotImplemented` +make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); + +// Exposes a macro to create `DataFusionError::Execution` +make_error!(exec_err, exec_datafusion_err, Execution); + +// Exposes a macro to create `DataFusionError::SQL` +#[macro_export] +macro_rules! sql_err { + ($ERR:expr) => { + Err(DataFusionError::SQL($ERR)) + }; +} + +// To avoid compiler error when using macro in the same crate: +// macros from the current crate cannot be referred to by absolute paths +pub use exec_err as _exec_err; +pub use internal_err as _internal_err; +pub use not_impl_err as _not_impl_err; +pub use plan_err as _plan_err; + #[cfg(test)] mod test { use std::sync::Arc; @@ -417,18 +540,50 @@ mod test { use arrow::error::ArrowError; #[test] - fn arrow_error_to_datafusion() { + fn datafusion_error_to_arrow() { let res = return_arrow_error().unwrap_err(); + assert!(res + .to_string() + .starts_with("External error: Error during planning: foo")); + } + + #[test] + fn arrow_error_to_datafusion() { + let res = return_datafusion_error().unwrap_err(); + assert_eq!(res.strip_backtrace(), "Arrow error: Schema error: bar"); + } + + // RUST_BACKTRACE=1 cargo test --features backtrace --package datafusion-common --lib -- error::test::test_backtrace + #[cfg(feature = "backtrace")] + #[test] + #[allow(clippy::unnecessary_literal_unwrap)] + fn test_enabled_backtrace() { + let res: Result<(), DataFusionError> = plan_err!("Err"); + let err = res.unwrap_err().to_string(); + assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); assert_eq!( - res.to_string(), - "External error: Error during planning: foo" + err.split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .first() + .unwrap(), + &"Error during planning: Err" ); + assert!(!err + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty()); } + #[cfg(not(feature = "backtrace"))] #[test] - fn datafusion_error_to_arrow() { - let res = return_datafusion_error().unwrap_err(); - assert_eq!(res.to_string(), "Arrow error: Schema error: bar"); + #[allow(clippy::unnecessary_literal_unwrap)] + fn test_disabled_backtrace() { + let res: Result<(), DataFusionError> = plan_err!("Err"); + let res = res.unwrap_err().to_string(); + assert!(!res.contains(DataFusionError::BACK_TRACE_SEP)); + assert_eq!(res, "Error during planning: Err"); } #[test] @@ -486,6 +641,46 @@ mod test { ); } + #[test] + #[allow(clippy::unnecessary_literal_unwrap)] + fn test_make_error_parse_input() { + let res: Result<(), DataFusionError> = plan_err!("Err"); + let res = res.unwrap_err(); + assert_eq!(res.strip_backtrace(), "Error during planning: Err"); + + let extra1 = "extra1"; + let extra2 = "extra2"; + + let res: Result<(), DataFusionError> = plan_err!("Err {} {}", extra1, extra2); + let res = res.unwrap_err(); + assert_eq!( + res.strip_backtrace(), + "Error during planning: Err extra1 extra2" + ); + + let res: Result<(), DataFusionError> = + plan_err!("Err {:?} {:#?}", extra1, extra2); + let res = res.unwrap_err(); + assert_eq!( + res.strip_backtrace(), + "Error during planning: Err \"extra1\" \"extra2\"" + ); + + let res: Result<(), DataFusionError> = plan_err!("Err {extra1} {extra2}"); + let res = res.unwrap_err(); + assert_eq!( + res.strip_backtrace(), + "Error during planning: Err extra1 extra2" + ); + + let res: Result<(), DataFusionError> = plan_err!("Err {extra1:?} {extra2:#?}"); + let res = res.unwrap_err(); + assert_eq!( + res.strip_backtrace(), + "Error during planning: Err \"extra1\" \"extra2\"" + ); + } + /// Model what happens when implementing SendableRecordBatchStream: /// DataFusion code needs to return an ArrowError fn return_arrow_error() -> arrow::error::Result<()> { @@ -504,30 +699,7 @@ mod test { let e = e.find_root(); // DataFusionError does not implement Eq, so we use a string comparison + some cheap "same variant" test instead - assert_eq!(e.to_string(), exp.to_string(),); + assert_eq!(e.strip_backtrace(), exp.strip_backtrace()); assert_eq!(std::mem::discriminant(e), std::mem::discriminant(&exp),) } } - -#[macro_export] -macro_rules! internal_err { - ($($arg:tt)*) => { - Err(DataFusionError::Internal(format!($($arg)*))) - }; -} - -/// Unwrap an `Option` if possible. Otherwise return an `DataFusionError::Internal`. -/// In normal usage of DataFusion the unwrap should always succeed. -/// -/// Example: `let values = unwrap_or_internal_err!(values)` -#[macro_export] -macro_rules! unwrap_or_internal_err { - ($Value: ident) => { - $Value.ok_or_else(|| { - DataFusionError::Internal(format!( - "{} should not be None", - stringify!($Value) - )) - })? - }; -} diff --git a/datafusion/core/tests/sqllogictests/src/utils.rs b/datafusion/common/src/file_options/arrow_writer.rs similarity index 52% rename from datafusion/core/tests/sqllogictests/src/utils.rs rename to datafusion/common/src/file_options/arrow_writer.rs index 4d064a76e2ad2..a30e6d800e20b 100644 --- a/datafusion/core/tests/sqllogictests/src/utils.rs +++ b/datafusion/common/src/file_options/arrow_writer.rs @@ -15,26 +15,22 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::{ - array::{Array, Decimal128Builder}, - datatypes::{Field, Schema}, - record_batch::RecordBatch, +//! Options related to how Arrow files should be written + +use crate::{ + config::ConfigOptions, + error::{DataFusionError, Result}, }; -use std::sync::Arc; -// TODO: move this to datafusion::test_utils? -pub fn make_decimal() -> RecordBatch { - let mut decimal_builder = Decimal128Builder::with_capacity(20); - for i in 110000..110010 { - decimal_builder.append_value(i as i128); - } - for i in 100000..100010 { - decimal_builder.append_value(-i as i128); +use super::StatementOptions; + +#[derive(Clone, Debug)] +pub struct ArrowWriterOptions {} + +impl TryFrom<(&ConfigOptions, &StatementOptions)> for ArrowWriterOptions { + type Error = DataFusionError; + + fn try_from(_value: (&ConfigOptions, &StatementOptions)) -> Result { + Ok(ArrowWriterOptions {}) } - let array = decimal_builder - .finish() - .with_precision_and_scale(10, 3) - .unwrap(); - let schema = Schema::new(vec![Field::new("c1", array.data_type().clone(), true)]); - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() } diff --git a/datafusion/common/src/file_options/avro_writer.rs b/datafusion/common/src/file_options/avro_writer.rs new file mode 100644 index 0000000000000..2e3a647058426 --- /dev/null +++ b/datafusion/common/src/file_options/avro_writer.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Options related to how avro files should be written + +use crate::{ + config::ConfigOptions, + error::{DataFusionError, Result}, +}; + +use super::StatementOptions; + +#[derive(Clone, Debug)] +pub struct AvroWriterOptions {} + +impl TryFrom<(&ConfigOptions, &StatementOptions)> for AvroWriterOptions { + type Error = DataFusionError; + + fn try_from(_value: (&ConfigOptions, &StatementOptions)) -> Result { + Ok(AvroWriterOptions {}) + } +} diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs new file mode 100644 index 0000000000000..d6046f0219dd3 --- /dev/null +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Options related to how csv files should be written + +use std::str::FromStr; + +use arrow::csv::WriterBuilder; + +use crate::{ + config::ConfigOptions, + error::{DataFusionError, Result}, + parsers::CompressionTypeVariant, +}; + +use super::StatementOptions; + +/// Options for writing CSV files +#[derive(Clone, Debug)] +pub struct CsvWriterOptions { + /// Struct from the arrow crate which contains all csv writing related settings + pub writer_options: WriterBuilder, + /// Compression to apply after ArrowWriter serializes RecordBatches. + /// This compression is applied by DataFusion not the ArrowWriter itself. + pub compression: CompressionTypeVariant, +} + +impl CsvWriterOptions { + pub fn new( + writer_options: WriterBuilder, + compression: CompressionTypeVariant, + ) -> Self { + Self { + writer_options, + compression, + } + } +} + +impl TryFrom<(&ConfigOptions, &StatementOptions)> for CsvWriterOptions { + type Error = DataFusionError; + + fn try_from(value: (&ConfigOptions, &StatementOptions)) -> Result { + let _configs = value.0; + let statement_options = value.1; + let mut builder = WriterBuilder::default(); + let mut compression = CompressionTypeVariant::UNCOMPRESSED; + for (option, value) in &statement_options.options { + builder = match option.to_lowercase().as_str(){ + "header" => { + let has_header = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as bool as required for {option}!")))?; + builder.with_header(has_header) + }, + "date_format" => builder.with_date_format(value.to_owned()), + "datetime_format" => builder.with_datetime_format(value.to_owned()), + "timestamp_format" => builder.with_timestamp_format(value.to_owned()), + "time_format" => builder.with_time_format(value.to_owned()), + "rfc3339" => builder, // No-op + "null_value" => builder.with_null(value.to_owned()), + "compression" => { + compression = CompressionTypeVariant::from_str(value.replace('\'', "").as_str())?; + builder + }, + "delimiter" => { + // Ignore string literal single quotes passed from sql parsing + let value = value.replace('\'', ""); + let chars: Vec = value.chars().collect(); + if chars.len()>1{ + return Err(DataFusionError::Configuration(format!( + "CSV Delimiter Option must be a single char, got: {}", value + ))) + } + builder.with_delimiter(chars[0].try_into().map_err(|_| { + DataFusionError::Internal( + "Unable to convert CSV delimiter into u8".into(), + ) + })?) + }, + "quote" | "escape" => { + // https://github.com/apache/arrow-rs/issues/5146 + // These two attributes are only available when reading csv files. + // To avoid error + builder + }, + _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for CSV format!"))) + } + } + Ok(CsvWriterOptions { + writer_options: builder, + compression, + }) + } +} diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs new file mode 100644 index 0000000000000..a07f2e0cb847b --- /dev/null +++ b/datafusion/common/src/file_options/file_type.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! File type abstraction + +use crate::error::{DataFusionError, Result}; + +use core::fmt; +use std::fmt::Display; +use std::str::FromStr; + +/// The default file extension of arrow files +pub const DEFAULT_ARROW_EXTENSION: &str = ".arrow"; +/// The default file extension of avro files +pub const DEFAULT_AVRO_EXTENSION: &str = ".avro"; +/// The default file extension of csv files +pub const DEFAULT_CSV_EXTENSION: &str = ".csv"; +/// The default file extension of json files +pub const DEFAULT_JSON_EXTENSION: &str = ".json"; +/// The default file extension of parquet files +pub const DEFAULT_PARQUET_EXTENSION: &str = ".parquet"; + +/// Define each `FileType`/`FileCompressionType`'s extension +pub trait GetExt { + /// File extension getter + fn get_ext(&self) -> String; +} + +/// Readable file type +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum FileType { + /// Apache Arrow file + ARROW, + /// Apache Avro file + AVRO, + /// Apache Parquet file + #[cfg(feature = "parquet")] + PARQUET, + /// CSV file + CSV, + /// JSON file + JSON, +} + +impl GetExt for FileType { + fn get_ext(&self) -> String { + match self { + FileType::ARROW => DEFAULT_ARROW_EXTENSION.to_owned(), + FileType::AVRO => DEFAULT_AVRO_EXTENSION.to_owned(), + #[cfg(feature = "parquet")] + FileType::PARQUET => DEFAULT_PARQUET_EXTENSION.to_owned(), + FileType::CSV => DEFAULT_CSV_EXTENSION.to_owned(), + FileType::JSON => DEFAULT_JSON_EXTENSION.to_owned(), + } + } +} + +impl Display for FileType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let out = match self { + FileType::CSV => "csv", + FileType::JSON => "json", + #[cfg(feature = "parquet")] + FileType::PARQUET => "parquet", + FileType::AVRO => "avro", + FileType::ARROW => "arrow", + }; + write!(f, "{}", out) + } +} + +impl FromStr for FileType { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + let s = s.to_uppercase(); + match s.as_str() { + "ARROW" => Ok(FileType::ARROW), + "AVRO" => Ok(FileType::AVRO), + #[cfg(feature = "parquet")] + "PARQUET" => Ok(FileType::PARQUET), + "CSV" => Ok(FileType::CSV), + "JSON" | "NDJSON" => Ok(FileType::JSON), + _ => Err(DataFusionError::NotImplemented(format!( + "Unknown FileType: {s}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + use crate::error::DataFusionError; + use crate::file_options::FileType; + use std::str::FromStr; + + #[test] + fn from_str() { + for (ext, file_type) in [ + ("csv", FileType::CSV), + ("CSV", FileType::CSV), + ("json", FileType::JSON), + ("JSON", FileType::JSON), + ("avro", FileType::AVRO), + ("AVRO", FileType::AVRO), + ("parquet", FileType::PARQUET), + ("PARQUET", FileType::PARQUET), + ] { + assert_eq!(FileType::from_str(ext).unwrap(), file_type); + } + + assert!(matches!( + FileType::from_str("Unknown"), + Err(DataFusionError::NotImplemented(_)) + )); + } +} diff --git a/datafusion/common/src/file_options/json_writer.rs b/datafusion/common/src/file_options/json_writer.rs new file mode 100644 index 0000000000000..7f988016c69df --- /dev/null +++ b/datafusion/common/src/file_options/json_writer.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Options related to how json files should be written + +use std::str::FromStr; + +use crate::{ + config::ConfigOptions, + error::{DataFusionError, Result}, + parsers::CompressionTypeVariant, +}; + +use super::StatementOptions; + +/// Options for writing JSON files +#[derive(Clone, Debug)] +pub struct JsonWriterOptions { + pub compression: CompressionTypeVariant, +} + +impl JsonWriterOptions { + pub fn new(compression: CompressionTypeVariant) -> Self { + Self { compression } + } +} + +impl TryFrom<(&ConfigOptions, &StatementOptions)> for JsonWriterOptions { + type Error = DataFusionError; + + fn try_from(value: (&ConfigOptions, &StatementOptions)) -> Result { + let _configs = value.0; + let statement_options = value.1; + let mut compression = CompressionTypeVariant::UNCOMPRESSED; + for (option, value) in &statement_options.options { + match option.to_lowercase().as_str(){ + "compression" => { + compression = CompressionTypeVariant::from_str(value.replace('\'', "").as_str())?; + }, + _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for JSON format!"))) + } + } + Ok(JsonWriterOptions { compression }) + } +} diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs new file mode 100644 index 0000000000000..b7c1341e30460 --- /dev/null +++ b/datafusion/common/src/file_options/mod.rs @@ -0,0 +1,548 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Options related to how files should be written + +pub mod arrow_writer; +pub mod avro_writer; +pub mod csv_writer; +pub mod file_type; +pub mod json_writer; +#[cfg(feature = "parquet")] +pub mod parquet_writer; +pub(crate) mod parse_utils; + +use std::{ + collections::HashMap, + fmt::{self, Display}, + path::Path, + str::FromStr, +}; + +use crate::{ + config::ConfigOptions, file_options::parse_utils::parse_boolean_string, + DataFusionError, FileType, Result, +}; + +#[cfg(feature = "parquet")] +use self::parquet_writer::ParquetWriterOptions; + +use self::{ + arrow_writer::ArrowWriterOptions, avro_writer::AvroWriterOptions, + csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions, +}; + +/// Represents a single arbitrary setting in a +/// [StatementOptions] where OptionTuple.0 determines +/// the specific setting to be modified and OptionTuple.1 +/// determines the value which should be applied +pub type OptionTuple = (String, String); + +/// Represents arbitrary tuples of options passed as String +/// tuples from SQL statements. As in the following statement: +/// COPY ... TO ... (setting1 value1, setting2 value2, ...) +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct StatementOptions { + options: Vec, +} + +/// Useful for conversion from external tables which use Hashmap +impl From<&HashMap> for StatementOptions { + fn from(value: &HashMap) -> Self { + Self { + options: value + .iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect::>(), + } + } +} + +impl StatementOptions { + pub fn new(options: Vec) -> Self { + Self { options } + } + + pub fn into_inner(self) -> Vec { + self.options + } + + /// Scans for option and if it exists removes it and attempts to parse as a boolean + /// Returns none if it does not exist. + pub fn take_bool_option(&mut self, find: &str) -> Result> { + let maybe_option = self.scan_and_remove_option(find); + maybe_option + .map(|(_, v)| parse_boolean_string(find, v)) + .transpose() + } + + /// Scans for option and if it exists removes it and returns it + /// Returns none if it does not exist + pub fn take_str_option(&mut self, find: &str) -> Option { + let maybe_option = self.scan_and_remove_option(find); + maybe_option.map(|(_, v)| v) + } + + /// Infers the file_type given a target and arbitrary options. + /// If the options contain an explicit "format" option, that will be used. + /// Otherwise, attempt to infer file_type from the extension of target. + /// Finally, return an error if unable to determine the file_type + /// If found, format is removed from the options list. + pub fn try_infer_file_type(&mut self, target: &str) -> Result { + let explicit_format = self.scan_and_remove_option("format"); + let format = match explicit_format { + Some(s) => FileType::from_str(s.1.as_str()), + None => { + // try to infer file format from file extension + let extension: &str = &Path::new(target) + .extension() + .ok_or(DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension!" + .to_string(), + ))? + .to_str() + .ok_or(DataFusionError::Configuration( + "Format not explicitly set and failed to parse file extension!" + .to_string(), + ))? + .to_lowercase(); + + FileType::from_str(extension) + } + }?; + + Ok(format) + } + + /// Finds an option in StatementOptions if exists, removes and returns it + /// along with the vec of remaining options. + fn scan_and_remove_option(&mut self, find: &str) -> Option { + let idx = self + .options + .iter() + .position(|(k, _)| k.to_lowercase() == find.to_lowercase()); + match idx { + Some(i) => Some(self.options.swap_remove(i)), + None => None, + } + } +} + +/// This type contains all options needed to initialize a particular +/// RecordBatchWriter type. Each element in the enum contains a thin wrapper +/// around a "writer builder" type (e.g. arrow::csv::WriterBuilder) +/// plus any DataFusion specific writing options (e.g. CSV compression) +#[derive(Clone, Debug)] +pub enum FileTypeWriterOptions { + #[cfg(feature = "parquet")] + Parquet(ParquetWriterOptions), + CSV(CsvWriterOptions), + JSON(JsonWriterOptions), + Avro(AvroWriterOptions), + Arrow(ArrowWriterOptions), +} + +impl FileTypeWriterOptions { + /// Constructs a FileTypeWriterOptions given a FileType to be written + /// and arbitrary String tuple options. May return an error if any + /// string setting is unrecognized or unsupported. + pub fn build( + file_type: &FileType, + config_defaults: &ConfigOptions, + statement_options: &StatementOptions, + ) -> Result { + let options = (config_defaults, statement_options); + + let file_type_write_options = match file_type { + #[cfg(feature = "parquet")] + FileType::PARQUET => { + FileTypeWriterOptions::Parquet(ParquetWriterOptions::try_from(options)?) + } + FileType::CSV => { + FileTypeWriterOptions::CSV(CsvWriterOptions::try_from(options)?) + } + FileType::JSON => { + FileTypeWriterOptions::JSON(JsonWriterOptions::try_from(options)?) + } + FileType::AVRO => { + FileTypeWriterOptions::Avro(AvroWriterOptions::try_from(options)?) + } + FileType::ARROW => { + FileTypeWriterOptions::Arrow(ArrowWriterOptions::try_from(options)?) + } + }; + + Ok(file_type_write_options) + } + + /// Constructs a FileTypeWriterOptions from session defaults only. + pub fn build_default( + file_type: &FileType, + config_defaults: &ConfigOptions, + ) -> Result { + let empty_statement = StatementOptions::new(vec![]); + let options = (config_defaults, &empty_statement); + + let file_type_write_options = match file_type { + #[cfg(feature = "parquet")] + FileType::PARQUET => { + FileTypeWriterOptions::Parquet(ParquetWriterOptions::try_from(options)?) + } + FileType::CSV => { + FileTypeWriterOptions::CSV(CsvWriterOptions::try_from(options)?) + } + FileType::JSON => { + FileTypeWriterOptions::JSON(JsonWriterOptions::try_from(options)?) + } + FileType::AVRO => { + FileTypeWriterOptions::Avro(AvroWriterOptions::try_from(options)?) + } + FileType::ARROW => { + FileTypeWriterOptions::Arrow(ArrowWriterOptions::try_from(options)?) + } + }; + + Ok(file_type_write_options) + } + + /// Tries to extract ParquetWriterOptions from this FileTypeWriterOptions enum. + /// Returns an error if a different type from parquet is set. + #[cfg(feature = "parquet")] + pub fn try_into_parquet(&self) -> Result<&ParquetWriterOptions> { + match self { + FileTypeWriterOptions::Parquet(opt) => Ok(opt), + _ => Err(DataFusionError::Internal(format!( + "Expected parquet options but found options for: {}", + self + ))), + } + } + + /// Tries to extract CsvWriterOptions from this FileTypeWriterOptions enum. + /// Returns an error if a different type from csv is set. + pub fn try_into_csv(&self) -> Result<&CsvWriterOptions> { + match self { + FileTypeWriterOptions::CSV(opt) => Ok(opt), + _ => Err(DataFusionError::Internal(format!( + "Expected csv options but found options for {}", + self + ))), + } + } + + /// Tries to extract JsonWriterOptions from this FileTypeWriterOptions enum. + /// Returns an error if a different type from json is set. + pub fn try_into_json(&self) -> Result<&JsonWriterOptions> { + match self { + FileTypeWriterOptions::JSON(opt) => Ok(opt), + _ => Err(DataFusionError::Internal(format!( + "Expected json options but found options for {}", + self, + ))), + } + } + + /// Tries to extract AvroWriterOptions from this FileTypeWriterOptions enum. + /// Returns an error if a different type from avro is set. + pub fn try_into_avro(&self) -> Result<&AvroWriterOptions> { + match self { + FileTypeWriterOptions::Avro(opt) => Ok(opt), + _ => Err(DataFusionError::Internal(format!( + "Expected avro options but found options for {}!", + self + ))), + } + } + + /// Tries to extract ArrowWriterOptions from this FileTypeWriterOptions enum. + /// Returns an error if a different type from arrow is set. + pub fn try_into_arrow(&self) -> Result<&ArrowWriterOptions> { + match self { + FileTypeWriterOptions::Arrow(opt) => Ok(opt), + _ => Err(DataFusionError::Internal(format!( + "Expected arrow options but found options for {}", + self + ))), + } + } +} + +impl Display for FileTypeWriterOptions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + FileTypeWriterOptions::Arrow(_) => "ArrowWriterOptions", + FileTypeWriterOptions::Avro(_) => "AvroWriterOptions", + FileTypeWriterOptions::CSV(_) => "CsvWriterOptions", + FileTypeWriterOptions::JSON(_) => "JsonWriterOptions", + #[cfg(feature = "parquet")] + FileTypeWriterOptions::Parquet(_) => "ParquetWriterOptions", + }; + write!(f, "{}", name) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use parquet::{ + basic::{Compression, Encoding, ZstdLevel}, + file::properties::{EnabledStatistics, WriterVersion}, + schema::types::ColumnPath, + }; + + use crate::{ + config::ConfigOptions, + file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, + parsers::CompressionTypeVariant, + }; + + use crate::Result; + + use super::{parquet_writer::ParquetWriterOptions, StatementOptions}; + + #[test] + fn test_writeroptions_parquet_from_statement_options() -> Result<()> { + let mut option_map: HashMap = HashMap::new(); + option_map.insert("max_row_group_size".to_owned(), "123".to_owned()); + option_map.insert("data_pagesize_limit".to_owned(), "123".to_owned()); + option_map.insert("write_batch_size".to_owned(), "123".to_owned()); + option_map.insert("writer_version".to_owned(), "2.0".to_owned()); + option_map.insert("dictionary_page_size_limit".to_owned(), "123".to_owned()); + option_map.insert("created_by".to_owned(), "df write unit test".to_owned()); + option_map.insert("column_index_truncate_length".to_owned(), "123".to_owned()); + option_map.insert("data_page_row_count_limit".to_owned(), "123".to_owned()); + option_map.insert("bloom_filter_enabled".to_owned(), "true".to_owned()); + option_map.insert("encoding".to_owned(), "plain".to_owned()); + option_map.insert("dictionary_enabled".to_owned(), "true".to_owned()); + option_map.insert("compression".to_owned(), "zstd(4)".to_owned()); + option_map.insert("statistics_enabled".to_owned(), "page".to_owned()); + option_map.insert("bloom_filter_fpp".to_owned(), "0.123".to_owned()); + option_map.insert("bloom_filter_ndv".to_owned(), "123".to_owned()); + + let options = StatementOptions::from(&option_map); + let config = ConfigOptions::new(); + + let parquet_options = ParquetWriterOptions::try_from((&config, &options))?; + let properties = parquet_options.writer_options(); + + // Verify the expected options propagated down to parquet crate WriterProperties struct + assert_eq!(properties.max_row_group_size(), 123); + assert_eq!(properties.data_page_size_limit(), 123); + assert_eq!(properties.write_batch_size(), 123); + assert_eq!(properties.writer_version(), WriterVersion::PARQUET_2_0); + assert_eq!(properties.dictionary_page_size_limit(), 123); + assert_eq!(properties.created_by(), "df write unit test"); + assert_eq!(properties.column_index_truncate_length(), Some(123)); + assert_eq!(properties.data_page_row_count_limit(), 123); + properties + .bloom_filter_properties(&ColumnPath::from("")) + .expect("expected bloom filter enabled"); + assert_eq!( + properties + .encoding(&ColumnPath::from("")) + .expect("expected default encoding"), + Encoding::PLAIN + ); + assert!(properties.dictionary_enabled(&ColumnPath::from(""))); + assert_eq!( + properties.compression(&ColumnPath::from("")), + Compression::ZSTD(ZstdLevel::try_new(4_i32)?) + ); + assert_eq!( + properties.statistics_enabled(&ColumnPath::from("")), + EnabledStatistics::Page + ); + assert_eq!( + properties + .bloom_filter_properties(&ColumnPath::from("")) + .expect("expected bloom properties!") + .fpp, + 0.123 + ); + assert_eq!( + properties + .bloom_filter_properties(&ColumnPath::from("")) + .expect("expected bloom properties!") + .ndv, + 123 + ); + + Ok(()) + } + + #[test] + fn test_writeroptions_parquet_column_specific() -> Result<()> { + let mut option_map: HashMap = HashMap::new(); + + option_map.insert("bloom_filter_enabled::col1".to_owned(), "true".to_owned()); + option_map.insert( + "bloom_filter_enabled::col2.nested".to_owned(), + "true".to_owned(), + ); + option_map.insert("encoding::col1".to_owned(), "plain".to_owned()); + option_map.insert("encoding::col2.nested".to_owned(), "rle".to_owned()); + option_map.insert("dictionary_enabled::col1".to_owned(), "true".to_owned()); + option_map.insert( + "dictionary_enabled::col2.nested".to_owned(), + "true".to_owned(), + ); + option_map.insert("compression::col1".to_owned(), "zstd(4)".to_owned()); + option_map.insert("compression::col2.nested".to_owned(), "zstd(10)".to_owned()); + option_map.insert("statistics_enabled::col1".to_owned(), "page".to_owned()); + option_map.insert( + "statistics_enabled::col2.nested".to_owned(), + "none".to_owned(), + ); + option_map.insert("bloom_filter_fpp::col1".to_owned(), "0.123".to_owned()); + option_map.insert( + "bloom_filter_fpp::col2.nested".to_owned(), + "0.456".to_owned(), + ); + option_map.insert("bloom_filter_ndv::col1".to_owned(), "123".to_owned()); + option_map.insert("bloom_filter_ndv::col2.nested".to_owned(), "456".to_owned()); + + let options = StatementOptions::from(&option_map); + let config = ConfigOptions::new(); + + let parquet_options = ParquetWriterOptions::try_from((&config, &options))?; + let properties = parquet_options.writer_options(); + + let col1 = ColumnPath::from(vec!["col1".to_owned()]); + let col2_nested = ColumnPath::from(vec!["col2".to_owned(), "nested".to_owned()]); + + // Verify the expected options propagated down to parquet crate WriterProperties struct + + properties + .bloom_filter_properties(&col1) + .expect("expected bloom filter enabled for col1"); + + properties + .bloom_filter_properties(&col2_nested) + .expect("expected bloom filter enabled cor col2_nested"); + + assert_eq!( + properties.encoding(&col1).expect("expected encoding"), + Encoding::PLAIN + ); + + assert_eq!( + properties + .encoding(&col2_nested) + .expect("expected encoding"), + Encoding::RLE + ); + + assert!(properties.dictionary_enabled(&col1)); + assert!(properties.dictionary_enabled(&col2_nested)); + + assert_eq!( + properties.compression(&col1), + Compression::ZSTD(ZstdLevel::try_new(4_i32)?) + ); + + assert_eq!( + properties.compression(&col2_nested), + Compression::ZSTD(ZstdLevel::try_new(10_i32)?) + ); + + assert_eq!( + properties.statistics_enabled(&col1), + EnabledStatistics::Page + ); + + assert_eq!( + properties.statistics_enabled(&col2_nested), + EnabledStatistics::None + ); + + assert_eq!( + properties + .bloom_filter_properties(&col1) + .expect("expected bloom properties!") + .fpp, + 0.123 + ); + + assert_eq!( + properties + .bloom_filter_properties(&col2_nested) + .expect("expected bloom properties!") + .fpp, + 0.456 + ); + + assert_eq!( + properties + .bloom_filter_properties(&col1) + .expect("expected bloom properties!") + .ndv, + 123 + ); + + assert_eq!( + properties + .bloom_filter_properties(&col2_nested) + .expect("expected bloom properties!") + .ndv, + 456 + ); + + Ok(()) + } + + #[test] + fn test_writeroptions_csv_from_statement_options() -> Result<()> { + let mut option_map: HashMap = HashMap::new(); + option_map.insert("header".to_owned(), "true".to_owned()); + option_map.insert("date_format".to_owned(), "123".to_owned()); + option_map.insert("datetime_format".to_owned(), "123".to_owned()); + option_map.insert("timestamp_format".to_owned(), "2.0".to_owned()); + option_map.insert("time_format".to_owned(), "123".to_owned()); + option_map.insert("rfc3339".to_owned(), "true".to_owned()); + option_map.insert("null_value".to_owned(), "123".to_owned()); + option_map.insert("compression".to_owned(), "gzip".to_owned()); + option_map.insert("delimiter".to_owned(), ";".to_owned()); + + let options = StatementOptions::from(&option_map); + let config = ConfigOptions::new(); + + let csv_options = CsvWriterOptions::try_from((&config, &options))?; + let builder = csv_options.writer_options; + assert!(builder.header()); + let buff = Vec::new(); + let _properties = builder.build(buff); + assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP); + // TODO expand unit test if csv::WriterBuilder allows public read access to properties + + Ok(()) + } + + #[test] + fn test_writeroptions_json_from_statement_options() -> Result<()> { + let mut option_map: HashMap = HashMap::new(); + option_map.insert("compression".to_owned(), "gzip".to_owned()); + + let options = StatementOptions::from(&option_map); + let config = ConfigOptions::new(); + + let json_options = JsonWriterOptions::try_from((&config, &options))?; + assert_eq!(json_options.compression, CompressionTypeVariant::GZIP); + + Ok(()) + } +} diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs new file mode 100644 index 0000000000000..80fa023587eef --- /dev/null +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -0,0 +1,373 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Options related to how parquet files should be written + +use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; + +use crate::{config::ConfigOptions, DataFusionError, Result}; + +use super::StatementOptions; + +use parquet::{ + basic::{BrotliLevel, GzipLevel, ZstdLevel}, + file::properties::{EnabledStatistics, WriterVersion}, + schema::types::ColumnPath, +}; + +/// Options for writing parquet files +#[derive(Clone, Debug)] +pub struct ParquetWriterOptions { + pub writer_options: WriterProperties, +} + +impl ParquetWriterOptions { + pub fn new(writer_options: WriterProperties) -> Self { + Self { writer_options } + } +} + +impl ParquetWriterOptions { + pub fn writer_options(&self) -> &WriterProperties { + &self.writer_options + } +} + +/// Constructs a default Parquet WriterPropertiesBuilder using +/// Session level ConfigOptions to initialize settings +pub fn default_builder(options: &ConfigOptions) -> Result { + let parquet_session_options = &options.execution.parquet; + let mut builder = WriterProperties::builder() + .set_data_page_size_limit(parquet_session_options.data_pagesize_limit) + .set_write_batch_size(parquet_session_options.write_batch_size) + .set_writer_version(parse_version_string( + &parquet_session_options.writer_version, + )?) + .set_dictionary_page_size_limit( + parquet_session_options.dictionary_page_size_limit, + ) + .set_max_row_group_size(parquet_session_options.max_row_group_size) + .set_created_by(parquet_session_options.created_by.clone()) + .set_column_index_truncate_length( + parquet_session_options.column_index_truncate_length, + ) + .set_data_page_row_count_limit(parquet_session_options.data_page_row_count_limit) + .set_bloom_filter_enabled(parquet_session_options.bloom_filter_enabled); + + builder = match &parquet_session_options.encoding { + Some(encoding) => builder.set_encoding(parse_encoding_string(encoding)?), + None => builder, + }; + + builder = match &parquet_session_options.dictionary_enabled { + Some(enabled) => builder.set_dictionary_enabled(*enabled), + None => builder, + }; + + builder = match &parquet_session_options.compression { + Some(compression) => { + builder.set_compression(parse_compression_string(compression)?) + } + None => builder, + }; + + builder = match &parquet_session_options.statistics_enabled { + Some(statistics) => { + builder.set_statistics_enabled(parse_statistics_string(statistics)?) + } + None => builder, + }; + + builder = match &parquet_session_options.max_statistics_size { + Some(size) => builder.set_max_statistics_size(*size), + None => builder, + }; + + builder = match &parquet_session_options.bloom_filter_fpp { + Some(fpp) => builder.set_bloom_filter_fpp(*fpp), + None => builder, + }; + + builder = match &parquet_session_options.bloom_filter_ndv { + Some(ndv) => builder.set_bloom_filter_ndv(*ndv), + None => builder, + }; + + Ok(builder) +} + +impl TryFrom<(&ConfigOptions, &StatementOptions)> for ParquetWriterOptions { + type Error = DataFusionError; + + fn try_from( + configs_and_statement_options: (&ConfigOptions, &StatementOptions), + ) -> Result { + let configs = configs_and_statement_options.0; + let statement_options = configs_and_statement_options.1; + let mut builder = default_builder(configs)?; + for (option, value) in &statement_options.options { + let (option, col_path) = split_option_and_column_path(option); + builder = match option.to_lowercase().as_str(){ + "max_row_group_size" => builder + .set_max_row_group_size(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as u64 as required for {option}!")))?), + "data_pagesize_limit" => builder + .set_data_page_size_limit(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?), + "write_batch_size" => builder + .set_write_batch_size(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?), + "writer_version" => builder + .set_writer_version(parse_version_string(value)?), + "dictionary_page_size_limit" => builder + .set_dictionary_page_size_limit(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?), + "created_by" => builder + .set_created_by(value.to_owned()), + "column_index_truncate_length" => builder + .set_column_index_truncate_length(Some(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?)), + "data_page_row_count_limit" => builder + .set_data_page_row_count_limit(value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?), + "bloom_filter_enabled" => { + let parsed_value = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as bool as required for {option}!")))?; + match col_path{ + Some(path) => builder.set_column_bloom_filter_enabled(path, parsed_value), + None => builder.set_bloom_filter_enabled(parsed_value) + } + }, + "encoding" => { + let parsed_encoding = parse_encoding_string(value)?; + match col_path{ + Some(path) => builder.set_column_encoding(path, parsed_encoding), + None => builder.set_encoding(parsed_encoding) + } + }, + "dictionary_enabled" => { + let parsed_value = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as bool as required for {option}!")))?; + match col_path{ + Some(path) => builder.set_column_dictionary_enabled(path, parsed_value), + None => builder.set_dictionary_enabled(parsed_value) + } + }, + "compression" => { + let parsed_compression = parse_compression_string(value)?; + match col_path{ + Some(path) => builder.set_column_compression(path, parsed_compression), + None => builder.set_compression(parsed_compression) + } + }, + "statistics_enabled" => { + let parsed_value = parse_statistics_string(value)?; + match col_path{ + Some(path) => builder.set_column_statistics_enabled(path, parsed_value), + None => builder.set_statistics_enabled(parsed_value) + } + }, + "max_statistics_size" => { + let parsed_value = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as usize as required for {option}!")))?; + match col_path{ + Some(path) => builder.set_column_max_statistics_size(path, parsed_value), + None => builder.set_max_statistics_size(parsed_value) + } + }, + "bloom_filter_fpp" => { + let parsed_value = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as f64 as required for {option}!")))?; + match col_path{ + Some(path) => builder.set_column_bloom_filter_fpp(path, parsed_value), + None => builder.set_bloom_filter_fpp(parsed_value) + } + }, + "bloom_filter_ndv" => { + let parsed_value = value.parse() + .map_err(|_| DataFusionError::Configuration(format!("Unable to parse {value} as u64 as required for {option}!")))?; + match col_path{ + Some(path) => builder.set_column_bloom_filter_ndv(path, parsed_value), + None => builder.set_bloom_filter_ndv(parsed_value) + } + }, + _ => return Err(DataFusionError::Configuration(format!("Found unsupported option {option} with value {value} for Parquet format!"))) + } + } + Ok(ParquetWriterOptions { + writer_options: builder.build(), + }) + } +} + +/// Parses datafusion.execution.parquet.encoding String to a parquet::basic::Encoding +pub(crate) fn parse_encoding_string( + str_setting: &str, +) -> Result { + let str_setting_lower: &str = &str_setting.to_lowercase(); + match str_setting_lower { + "plain" => Ok(parquet::basic::Encoding::PLAIN), + "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), + "rle" => Ok(parquet::basic::Encoding::RLE), + "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), + "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), + "delta_length_byte_array" => { + Ok(parquet::basic::Encoding::DELTA_LENGTH_BYTE_ARRAY) + } + "delta_byte_array" => Ok(parquet::basic::Encoding::DELTA_BYTE_ARRAY), + "rle_dictionary" => Ok(parquet::basic::Encoding::RLE_DICTIONARY), + "byte_stream_split" => Ok(parquet::basic::Encoding::BYTE_STREAM_SPLIT), + _ => Err(DataFusionError::Configuration(format!( + "Unknown or unsupported parquet encoding: \ + {str_setting}. Valid values are: plain, plain_dictionary, rle, \ + bit_packed, delta_binary_packed, delta_length_byte_array, \ + delta_byte_array, rle_dictionary, and byte_stream_split." + ))), + } +} + +/// Splits compression string into compression codec and optional compression_level +/// I.e. gzip(2) -> gzip, 2 +fn split_compression_string(str_setting: &str) -> Result<(String, Option)> { + // ignore string literal chars passed from sqlparser i.e. remove single quotes + let str_setting = str_setting.replace('\'', ""); + let split_setting = str_setting.split_once('('); + + match split_setting { + Some((codec, rh)) => { + let level = &rh[..rh.len() - 1].parse::().map_err(|_| { + DataFusionError::Configuration(format!( + "Could not parse compression string. \ + Got codec: {} and unknown level from {}", + codec, str_setting + )) + })?; + Ok((codec.to_owned(), Some(*level))) + } + None => Ok((str_setting.to_owned(), None)), + } +} + +/// Helper to ensure compression codecs which don't support levels +/// don't have one set. E.g. snappy(2) is invalid. +fn check_level_is_none(codec: &str, level: &Option) -> Result<()> { + if level.is_some() { + return Err(DataFusionError::Configuration(format!( + "Compression {codec} does not support specifying a level" + ))); + } + Ok(()) +} + +/// Helper to ensure compression codecs which require a level +/// do have one set. E.g. zstd is invalid, zstd(3) is valid +fn require_level(codec: &str, level: Option) -> Result { + level.ok_or(DataFusionError::Configuration(format!( + "{codec} compression requires specifying a level such as {codec}(4)" + ))) +} + +/// Parses datafusion.execution.parquet.compression String to a parquet::basic::Compression +pub(crate) fn parse_compression_string( + str_setting: &str, +) -> Result { + let str_setting_lower: &str = &str_setting.to_lowercase(); + let (codec, level) = split_compression_string(str_setting_lower)?; + let codec = codec.as_str(); + match codec { + "uncompressed" => { + check_level_is_none(codec, &level)?; + Ok(parquet::basic::Compression::UNCOMPRESSED) + } + "snappy" => { + check_level_is_none(codec, &level)?; + Ok(parquet::basic::Compression::SNAPPY) + } + "gzip" => { + let level = require_level(codec, level)?; + Ok(parquet::basic::Compression::GZIP(GzipLevel::try_new( + level, + )?)) + } + "lzo" => { + check_level_is_none(codec, &level)?; + Ok(parquet::basic::Compression::LZO) + } + "brotli" => { + let level = require_level(codec, level)?; + Ok(parquet::basic::Compression::BROTLI(BrotliLevel::try_new( + level, + )?)) + } + "lz4" => { + check_level_is_none(codec, &level)?; + Ok(parquet::basic::Compression::LZ4) + } + "zstd" => { + let level = require_level(codec, level)?; + Ok(parquet::basic::Compression::ZSTD(ZstdLevel::try_new( + level as i32, + )?)) + } + "lz4_raw" => { + check_level_is_none(codec, &level)?; + Ok(parquet::basic::Compression::LZ4_RAW) + } + _ => Err(DataFusionError::Configuration(format!( + "Unknown or unsupported parquet compression: \ + {str_setting}. Valid values are: uncompressed, snappy, gzip(level), \ + lzo, brotli(level), lz4, zstd(level), and lz4_raw." + ))), + } +} + +pub(crate) fn parse_version_string(str_setting: &str) -> Result { + let str_setting_lower: &str = &str_setting.to_lowercase(); + match str_setting_lower { + "1.0" => Ok(WriterVersion::PARQUET_1_0), + "2.0" => Ok(WriterVersion::PARQUET_2_0), + _ => Err(DataFusionError::Configuration(format!( + "Unknown or unsupported parquet writer version {str_setting} \ + valid options are 1.0 and 2.0" + ))), + } +} + +pub(crate) fn parse_statistics_string(str_setting: &str) -> Result { + let str_setting_lower: &str = &str_setting.to_lowercase(); + match str_setting_lower { + "none" => Ok(EnabledStatistics::None), + "chunk" => Ok(EnabledStatistics::Chunk), + "page" => Ok(EnabledStatistics::Page), + _ => Err(DataFusionError::Configuration(format!( + "Unknown or unsupported parquet statistics setting {str_setting} \ + valid options are none, page, and chunk" + ))), + } +} + +pub(crate) fn split_option_and_column_path( + str_setting: &str, +) -> (String, Option) { + match str_setting.replace('\'', "").split_once("::") { + Some((s1, s2)) => { + let col_path = ColumnPath::new(s2.split('.').map(|s| s.to_owned()).collect()); + (s1.to_owned(), Some(col_path)) + } + None => (str_setting.to_owned(), None), + } +} diff --git a/datafusion/common/src/file_options/parse_utils.rs b/datafusion/common/src/file_options/parse_utils.rs new file mode 100644 index 0000000000000..38cf5eb489f7f --- /dev/null +++ b/datafusion/common/src/file_options/parse_utils.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functions for parsing arbitrary passed strings to valid file_option settings +use crate::{DataFusionError, Result}; + +/// Converts a String option to a bool, or returns an error if not a valid bool string. +pub(crate) fn parse_boolean_string(option: &str, value: String) -> Result { + match value.to_lowercase().as_str() { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(DataFusionError::Configuration(format!( + "Unsupported value {value} for option {option}! \ + Valid values are true or false!" + ))), + } +} diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs new file mode 100644 index 0000000000000..d5421c36cd734 --- /dev/null +++ b/datafusion/common/src/format.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::display::{DurationFormat, FormatOptions}; + +/// The default [`FormatOptions`] to use within DataFusion +pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = + FormatOptions::new().with_duration_format(DurationFormat::Pretty); diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs new file mode 100644 index 0000000000000..1cb1751d713ef --- /dev/null +++ b/datafusion/common/src/functional_dependencies.rs @@ -0,0 +1,675 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! FunctionalDependencies keeps track of functional dependencies +//! inside DFSchema. + +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; +use std::ops::Deref; +use std::vec::IntoIter; + +use crate::error::_plan_err; +use crate::utils::{merge_and_order_indices, set_difference}; +use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; + +use sqlparser::ast::TableConstraint; + +/// This object defines a constraint on a table. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Constraint { + /// Columns with the given indices form a composite primary key (they are + /// jointly unique and not nullable): + PrimaryKey(Vec), + /// Columns with the given indices form a composite unique key: + Unique(Vec), +} + +/// This object encapsulates a list of functional constraints: +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Constraints { + inner: Vec, +} + +impl Constraints { + /// Create empty constraints + pub fn empty() -> Self { + Constraints::new_unverified(vec![]) + } + + /// Create a new `Constraints` object from the given `constraints`. + /// Users should use the `empty` or `new_from_table_constraints` functions + /// for constructing `Constraints`. This constructor is for internal + /// purposes only and does not check whether the argument is valid. The user + /// is responsible for supplying a valid vector of `Constraint` objects. + pub fn new_unverified(constraints: Vec) -> Self { + Self { inner: constraints } + } + + /// Convert each `TableConstraint` to corresponding `Constraint` + pub fn new_from_table_constraints( + constraints: &[TableConstraint], + df_schema: &DFSchemaRef, + ) -> Result { + let constraints = constraints + .iter() + .map(|c: &TableConstraint| match c { + TableConstraint::Unique { + columns, + is_primary, + .. + } => { + // Get primary key and/or unique indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = df_schema + .fields() + .iter() + .position(|item| { + item.qualified_name() == pk.value.clone() + }) + .ok_or_else(|| { + DataFusionError::Execution( + "Primary key doesn't exist".to_string(), + ) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(if *is_primary { + Constraint::PrimaryKey(indices) + } else { + Constraint::Unique(indices) + }) + } + TableConstraint::ForeignKey { .. } => { + _plan_err!("Foreign key constraints are not currently supported") + } + TableConstraint::Check { .. } => { + _plan_err!("Check constraints are not currently supported") + } + TableConstraint::Index { .. } => { + _plan_err!("Indexes are not currently supported") + } + TableConstraint::FulltextOrSpatial { .. } => { + _plan_err!("Indexes are not currently supported") + } + }) + .collect::>>()?; + Ok(Constraints::new_unverified(constraints)) + } + + /// Check whether constraints is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl IntoIterator for Constraints { + type Item = Constraint; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +impl Display for Constraints { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let pk: Vec = self.inner.iter().map(|c| format!("{:?}", c)).collect(); + let pk = pk.join(", "); + if !pk.is_empty() { + write!(f, " constraints=[{pk}]") + } else { + write!(f, "") + } + } +} + +impl Deref for Constraints { + type Target = [Constraint]; + + fn deref(&self) -> &Self::Target { + self.inner.as_slice() + } +} + +/// This object defines a functional dependence in the schema. A functional +/// dependence defines a relationship between determinant keys and dependent +/// columns. A determinant key is a column, or a set of columns, whose value +/// uniquely determines values of some other (dependent) columns. If two rows +/// have the same determinant key, dependent columns in these rows are +/// necessarily the same. If the determinant key is unique, the set of +/// dependent columns is equal to the entire schema and the determinant key can +/// serve as a primary key. Note that a primary key may "downgrade" into a +/// determinant key due to an operation such as a join, and this object is +/// used to track dependence relationships in such cases. For more information +/// on functional dependencies, see: +/// +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionalDependence { + // Column indices of the (possibly composite) determinant key: + pub source_indices: Vec, + // Column indices of dependent column(s): + pub target_indices: Vec, + /// Flag indicating whether one of the `source_indices` can receive NULL values. + /// For a data source, if the constraint in question is `Constraint::Unique`, + /// this flag is `true`. If the constraint in question is `Constraint::PrimaryKey`, + /// this flag is `false`. + /// Note that as the schema changes between different stages in a plan, + /// such as after LEFT JOIN or RIGHT JOIN operations, this property may + /// change. + pub nullable: bool, + // The functional dependency mode: + pub mode: Dependency, +} + +/// Describes functional dependency mode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Dependency { + Single, // A determinant key may occur only once. + Multi, // A determinant key may occur multiple times (in multiple rows). +} + +impl FunctionalDependence { + // Creates a new functional dependence. + pub fn new( + source_indices: Vec, + target_indices: Vec, + nullable: bool, + ) -> Self { + Self { + source_indices, + target_indices, + nullable, + // Start with the least restrictive mode by default: + mode: Dependency::Multi, + } + } + + pub fn with_mode(mut self, mode: Dependency) -> Self { + self.mode = mode; + self + } +} + +/// This object encapsulates all functional dependencies in a given relation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionalDependencies { + deps: Vec, +} + +impl FunctionalDependencies { + /// Creates an empty `FunctionalDependencies` object. + pub fn empty() -> Self { + Self { deps: vec![] } + } + + /// Creates a new `FunctionalDependencies` object from a vector of + /// `FunctionalDependence` objects. + pub fn new(dependencies: Vec) -> Self { + Self { deps: dependencies } + } + + /// Creates a new `FunctionalDependencies` object from the given constraints. + pub fn new_from_constraints( + constraints: Option<&Constraints>, + n_field: usize, + ) -> Self { + if let Some(Constraints { inner: constraints }) = constraints { + // Construct dependency objects based on each individual constraint: + let dependencies = constraints + .iter() + .map(|constraint| { + // All the field indices are associated with the whole table + // since we are dealing with table level constraints: + let dependency = match constraint { + Constraint::PrimaryKey(indices) => FunctionalDependence::new( + indices.to_vec(), + (0..n_field).collect::>(), + false, + ), + Constraint::Unique(indices) => FunctionalDependence::new( + indices.to_vec(), + (0..n_field).collect::>(), + true, + ), + }; + // As primary keys are guaranteed to be unique, set the + // functional dependency mode to `Dependency::Single`: + dependency.with_mode(Dependency::Single) + }) + .collect::>(); + Self::new(dependencies) + } else { + // There is no constraint, return an empty object: + Self::empty() + } + } + + pub fn with_dependency(mut self, mode: Dependency) -> Self { + self.deps.iter_mut().for_each(|item| item.mode = mode); + self + } + + /// Merges the given functional dependencies with these. + pub fn extend(&mut self, other: FunctionalDependencies) { + self.deps.extend(other.deps); + } + + /// Sanity checks if functional dependencies are valid. For example, if + /// there are 10 fields, we cannot receive any index further than 9. + pub fn is_valid(&self, n_field: usize) -> bool { + self.deps.iter().all( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + source_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + && target_indices + .iter() + .max() + .map(|&max_index| max_index < n_field) + .unwrap_or(true) + }, + ) + } + + /// Adds the `offset` value to `source_indices` and `target_indices` for + /// each functional dependency. + pub fn add_offset(&mut self, offset: usize) { + self.deps.iter_mut().for_each( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + *source_indices = add_offset_to_vec(source_indices, offset); + *target_indices = add_offset_to_vec(target_indices, offset); + }, + ) + } + + /// Updates `source_indices` and `target_indices` of each functional + /// dependence using the index mapping given in `proj_indices`. + /// + /// Assume that `proj_indices` is \[2, 5, 8\] and we have a functional + /// dependence \[5\] (`source_indices`) -> \[5, 8\] (`target_indices`). + /// In the updated schema, fields at indices \[2, 5, 8\] will transform + /// to \[0, 1, 2\]. Therefore, the resulting functional dependence will + /// be \[1\] -> \[1, 2\]. + pub fn project_functional_dependencies( + &self, + proj_indices: &[usize], + // The argument `n_out` denotes the schema field length, which is needed + // to correctly associate a `Single`-mode dependence with the whole table. + n_out: usize, + ) -> FunctionalDependencies { + let mut projected_func_dependencies = vec![]; + for FunctionalDependence { + source_indices, + target_indices, + nullable, + mode, + } in &self.deps + { + let new_source_indices = + update_elements_with_matching_indices(source_indices, proj_indices); + let new_target_indices = if *mode == Dependency::Single { + // Associate with all of the fields in the schema: + (0..n_out).collect() + } else { + // Update associations according to projection: + update_elements_with_matching_indices(target_indices, proj_indices) + }; + // All of the composite indices should still be valid after projection; + // otherwise, functional dependency cannot be propagated. + if new_source_indices.len() == source_indices.len() { + let new_func_dependence = FunctionalDependence::new( + new_source_indices, + new_target_indices, + *nullable, + ) + .with_mode(*mode); + projected_func_dependencies.push(new_func_dependence); + } + } + FunctionalDependencies::new(projected_func_dependencies) + } + + /// This function joins this set of functional dependencies with the `other` + /// according to the given `join_type`. + pub fn join( + &self, + other: &FunctionalDependencies, + join_type: &JoinType, + left_cols_len: usize, + ) -> FunctionalDependencies { + // Get mutable copies of left and right side dependencies: + let mut right_func_dependencies = other.clone(); + let mut left_func_dependencies = self.clone(); + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right => { + // Add offset to right schema: + right_func_dependencies.add_offset(left_cols_len); + + // Result may have multiple values, update the dependency mode: + left_func_dependencies = + left_func_dependencies.with_dependency(Dependency::Multi); + right_func_dependencies = + right_func_dependencies.with_dependency(Dependency::Multi); + + if *join_type == JoinType::Left { + // Downgrade the right side, since it may have additional NULL values: + right_func_dependencies.downgrade_dependencies(); + } else if *join_type == JoinType::Right { + // Downgrade the left side, since it may have additional NULL values: + left_func_dependencies.downgrade_dependencies(); + } + // Combine left and right functional dependencies: + left_func_dependencies.extend(right_func_dependencies); + left_func_dependencies + } + JoinType::LeftSemi | JoinType::LeftAnti => { + // These joins preserve functional dependencies of the left side: + left_func_dependencies + } + JoinType::RightSemi | JoinType::RightAnti => { + // These joins preserve functional dependencies of the right side: + right_func_dependencies + } + JoinType::Full => { + // All of the functional dependencies are lost in a FULL join: + FunctionalDependencies::empty() + } + } + } + + /// This function downgrades a functional dependency when nullability becomes + /// a possibility: + /// - If the dependency in question is UNIQUE (i.e. nullable), a new null value + /// invalidates the dependency. + /// - If the dependency in question is PRIMARY KEY (i.e. not nullable), a new + /// null value turns it into UNIQUE mode. + fn downgrade_dependencies(&mut self) { + // Delete nullable dependencies, since they are no longer valid: + self.deps.retain(|item| !item.nullable); + self.deps.iter_mut().for_each(|item| item.nullable = true); + } + + /// This function ensures that functional dependencies involving uniquely + /// occuring determinant keys cover their entire table in terms of + /// dependent columns. + pub fn extend_target_indices(&mut self, n_out: usize) { + self.deps.iter_mut().for_each( + |FunctionalDependence { + mode, + target_indices, + .. + }| { + // If unique, cover the whole table: + if *mode == Dependency::Single { + *target_indices = (0..n_out).collect::>(); + } + }, + ) + } +} + +impl Deref for FunctionalDependencies { + type Target = [FunctionalDependence]; + + fn deref(&self) -> &Self::Target { + self.deps.as_slice() + } +} + +/// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. +pub fn aggregate_functional_dependencies( + aggr_input_schema: &DFSchema, + group_by_expr_names: &[String], + aggr_schema: &DFSchema, +) -> FunctionalDependencies { + let mut aggregate_func_dependencies = vec![]; + let aggr_input_fields = aggr_input_schema.fields(); + let aggr_fields = aggr_schema.fields(); + // Association covers the whole table: + let target_indices = (0..aggr_schema.fields().len()).collect::>(); + // Get functional dependencies of the schema: + let func_dependencies = aggr_input_schema.functional_dependencies(); + for FunctionalDependence { + source_indices, + nullable, + mode, + .. + } in &func_dependencies.deps + { + // Keep source indices in a `HashSet` to prevent duplicate entries: + let mut new_source_indices = vec![]; + let mut new_source_field_names = vec![]; + let source_field_names = source_indices + .iter() + .map(|&idx| aggr_input_fields[idx].qualified_name()) + .collect::>(); + + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { + // When one of the input determinant expressions matches with + // the GROUP BY expression, add the index of the GROUP BY + // expression as a new determinant key: + if source_field_names.contains(group_by_expr_name) { + new_source_indices.push(idx); + new_source_field_names.push(group_by_expr_name.clone()); + } + } + let existing_target_indices = + get_target_functional_dependencies(aggr_input_schema, group_by_expr_names); + let new_target_indices = get_target_functional_dependencies( + aggr_input_schema, + &new_source_field_names, + ); + let mode = if existing_target_indices == new_target_indices + && new_target_indices.is_some() + { + // If dependency covers all GROUP BY expressions, mode will be `Single`: + Dependency::Single + } else { + // Otherwise, existing mode is preserved: + *mode + }; + // All of the composite indices occur in the GROUP BY expression: + if new_source_indices.len() == source_indices.len() { + aggregate_func_dependencies.push( + FunctionalDependence::new( + new_source_indices, + target_indices.clone(), + *nullable, + ) + .with_mode(mode), + ); + } + } + + // If we have a single GROUP BY key, we can guarantee uniqueness after + // aggregation: + if group_by_expr_names.len() == 1 { + // If `source_indices` contain 0, delete this functional dependency + // as it will be added anyway with mode `Dependency::Single`: + aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); + // Add a new functional dependency associated with the whole table: + aggregate_func_dependencies.push( + // Use nullable property of the group by expression + FunctionalDependence::new( + vec![0], + target_indices, + aggr_fields[0].is_nullable(), + ) + .with_mode(Dependency::Single), + ); + } + FunctionalDependencies::new(aggregate_func_dependencies) +} + +/// Returns target indices, for the determinant keys that are inside +/// group by expressions. +pub fn get_target_functional_dependencies( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let mut combined_target_indices = HashSet::new(); + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + let source_key_names = source_indices + .iter() + .map(|id_key_idx| field_names[*id_key_idx].clone()) + .collect::>(); + // If the GROUP BY expression contains a determinant key, we can use + // the associated fields after aggregation even if they are not part + // of the GROUP BY expression. + if source_key_names + .iter() + .all(|source_key_name| group_by_expr_names.contains(source_key_name)) + { + combined_target_indices.extend(target_indices.iter()); + } + } + (!combined_target_indices.is_empty()).then_some({ + let mut result = combined_target_indices.into_iter().collect::>(); + result.sort(); + result + }) +} + +/// Returns indices for the minimal subset of GROUP BY expressions that are +/// functionally equivalent to the original set of GROUP BY expressions. +pub fn get_required_group_by_exprs_indices( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + let mut groupby_expr_indices = group_by_expr_names + .iter() + .map(|group_by_expr_name| { + field_names + .iter() + .position(|field_name| field_name == group_by_expr_name) + }) + .collect::>>()?; + + groupby_expr_indices.sort(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + if source_indices + .iter() + .all(|source_idx| groupby_expr_indices.contains(source_idx)) + { + // If all source indices are among GROUP BY expression indices, we + // can remove target indices from GROUP BY expression indices and + // use source indices instead. + groupby_expr_indices = set_difference(&groupby_expr_indices, target_indices); + groupby_expr_indices = + merge_and_order_indices(groupby_expr_indices, source_indices); + } + } + groupby_expr_indices + .iter() + .map(|idx| { + group_by_expr_names + .iter() + .position(|name| &field_names[*idx] == name) + }) + .collect() +} + +/// Updates entries inside the `entries` vector with their corresponding +/// indices inside the `proj_indices` vector. +fn update_elements_with_matching_indices( + entries: &[usize], + proj_indices: &[usize], +) -> Vec { + entries + .iter() + .filter_map(|val| proj_indices.iter().position(|proj_idx| proj_idx == val)) + .collect() +} + +/// Adds `offset` value to each entry inside `in_data`. +fn add_offset_to_vec>( + in_data: &[T], + offset: T, +) -> Vec { + in_data.iter().map(|&item| item + offset).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constraints_iter() { + let constraints = Constraints::new_unverified(vec![ + Constraint::PrimaryKey(vec![10]), + Constraint::Unique(vec![20]), + ]); + let mut iter = constraints.iter(); + assert_eq!(iter.next(), Some(&Constraint::PrimaryKey(vec![10]))); + assert_eq!(iter.next(), Some(&Constraint::Unique(vec![20]))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_get_updated_id_keys() { + let fund_dependencies = + FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![1], + vec![0, 1, 2], + true, + )]); + let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); + let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![0], + vec![0, 1], + true, + )]); + assert_eq!(res, expected); + } +} diff --git a/datafusion/physical-expr/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs similarity index 68% rename from datafusion/physical-expr/src/hash_utils.rs rename to datafusion/common/src/hash_utils.rs index b751df928d2a8..9198461e00bf9 100644 --- a/datafusion/physical-expr/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -17,19 +17,19 @@ //! Functionality used both on logical and physical plans +use std::sync::Arc; + use ahash::RandomState; use arrow::array::*; use arrow::datatypes::*; use arrow::row::Rows; use arrow::{downcast_dictionary_array, downcast_primitive_array}; use arrow_buffer::i256; -use datafusion_common::{ - cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, - }, - DataFusionError, Result, + +use crate::cast::{ + as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, }; -use std::sync::Arc; +use crate::error::{DataFusionError, Result, _internal_err}; // Combines two hashes into one hash #[inline] @@ -51,7 +51,7 @@ fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: } } -pub(crate) trait HashValue { +pub trait HashValue { fn hash_one(&self, state: &RandomState) -> u64; } @@ -84,35 +84,93 @@ macro_rules! hash_float_value { } hash_float_value!((half::f16, u16), (f32, u32), (f64, u64)); +/// Builds hash values of PrimitiveArray and writes them into `hashes_buffer` +/// If `rehash==true` this combines the previous hash value in the buffer +/// with the new hash using `combine_hashes` +fn hash_array_primitive( + array: &PrimitiveArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) where + T: ArrowPrimitiveType, + ::Native: HashValue, +{ + assert_eq!( + hashes_buffer.len(), + array.len(), + "hashes_buffer and array should be of equal length" + ); + + if array.null_count() == 0 { + if rehash { + for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { + *hash = combine_hashes(value.hash_one(random_state), *hash); + } + } else { + for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { + *hash = value.hash_one(random_state); + } + } + } else if rehash { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + if !array.is_null(i) { + let value = unsafe { array.value_unchecked(i) }; + *hash = combine_hashes(value.hash_one(random_state), *hash); + } + } + } else { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + if !array.is_null(i) { + let value = unsafe { array.value_unchecked(i) }; + *hash = value.hash_one(random_state); + } + } + } +} + +/// Hashes one array into the `hashes_buffer` +/// If `rehash==true` this combines the previous hash value in the buffer +/// with the new hash using `combine_hashes` fn hash_array( array: T, random_state: &RandomState, hashes_buffer: &mut [u64], - multi_col: bool, + rehash: bool, ) where T: ArrayAccessor, T::Item: HashValue, { + assert_eq!( + hashes_buffer.len(), + array.len(), + "hashes_buffer and array should be of equal length" + ); + if array.null_count() == 0 { - if multi_col { + if rehash { for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = combine_hashes(array.value(i).hash_one(random_state), *hash); + let value = unsafe { array.value_unchecked(i) }; + *hash = combine_hashes(value.hash_one(random_state), *hash); } } else { for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = array.value(i).hash_one(random_state); + let value = unsafe { array.value_unchecked(i) }; + *hash = value.hash_one(random_state); } } - } else if multi_col { + } else if rehash { for (i, hash) in hashes_buffer.iter_mut().enumerate() { if !array.is_null(i) { - *hash = combine_hashes(array.value(i).hash_one(random_state), *hash); + let value = unsafe { array.value_unchecked(i) }; + *hash = combine_hashes(value.hash_one(random_state), *hash); } } } else { for (i, hash) in hashes_buffer.iter_mut().enumerate() { if !array.is_null(i) { - *hash = array.value(i).hash_one(random_state); + let value = unsafe { array.value_unchecked(i) }; + *hash = value.hash_one(random_state); } } } @@ -149,6 +207,39 @@ fn hash_dictionary( Ok(()) } +fn hash_list_array( + array: &GenericListArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> +where + OffsetSize: OffsetSizeTrait, +{ + let values = array.values().clone(); + let offsets = array.value_offsets(); + let nulls = array.nulls(); + let mut values_hashes = vec![0u64; values.len()]; + create_hashes(&[values], random_state, &mut values_hashes)?; + if let Some(nulls) = nulls { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + if nulls.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + /// Test version of `create_hashes` that produces the same value for /// all hashes (to test collisions) /// @@ -208,42 +299,48 @@ pub fn create_hashes<'a>( random_state: &RandomState, hashes_buffer: &'a mut Vec, ) -> Result<&'a mut Vec> { - // combine hashes with `combine_hashes` if we have more than 1 column - - let multi_col = arrays.len() > 1; - - for col in arrays { + for (i, col) in arrays.iter().enumerate() { let array = col.as_ref(); + // combine hashes with `combine_hashes` for all columns besides the first + let rehash = i >= 1; downcast_primitive_array! { - array => hash_array(array, random_state, hashes_buffer, multi_col), - DataType::Null => hash_null(random_state, hashes_buffer, multi_col), - DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, multi_col), - DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, multi_col), - DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, multi_col), - DataType::Binary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, multi_col), - DataType::LargeBinary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, multi_col), + array => hash_array_primitive(array, random_state, hashes_buffer, rehash), + DataType::Null => hash_null(random_state, hashes_buffer, rehash), + DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash), + DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash), + DataType::Binary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::LargeBinary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), DataType::FixedSizeBinary(_) => { let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); - hash_array(array, random_state, hashes_buffer, multi_col) + hash_array(array, random_state, hashes_buffer, rehash) } DataType::Decimal128(_, _) => { let array = as_primitive_array::(array)?; - hash_array(array, random_state, hashes_buffer, multi_col) + hash_array_primitive(array, random_state, hashes_buffer, rehash) } DataType::Decimal256(_, _) => { let array = as_primitive_array::(array)?; - hash_array(array, random_state, hashes_buffer, multi_col) + hash_array_primitive(array, random_state, hashes_buffer, rehash) } DataType::Dictionary(_, _) => downcast_dictionary_array! { - array => hash_dictionary(array, random_state, hashes_buffer, multi_col)?, + array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() } + DataType::List(_) => { + let array = as_list_array(array); + hash_list_array(array, random_state, hashes_buffer)?; + } + DataType::LargeList(_) => { + let array = as_large_list_array(array); + hash_list_array(array, random_state, hashes_buffer)?; + } _ => { // This is internal because we should have caught this before. - return Err(DataFusionError::Internal(format!( + return _internal_err!( "Unsupported data type in hasher: {}", col.data_type() - ))); + ); } } } @@ -356,7 +453,7 @@ mod tests { // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] fn create_hashes_for_dict_arrays() { - let strings = vec![Some("foo"), None, Some("bar"), Some("foo"), None]; + let strings = [Some("foo"), None, Some("bar"), Some("foo"), None]; let string_array = Arc::new(strings.iter().cloned().collect::()); let dict_array = Arc::new( @@ -396,12 +493,34 @@ mod tests { assert_ne!(dict_hashes[0], dict_hashes[2]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + Some(vec![Some(0), Some(1), Some(2)]), + ]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[5]); + assert_eq!(hashes[1], hashes[4]); + assert_eq!(hashes[2], hashes[3]); + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] fn create_multi_column_hash_for_dict_arrays() { - let strings1 = vec![Some("foo"), None, Some("bar")]; - let strings2 = vec![Some("blarg"), Some("blah"), None]; + let strings1 = [Some("foo"), None, Some("bar")]; + let strings2 = [Some("blarg"), Some("blah"), None]; let string_array = Arc::new(strings1.iter().cloned().collect::()); let dict_array = Arc::new( diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index 9da9e5625f726..0a00a57ba45fe 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! [`JoinType`] and [`JoinConstraint`] +//! Defines the [`JoinType`], [`JoinConstraint`] and [`JoinSide`] types. use std::{ fmt::{self, Display, Formatter}, str::FromStr, }; +use crate::error::_not_impl_err; use crate::{DataFusionError, Result}; /// Join type @@ -81,9 +82,7 @@ impl FromStr for JoinType { "RIGHTSEMI" => Ok(JoinType::RightSemi), "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), - _ => Err(DataFusionError::NotImplemented(format!( - "The join type {s} does not exist or is not implemented" - ))), + _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } } @@ -96,3 +95,32 @@ pub enum JoinConstraint { /// Join USING Using, } + +impl Display for JoinSide { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + JoinSide::Left => write!(f, "left"), + JoinSide::Right => write!(f, "right"), + } + } +} + +/// Join side. +/// Stores the referred table side during calculations +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum JoinSide { + /// Left side of the join + Left, + /// Right side of the join + Right, +} + +impl JoinSide { + /// Inverse the join side + pub fn negate(&self) -> Self { + match self { + JoinSide::Left => JoinSide::Right, + JoinSide::Right => JoinSide::Left, + } + } +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index ef7e0947008a2..ed547782e4a5e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -15,35 +15,59 @@ // specific language governing permissions and limitations // under the License. -pub mod cast; mod column; -pub mod config; -pub mod delta; mod dfschema; mod error; +mod functional_dependencies; mod join_type; -pub mod parsers; +mod param_value; #[cfg(feature = "pyarrow")] mod pyarrow; -pub mod scalar; mod schema_reference; -pub mod stats; mod table_reference; +mod unnest; + +pub mod alias; +pub mod cast; +pub mod config; +pub mod display; +pub mod file_options; +pub mod format; +pub mod hash_utils; +pub mod parsers; +pub mod rounding; +pub mod scalar; +pub mod stats; pub mod test_util; pub mod tree_node; pub mod utils; +/// Reexport arrow crate +pub use arrow; pub use column::Column; -pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; +pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema}; pub use error::{ field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, SharedResult, }; -pub use join_type::{JoinConstraint, JoinType}; +pub use file_options::file_type::{ + FileType, GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, + DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, +}; +pub use file_options::FileTypeWriterOptions; +pub use functional_dependencies::{ + aggregate_functional_dependencies, get_required_group_by_exprs_indices, + get_target_functional_dependencies, Constraint, Constraints, Dependency, + FunctionalDependence, FunctionalDependencies, +}; +pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{OwnedTableReference, ResolvedTableReference, TableReference}; +pub use unnest::UnnestOptions; +pub use utils::project_schema; /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs new file mode 100644 index 0000000000000..253c312b66d51 --- /dev/null +++ b/datafusion/common/src/param_value.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{_internal_err, _plan_err}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow_schema::DataType; +use std::collections::HashMap; + +/// The parameter value corresponding to the placeholder +#[derive(Debug, Clone)] +pub enum ParamValues { + /// for positional query parameters, like select * from test where a > $1 and b = $2 + LIST(Vec), + /// for named query parameters, like select * from test where a > $foo and b = $goo + MAP(HashMap), +} + +impl ParamValues { + /// Verify parameter list length and type + pub fn verify(&self, expect: &Vec) -> Result<()> { + match self { + ParamValues::LIST(list) => { + // Verify if the number of params matches the number of values + if expect.len() != list.len() { + return _plan_err!( + "Expected {} parameters, got {}", + expect.len(), + list.len() + ); + } + + // Verify if the types of the params matches the types of the values + let iter = expect.iter().zip(list.iter()); + for (i, (param_type, value)) in iter.enumerate() { + if *param_type != value.data_type() { + return _plan_err!( + "Expected parameter of type {:?}, got {:?} at index {}", + param_type, + value.data_type(), + i + ); + } + } + Ok(()) + } + ParamValues::MAP(_) => { + // If it is a named query, variables can be reused, + // but the lengths are not necessarily equal + Ok(()) + } + } + } + + pub fn get_placeholders_with_values( + &self, + id: &String, + data_type: &Option, + ) -> Result { + match self { + ParamValues::LIST(list) => { + if id.is_empty() || id == "$0" { + return _plan_err!("Empty placeholder id"); + } + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..].parse::().map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {e}" + )) + })? - 1; + // value at the idx-th position in param_values should be the value for the placeholder + let value = list.get(idx).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + ParamValues::MAP(map) => { + // convert name (in format $a, $b, ..) to mapped values (a, b, ..) + let name = &id[1..]; + // value at the name position in param_values should be the value for the placeholder + let value = map.get(name).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with name {id}" + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if Some(value.data_type()) != *data_type { + return _internal_err!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.data_type() + ); + } + Ok(value.clone()) + } + } + } +} + +impl From> for ParamValues { + fn from(value: Vec) -> Self { + Self::LIST(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: Vec<(K, ScalarValue)>) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} + +impl From> for ParamValues +where + K: Into, +{ + fn from(value: HashMap) -> Self { + let value: HashMap = + value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + Self::MAP(value) + } +} diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index 58f4db751c4c2..ea2508f8c4559 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -46,7 +46,7 @@ impl FromStr for CompressionTypeVariant { "BZIP2" | "BZ2" => Ok(Self::BZIP2), "XZ" => Ok(Self::XZ), "ZST" | "ZSTD" => Ok(Self::ZSTD), - "" => Ok(Self::UNCOMPRESSED), + "" | "UNCOMPRESSED" => Ok(Self::UNCOMPRESSED), _ => Err(ParserError::ParserError(format!( "Unsupported file compression type {s}" ))), diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index d18782e037ae4..f4356477532f4 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! PyArrow +//! Conversions between PyArrow and DataFusion types use arrow::array::ArrayData; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; @@ -54,7 +54,7 @@ impl FromPyArrow for ScalarValue { impl ToPyArrow for ScalarValue { fn to_pyarrow(&self, py: Python) -> PyResult { - let array = self.to_array(); + let array = self.to_array()?; // convert to pyarrow array using C data interface let pyarray = array.to_data().to_pyarrow(py)?; let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; @@ -94,10 +94,11 @@ mod tests { Some(locals), ) .expect("Couldn't get python info"); - let executable: String = - locals.get_item("executable").unwrap().extract().unwrap(); - let python_path: Vec<&str> = - locals.get_item("python_path").unwrap().extract().unwrap(); + let executable = locals.get_item("executable").unwrap().unwrap(); + let executable: String = executable.extract().unwrap(); + + let python_path = locals.get_item("python_path").unwrap().unwrap(); + let python_path: Vec<&str> = python_path.extract().unwrap(); panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ HINT: try `pip install pyarrow`\n\ @@ -118,7 +119,7 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Int32(Some(23)), ScalarValue::Float64(Some(12.34)), - ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::from("Hello!"), ScalarValue::Date32(Some(1234)), ]; diff --git a/datafusion/physical-expr/src/intervals/rounding.rs b/datafusion/common/src/rounding.rs similarity index 97% rename from datafusion/physical-expr/src/intervals/rounding.rs rename to datafusion/common/src/rounding.rs index 06c4f9e8a9570..413067ecd61ed 100644 --- a/datafusion/physical-expr/src/intervals/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -22,8 +22,8 @@ use std::ops::{Add, BitAnd, Sub}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use crate::Result; +use crate::ScalarValue; // Define constants for ARM #[cfg(all(target_arch = "aarch64", not(target_os = "windows")))] @@ -162,13 +162,12 @@ impl FloatBits for f64 { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_up; +/// use datafusion_common::rounding::next_up; /// /// let f: f32 = 1.0; /// let next_f = next_up(f); /// assert_eq!(next_f, 1.0000001); /// ``` -#[allow(dead_code)] pub fn next_up(float: F) -> F { let bits = float.to_bits(); if float.float_is_nan() || bits == F::infinity().to_bits() { @@ -196,13 +195,12 @@ pub fn next_up(float: F) -> F { /// # Examples /// /// ``` -/// use datafusion_physical_expr::intervals::rounding::next_down; +/// use datafusion_common::rounding::next_down; /// /// let f: f32 = 1.0; /// let next_f = next_down(f); /// assert_eq!(next_f, 0.99999994); /// ``` -#[allow(dead_code)] pub fn next_down(float: F) -> F { let bits = float.to_bits(); if float.float_is_nan() || bits == F::neg_infinity().to_bits() { diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index e84ef545198ef..d730fbf89b723 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -18,48 +18,83 @@ //! This module provides ScalarValue, an enum that can be used for storage of single elements use std::borrow::Borrow; -use std::cmp::{max, Ordering}; +use std::cmp::Ordering; use std::collections::HashSet; use std::convert::{Infallible, TryInto}; -use std::ops::{Add, Sub}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::cast::{ - as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array, - as_fixed_size_list_array, as_list_array, as_struct_array, + as_decimal128_array, as_decimal256_array, as_dictionary_array, + as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, }; -use crate::delta::shift_months; -use crate::error::{DataFusionError, Result}; -use arrow::buffer::NullBuffer; -use arrow::compute::nullif; -use arrow::datatypes::{FieldRef, Fields, SchemaBuilder}; +use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; +use crate::hash_utils::create_hashes; +use crate::utils::{array_into_large_list_array, array_into_list_array}; +use arrow::compute::kernels::numeric::*; +use arrow::datatypes::{i256, Fields, SchemaBuilder}; +use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; -use arrow_array::timezone::Tz; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; - -// Constants we use throughout this file: -const MILLISECS_IN_ONE_DAY: i64 = 86_400_000; -const NANOSECS_IN_ONE_DAY: i64 = 86_400_000_000_000; -const SECS_IN_ONE_MONTH: i64 = 2_592_000; // assuming 30 days. -const MILLISECS_IN_ONE_MONTH: i64 = 2_592_000_000; // assuming 30 days. -const MICROSECS_IN_ONE_MONTH: i64 = 2_592_000_000_000; // assuming 30 days. -const NANOSECS_IN_ONE_MONTH: i128 = 2_592_000_000_000_000; // assuming 30 days. - -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part to arrow's [`Array`]. +use arrow_array::cast::as_list_array; +use arrow_array::types::ArrowTimestampType; +use arrow_array::{ArrowNativeTypeOp, Scalar}; + +/// A dynamically typed, nullable single value, (the single-valued counter-part +/// to arrow's [`Array`]) +/// +/// # Performance +/// +/// In general, please use arrow [`Array`]s rather than [`ScalarValue`] whenever +/// possible, as it is far more efficient for multiple values. +/// +/// # Example +/// ``` +/// # use datafusion_common::ScalarValue; +/// // Create single scalar value for an Int32 value +/// let s1 = ScalarValue::Int32(Some(10)); +/// +/// // You can also create values using the From impl: +/// let s2 = ScalarValue::from(10i32); +/// assert_eq!(s1, s2); +/// ``` +/// +/// # Null Handling /// +/// `ScalarValue` represents null values in the same way as Arrow. Nulls are +/// "typed" in the sense that a null value in an [`Int32Array`] is different +/// than a null value in a [`Float64Array`], and is different than the values in +/// a [`NullArray`]. +/// +/// ``` +/// # fn main() -> datafusion_common::Result<()> { +/// # use std::collections::hash_set::Difference; +/// # use datafusion_common::ScalarValue; +/// # use arrow::datatypes::DataType; +/// // You can create a 'null' Int32 value directly: +/// let s1 = ScalarValue::Int32(None); +/// +/// // You can also create a null value for a given datatype: +/// let s2 = ScalarValue::try_from(&DataType::Int32)?; +/// assert_eq!(s1, s2); +/// +/// // Note that this is DIFFERENT than a `ScalarValue::Null` +/// let s3 = ScalarValue::Null; +/// assert_ne!(s1, s3); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Further Reading /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. @@ -75,6 +110,8 @@ pub enum ScalarValue { Float64(Option), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), + /// 256bit decimal, using the i256 to represent the decimal, precision scale + Decimal256(Option, u8, i8), /// signed 8bit int Int8(Option), /// signed 16bit int @@ -101,8 +138,16 @@ pub enum ScalarValue { FixedSizeBinary(i32, Option>), /// large binary LargeBinary(Option>), - /// list of nested ScalarValue - List(Option>, FieldRef), + /// Fixed size list scalar. + /// + /// The array must be a FixedSizeListArray with length 1. + FixedSizeList(ArrayRef), + /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] + /// + /// The array must be a ListArray with length 1. + List(ArrayRef), + /// The array must be a LargeListArray with length 1. + LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -132,6 +177,14 @@ pub enum ScalarValue { /// Months and days are encoded as 32-bit signed integers. /// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds). IntervalMonthDayNano(Option), + /// Duration in seconds + DurationSecond(Option), + /// Duration in milliseconds + DurationMillisecond(Option), + /// Duration in microseconds + DurationMicrosecond(Option), + /// Duration in nanoseconds + DurationNanosecond(Option), /// struct of nested ScalarValue Struct(Option>, Fields), /// Dictionary type: index type and value @@ -150,6 +203,10 @@ impl PartialEq for ScalarValue { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } (Decimal128(_, _, _), _) => false, + (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal256(_, _, _), _) => false, (Boolean(v1), Boolean(v2)) => v1.eq(v2), (Boolean(_), _) => false, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -188,8 +245,12 @@ impl PartialEq for ScalarValue { (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), (LargeBinary(_), _) => false, - (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (List(_, _), _) => false, + (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), + (FixedSizeList(_), _) => false, + (List(v1), List(v2)) => v1.eq(v2), + (List(_), _) => false, + (LargeList(v1), LargeList(v2)) => v1.eq(v2), + (LargeList(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -210,29 +271,19 @@ impl PartialEq for ScalarValue { (TimestampMicrosecond(_, _), _) => false, (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), (TimestampNanosecond(_, _), _) => false, + (DurationSecond(v1), DurationSecond(v2)) => v1.eq(v2), + (DurationSecond(_), _) => false, + (DurationMillisecond(v1), DurationMillisecond(v2)) => v1.eq(v2), + (DurationMillisecond(_), _) => false, + (DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.eq(v2), + (DurationMicrosecond(_), _) => false, + (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.eq(v2), + (DurationNanosecond(_), _) => false, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), - (IntervalYearMonth(v1), IntervalDayTime(v2)) => { - ym_to_milli(v1).eq(&dt_to_milli(v2)) - } - (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { - ym_to_nano(v1).eq(&mdn_to_nano(v2)) - } (IntervalYearMonth(_), _) => false, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), - (IntervalDayTime(v1), IntervalYearMonth(v2)) => { - dt_to_milli(v1).eq(&ym_to_milli(v2)) - } - (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { - dt_to_nano(v1).eq(&mdn_to_nano(v2)) - } (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), - (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { - mdn_to_nano(v1).eq(&ym_to_nano(v2)) - } - (IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { - mdn_to_nano(v1).eq(&dt_to_nano(v2)) - } (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, @@ -261,6 +312,15 @@ impl PartialOrd for ScalarValue { } } (Decimal128(_, _, _), _) => None, + (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal256(_, _, _), _) => None, (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), (Boolean(_), _) => None, (Float32(v1), Float32(v2)) => match (v1, v2) { @@ -299,1158 +359,111 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(v1, t1), List(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (List(_, _), _) => None, - (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), - (Date32(_), _) => None, - (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), - (Date64(_), _) => None, - (Time32Second(v1), Time32Second(v2)) => v1.partial_cmp(v2), - (Time32Second(_), _) => None, - (Time32Millisecond(v1), Time32Millisecond(v2)) => v1.partial_cmp(v2), - (Time32Millisecond(_), _) => None, - (Time64Microsecond(v1), Time64Microsecond(v2)) => v1.partial_cmp(v2), - (Time64Microsecond(_), _) => None, - (Time64Nanosecond(v1), Time64Nanosecond(v2)) => v1.partial_cmp(v2), - (Time64Nanosecond(_), _) => None, - (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), - (TimestampSecond(_, _), _) => None, - (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMillisecond(_, _), _) => None, - (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMicrosecond(_, _), _) => None, - (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampNanosecond(_, _), _) => None, - (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), - (IntervalYearMonth(v1), IntervalDayTime(v2)) => { - ym_to_milli(v1).partial_cmp(&dt_to_milli(v2)) - } - (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { - ym_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) - } - (IntervalYearMonth(_), _) => None, - (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), - (IntervalDayTime(v1), IntervalYearMonth(v2)) => { - dt_to_milli(v1).partial_cmp(&ym_to_milli(v2)) - } - (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { - dt_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) - } - (IntervalDayTime(_), _) => None, - (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), - (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { - mdn_to_nano(v1).partial_cmp(&ym_to_nano(v2)) - } - (IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { - mdn_to_nano(v1).partial_cmp(&dt_to_nano(v2)) - } - (IntervalMonthDayNano(_), _) => None, - (Struct(v1, t1), Struct(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None + (List(arr1), List(arr2)) + | (FixedSizeList(arr1), FixedSizeList(arr2)) + | (LargeList(arr1), LargeList(arr2)) => { + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + assert_eq!(arr1.len(), 1); + assert_eq!(arr2.len(), 1); + + if arr1.data_type() != arr2.data_type() { + return None; } - } - (Struct(_, _), _) => None, - (Dictionary(k1, v1), Dictionary(k2, v2)) => { - // Don't compare if the key types don't match (it is effectively a different datatype) - if k1 == k2 { - v1.partial_cmp(v2) - } else { - None - } - } - (Dictionary(_, _), _) => None, - (Null, Null) => Some(Ordering::Equal), - (Null, _) => None, - } - } -} -/// This function computes the duration (in milliseconds) of the given -/// year-month-interval. -#[inline] -pub fn ym_to_sec(val: &Option) -> Option { - val.map(|value| (value as i64) * SECS_IN_ONE_MONTH) -} - -/// This function computes the duration (in milliseconds) of the given -/// year-month-interval. -#[inline] -pub fn ym_to_milli(val: &Option) -> Option { - val.map(|value| (value as i64) * MILLISECS_IN_ONE_MONTH) -} - -/// This function computes the duration (in milliseconds) of the given -/// year-month-interval. -#[inline] -pub fn ym_to_micro(val: &Option) -> Option { - val.map(|value| (value as i64) * MICROSECS_IN_ONE_MONTH) -} - -/// This function computes the duration (in nanoseconds) of the given -/// year-month-interval. -#[inline] -pub fn ym_to_nano(val: &Option) -> Option { - val.map(|value| (value as i128) * NANOSECS_IN_ONE_MONTH) -} - -/// This function computes the duration (in seconds) of the given -/// daytime-interval. -#[inline] -pub fn dt_to_sec(val: &Option) -> Option { - val.map(|val| { - let (days, millis) = IntervalDayTimeType::to_parts(val); - (days as i64) * MILLISECS_IN_ONE_DAY + (millis as i64 / 1_000) - }) -} - -/// This function computes the duration (in milliseconds) of the given -/// daytime-interval. -#[inline] -pub fn dt_to_milli(val: &Option) -> Option { - val.map(|val| { - let (days, millis) = IntervalDayTimeType::to_parts(val); - (days as i64) * MILLISECS_IN_ONE_DAY + (millis as i64) - }) -} - -/// This function computes the duration (in microseconds) of the given -/// daytime-interval. -#[inline] -pub fn dt_to_micro(val: &Option) -> Option { - val.map(|val| { - let (days, millis) = IntervalDayTimeType::to_parts(val); - (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + (millis as i128) * 1_000 - }) -} - -/// This function computes the duration (in nanoseconds) of the given -/// daytime-interval. -#[inline] -pub fn dt_to_nano(val: &Option) -> Option { - val.map(|val| { - let (days, millis) = IntervalDayTimeType::to_parts(val); - (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + (millis as i128) * 1_000_000 - }) -} - -/// This function computes the duration (in seconds) of the given -/// month-day-nano-interval. Assumes a month is 30 days long. -#[inline] -pub fn mdn_to_sec(val: &Option) -> Option { - val.map(|val| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); - (months as i128) * NANOSECS_IN_ONE_MONTH - + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) - + (nanos as i128) / 1_000_000_000 - }) -} - -/// This function computes the duration (in milliseconds) of the given -/// month-day-nano-interval. Assumes a month is 30 days long. -#[inline] -pub fn mdn_to_milli(val: &Option) -> Option { - val.map(|val| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); - (months as i128) * NANOSECS_IN_ONE_MONTH - + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) - + (nanos as i128) / 1_000_000 - }) -} - -/// This function computes the duration (in microseconds) of the given -/// month-day-nano-interval. Assumes a month is 30 days long. -#[inline] -pub fn mdn_to_micro(val: &Option) -> Option { - val.map(|val| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); - (months as i128) * NANOSECS_IN_ONE_MONTH - + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) - + (nanos as i128) / 1_000 - }) -} - -/// This function computes the duration (in nanoseconds) of the given -/// month-day-nano-interval. Assumes a month is 30 days long. -#[inline] -pub fn mdn_to_nano(val: &Option) -> Option { - val.map(|val| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); - (months as i128) * NANOSECS_IN_ONE_MONTH - + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) - + (nanos as i128) - }) -} - -impl Eq for ScalarValue {} - -// TODO implement this in arrow-rs with simd -// https://github.com/apache/arrow-rs/issues/1010 -macro_rules! decimal_op { - ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt) => {{ - let (difference, side) = if $LHS_SCALE > $RHS_SCALE { - ($LHS_SCALE - $RHS_SCALE, true) - } else { - ($RHS_SCALE - $LHS_SCALE, false) - }; - let scale = max($LHS_SCALE, $RHS_SCALE); - Ok(match ($LHS, $RHS, difference) { - (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale), - (lhs, None, 0) => ScalarValue::Decimal128(*lhs, $PRECISION, scale), - (Some(lhs_value), None, _) => { - let mut new_value = *lhs_value; - if !side { - new_value *= 10_i128.pow(difference as u32) - } - ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) - } - (None, Some(rhs_value), 0) => { - let value = decimal_right!(*rhs_value, $OPERATION); - ScalarValue::Decimal128(Some(value), $PRECISION, scale) - } - (None, Some(rhs_value), _) => { - let mut new_value = decimal_right!(*rhs_value, $OPERATION); - if side { - new_value *= 10_i128.pow(difference as u32) - }; - ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) - } - (Some(lhs_value), Some(rhs_value), 0) => { - decimal_binary_op!(lhs_value, rhs_value, $OPERATION, $PRECISION, scale) - } - (Some(lhs_value), Some(rhs_value), _) => { - let (left_arg, right_arg) = if side { - (*lhs_value, rhs_value * 10_i128.pow(difference as u32)) - } else { - (lhs_value * 10_i128.pow(difference as u32), *rhs_value) - }; - decimal_binary_op!(left_arg, right_arg, $OPERATION, $PRECISION, scale) - } - }) - }}; -} - -macro_rules! decimal_binary_op { - ($LHS:expr, $RHS:expr, $OPERATION:tt, $PRECISION:expr, $SCALE:expr) => { - // TODO: This simple implementation loses precision for calculations like - // multiplication and division. Improve this implementation for such - // operations. - ScalarValue::Decimal128(Some($LHS $OPERATION $RHS), $PRECISION, $SCALE) - }; -} - -macro_rules! decimal_right { - ($TERM:expr, +) => { - $TERM - }; - ($TERM:expr, *) => { - $TERM - }; - ($TERM:expr, -) => { - -$TERM - }; - ($TERM:expr, /) => { - Err(DataFusionError::NotImplemented(format!( - "Decimal reciprocation not yet supported", - ))) - }; -} - -// Returns the result of applying operation to two scalar values. -macro_rules! primitive_op { - ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $OPERATION:tt) => { - match ($LEFT, $RIGHT) { - (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)), - #[allow(unused_variables)] - (None, Some(b)) => { primitive_right!(*b, $OPERATION, $SCALAR) }, - (Some(a), Some(b)) => Ok(ScalarValue::$SCALAR(Some(*a $OPERATION *b))), - } - }; -} -macro_rules! primitive_checked_op { - ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $FUNCTION:ident, $OPERATION:tt) => { - match ($LEFT, $RIGHT) { - (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)), - #[allow(unused_variables)] - (None, Some(b)) => { - primitive_checked_right!(*b, $OPERATION, $SCALAR) - } - (Some(a), Some(b)) => { - if let Some(value) = (*a).$FUNCTION(*b) { - Ok(ScalarValue::$SCALAR(Some(value))) - } else { - Err(DataFusionError::Execution( - "Overflow while calculating ScalarValue.".to_string(), - )) + fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } } - } - } - }; -} - -macro_rules! primitive_checked_right { - ($TERM:expr, -, $SCALAR:ident) => { - if let Some(value) = $TERM.checked_neg() { - Ok(ScalarValue::$SCALAR(Some(value))) - } else { - Err(DataFusionError::Execution( - "Overflow while calculating ScalarValue.".to_string(), - )) - } - }; - ($TERM:expr, $OPERATION:tt, $SCALAR:ident) => { - primitive_right!($TERM, $OPERATION, $SCALAR) - }; -} - -macro_rules! primitive_right { - ($TERM:expr, +, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, *, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, -, UInt64) => { - unsigned_subtraction_error!("UInt64") - }; - ($TERM:expr, -, UInt32) => { - unsigned_subtraction_error!("UInt32") - }; - ($TERM:expr, -, UInt16) => { - unsigned_subtraction_error!("UInt16") - }; - ($TERM:expr, -, UInt8) => { - unsigned_subtraction_error!("UInt8") - }; - ($TERM:expr, -, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some(-$TERM))) - }; - ($TERM:expr, /, Float64) => { - Ok(ScalarValue::$SCALAR(Some($TERM.recip()))) - }; - ($TERM:expr, /, Float32) => { - Ok(ScalarValue::$SCALAR(Some($TERM.recip()))) - }; - ($TERM:expr, /, $SCALAR:ident) => { - Err(DataFusionError::Internal(format!( - "Can not divide an uninitialized value to a non-floating point value", - ))) - }; - ($TERM:expr, &, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, |, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, ^, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, &&, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; - ($TERM:expr, ||, $SCALAR:ident) => { - Ok(ScalarValue::$SCALAR(Some($TERM))) - }; -} - -macro_rules! unsigned_subtraction_error { - ($SCALAR:expr) => {{ - let msg = format!( - "Can not subtract a {} value from an uninitialized value", - $SCALAR - ); - Err(DataFusionError::Internal(msg)) - }}; -} - -macro_rules! impl_checked_op { - ($LHS:expr, $RHS:expr, $FUNCTION:ident, $OPERATION:tt) => { - // Only covering primitive types that support checked_* operands, and fall back to raw operation for other types. - match ($LHS, $RHS) { - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - primitive_checked_op!(lhs, rhs, UInt64, $FUNCTION, $OPERATION) - }, - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - primitive_checked_op!(lhs, rhs, Int64, $FUNCTION, $OPERATION) - }, - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - primitive_checked_op!(lhs, rhs, UInt32, $FUNCTION, $OPERATION) - }, - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - primitive_checked_op!(lhs, rhs, Int32, $FUNCTION, $OPERATION) - }, - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - primitive_checked_op!(lhs, rhs, UInt16, $FUNCTION, $OPERATION) - }, - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - primitive_checked_op!(lhs, rhs, Int16, $FUNCTION, $OPERATION) - }, - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - primitive_checked_op!(lhs, rhs, UInt8, $FUNCTION, $OPERATION) - }, - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - primitive_checked_op!(lhs, rhs, Int8, $FUNCTION, $OPERATION) - }, - _ => { - impl_op!($LHS, $RHS, $OPERATION) - } - } - }; -} - -macro_rules! impl_op { - ($LHS:expr, $RHS:expr, +) => { - impl_op_arithmetic!($LHS, $RHS, +) - }; - ($LHS:expr, $RHS:expr, -) => { - match ($LHS, $RHS) { - ( - ScalarValue::TimestampSecond(Some(ts_lhs), tz_lhs), - ScalarValue::TimestampSecond(Some(ts_rhs), tz_rhs), - ) => { - let err = || { - DataFusionError::Execution( - "Overflow while converting seconds to milliseconds".to_string(), - ) - }; - ts_sub_to_interval::( - ts_lhs.checked_mul(1_000).ok_or_else(err)?, - ts_rhs.checked_mul(1_000).ok_or_else(err)?, - tz_lhs.as_deref(), - tz_rhs.as_deref(), - ) - }, - ( - ScalarValue::TimestampMillisecond(Some(ts_lhs), tz_lhs), - ScalarValue::TimestampMillisecond(Some(ts_rhs), tz_rhs), - ) => ts_sub_to_interval::( - *ts_lhs, - *ts_rhs, - tz_lhs.as_deref(), - tz_rhs.as_deref(), - ), - ( - ScalarValue::TimestampMicrosecond(Some(ts_lhs), tz_lhs), - ScalarValue::TimestampMicrosecond(Some(ts_rhs), tz_rhs), - ) => { - let err = || { - DataFusionError::Execution( - "Overflow while converting microseconds to nanoseconds".to_string(), - ) - }; - ts_sub_to_interval::( - ts_lhs.checked_mul(1_000).ok_or_else(err)?, - ts_rhs.checked_mul(1_000).ok_or_else(err)?, - tz_lhs.as_deref(), - tz_rhs.as_deref(), - ) - }, - ( - ScalarValue::TimestampNanosecond(Some(ts_lhs), tz_lhs), - ScalarValue::TimestampNanosecond(Some(ts_rhs), tz_rhs), - ) => ts_sub_to_interval::( - *ts_lhs, - *ts_rhs, - tz_lhs.as_deref(), - tz_rhs.as_deref(), - ), - _ => impl_op_arithmetic!($LHS, $RHS, -) - } - }; - ($LHS:expr, $RHS:expr, &) => { - impl_bit_op_arithmetic!($LHS, $RHS, &) - }; - ($LHS:expr, $RHS:expr, |) => { - impl_bit_op_arithmetic!($LHS, $RHS, |) - }; - ($LHS:expr, $RHS:expr, ^) => { - impl_bit_op_arithmetic!($LHS, $RHS, ^) - }; - ($LHS:expr, $RHS:expr, &&) => { - impl_bool_op_arithmetic!($LHS, $RHS, &&) - }; - ($LHS:expr, $RHS:expr, ||) => { - impl_bool_op_arithmetic!($LHS, $RHS, ||) - }; -} -macro_rules! impl_bit_op_arithmetic { - ($LHS:expr, $RHS:expr, $OPERATION:tt) => { - match ($LHS, $RHS) { - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - primitive_op!(lhs, rhs, UInt64, $OPERATION) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - primitive_op!(lhs, rhs, Int64, $OPERATION) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - primitive_op!(lhs, rhs, UInt32, $OPERATION) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - primitive_op!(lhs, rhs, Int32, $OPERATION) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - primitive_op!(lhs, rhs, UInt16, $OPERATION) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - primitive_op!(lhs, rhs, Int16, $OPERATION) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - primitive_op!(lhs, rhs, UInt8, $OPERATION) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - primitive_op!(lhs, rhs, Int8, $OPERATION) - } - _ => Err(DataFusionError::Internal(format!( - "Operator {} is not implemented for types {:?} and {:?}", - stringify!($OPERATION), - $LHS, - $RHS - ))), - } - }; -} - -macro_rules! impl_bool_op_arithmetic { - ($LHS:expr, $RHS:expr, $OPERATION:tt) => { - match ($LHS, $RHS) { - (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { - primitive_op!(lhs, rhs, Boolean, $OPERATION) - } - _ => Err(DataFusionError::Internal(format!( - "Operator {} is not implemented for types {:?} and {:?}", - stringify!($OPERATION), - $LHS, - $RHS - ))), - } - }; -} - -macro_rules! impl_op_arithmetic { - ($LHS:expr, $RHS:expr, $OPERATION:tt) => { - match ($LHS, $RHS) { - // Binary operations on arguments with the same type: - ( - ScalarValue::Decimal128(v1, p1, s1), - ScalarValue::Decimal128(v2, p2, s2), - ) => { - decimal_op!(v1, v2, *p1.max(p2), *s1, *s2, $OPERATION) - } - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - primitive_op!(lhs, rhs, Float64, $OPERATION) - } - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - primitive_op!(lhs, rhs, Float32, $OPERATION) - } - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - primitive_op!(lhs, rhs, UInt64, $OPERATION) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - primitive_op!(lhs, rhs, Int64, $OPERATION) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - primitive_op!(lhs, rhs, UInt32, $OPERATION) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - primitive_op!(lhs, rhs, Int32, $OPERATION) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - primitive_op!(lhs, rhs, UInt16, $OPERATION) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - primitive_op!(lhs, rhs, Int16, $OPERATION) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - primitive_op!(lhs, rhs, UInt8, $OPERATION) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - primitive_op!(lhs, rhs, Int8, $OPERATION) - } - ( - ScalarValue::IntervalYearMonth(Some(lhs)), - ScalarValue::IntervalYearMonth(Some(rhs)), - ) => Ok(ScalarValue::IntervalYearMonth(Some(op_ym( - *lhs, - *rhs, - get_sign!($OPERATION), - )))), - ( - ScalarValue::IntervalDayTime(Some(lhs)), - ScalarValue::IntervalDayTime(Some(rhs)), - ) => Ok(ScalarValue::IntervalDayTime(Some(op_dt( - *lhs, - *rhs, - get_sign!($OPERATION), - )))), - ( - ScalarValue::IntervalMonthDayNano(Some(lhs)), - ScalarValue::IntervalMonthDayNano(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_mdn( - *lhs, - *rhs, - get_sign!($OPERATION), - )))), - // Binary operations on arguments with different types: - (ScalarValue::Date32(Some(days)), _) => { - let value = date32_op(*days, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::Date32(Some(value))) - } - (ScalarValue::Date64(Some(ms)), _) => { - let value = date64_op(*ms, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::Date64(Some(value))) - } - (ScalarValue::TimestampSecond(Some(ts_s), zone), _) => { - let value = seconds_add(*ts_s, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampSecond(Some(value), zone.clone())) - } - (_, ScalarValue::TimestampSecond(Some(ts_s), zone)) => { - let value = seconds_add(*ts_s, $LHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampSecond(Some(value), zone.clone())) - } - (ScalarValue::TimestampMillisecond(Some(ts_ms), zone), _) => { - let value = milliseconds_add(*ts_ms, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampMillisecond(Some(value), zone.clone())) - } - (_, ScalarValue::TimestampMillisecond(Some(ts_ms), zone)) => { - let value = milliseconds_add(*ts_ms, $LHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampMillisecond(Some(value), zone.clone())) - } - (ScalarValue::TimestampMicrosecond(Some(ts_us), zone), _) => { - let value = microseconds_add(*ts_us, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampMicrosecond(Some(value), zone.clone())) - } - (_, ScalarValue::TimestampMicrosecond(Some(ts_us), zone)) => { - let value = microseconds_add(*ts_us, $LHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampMicrosecond(Some(value), zone.clone())) - } - (ScalarValue::TimestampNanosecond(Some(ts_ns), zone), _) => { - let value = nanoseconds_add(*ts_ns, $RHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) - } - (_, ScalarValue::TimestampNanosecond(Some(ts_ns), zone)) => { - let value = nanoseconds_add(*ts_ns, $LHS, get_sign!($OPERATION))?; - Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) - } - ( - ScalarValue::IntervalYearMonth(Some(lhs)), - ScalarValue::IntervalDayTime(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_ym_dt( - *lhs, - *rhs, - get_sign!($OPERATION), - false, - )))), - ( - ScalarValue::IntervalYearMonth(Some(lhs)), - ScalarValue::IntervalMonthDayNano(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_ym_mdn( - *lhs, - *rhs, - get_sign!($OPERATION), - false, - )))), - ( - ScalarValue::IntervalDayTime(Some(lhs)), - ScalarValue::IntervalYearMonth(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_ym_dt( - *rhs, - *lhs, - get_sign!($OPERATION), - true, - )))), - ( - ScalarValue::IntervalDayTime(Some(lhs)), - ScalarValue::IntervalMonthDayNano(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_dt_mdn( - *lhs, - *rhs, - get_sign!($OPERATION), - false, - )))), - ( - ScalarValue::IntervalMonthDayNano(Some(lhs)), - ScalarValue::IntervalYearMonth(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_ym_mdn( - *rhs, - *lhs, - get_sign!($OPERATION), - true, - )))), - ( - ScalarValue::IntervalMonthDayNano(Some(lhs)), - ScalarValue::IntervalDayTime(Some(rhs)), - ) => Ok(ScalarValue::IntervalMonthDayNano(Some(op_dt_mdn( - *rhs, - *lhs, - get_sign!($OPERATION), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "Operator {} is not implemented for types {:?} and {:?}", - stringify!($OPERATION), - $LHS, - $RHS - ))), - } - }; -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different -/// types ([`IntervalYearMonthType`] and [`IntervalDayTimeType`], respectively). -/// The argument `sign` chooses between addition and subtraction, the argument -/// `commute` swaps `lhs` and `rhs`. The return value is an 128-bit integer. -/// It can be involved in a [`IntervalMonthDayNanoType`] in the outer scope. -#[inline] -pub fn op_ym_dt(mut lhs: i32, rhs: i64, sign: i32, commute: bool) -> i128 { - let (mut days, millis) = IntervalDayTimeType::to_parts(rhs); - let mut nanos = (millis as i64) * 1_000_000; - if commute { - lhs *= sign; - } else { - days *= sign; - nanos *= sign as i64; - }; - IntervalMonthDayNanoType::make_value(lhs, days, nanos) -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different -/// types ([`IntervalYearMonthType`] and [`IntervalMonthDayNanoType`], respectively). -/// The argument `sign` chooses between addition and subtraction, the argument -/// `commute` swaps `lhs` and `rhs`. The return value is an 128-bit integer. -/// It can be involved in a [`IntervalMonthDayNanoType`] in the outer scope. -#[inline] -pub fn op_ym_mdn(lhs: i32, rhs: i128, sign: i32, commute: bool) -> i128 { - let (mut months, mut days, mut nanos) = IntervalMonthDayNanoType::to_parts(rhs); - if commute { - months += lhs * sign; - } else { - months = lhs + (months * sign); - days *= sign; - nanos *= sign as i64; - } - IntervalMonthDayNanoType::make_value(months, days, nanos) -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different -/// types ([`IntervalDayTimeType`] and [`IntervalMonthDayNanoType`], respectively). -/// The argument `sign` chooses between addition and subtraction, the argument -/// `commute` swaps `lhs` and `rhs`. The return value is an 128-bit integer. -/// It can be involved in a [`IntervalMonthDayNanoType`] in the outer scope. -#[inline] -pub fn op_dt_mdn(lhs: i64, rhs: i128, sign: i32, commute: bool) -> i128 { - let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(lhs); - let (rhs_months, rhs_days, rhs_nanos) = IntervalMonthDayNanoType::to_parts(rhs); - if commute { - IntervalMonthDayNanoType::make_value( - rhs_months, - lhs_days * sign + rhs_days, - (lhs_millis * sign) as i64 * 1_000_000 + rhs_nanos, - ) - } else { - IntervalMonthDayNanoType::make_value( - rhs_months * sign, - lhs_days + rhs_days * sign, - (lhs_millis as i64) * 1_000_000 + rhs_nanos * (sign as i64), - ) - } -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of -/// the same type [`IntervalYearMonthType`]. The argument `sign` chooses between -/// addition and subtraction. The return value is an 32-bit integer. It can be -/// involved in a [`IntervalYearMonthType`] in the outer scope. -#[inline] -pub fn op_ym(lhs: i32, rhs: i32, sign: i32) -> i32 { - lhs + rhs * sign -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of -/// the same type [`IntervalDayTimeType`]. The argument `sign` chooses between -/// addition and subtraction. The return value is an 64-bit integer. It can be -/// involved in a [`IntervalDayTimeType`] in the outer scope. -#[inline] -pub fn op_dt(lhs: i64, rhs: i64, sign: i32) -> i64 { - let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(lhs); - let (rhs_days, rhs_millis) = IntervalDayTimeType::to_parts(rhs); - IntervalDayTimeType::make_value( - lhs_days + rhs_days * sign, - lhs_millis + rhs_millis * sign, - ) -} - -/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of -/// the same type [`IntervalMonthDayNanoType`]. The argument `sign` chooses between -/// addition and subtraction. The return value is an 128-bit integer. It can be -/// involved in a [`IntervalMonthDayNanoType`] in the outer scope. -#[inline] -pub fn op_mdn(lhs: i128, rhs: i128, sign: i32) -> i128 { - let (lhs_months, lhs_days, lhs_nanos) = IntervalMonthDayNanoType::to_parts(lhs); - let (rhs_months, rhs_days, rhs_nanos) = IntervalMonthDayNanoType::to_parts(rhs); - IntervalMonthDayNanoType::make_value( - lhs_months + rhs_months * sign, - lhs_days + rhs_days * sign, - lhs_nanos + rhs_nanos * (sign as i64), - ) -} - -macro_rules! get_sign { - (+) => { - 1 - }; - (-) => { - -1 - }; -} - -pub const YM_MODE: i8 = 0; -pub const DT_MODE: i8 = 1; -pub const MDN_MODE: i8 = 2; - -pub const MILLISECOND_MODE: bool = false; -pub const NANOSECOND_MODE: bool = true; -/// This function computes subtracts `rhs_ts` from `lhs_ts`, taking timezones -/// into account when given. Units of the resulting interval is specified by -/// the constant `TIME_MODE`. -/// The default behavior of Datafusion is the following: -/// - When subtracting timestamps at seconds/milliseconds precision, the output -/// interval will have the type [`IntervalDayTimeType`]. -/// - When subtracting timestamps at microseconds/nanoseconds precision, the -/// output interval will have the type [`IntervalMonthDayNanoType`]. -fn ts_sub_to_interval( - lhs_ts: i64, - rhs_ts: i64, - lhs_tz: Option<&str>, - rhs_tz: Option<&str>, -) -> Result { - let parsed_lhs_tz = parse_timezones(lhs_tz)?; - let parsed_rhs_tz = parse_timezones(rhs_tz)?; - - let (naive_lhs, naive_rhs) = - calculate_naives::(lhs_ts, parsed_lhs_tz, rhs_ts, parsed_rhs_tz)?; - let delta_secs = naive_lhs.signed_duration_since(naive_rhs); - - match TIME_MODE { - MILLISECOND_MODE => { - let as_millisecs = delta_secs.num_milliseconds(); - Ok(ScalarValue::new_interval_dt( - (as_millisecs / MILLISECS_IN_ONE_DAY) as i32, - (as_millisecs % MILLISECS_IN_ONE_DAY) as i32, - )) - } - NANOSECOND_MODE => { - let as_nanosecs = delta_secs.num_nanoseconds().ok_or_else(|| { - DataFusionError::Execution(String::from( - "Can not compute timestamp differences with nanosecond precision", - )) - })?; - Ok(ScalarValue::new_interval_mdn( - 0, - (as_nanosecs / NANOSECS_IN_ONE_DAY) as i32, - as_nanosecs % NANOSECS_IN_ONE_DAY, - )) - } - } -} - -/// This function parses the timezone from string to Tz. -/// If it cannot parse or timezone field is [`None`], it returns [`None`]. -pub fn parse_timezones(tz: Option<&str>) -> Result> { - if let Some(tz) = tz { - let parsed_tz: Tz = tz.parse().map_err(|_| { - DataFusionError::Execution("cannot parse given timezone".to_string()) - })?; - Ok(Some(parsed_tz)) - } else { - Ok(None) - } -} - -/// This function takes two timestamps with an optional timezone, -/// and returns the duration between them. If one of the timestamps -/// has a [`None`] timezone, the other one is also treated as having [`None`]. -pub fn calculate_naives( - lhs_ts: i64, - parsed_lhs_tz: Option, - rhs_ts: i64, - parsed_rhs_tz: Option, -) -> Result<(NaiveDateTime, NaiveDateTime)> { - let err = || { - DataFusionError::Execution(String::from( - "error while converting Int64 to DateTime in timestamp subtraction", - )) - }; - match (parsed_lhs_tz, parsed_rhs_tz, TIME_MODE) { - (Some(lhs_tz), Some(rhs_tz), MILLISECOND_MODE) => { - let lhs = arrow_array::temporal_conversions::as_datetime_with_timezone::< - arrow_array::types::TimestampMillisecondType, - >(lhs_ts, rhs_tz) - .ok_or_else(err)? - .naive_local(); - let rhs = arrow_array::temporal_conversions::as_datetime_with_timezone::< - arrow_array::types::TimestampMillisecondType, - >(rhs_ts, lhs_tz) - .ok_or_else(err)? - .naive_local(); - Ok((lhs, rhs)) - } - (Some(lhs_tz), Some(rhs_tz), NANOSECOND_MODE) => { - let lhs = arrow_array::temporal_conversions::as_datetime_with_timezone::< - arrow_array::types::TimestampNanosecondType, - >(lhs_ts, rhs_tz) - .ok_or_else(err)? - .naive_local(); - let rhs = arrow_array::temporal_conversions::as_datetime_with_timezone::< - arrow_array::types::TimestampNanosecondType, - >(rhs_ts, lhs_tz) - .ok_or_else(err)? - .naive_local(); - Ok((lhs, rhs)) - } - (_, _, MILLISECOND_MODE) => { - let lhs = arrow_array::temporal_conversions::as_datetime::< - arrow_array::types::TimestampMillisecondType, - >(lhs_ts) - .ok_or_else(err)?; - let rhs = arrow_array::temporal_conversions::as_datetime::< - arrow_array::types::TimestampMillisecondType, - >(rhs_ts) - .ok_or_else(err)?; - Ok((lhs, rhs)) - } - (_, _, NANOSECOND_MODE) => { - let lhs = arrow_array::temporal_conversions::as_datetime::< - arrow_array::types::TimestampNanosecondType, - >(lhs_ts) - .ok_or_else(err)?; - let rhs = arrow_array::temporal_conversions::as_datetime::< - arrow_array::types::TimestampNanosecondType, - >(rhs_ts) - .ok_or_else(err)?; - Ok((lhs, rhs)) - } - } -} - -#[inline] -pub fn date32_op(days: i32, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let prior = epoch.add(Duration::days(days as i64)); - do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_days() as i32) -} - -#[inline] -pub fn date64_op(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let prior = epoch.add(Duration::milliseconds(ms)); - do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_milliseconds()) -} - -#[inline] -pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - do_date_time_math(ts_s, 0, scalar, sign).map(|dt| dt.timestamp()) -} - -#[inline] -pub fn seconds_add_array( - ts_s: i64, - interval: i128, - sign: i32, -) -> Result { - do_date_time_math_array::(ts_s, 0, interval, sign) - .map(|dt| dt.timestamp()) -} - -#[inline] -pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ms.div_euclid(1000); - let nsecs = ts_ms.rem_euclid(1000) * 1_000_000; - do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_millis()) -} - -#[inline] -pub fn milliseconds_add_array( - ts_ms: i64, - interval: i128, - sign: i32, -) -> Result { - let secs = ts_ms.div_euclid(1000); - let nsecs = ts_ms.rem_euclid(1000) * 1_000_000; - do_date_time_math_array::(secs, nsecs as u32, interval, sign) - .map(|dt| dt.timestamp_millis()) -} - -#[inline] -pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_us.div_euclid(1_000_000); - let nsecs = ts_us.rem_euclid(1_000_000) * 1_000; - do_date_time_math(secs, nsecs as u32, scalar, sign) - .map(|dt| dt.timestamp_nanos() / 1000) -} - -#[inline] -pub fn microseconds_add_array( - ts_us: i64, - interval: i128, - sign: i32, -) -> Result { - let secs = ts_us.div_euclid(1_000_000); - let nsecs = ts_us.rem_euclid(1_000_000) * 1_000; - do_date_time_math_array::(secs, nsecs as u32, interval, sign) - .map(|dt| dt.timestamp_nanos() / 1000) -} - -#[inline] -pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ns.div_euclid(1_000_000_000); - let nsecs = ts_ns.rem_euclid(1_000_000_000); - do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_nanos()) -} - -#[inline] -pub fn nanoseconds_add_array( - ts_ns: i64, - interval: i128, - sign: i32, -) -> Result { - let secs = ts_ns.div_euclid(1_000_000_000); - let nsecs = ts_ns.rem_euclid(1_000_000_000); - do_date_time_math_array::(secs, nsecs as u32, interval, sign) - .map(|dt| dt.timestamp_nanos()) -} - -#[inline] -pub fn seconds_sub(ts_lhs: i64, ts_rhs: i64) -> i64 { - let diff_ms = (ts_lhs - ts_rhs) * 1000; - let days = (diff_ms / MILLISECS_IN_ONE_DAY) as i32; - let millis = (diff_ms % MILLISECS_IN_ONE_DAY) as i32; - IntervalDayTimeType::make_value(days, millis) -} -#[inline] -pub fn milliseconds_sub(ts_lhs: i64, ts_rhs: i64) -> i64 { - let diff_ms = ts_lhs - ts_rhs; - let days = (diff_ms / MILLISECS_IN_ONE_DAY) as i32; - let millis = (diff_ms % MILLISECS_IN_ONE_DAY) as i32; - IntervalDayTimeType::make_value(days, millis) -} -#[inline] -pub fn microseconds_sub(ts_lhs: i64, ts_rhs: i64) -> i128 { - let diff_ns = (ts_lhs - ts_rhs) * 1000; - let days = (diff_ns / NANOSECS_IN_ONE_DAY) as i32; - let nanos = diff_ns % NANOSECS_IN_ONE_DAY; - IntervalMonthDayNanoType::make_value(0, days, nanos) -} -#[inline] -pub fn nanoseconds_sub(ts_lhs: i64, ts_rhs: i64) -> i128 { - let diff_ns = ts_lhs - ts_rhs; - let days = (diff_ns / NANOSECS_IN_ONE_DAY) as i32; - let nanos = diff_ns % NANOSECS_IN_ONE_DAY; - IntervalMonthDayNanoType::make_value(0, days, nanos) -} - -#[inline] -fn do_date_time_math( - secs: i64, - nsecs: u32, - scalar: &ScalarValue, - sign: i32, -) -> Result { - let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| { - DataFusionError::Internal(format!( - "Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs} scalar {scalar:?} sign {sign}" - )) - })?; - do_date_math(prior, scalar, sign) -} + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); -#[inline] -fn do_date_time_math_array( - secs: i64, - nsecs: u32, - interval: i128, - sign: i32, -) -> Result { - let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| { - DataFusionError::Internal(format!( - "Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs}" - )) - })?; - do_date_math_array::<_, INTERVAL_MODE>(prior, interval, sign) -} + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; -fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result -where - D: Datelike + Add, -{ - Ok(match scalar { - ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), - ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i, sign), - ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), - other => Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {other:?}" - )))?, - }) -} + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } -fn do_date_math_array( - prior: D, - interval: i128, - sign: i32, -) -> Result -where - D: Datelike + Add, -{ - Ok(match INTERVAL_MODE { - YM_MODE => shift_months(prior, interval as i32, sign), - DT_MODE => add_day_time(prior, interval as i64, sign), - MDN_MODE => add_m_d_nano(prior, interval, sign), - _ => { - return Err(DataFusionError::Internal( - "Undefined interval mode for interval calculations".to_string(), - )); + Some(Ordering::Equal) + } + (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), + (Date64(_), _) => None, + (Time32Second(v1), Time32Second(v2)) => v1.partial_cmp(v2), + (Time32Second(_), _) => None, + (Time32Millisecond(v1), Time32Millisecond(v2)) => v1.partial_cmp(v2), + (Time32Millisecond(_), _) => None, + (Time64Microsecond(v1), Time64Microsecond(v2)) => v1.partial_cmp(v2), + (Time64Microsecond(_), _) => None, + (Time64Nanosecond(v1), Time64Nanosecond(v2)) => v1.partial_cmp(v2), + (Time64Nanosecond(_), _) => None, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), + (TimestampSecond(_, _), _) => None, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMillisecond(_, _), _) => None, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMicrosecond(_, _), _) => None, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampNanosecond(_, _), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(_), _) => None, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(_), _) => None, + (DurationSecond(v1), DurationSecond(v2)) => v1.partial_cmp(v2), + (DurationSecond(_), _) => None, + (DurationMillisecond(v1), DurationMillisecond(v2)) => v1.partial_cmp(v2), + (DurationMillisecond(_), _) => None, + (DurationMicrosecond(v1), DurationMicrosecond(v2)) => v1.partial_cmp(v2), + (DurationMicrosecond(_), _) => None, + (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), + (DurationNanosecond(_), _) => None, + (Struct(v1, t1), Struct(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Struct(_, _), _) => None, + (Dictionary(k1, v1), Dictionary(k2, v2)) => { + // Don't compare if the key types don't match (it is effectively a different datatype) + if k1 == k2 { + v1.partial_cmp(v2) + } else { + None + } + } + (Dictionary(_, _), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, } - }) -} - -// Can remove once chrono:0.4.23 is released -pub fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D -where - D: Datelike + Add, -{ - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval); - let months = months * sign; - let days = days * sign; - let nanos = nanos * sign as i64; - let a = shift_months(prior, months, 1); - let b = a.add(Duration::days(days as i64)); - b.add(Duration::nanoseconds(nanos)) + } } -// Can remove once chrono:0.4.23 is released -pub fn add_day_time(prior: D, interval: i64, sign: i32) -> D -where - D: Datelike + Add, -{ - let (days, ms) = IntervalDayTimeType::to_parts(interval); - let days = days * sign; - let ms = ms * sign; - let intermediate = prior.add(Duration::days(days as i64)); - intermediate.add(Duration::milliseconds(ms as i64)) -} +impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper struct Fl(T); @@ -1469,6 +482,10 @@ macro_rules! hash_float_value { hash_float_value!((f64, u64), (f32, u32)); // manual implementation of `Hash` +// +// # Panics +// +// Panics if there is an error when creating hash values for rows impl std::hash::Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; @@ -1478,6 +495,11 @@ impl std::hash::Hash for ScalarValue { p.hash(state); s.hash(state) } + Decimal256(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Boolean(v) => v.hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), @@ -1494,9 +516,14 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(v, t) => { - v.hash(state); - t.hash(state); + List(arr) | LargeList(arr) | FixedSizeList(arr) => { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = + create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -1508,6 +535,10 @@ impl std::hash::Hash for ScalarValue { TimestampMillisecond(v, _) => v.hash(state), TimestampMicrosecond(v, _) => v.hash(state), TimestampNanosecond(v, _) => v.hash(state), + DurationSecond(v) => v.hash(state), + DurationMillisecond(v) => v.hash(state), + DurationMicrosecond(v) => v.hash(state), + DurationNanosecond(v) => v.hash(state), IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), @@ -1525,15 +556,19 @@ impl std::hash::Hash for ScalarValue { } } -/// return a reference to the values array and the index into it for a +/// Return a reference to the values array and the index into it for a /// dictionary array +/// +/// # Errors +/// +/// Errors if the array cannot be downcasted to DictionaryArray #[inline] pub fn get_dict_value( array: &dyn Array, index: usize, -) -> (&ArrayRef, Option) { - let dict_array = as_dictionary_array::(array).unwrap(); - (dict_array.values(), dict_array.key(index)) +) -> Result<(&ArrayRef, Option)> { + let dict_array = as_dictionary_array::(array)?; + Ok((dict_array.values(), dict_array.key(index))) } /// Create a dictionary array representing `value` repeated `size` @@ -1541,9 +576,9 @@ pub fn get_dict_value( fn dict_from_scalar( value: &ScalarValue, size: usize, -) -> ArrayRef { +) -> Result { // values array is one element long (the value) - let values_array = value.to_array_of_size(1); + let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) @@ -1555,11 +590,9 @@ fn dict_from_scalar( // Note: this path could be made faster by using the ArrayData // APIs and skipping validation, if it every comes up in // performance traces. - Arc::new( - DictionaryArray::::try_new(key_array, values_array) - // should always be valid by construction above - .expect("Can not construct dictionary array"), - ) + Ok(Arc::new( + DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above + )) } /// Create a dictionary array representing all the values in values @@ -1598,152 +631,44 @@ fn dict_from_values( macro_rules! typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR( + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, false => Some(array.value($index).into()), }, $TZ.clone(), - ) + )) }}; } macro_rules! typed_cast { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) - }}; -} - -// keep until https://github.com/apache/arrow-rs/issues/2054 is finished -macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::$SCALAR_TY, - true, - ))), - $SIZE, - ) - } - Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE) - } - } - }}; -} - -macro_rules! build_timestamp_list { - ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::Timestamp($TIME_UNIT, $TIME_ZONE), - true, - ))), - $SIZE, - ) - } - Some(values) => match $TIME_UNIT { - TimeUnit::Second => { - build_values_list_tz!( - TimestampSecondBuilder, - TimestampSecond, - values, - $SIZE - ) - } - TimeUnit::Microsecond => build_values_list_tz!( - TimestampMillisecondBuilder, - TimestampMillisecond, - values, - $SIZE - ), - TimeUnit::Millisecond => build_values_list_tz!( - TimestampMicrosecondBuilder, - TimestampMicrosecond, - values, - $SIZE - ), - TimeUnit::Nanosecond => build_values_list_tz!( - TimestampNanosecondBuilder, - TimestampNanosecond, - values, - $SIZE - ), + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; + Ok::(ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), }, - } - }}; -} - -macro_rules! new_builder { - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -macro_rules! build_values_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let builder = new_builder!($VALUE_BUILDER_TY, $VALUES.len()); - let mut builder = ListBuilder::new(builder); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(v.clone()); - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true); - } - - builder.finish() - }}; -} - -macro_rules! build_values_list_tz { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let mut builder = - ListBuilder::new($VALUE_BUILDER_TY::with_capacity($VALUES.len())); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v), _) => { - builder.values().append_value(v.clone()); - } - ScalarValue::$SCALAR_TY(None, _) => { - builder.values().append_null(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true); - } - - builder.finish() + )) }}; } @@ -1775,30 +700,58 @@ macro_rules! build_timestamp_array_from_option { macro_rules! eq_array_primitive { ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + use std::any::type_name; + let array = $array + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$ARRAYTYPE>() + )) + })?; let is_valid = array.is_valid($index); - match $VALUE { + Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, None => !is_valid, - } + }) }}; } impl ScalarValue { + /// Create a [`Result`] with the provided value and datatype + /// + /// # Panics + /// + /// Panics if d is not compatible with T + pub fn new_primitive( + a: Option, + d: &DataType, + ) -> Result { + match a { + None => d.try_into(), + Some(v) => { + let array = PrimitiveArray::::new(vec![v].into(), None) + .with_data_type(d.clone()); + Self::try_from_array(&array, 0) + } + } + } + /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result { // make sure the precision and scale is valid if precision <= DECIMAL128_MAX_PRECISION && scale.unsigned_abs() <= precision { return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); } - Err(DataFusionError::Internal(format!( + _internal_err!( "Can not new a decimal type ScalarValue for precision {precision} and scale {scale}" - ))) + ) } /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { - ScalarValue::Utf8(Some(val.into())) + ScalarValue::from(val.into()) } /// Returns a [`ScalarValue::IntervalYearMonth`] representing @@ -1822,9 +775,18 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } - /// Create a new nullable ScalarValue::List with the specified child_type - pub fn new_list(scalars: Option>, child_type: DataType) -> Self { - Self::List(scalars, Arc::new(Field::new("item", child_type, true))) + /// Returns a [`ScalarValue`] representing + /// `value` and `tz_opt` timezone + pub fn new_timestamp( + value: Option, + tz_opt: Option>, + ) -> Self { + match T::UNIT { + TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt), + } } /// Create a zero value in the given type. @@ -1863,10 +825,20 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { ScalarValue::IntervalMonthDayNano(Some(0)) } + DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(Some(0)), + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(Some(0)) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(Some(0)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(Some(0)) + } _ => { - return Err(DataFusionError::NotImplemented(format!( + return _not_impl_err!( "Can't create a zero scalar from data_type \"{datatype:?}\"" - ))); + ); } }) } @@ -1886,9 +858,9 @@ impl ScalarValue { DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), _ => { - return Err(DataFusionError::NotImplemented(format!( + return _not_impl_err!( "Can't create an one scalar from data_type \"{datatype:?}\"" - ))); + ); } }) } @@ -1904,9 +876,9 @@ impl ScalarValue { DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), _ => { - return Err(DataFusionError::NotImplemented(format!( + return _not_impl_err!( "Can't create a negative one scalar from data_type \"{datatype:?}\"" - ))); + ); } }) } @@ -1925,15 +897,15 @@ impl ScalarValue { DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), _ => { - return Err(DataFusionError::NotImplemented(format!( + return _not_impl_err!( "Can't create a negative one scalar from data_type \"{datatype:?}\"" - ))); + ); } }) } - /// Getter for the `DataType` of the value - pub fn get_datatype(&self) -> DataType { + /// return the [`DataType`] of this `ScalarValue` + pub fn data_type(&self) -> DataType { match self { ScalarValue::Boolean(_) => DataType::Boolean, ScalarValue::UInt8(_) => DataType::UInt8, @@ -1947,6 +919,9 @@ impl ScalarValue { ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } + ScalarValue::Decimal256(_, precision, scale) => { + DataType::Decimal256(*precision, *scale) + } ScalarValue::TimestampSecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) } @@ -1966,11 +941,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1984,14 +957,32 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { DataType::Interval(IntervalUnit::MonthDayNano) } + ScalarValue::DurationSecond(_) => DataType::Duration(TimeUnit::Second), + ScalarValue::DurationMillisecond(_) => { + DataType::Duration(TimeUnit::Millisecond) + } + ScalarValue::DurationMicrosecond(_) => { + DataType::Duration(TimeUnit::Microsecond) + } + ScalarValue::DurationNanosecond(_) => { + DataType::Duration(TimeUnit::Nanosecond) + } ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), ScalarValue::Dictionary(k, v) => { - DataType::Dictionary(k.clone(), Box::new(v.get_datatype())) + DataType::Dictionary(k.clone(), Box::new(v.data_type())) } ScalarValue::Null => DataType::Null, } } + /// Getter for the `DataType` of the value. + /// + /// Suggest using [`Self::data_type`] as a more standard API + #[deprecated(since = "31.0.0", note = "use data_type instead")] + pub fn get_datatype(&self) -> DataType { + self.data_type() + } + /// Calculate arithmetic negation for a scalar value pub fn arithmetic_negate(&self) -> Result { match self { @@ -1999,7 +990,8 @@ impl ScalarValue { | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) - | ScalarValue::Float32(None) => Ok(self.clone()), + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) => Ok(self.clone()), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))), @@ -2022,55 +1014,102 @@ impl ScalarValue { ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) } - value => Err(DataFusionError::Internal(format!( + ScalarValue::Decimal256(Some(v), precision, scale) => Ok( + ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), + ), + ScalarValue::TimestampSecond(Some(v), tz) => { + Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampNanosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMicrosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + } + ScalarValue::TimestampMillisecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + } + value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" - ))), + ), } } + /// Wrapping addition of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels pub fn add>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, +) + let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } - + /// Checked addition of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels pub fn add_checked>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_checked_op!(self, rhs, checked_add, +) + let r = add(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } + /// Wrapping subtraction of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels pub fn sub>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, -) + let r = sub_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } + /// Checked subtraction of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels pub fn sub_checked>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_checked_op!(self, rhs, checked_sub, -) - } - - pub fn and>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, &&) + let r = sub(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } - pub fn or>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, ||) + /// Wrapping multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul>(&self, other: T) -> Result { + let r = mul_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } - pub fn bitand>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, &) + /// Checked multiplication of `ScalarValue` + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn mul_checked>(&self, other: T) -> Result { + let r = mul(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } - pub fn bitor>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, |) + /// Performs `lhs / rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn div>(&self, other: T) -> Result { + let r = div(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } - pub fn bitxor>(&self, other: T) -> Result { - let rhs = other.borrow(); - impl_op!(self, rhs, ^) + /// Performs `lhs % rhs` + /// + /// Overflow or division by zero will result in an error, with exception to + /// floating point numbers, which instead follow the IEEE 754 rules. + /// + /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code + /// should operate on Arrays directly, using vectorized array kernels. + pub fn rem>(&self, other: T) -> Result { + let r = rem(&self.to_scalar()?, &other.borrow().to_scalar()?)?; + Self::try_from_array(r.as_ref(), 0) } pub fn is_unsigned(&self) -> bool { @@ -2091,6 +1130,7 @@ impl ScalarValue { ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), + ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), ScalarValue::Int16(v) => v.is_none(), ScalarValue::Int32(v) => v.is_none(), @@ -2104,7 +1144,11 @@ impl ScalarValue { ScalarValue::Binary(v) => v.is_none(), ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), - ScalarValue::List(v, _) => v.is_none(), + // arr.len() should be 1 for a list scalar, but we don't seem to + // enforce that anywhere, so we still check against array length. + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -2118,6 +1162,10 @@ impl ScalarValue { ScalarValue::IntervalYearMonth(v) => v.is_none(), ScalarValue::IntervalDayTime(v) => v.is_none(), ScalarValue::IntervalMonthDayNano(v) => v.is_none(), + ScalarValue::DurationSecond(v) => v.is_none(), + ScalarValue::DurationMillisecond(v) => v.is_none(), + ScalarValue::DurationMicrosecond(v) => v.is_none(), + ScalarValue::DurationNanosecond(v) => v.is_none(), ScalarValue::Struct(v, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } @@ -2131,47 +1179,83 @@ impl ScalarValue { /// /// Note: the datatype itself must support subtraction. pub fn distance(&self, other: &ScalarValue) -> Option { - // Having an explicit null check here is important because the - // subtraction for scalar values will return a real value even - // if one side is null. - if self.is_null() || other.is_null() { - return None; - } - - let distance = if self > other { - self.sub_checked(other).ok()? - } else { - other.sub_checked(self).ok()? - }; - - match distance { - ScalarValue::Int8(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int16(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int32(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int64(Some(v)) => usize::try_from(v).ok(), - ScalarValue::UInt8(Some(v)) => Some(v as usize), - ScalarValue::UInt16(Some(v)) => Some(v as usize), - ScalarValue::UInt32(Some(v)) => usize::try_from(v).ok(), - ScalarValue::UInt64(Some(v)) => usize::try_from(v).ok(), + match (self, other) { + (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), // TODO: we might want to look into supporting ceil/floor here for floats. - ScalarValue::Float32(Some(v)) => Some(v.round() as usize), - ScalarValue::Float64(Some(v)) => Some(v.round() as usize), + (Self::Float32(Some(l)), Self::Float32(Some(r))) => { + Some((l - r).abs().round() as _) + } + (Self::Float64(Some(l)), Self::Float64(Some(r))) => { + Some((l - r).abs().round() as _) + } _ => None, } } /// Converts a scalar value into an 1-row array. - pub fn to_array(&self) -> ArrayRef { + /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + pub fn to_array(&self) -> Result { self.to_array_of_size(1) } + /// Converts a scalar into an arrow [`Scalar`] (which implements + /// the [`Datum`] interface). + /// + /// This can be used to call arrow compute kernels such as `lt` + /// + /// # Errors + /// + /// Errors if the ScalarValue cannot be converted into a 1-row array + /// + /// # Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{BooleanArray, Int32Array}; + /// + /// let arr = Int32Array::from(vec![Some(1), None, Some(10)]); + /// let five = ScalarValue::Int32(Some(5)); + /// + /// let result = arrow::compute::kernels::cmp::lt( + /// &arr, + /// &five.to_scalar().unwrap(), + /// ).unwrap(); + /// + /// let expected = BooleanArray::from(vec![ + /// Some(true), + /// None, + /// Some(false) + /// ] + /// ); + /// + /// assert_eq!(&result, &expected); + /// ``` + /// [`Datum`]: arrow_array::Datum + pub fn to_scalar(&self) -> Result> { + Ok(Scalar::new(self.to_array_of_size(1)?)) + } + /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] - /// corresponding to those values. For example, + /// corresponding to those values. For example, an iterator of + /// [`ScalarValue::Int32`] would be converted to an [`Int32Array`]. /// /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// - /// Example + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type + /// + /// # Example /// ``` /// use datafusion_common::ScalarValue; /// use arrow::array::{ArrayRef, BooleanArray}; @@ -2204,11 +1288,11 @@ impl ScalarValue { // figure out the type based on the first element let data_type = match scalars.peek() { None => { - return Err(DataFusionError::Internal( - "Empty iterator passed to ScalarValue::iter_to_array".to_string(), - )); + return _internal_err!( + "Empty iterator passed to ScalarValue::iter_to_array" + ); } - Some(sv) => sv.get_datatype(), + Some(sv) => sv.data_type(), }; /// Creates an array of $ARRAY_TY by unpacking values of @@ -2220,11 +1304,11 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) } else { - Err(DataFusionError::Internal(format!( + _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv - ))) + ) } }) .collect::>()?; @@ -2240,11 +1324,11 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v, _) = sv { Ok(v) } else { - Err(DataFusionError::Internal(format!( + _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv - ))) + ) } }) .collect::>()?; @@ -2262,11 +1346,11 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) } else { - Err(DataFusionError::Internal(format!( + _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv - ))) + ) } }) .collect::>()?; @@ -2275,70 +1359,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(xs, _) => xs.map(|x| { - x.iter().map(|x| match x { - ScalarValue::$SCALAR_TY(i) => *i, - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>() - }), - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }), - )) - }}; - } + fn build_list_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; - macro_rules! build_array_list_string { - ($BUILDER:ident, $SCALAR_TY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(Some(xs), _) => { - for s in xs { - match s { - ScalarValue::$SCALAR_TY(Some(val)) => { - builder.values().append_value(val); - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null(); - } - sv => { - return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ))) - } - } - } - builder.append(true); - } - ScalarValue::List(None, _) => { - builder.append(false); - } - sv => { - return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ))) - } - } + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::>(); + + let arrays_ref = arrays_data.iter().collect::>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. + for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -2347,12 +1397,12 @@ impl ScalarValue { ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Decimal256(_, _) => { - return Err(DataFusionError::Internal( - "Decimal256 is not supported for ScalarValue".to_string(), - )); + DataType::Decimal256(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) } - DataType::Null => ScalarValue::iter_to_null_array(scalars), + DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -2415,47 +1465,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8) - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16) - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32) - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64) - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32) - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64) - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, Utf8) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, LargeUtf8) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -2487,9 +1497,7 @@ impl ScalarValue { } }; } else { - return Err(DataFusionError::Internal(format!( - "Expected Struct but found: {scalar}" - ))); + return _internal_err!("Expected Struct but found: {scalar}"); }; } @@ -2503,7 +1511,7 @@ impl ScalarValue { .collect::>>()?; let array = StructArray::from(field_values); - nullif(&array, &null_mask_builder.finish())? + arrow::compute::nullif(&array, &null_mask_builder.finish())? } DataType::Dictionary(key_type, value_type) => { // create the values array @@ -2513,13 +1521,13 @@ impl ScalarValue { if &inner_key_type == key_type { Ok(*scalar) } else { - panic!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})"); + _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") } } _ => { - Err(DataFusionError::Internal(format!( + _internal_err!( "Expected scalar of type {value_type} but found: {scalar} {scalar:?}" - ))) + ) } }) .collect::>>()?; @@ -2545,10 +1553,10 @@ impl ScalarValue { if let ScalarValue::FixedSizeBinary(_, v) = sv { Ok(v) } else { - Err(DataFusionError::Internal(format!( + _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {data_type:?}, got {sv:?}" - ))) + ) } }) .collect::>>()?; @@ -2570,30 +1578,33 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { - return Err(DataFusionError::Internal(format!( + return _internal_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, scalars.peek() - ))); + ); } }; Ok(array) } - fn iter_to_null_array(scalars: impl IntoIterator) -> ArrayRef { - let length = - scalars - .into_iter() - .fold(0usize, |r, element: ScalarValue| match element { - ScalarValue::Null => r + 1, - _ => unreachable!(), - }); - new_null_array(&DataType::Null, length) + fn iter_to_null_array( + scalars: impl IntoIterator, + ) -> Result { + let length = scalars.into_iter().try_fold( + 0usize, + |r, element: ScalarValue| match element { + ScalarValue::Null => Ok::(r + 1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }, + )?; + Ok(new_null_array(&DataType::Null, length)) } fn iter_to_decimal_array( @@ -2604,76 +1615,34 @@ impl ScalarValue { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => v1, - _ => unreachable!(), + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } }) - .collect::() + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } - fn iter_to_array_list( + fn iter_to_decimal256_array( scalars: impl IntoIterator, - data_type: &DataType, - ) -> Result> { - let mut offsets = Int32Array::builder(0); - offsets.append_value(0); - - let mut elements: Vec = Vec::new(); - let mut valid = BooleanBufferBuilder::new(0); - let mut flat_len = 0i32; - for scalar in scalars { - if let ScalarValue::List(values, field) = scalar { - match values { - Some(values) => { - let element_array = if !values.is_empty() { - ScalarValue::iter_to_array(values)? - } else { - arrow::array::new_empty_array(field.data_type()) - }; - - // Add new offset index - flat_len += element_array.len() as i32; - offsets.append_value(flat_len); - - elements.push(element_array); - - // Element is valid - valid.append(true); - } - None => { - // Repeat previous offset index - offsets.append_value(flat_len); - - // Element is null - valid.append(false); - } + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + s => { + _internal_err!( + "Expected ScalarValue::Decimal256 element. Received {s:?}" + ) } - } else { - return Err(DataFusionError::Internal(format!( - "Expected ScalarValue::List element. Received {scalar:?}" - ))); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices - let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(data_type.clone()) - .len(offsets_array.len() - 1) - .nulls(Some(NullBuffer::new(valid.finish()))) - .add_buffer(offsets_array.values().inner().clone()) - .add_child_data(flat_array.to_data()); - - let list_array = ListArray::from(array_data.build()?); - Ok(list_array) + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) } fn build_decimal_array( @@ -2681,19 +1650,120 @@ impl ScalarValue { precision: u8, scale: i8, size: usize, - ) -> Decimal128Array { + ) -> Result { + match value { + Some(val) => Decimal128Array::from(vec![val; size]) + .with_precision_and_scale(precision, scale) + .map_err(DataFusionError::ArrowError), + None => { + let mut builder = Decimal128Array::builder(size) + .with_precision_and_scale(precision, scale) + .map_err(DataFusionError::ArrowError)?; + builder.append_nulls(size); + Ok(builder.finish()) + } + } + } + + fn build_decimal256_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Result { std::iter::repeat(value) .take(size) - .collect::() + .collect::() .with_precision_and_scale(precision, scale) - .unwrap() + .map_err(DataFusionError::ArrowError) + } + + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`ListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = as_list_array(&array).unwrap(); + /// + /// let expected = ListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_list_array(values)) + } + + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`LargeListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{LargeListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_large_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); + /// let result = as_large_list_array(&array).unwrap(); + /// + /// let expected = LargeListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_large_list_array(values)) } /// Converts a scalar value into an array of `size` rows. - pub fn to_array_of_size(&self, size: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is + /// - a decimal that fails be converted to a decimal array of size + /// - a `Fixedsizelist` that fails to be concatenated into an array of size + /// - a `List` that fails to be concatenated into an array of size + /// - a `Dictionary` that fails be converted to a dictionary array of size + pub fn to_array_of_size(&self, size: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size), + ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, + ), + ScalarValue::Decimal256(e, precision, scale) => Arc::new( + ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, ), ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef @@ -2805,35 +1875,15 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(values, field) => Arc::new(match field.data_type() { - DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), - DataType::Int8 => build_list!(Int8Builder, Int8, values, size), - DataType::Int16 => build_list!(Int16Builder, Int16, values, size), - DataType::Int32 => build_list!(Int32Builder, Int32, values, size), - DataType::Int64 => build_list!(Int64Builder, Int64, values, size), - DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), - DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), - DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), - DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), - DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), - DataType::Float32 => build_list!(Float32Builder, Float32, values, size), - DataType::Float64 => build_list!(Float64Builder, Float64, values, size), - DataType::Timestamp(unit, tz) => { - build_timestamp_list!(unit.clone(), tz.clone(), values, size) - } - &DataType::LargeUtf8 => { - build_list!(LargeStringBuilder, LargeUtf8, values, size) - } - _ => ScalarValue::iter_to_array_list( - repeat(self.clone()).take(size), - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - ) - .unwrap(), - }), + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + let arrays = std::iter::repeat(arr.as_ref()) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()) + .map_err(DataFusionError::ArrowError)? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2897,39 +1947,67 @@ impl ScalarValue { e, size ), + ScalarValue::DurationSecond(e) => build_array_from_option!( + Duration, + TimeUnit::Second, + DurationSecondArray, + e, + size + ), + ScalarValue::DurationMillisecond(e) => build_array_from_option!( + Duration, + TimeUnit::Millisecond, + DurationMillisecondArray, + e, + size + ), + ScalarValue::DurationMicrosecond(e) => build_array_from_option!( + Duration, + TimeUnit::Microsecond, + DurationMicrosecondArray, + e, + size + ), + ScalarValue::DurationNanosecond(e) => build_array_from_option!( + Duration, + TimeUnit::Nanosecond, + DurationNanosecondArray, + e, + size + ), ScalarValue::Struct(values, fields) => match values { Some(values) => { - let field_values: Vec<_> = fields + let field_values = fields .iter() .zip(values.iter()) .map(|(field, value)| { - (field.clone(), value.to_array_of_size(size)) + Ok((field.clone(), value.to_array_of_size(size)?)) }) - .collect(); + .collect::>>()?; Arc::new(StructArray::from(field_values)) } None => { - let dt = self.get_datatype(); + let dt = self.data_type(); new_null_array(&dt, size) } }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size), - DataType::Int16 => dict_from_scalar::(v, size), - DataType::Int32 => dict_from_scalar::(v, size), - DataType::Int64 => dict_from_scalar::(v, size), - DataType::UInt8 => dict_from_scalar::(v, size), - DataType::UInt16 => dict_from_scalar::(v, size), - DataType::UInt32 => dict_from_scalar::(v, size), - DataType::UInt64 => dict_from_scalar::(v, size), + DataType::Int8 => dict_from_scalar::(v, size)?, + DataType::Int16 => dict_from_scalar::(v, size)?, + DataType::Int32 => dict_from_scalar::(v, size)?, + DataType::Int64 => dict_from_scalar::(v, size)?, + DataType::UInt8 => dict_from_scalar::(v, size)?, + DataType::UInt16 => dict_from_scalar::(v, size)?, + DataType::UInt32 => dict_from_scalar::(v, size)?, + DataType::UInt64 => dict_from_scalar::(v, size)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), } } ScalarValue::Null => new_null_array(&DataType::Null, size), - } + }) } fn get_decimal_value_from_array( @@ -2938,12 +2016,91 @@ impl ScalarValue { precision: u8, scale: i8, ) -> Result { - let array = as_decimal128_array(array)?; - if array.is_null(index) { - Ok(ScalarValue::Decimal128(None, precision, scale)) - } else { - let value = array.value(index); - Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + match array.data_type() { + DataType::Decimal128(_, _) => { + let array = as_decimal128_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal128(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) + } + } + DataType::Decimal256(_, _) => { + let array = as_decimal256_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal256(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal256(Some(value), precision, scale)) + } + } + _ => _internal_err!("Unsupported decimal type"), + } + } + + /// Retrieve ScalarValue for each row in `array` + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::ListArray; + /// use arrow::datatypes::{DataType, Int32Type}; + /// + /// let list_arr = ListArray::from_iter_primitive::(vec![ + /// Some(vec![Some(1), Some(2), Some(3)]), + /// None, + /// Some(vec![Some(4), Some(5)]) + /// ]); + /// + /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// + /// let expected = vec![ + /// vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ], + /// vec![], + /// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))] + /// ]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { + let mut scalars = Vec::with_capacity(array.len()); + + for index in 0..array.len() { + let scalar_values = match array.data_type() { + DataType::List(_) => { + let list_array = as_list_array(array); + match list_array.is_null(index) { + true => Vec::new(), + false => { + let nested_array = list_array.value(index); + ScalarValue::convert_array_to_scalar_vec(&nested_array)? + .into_iter() + .flatten() + .collect() + } + } + } + _ => { + let scalar = ScalarValue::try_from_array(array, index)?; + vec![scalar] + } + }; + scalars.push(scalar_values); + } + Ok(scalars) + } + + // TODO: Support more types after other ScalarValue is wrapped with ArrayRef + /// Get raw data (inner array) inside ScalarValue + pub fn raw_data(&self) -> Result { + match self { + ScalarValue::List(arr) => Ok(arr.to_owned()), + _ => _internal_err!("ScalarValue is not a list"), } } @@ -2961,101 +2118,107 @@ impl ScalarValue { array, index, *precision, *scale, )? } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::Decimal256(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary) - } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), - DataType::List(nested_type) => { - let list_array = as_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - ScalarValue::new_list(value, nested_type.data_type().clone()) + typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::LargeUtf8 => { + typed_cast!(array, index, LargeStringArray, LargeUtf8)? } - DataType::Date32 => { - typed_cast!(array, index, Date32Array, Date32) + DataType::List(_) => { + let list_array = as_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `ListArray` with the value at `index`. + let arr = Arc::new(array_into_list_array(nested_array)); + + ScalarValue::List(arr) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `LargeListArray` with the value at `index`. + let arr = Arc::new(array_into_large_list_array(nested_array)); + + ScalarValue::LargeList(arr) } - DataType::Date64 => { - typed_cast!(array, index, Date64Array, Date64) + // TODO: There is no test for FixedSizeList now, add it later + DataType::FixedSizeList(_, _) => { + let list_array = as_fixed_size_list_array(array)?; + let nested_array = list_array.value(index); + // Produces a single element `ListArray` with the value at `index`. + let arr = Arc::new(array_into_list_array(nested_array)); + + ScalarValue::List(arr) } + DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, + DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second) + typed_cast!(array, index, Time32SecondArray, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond) + typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? } DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond) + typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampSecondArray, - TimestampSecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMillisecondArray, - TimestampMillisecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMicrosecondArray, - TimestampMicrosecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampNanosecondArray, - TimestampNanosecond, - tz_opt - ) + typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? } + DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( + array, + index, + TimestampSecondArray, + TimestampSecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMillisecondArray, + TimestampMillisecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampMicrosecondArray, + TimestampMicrosecond, + tz_opt + )?, + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( + array, + index, + TimestampNanosecondArray, + TimestampNanosecond, + tz_opt + )?, DataType::Dictionary(key_type, _) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary @@ -3079,20 +2242,6 @@ impl ScalarValue { } Self::Struct(Some(field_values), fields.clone()) } - DataType::FixedSizeList(nested_type, _len) => { - let list_array = as_fixed_size_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - ScalarValue::new_list(value, nested_type.data_type().clone()) - } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; let size = match array.data_type() { @@ -3108,35 +2257,47 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime) + typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth) + typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - ) + DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( + array, + index, + IntervalMonthDayNanoArray, + IntervalMonthDayNano + )?, + + DataType::Duration(TimeUnit::Second) => { + typed_cast!(array, index, DurationSecondArray, DurationSecond)? + } + DataType::Duration(TimeUnit::Millisecond) => { + typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? } + DataType::Duration(TimeUnit::Microsecond) => { + typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? + } + DataType::Duration(TimeUnit::Nanosecond) => { + typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? + } + other => { - return Err(DataFusionError::NotImplemented(format!( + return _not_impl_err!( "Can't create a scalar from array of type \"{other:?}\"" - ))); + ); } }) } /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::Utf8(Some(value)); + let value = ScalarValue::from(value); let cast_options = CastOptions { safe: false, format_options: Default::default(), }; - let cast_arr = cast_with_options(&value.to_array(), target_type, &cast_options)?; + let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -3159,6 +2320,25 @@ impl ScalarValue { } } + fn eq_array_decimal256( + array: &ArrayRef, + index: usize, + value: Option<&i256>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal256_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + /// Compares a single row of array @ index for equality with self, /// in an optimized fashion. /// @@ -3175,9 +2355,19 @@ impl ScalarValue { /// /// This function has a few narrow usescases such as hash table key /// comparisons where comparing a single row at a time is necessary. + /// + /// # Errors + /// + /// Errors if + /// - it fails to downcast `array` to the data type of `self` + /// - `self` is a `Struct` + /// + /// # Panics + /// + /// Panics if `self` is a dictionary with invalid key type #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - match self { + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -3185,106 +2375,143 @@ impl ScalarValue { v.as_ref(), *precision, *scale, - ) - .unwrap() + )? + } + ScalarValue::Decimal256(v, precision, scale) => { + ScalarValue::eq_array_decimal256( + array, + index, + v.as_ref(), + *precision, + *scale, + )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) + eq_array_primitive!(array, index, BooleanArray, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) + eq_array_primitive!(array, index, Float32Array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) + eq_array_primitive!(array, index, Float64Array, val)? + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, + ScalarValue::Int16(val) => { + eq_array_primitive!(array, index, Int16Array, val)? + } + ScalarValue::Int32(val) => { + eq_array_primitive!(array, index, Int32Array, val)? + } + ScalarValue::Int64(val) => { + eq_array_primitive!(array, index, Int64Array, val)? + } + ScalarValue::UInt8(val) => { + eq_array_primitive!(array, index, UInt8Array, val)? } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) + eq_array_primitive!(array, index, UInt16Array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) + eq_array_primitive!(array, index, UInt32Array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) + eq_array_primitive!(array, index, UInt64Array, val)? + } + ScalarValue::Utf8(val) => { + eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) + eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) + eq_array_primitive!(array, index, BinaryArray, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val) + eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) + eq_array_primitive!(array, index, LargeBinaryArray, val)? + } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + let right = array.slice(index, 1); + arr == &right } - ScalarValue::List(_, _) => unimplemented!(), ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) + eq_array_primitive!(array, index, Date32Array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) + eq_array_primitive!(array, index, Date64Array, val)? } ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val) + eq_array_primitive!(array, index, Time32SecondArray, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val) + eq_array_primitive!(array, index, Time32MillisecondArray, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val) + eq_array_primitive!(array, index, Time64MicrosecondArray, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val) + eq_array_primitive!(array, index, Time64NanosecondArray, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) + eq_array_primitive!(array, index, TimestampSecondArray, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) + eq_array_primitive!(array, index, TimestampMillisecondArray, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? } ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) + eq_array_primitive!(array, index, TimestampNanosecondArray, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) + eq_array_primitive!(array, index, IntervalYearMonthArray, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) + eq_array_primitive!(array, index, IntervalDayTimeArray, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) + eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? + } + ScalarValue::DurationSecond(val) => { + eq_array_primitive!(array, index, DurationSecondArray, val)? + } + ScalarValue::DurationMillisecond(val) => { + eq_array_primitive!(array, index, DurationMillisecondArray, val)? + } + ScalarValue::DurationMicrosecond(val) => { + eq_array_primitive!(array, index, DurationMicrosecondArray, val)? + } + ScalarValue::DurationNanosecond(val) => { + eq_array_primitive!(array, index, DurationNanosecondArray, val)? + } + ScalarValue::Struct(_, _) => { + return _not_impl_err!("Struct is not supported yet") } - ScalarValue::Struct(_, _) => unimplemented!(), ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index), - DataType::Int16 => get_dict_value::(array, index), - DataType::Int32 => get_dict_value::(array, index), - DataType::Int64 => get_dict_value::(array, index), - DataType::UInt8 => get_dict_value::(array, index), - DataType::UInt16 => get_dict_value::(array, index), - DataType::UInt32 => get_dict_value::(array, index), - DataType::UInt64 => get_dict_value::(array, index), + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // was the value in the array non null? match values_index { - Some(values_index) => v.eq_array(values_array, values_index), + Some(values_index) => v.eq_array(values_array, values_index)?, None => v.is_null(), } } ScalarValue::Null => array.is_null(index), - } + }) } /// Estimate size if bytes including `Self`. For values with internal containers such as `String` @@ -3297,6 +2524,7 @@ impl ScalarValue { | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) | ScalarValue::Int16(_) | ScalarValue::Int32(_) @@ -3313,7 +2541,11 @@ impl ScalarValue { | ScalarValue::Time64Nanosecond(_) | ScalarValue::IntervalYearMonth(_) | ScalarValue::IntervalDayTime(_) - | ScalarValue::IntervalMonthDayNano(_) => 0, + | ScalarValue::IntervalMonthDayNano(_) + | ScalarValue::DurationSecond(_) + | ScalarValue::DurationMillisecond(_) + | ScalarValue::DurationMicrosecond(_) + | ScalarValue::DurationNanosecond(_) => 0, ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } @@ -3328,13 +2560,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(vals, field) => { - vals.as_ref() - .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) - .unwrap_or_default() - // `field` is boxed, so it is NOT already included in `self` - + field.size() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -3430,13 +2658,17 @@ impl FromStr for ScalarValue { } } +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue::Utf8(Some(value)) + } +} + impl From> for ScalarValue { fn from(value: Vec<(&str, ScalarValue)>) -> Self { let (fields, scalars): (SchemaBuilder, Vec<_>) = value .into_iter() - .map(|(name, scalar)| { - (Field::new(name, scalar.get_datatype(), false), scalar) - }) + .map(|(name, scalar)| (Field::new(name, scalar.data_type(), false), scalar)) .unzip(); Self::Struct(Some(scalars), fields.finish().fields) @@ -3451,11 +2683,11 @@ macro_rules! impl_try_from { fn try_from(value: ScalarValue) -> Result { match value { ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( + _ => _internal_err!( "Cannot convert {:?} to {}", value, std::any::type_name::() - ))), + ), } } } @@ -3475,11 +2707,11 @@ impl TryFrom for i32 { | ScalarValue::Date32(Some(inner_value)) | ScalarValue::Time32Second(Some(inner_value)) | ScalarValue::Time32Millisecond(Some(inner_value)) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( + _ => _internal_err!( "Cannot convert {:?} to {}", value, std::any::type_name::() - ))), + ), } } } @@ -3498,27 +2730,43 @@ impl TryFrom for i64 { | ScalarValue::TimestampMicrosecond(Some(inner_value), _) | ScalarValue::TimestampMillisecond(Some(inner_value), _) | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( + _ => _internal_err!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ), + } + } +} + +// special implementation for i128 because of Decimal128 +impl TryFrom for i128 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), + _ => _internal_err!( "Cannot convert {:?} to {}", value, std::any::type_name::() - ))), + ), } } } -// special implementation for i128 because of Decimal128 -impl TryFrom for i128 { +// special implementation for i256 because of Decimal128 +impl TryFrom for i256 { type Error = DataFusionError; fn try_from(value: ScalarValue) -> Result { match value { - ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( + ScalarValue::Decimal256(Some(inner_value), _, _) => Ok(inner_value), + _ => _internal_err!( "Cannot convert {:?} to {}", value, std::any::type_name::() - ))), + ), } } } @@ -3544,8 +2792,8 @@ impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError; /// Create a Null instance of ScalarValue for this datatype - fn try_from(datatype: &DataType) -> Result { - Ok(match datatype { + fn try_from(data_type: &DataType) -> Result { + Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), @@ -3560,6 +2808,9 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), DataType::Binary => ScalarValue::Binary(None), @@ -3596,19 +2847,37 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { ScalarValue::IntervalMonthDayNano(None) } + + DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(None) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(None) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(None) + } + DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - DataType::List(ref nested_type) => { - ScalarValue::new_list(None, nested_type.data_type().clone()) - } + // `ScalaValue::List` contains single element `ListArray`. + DataType::List(field) => ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, + )), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { - return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar from data_type \"{datatype:?}\"" - ))); + return _not_impl_err!( + "Can't create a scalar from data_type \"{data_type:?}\"" + ); } }) } @@ -3623,12 +2892,20 @@ macro_rules! format_option { }}; } +// Implement Display trait for ScalarValue +// +// # Panics +// +// Panics if there is an error when creating a visual representation of columns via `arrow::util::pretty` impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } + ScalarValue::Decimal256(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Boolean(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, @@ -3646,40 +2923,9 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::LargeUtf8(e) => format_option!(f, e)?, - ScalarValue::Binary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::FixedSizeBinary(_, e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::LargeBinary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::List(e, _) => match e { + ScalarValue::Binary(e) + | ScalarValue::FixedSizeBinary(_, e) + | ScalarValue::LargeBinary(e) => match e { Some(l) => write!( f, "{}", @@ -3690,6 +2936,16 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + // ScalarValue List should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}")? + } ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -3699,6 +2955,10 @@ impl fmt::Display for ScalarValue { ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::DurationSecond(e) => format_option!(f, e)?, + ScalarValue::DurationMillisecond(e) => format_option!(f, e)?, + ScalarValue::DurationMicrosecond(e) => format_option!(f, e)?, + ScalarValue::DurationNanosecond(e) => format_option!(f, e)?, ScalarValue::Struct(e, fields) => match e { Some(l) => write!( f, @@ -3722,6 +2982,7 @@ impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), + ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), @@ -3759,7 +3020,9 @@ impl fmt::Debug for ScalarValue { } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), - ScalarValue::List(_, _) => write!(f, "List([{self}])"), + ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), + ScalarValue::List(_) => write!(f, "List({self})"), + ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3781,6 +3044,16 @@ impl fmt::Debug for ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { write!(f, "IntervalMonthDayNano(\"{self}\")") } + ScalarValue::DurationSecond(_) => write!(f, "DurationSecond(\"{self}\")"), + ScalarValue::DurationMillisecond(_) => { + write!(f, "DurationMillisecond(\"{self}\")") + } + ScalarValue::DurationMicrosecond(_) => { + write!(f, "DurationMicrosecond(\"{self}\")") + } + ScalarValue::DurationNanosecond(_) => { + write!(f, "DurationNanosecond(\"{self}\")") + } ScalarValue::Struct(e, fields) => { // Use Debug representation of field values match e { @@ -3802,7 +3075,7 @@ impl fmt::Debug for ScalarValue { } } -/// Trait used to map a NativeTime to a ScalarType. +/// Trait used to map a NativeType to a ScalarValue pub trait ScalarType { /// returns a scalar from an optional T fn scalar(r: Option) -> ScalarValue; @@ -3840,19 +3113,219 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; - use arrow::compute::{self, concat, is_null}; + use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; + #[test] + fn test_to_array_of_size_for_list() { + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + ])]); + + let sv = ScalarValue::List(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + let actual_list_arr = as_list_array(&actual_arr); + + let arr = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(1), None, Some(2)]), + ]); + + assert_eq!(&arr, actual_list_arr); + } + + #[test] + fn test_to_array_of_size_for_fsl() { + let values = Int32Array::from_iter([Some(1), None, Some(2)]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let sv = ScalarValue::FixedSizeList(Arc::new(arr)); + let actual_arr = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + + let expected_values = + Int32Array::from_iter([Some(1), None, Some(2), Some(1), None, Some(2)]); + let expected_arr = + FixedSizeListArray::new(field, 3, Arc::new(expected_values), None); + + assert_eq!( + &expected_arr, + as_fixed_size_list_array(actual_arr.as_ref()).unwrap() + ); + } + + #[test] + fn test_list_to_array_string() { + let scalars = vec![ + ScalarValue::from("rust"), + ScalarValue::from("arrow"), + ScalarValue::from("data-fusion"), + ]; + + let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + + let expected = array_into_list_array(Arc::new(StringArray::from(vec![ + "rust", + "arrow", + "data-fusion", + ]))); + let result = as_list_array(&array); + assert_eq!(result, &expected); + } + + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + + #[test] + fn iter_to_array_primitive_test() { + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + } + + #[test] + fn iter_to_array_string_test() { + let arr1 = + array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let arr2 = + array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + + let scalars = vec![ + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let result = as_list_array(&array); + + // build expected array + let string_builder = StringBuilder::with_capacity(5, 25); + let mut list_of_string_builder = ListBuilder::new(string_builder); + + list_of_string_builder.values().append_value("foo"); + list_of_string_builder.values().append_value("bar"); + list_of_string_builder.values().append_value("baz"); + list_of_string_builder.append(true); + + list_of_string_builder.values().append_value("rust"); + list_of_string_builder.values().append_value("world"); + list_of_string_builder.append(true); + let expected = list_of_string_builder.finish(); + + assert_eq!(result, &expected); + } + + #[test] + fn test_list_scalar_eq_to_array() { + let list_array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![None, Some(5)]), + ])); + + let fsl_array: ArrayRef = + Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ], + 3, + )); + + for arr in [list_array, fsl_array] { + for i in 0..arr.len() { + let scalar = ScalarValue::List(arr.slice(i, 1)); + assert!(scalar.eq_array(&arr, i).unwrap()); + } + } + } #[test] fn scalar_add_trait_test() -> Result<()> { @@ -3894,14 +3367,17 @@ mod tests { } #[test] - fn scalar_sub_trait_int32_overflow_test() -> Result<()> { + fn scalar_sub_trait_int32_overflow_test() { let int_value = ScalarValue::Int32(Some(i32::MAX)); let int_value_2 = ScalarValue::Int32(Some(i32::MIN)); - assert!(matches!( - int_value.sub_checked(&int_value_2), - Err(DataFusionError::Execution(msg)) if msg == "Overflow while calculating ScalarValue." - )); - Ok(()) + let err = int_value + .sub_checked(&int_value_2) + .unwrap_err() + .strip_backtrace(); + assert_eq!( + err, + "Arrow error: Compute error: Overflow happened on: 2147483647 - -2147483648" + ) } #[test] @@ -3914,14 +3390,14 @@ mod tests { } #[test] - fn scalar_sub_trait_int64_overflow_test() -> Result<()> { + fn scalar_sub_trait_int64_overflow_test() { let int_value = ScalarValue::Int64(Some(i64::MAX)); let int_value_2 = ScalarValue::Int64(Some(i64::MIN)); - assert!(matches!( - int_value.sub_checked(&int_value_2), - Err(DataFusionError::Execution(msg)) if msg == "Overflow while calculating ScalarValue." - )); - Ok(()) + let err = int_value + .sub_checked(&int_value_2) + .unwrap_err() + .strip_backtrace(); + assert_eq!(err, "Arrow error: Compute error: Overflow happened on: 9223372036854775807 - -9223372036854775808") } #[test] @@ -3969,11 +3445,11 @@ mod tests { { let scalar_result = left.add_checked(&right); - let left_array = left.to_array(); - let right_array = right.to_array(); + let left_array = left.to_array().expect("Failed to convert to array"); + let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); - let arrow_result = compute::add_checked(arrow_left_array, arrow_right_array); + let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } @@ -4003,7 +3479,7 @@ mod tests { #[test] fn scalar_decimal_test() -> Result<()> { let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); - assert_eq!(DataType::Decimal128(10, 1), decimal_value.get_datatype()); + assert_eq!(DataType::Decimal128(10, 1), decimal_value.data_type()); let try_into_value: i128 = decimal_value.clone().try_into().unwrap(); assert_eq!(123_i128, try_into_value); assert!(!decimal_value.is_null()); @@ -4018,22 +3494,30 @@ mod tests { } // decimal scalar to array - let array = decimal_value.to_array(); + let array = decimal_value + .to_array() + .expect("Failed to convert to array"); let array = as_decimal128_array(&array)?; assert_eq!(1, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array.value(0)); // decimal scalar to array with size - let array = decimal_value.to_array_of_size(10); + let array = decimal_value + .to_array_of_size(10) + .expect("Failed to convert to array of size"); let array_decimal = as_decimal128_array(&array)?; assert_eq!(10, array.len()); assert_eq!(DataType::Decimal128(10, 1), array.data_type().clone()); assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value.eq_array(&array, 1)); - assert!(decimal_value.eq_array(&array, 5)); + assert!(decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays")); + assert!(decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays")); // test try from array assert_eq!( decimal_value, @@ -4064,7 +3548,7 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ]; // convert the vec to decimal array and check the result - let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); assert_eq!(3, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); @@ -4074,19 +3558,22 @@ mod tests { ScalarValue::Decimal128(Some(3), 10, 2), ScalarValue::Decimal128(None, 10, 2), ]; - let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + let array = ScalarValue::iter_to_array(decimal_vec).unwrap(); assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); assert!(ScalarValue::try_new_decimal128(1, 10, 2) .unwrap() - .eq_array(&array, 0)); + .eq_array(&array, 0) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(2, 10, 2) .unwrap() - .eq_array(&array, 1)); + .eq_array(&array, 1) + .expect("Failed to compare arrays")); assert!(ScalarValue::try_new_decimal128(3, 10, 2) .unwrap() - .eq_array(&array, 2)); + .eq_array(&array, 2) + .expect("Failed to compare arrays")); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -4095,17 +3582,74 @@ mod tests { Ok(()) } + #[test] + fn test_list_partial_cmp() { + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + } + #[test] fn scalar_value_to_array_u64() -> Result<()> { let value = ScalarValue::UInt64(Some(13u64)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt64(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint64_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -4115,14 +3659,14 @@ mod tests { #[test] fn scalar_value_to_array_u32() -> Result<()> { let value = ScalarValue::UInt32(Some(13u32)); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(!array.is_null(0)); assert_eq!(array.value(0), 13); let value = ScalarValue::UInt32(None); - let array = value.to_array(); + let array = value.to_array().expect("Failed to convert to array"); let array = as_uint32_array(&array)?; assert_eq!(array.len(), 1); assert!(array.is_null(0)); @@ -4131,31 +3675,52 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::List( - None, - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); - let list_array = as_list_array(&list_array_ref).unwrap(); + let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); + + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + + #[test] + fn scalar_large_list_null_to_array() { + let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); - assert!(list_array.is_null(0)); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); } #[test] fn scalar_list_to_array() -> Result<()> { - let list_array_ref = ScalarValue::List( - Some(vec![ - ScalarValue::UInt64(Some(100)), - ScalarValue::UInt64(None), - ScalarValue::UInt64(Some(101)), - ]), - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } - let list_array = as_list_array(&list_array_ref)?; + #[test] + fn scalar_large_list_to_array() -> Result<()> { + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -4237,6 +3802,8 @@ mod tests { } #[test] + // despite clippy claiming they are useless, the code doesn't compile otherwise. + #[allow(clippy::useless_vec)] fn scalar_iter_to_array_boolean() { check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); @@ -4299,7 +3866,7 @@ mod tests { fn scalar_iter_to_array_empty() { let scalars = vec![] as Vec; - let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars).unwrap_err(); assert!( result .to_string() @@ -4317,13 +3884,13 @@ mod tests { ScalarValue::Dictionary(Box::new(key_type), Box::new(value)) } - let scalars = vec![ + let scalars = [ make_val(Some("Foo".into())), make_val(None), make_val(Some("Bar".into())), ]; - let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars).unwrap(); let array = as_dictionary_array::(&array).unwrap(); let values_array = as_string_array(array.values()).unwrap(); @@ -4345,9 +3912,9 @@ mod tests { fn scalar_iter_to_array_mismatched_types() { use ScalarValue::*; // If the scalar values are not all the correct type, error here - let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; + let scalars = [Boolean(Some(true)), Int32(Some(5))]; - let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars).unwrap_err(); assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), "{}", result); } @@ -4367,6 +3934,78 @@ mod tests { ); } + #[test] + fn scalar_try_from_array_list_array_null() { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ]); + + let non_null_list_scalar = ScalarValue::try_from_array(&list, 0).unwrap(); + let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap(); + + let data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + + assert_eq!(non_null_list_scalar.data_type(), data_type.clone()); + assert_eq!(null_list_scalar.data_type(), data_type); + } + + #[test] + fn scalar_try_from_list() { + let data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = &data_type; + let scalar: ScalarValue = data_type.try_into().unwrap(); + + let expected = ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 1, + )); + + assert_eq!(expected, scalar) + } + + #[test] + fn scalar_try_from_list_of_list() { + let data_type = DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))); + let data_type = &data_type; + let scalar: ScalarValue = data_type.try_into().unwrap(); + + let expected = ScalarValue::List(new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + 1, + )); + + assert_eq!(expected, scalar) + } + + #[test] + fn scalar_try_from_not_equal_list_nested_list() { + let list_data_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + let data_type = &list_data_type; + let list_scalar: ScalarValue = data_type.try_into().unwrap(); + + let nested_list_data_type = DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))); + let data_type = &nested_list_data_type; + let nested_list_scalar: ScalarValue = data_type.try_into().unwrap(); + + assert_ne!(list_scalar, nested_list_scalar); + } + #[test] fn scalar_try_from_dict_datatype() { let data_type = @@ -4379,6 +4018,8 @@ mod tests { assert_eq!(expected, data_type.try_into().unwrap()) } + // this test fails on aarch, so don't run it there + #[cfg(not(target_arch = "aarch64"))] #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, @@ -4433,21 +4074,21 @@ mod tests { }}; } - let bool_vals = vec![Some(true), None, Some(false)]; - let f32_vals = vec![Some(-1.0), None, Some(1.0)]; + let bool_vals = [Some(true), None, Some(false)]; + let f32_vals = [Some(-1.0), None, Some(1.0)]; let f64_vals = make_typed_vec!(f32_vals, f64); - let i8_vals = vec![Some(-1), None, Some(1)]; + let i8_vals = [Some(-1), None, Some(1)]; let i16_vals = make_typed_vec!(i8_vals, i16); let i32_vals = make_typed_vec!(i8_vals, i32); let i64_vals = make_typed_vec!(i8_vals, i64); - let u8_vals = vec![Some(0), None, Some(1)]; + let u8_vals = [Some(0), None, Some(1)]; let u16_vals = make_typed_vec!(u8_vals, u16); let u32_vals = make_typed_vec!(u8_vals, u32); let u64_vals = make_typed_vec!(u8_vals, u64); - let str_vals = vec![Some("foo"), None, Some("bar")]; + let str_vals = [Some("foo"), None, Some("bar")]; /// Test each value in `scalar` with the corresponding element /// at `array`. Assumes each element is unique (aka not equal @@ -4614,7 +4255,9 @@ mod tests { for (index, scalar) in scalars.into_iter().enumerate() { assert!( - scalar.eq_array(&array, index), + scalar + .eq_array(&array, index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be equal to {array:?} at index {index}" ); @@ -4622,7 +4265,7 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index), + !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -4651,55 +4294,6 @@ mod tests { assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Equal) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Greater) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Less) - ); - - // For different data type, `partial_cmp` returns None. - assert_eq!( - List( - Some(vec![Int64(Some(1)), Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - None - ); - assert_eq!( ScalarValue::from(vec![ ("A", ScalarValue::from(1.0)), @@ -4724,53 +4318,16 @@ mod tests { ])), None ); - // Different type of intervals can be compared. - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 2))) - < IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 14, 0, 1 - ))), - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 4))) - >= IntervalDayTime(Some(IntervalDayTimeType::make_value(119, 1))) - ); - assert!( - IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 86_399_999))) - >= IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 0))) - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(2, 12))) - == IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 36, 0, 0 - ))), - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 0))) - != IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 1))) - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 4))) - == IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 16))), - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 3))) - > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 2, - 28, - 999_999_999 - ))), - ); - assert!( - IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 1))) - > IntervalDayTime(Some(IntervalDayTimeType::make_value(29, 9_999))), - ); - assert!( - IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value(1, 12, 34))) - > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 0, 142, 34 - ))) - ); + } + + #[test] + fn test_scalar_value_from_string() { + let scalar = ScalarValue::from("foo"); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from("foo".to_string()); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); + let scalar = ScalarValue::from_str("foo").unwrap(); + assert_eq!(scalar, ScalarValue::Utf8(Some("foo".to_string()))); } #[test] @@ -4791,7 +4348,7 @@ mod tests { Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), - ScalarValue::Utf8(Some("Hello".to_string())), + ScalarValue::from("Hello"), ScalarValue::from(vec![ ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), @@ -4821,7 +4378,9 @@ mod tests { ); // Convert to length-2 array - let array = scalar.to_array_of_size(2); + let array = scalar + .to_array_of_size(2) + .expect("Failed to convert to array of size"); let expected = Arc::new(StructArray::from(vec![ ( @@ -4959,38 +4518,40 @@ mod tests { )); // Define primitive list scalars - let l0 = ScalarValue::List( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l1 = ScalarValue::List( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l2 = ScalarValue::List( - Some(vec![ScalarValue::from(6i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); + let l0 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let l1 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]), + )); + let l2 = ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![Some(6)])]))); // Define struct scalars let s0 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("First")))), + ("A", ScalarValue::from("First")), ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Second")))), + ("A", ScalarValue::from("Second")), ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ - ("A", ScalarValue::Utf8(Some(String::from("Third")))), + ("A", ScalarValue::from("Third")), ("primitive_list", l2), ]); @@ -5016,15 +4577,19 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = - ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.get_datatype()); - let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.get_datatype()); + let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); + let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array))); + + let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); + let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array))); + + let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); + let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array))); - let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.get_datatype()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); - let array = as_list_array(&array).unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders let field_a_builder = StringBuilder::with_capacity(4, 1024); @@ -5144,54 +4709,37 @@ mod tests { assert_eq!(array, &expected); } + fn build_2d_list(data: Vec>) -> ListArray { + let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); + ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::::from_lengths([1]), + Arc::new(a1), + None, + ) + } + #[test] fn test_nested_lists() { // Define inner list scalars - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); - - let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - let array = as_list_array(&array).unwrap(); + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); + + let array = ScalarValue::iter_to_array(vec![ + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), + ]) + .unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -5199,6 +4747,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -5207,14 +4756,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish(); @@ -5230,11 +4771,11 @@ mod tests { ); assert_eq!( - scalar.get_datatype(), + scalar.data_type(), DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); - let array = scalar.to_array(); + let array = scalar.to_array().expect("Failed to convert to array"); assert_eq!(array.len(), 1); assert_eq!( array.data_type(), @@ -5243,7 +4784,7 @@ mod tests { let newscalar = ScalarValue::try_from_array(&array, 0).unwrap(); assert_eq!( - newscalar.get_datatype(), + newscalar.data_type(), DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())) ); } @@ -5258,7 +4799,7 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); check_scalar_cast( - ScalarValue::Utf8(Some("foo".to_string())), + ScalarValue::from("foo"), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); @@ -5271,16 +4812,18 @@ mod tests { // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` fn check_scalar_cast(scalar: ScalarValue, desired_type: DataType) { // convert from scalar --> Array to call cast - let scalar_array = scalar.to_array(); + let scalar_array = scalar.to_array().expect("Failed to convert to array"); // cast the actual value let cast_array = kernels::cast::cast(&scalar_array, &desired_type).unwrap(); // turn it back to a scalar let cast_scalar = ScalarValue::try_from_array(&cast_array, 0).unwrap(); - assert_eq!(cast_scalar.get_datatype(), desired_type); + assert_eq!(cast_scalar.data_type(), desired_type); // Some time later the "cast" scalar is turned back into an array: - let array = cast_scalar.to_array_of_size(10); + let array = cast_scalar + .to_array_of_size(10) + .expect("Failed to convert to array of size"); // The datatype should be "Dictionary" but is actually Utf8!!! assert_eq!(array.data_type(), &desired_type) @@ -5329,8 +4872,16 @@ mod tests { }; } - expect_operation_error!(expect_add_error, add, "Operator + is not implemented"); - expect_operation_error!(expect_sub_error, sub, "Operator - is not implemented"); + expect_operation_error!( + expect_add_error, + add, + "Invalid arithmetic operation: UInt64 + Int32" + ); + expect_operation_error!( + expect_sub_error, + sub, + "Invalid arithmetic operation: UInt64 - Int32" + ); macro_rules! decimal_op_test_cases { ($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => { @@ -5350,7 +4901,7 @@ mod tests { decimal_op_test_cases!( add, [ - [Some(123), 10, 2, Some(124), 10, 2, Some(123 + 124), 10, 2], + [Some(123), 10, 2, Some(124), 10, 2, Some(123 + 124), 11, 2], // test sum decimal with diff scale [ Some(123), @@ -5360,7 +4911,7 @@ mod tests { 10, 2, Some(123 + 124 * 10_i128.pow(1)), - 10, + 12, 3 ], // diff precision and scale for decimal data type @@ -5372,7 +4923,7 @@ mod tests { 11, 3, Some(123 * 10_i128.pow(3 - 2) + 124), - 11, + 12, 3 ] ] @@ -5385,17 +4936,17 @@ mod tests { add, [ // Case: (None, Some, 0) - [None, 10, 2, Some(123), 10, 2, Some(123), 10, 2], + [None, 10, 2, Some(123), 10, 2, None, 11, 2], // Case: (Some, None, 0) - [Some(123), 10, 2, None, 10, 2, Some(123), 10, 2], + [Some(123), 10, 2, None, 10, 2, None, 11, 2], // Case: (Some, None, _) + Side=False - [Some(123), 8, 2, None, 10, 3, Some(1230), 10, 3], + [Some(123), 8, 2, None, 10, 3, None, 11, 3], // Case: (None, Some, _) + Side=False - [None, 8, 2, Some(123), 10, 3, Some(123), 10, 3], + [None, 8, 2, Some(123), 10, 3, None, 11, 3], // Case: (Some, None, _) + Side=True - [Some(123), 8, 4, None, 10, 3, Some(123), 10, 4], + [Some(123), 8, 4, None, 10, 3, None, 12, 4], // Case: (None, Some, _) + Side=True - [None, 10, 3, Some(123), 8, 4, Some(123), 10, 4] + [None, 10, 3, Some(123), 8, 4, None, 12, 4] ] ); } @@ -5529,10 +5080,7 @@ mod tests { (ScalarValue::Int8(None), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Int16(None)), // Unsupported types - ( - ScalarValue::Utf8(Some("foo".to_string())), - ScalarValue::Utf8(Some("bar".to_string())), - ), + (ScalarValue::from("foo"), ScalarValue::from("bar")), ( ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -5543,11 +5091,6 @@ mod tests { ScalarValue::Decimal128(Some(123), 5, 5), ScalarValue::Decimal128(Some(120), 5, 5), ), - // Overflows - ( - ScalarValue::Int8(Some(i8::MAX)), - ScalarValue::Int8(Some(i8::MIN)), - ), ]; for (lhs, rhs) in cases { let distance = lhs.distance(&rhs); @@ -5595,36 +5138,6 @@ mod tests { ScalarValue::new_interval_mdn(12, 15, 123_456), ScalarValue::new_interval_mdn(24, 30, 246_912), ), - ( - ScalarValue::new_interval_ym(0, 1), - ScalarValue::new_interval_dt(29, 86_390), - ScalarValue::new_interval_mdn(1, 29, 86_390_000_000), - ), - ( - ScalarValue::new_interval_ym(0, 1), - ScalarValue::new_interval_mdn(2, 10, 999_999_999), - ScalarValue::new_interval_mdn(3, 10, 999_999_999), - ), - ( - ScalarValue::new_interval_dt(400, 123_456), - ScalarValue::new_interval_ym(1, 1), - ScalarValue::new_interval_mdn(13, 400, 123_456_000_000), - ), - ( - ScalarValue::new_interval_dt(65, 321), - ScalarValue::new_interval_mdn(2, 5, 1_000_000), - ScalarValue::new_interval_mdn(2, 70, 322_000_000), - ), - ( - ScalarValue::new_interval_mdn(12, 15, 123_456), - ScalarValue::new_interval_ym(2, 0), - ScalarValue::new_interval_mdn(36, 15, 123_456), - ), - ( - ScalarValue::new_interval_mdn(12, 15, 100_000), - ScalarValue::new_interval_dt(370, 1), - ScalarValue::new_interval_mdn(12, 385, 1_100_000), - ), ]; for (lhs, rhs, expected) in cases.iter() { let result = lhs.add(rhs).unwrap(); @@ -5652,36 +5165,6 @@ mod tests { ScalarValue::new_interval_mdn(12, 15, 123_456), ScalarValue::new_interval_mdn(0, 0, 0), ), - ( - ScalarValue::new_interval_ym(0, 1), - ScalarValue::new_interval_dt(29, 999_999), - ScalarValue::new_interval_mdn(1, -29, -999_999_000_000), - ), - ( - ScalarValue::new_interval_ym(0, 1), - ScalarValue::new_interval_mdn(2, 10, 999_999_999), - ScalarValue::new_interval_mdn(-1, -10, -999_999_999), - ), - ( - ScalarValue::new_interval_dt(400, 123_456), - ScalarValue::new_interval_ym(1, 1), - ScalarValue::new_interval_mdn(-13, 400, 123_456_000_000), - ), - ( - ScalarValue::new_interval_dt(65, 321), - ScalarValue::new_interval_mdn(2, 5, 1_000_000), - ScalarValue::new_interval_mdn(-2, 60, 320_000_000), - ), - ( - ScalarValue::new_interval_mdn(12, 15, 123_456), - ScalarValue::new_interval_ym(2, 0), - ScalarValue::new_interval_mdn(-12, 15, 123_456), - ), - ( - ScalarValue::new_interval_mdn(12, 15, 100_000), - ScalarValue::new_interval_dt(370, 1), - ScalarValue::new_interval_mdn(12, -355, -900_000), - ), ]; for (lhs, rhs, expected) in cases.iter() { let result = lhs.sub(rhs).unwrap(); @@ -5689,25 +5172,11 @@ mod tests { } } - #[test] - fn timestamp_op_tests() { - // positive interval, edge cases - let test_data = get_timestamp_test_data(1); - for (lhs, rhs, expected) in test_data.into_iter() { - assert_eq!(expected, lhs.sub(rhs).unwrap()) - } - - // negative interval, edge cases - let test_data = get_timestamp_test_data(-1); - for (rhs, lhs, expected) in test_data.into_iter() { - assert_eq!(expected, lhs.sub(rhs).unwrap()); - } - } #[test] fn timestamp_op_random_tests() { // timestamp1 + (or -) interval = timestamp2 // timestamp2 - timestamp1 (or timestamp1 - timestamp2) = interval ? - let sample_size = 1000000; + let sample_size = 1000; let timestamps1 = get_random_timestamps(sample_size); let intervals = get_random_intervals(sample_size); // ts(sec) + interval(ns) = ts(sec); however, @@ -5716,18 +5185,12 @@ mod tests { for (idx, ts1) in timestamps1.iter().enumerate() { if idx % 2 == 0 { let timestamp2 = ts1.add(intervals[idx].clone()).unwrap(); - assert_eq!( - intervals[idx], - timestamp2.sub(ts1).unwrap(), - "index:{idx}, operands: {timestamp2:?} (-) {ts1:?}" - ); + let back = timestamp2.sub(intervals[idx].clone()).unwrap(); + assert_eq!(ts1, &back); } else { let timestamp2 = ts1.sub(intervals[idx].clone()).unwrap(); - assert_eq!( - intervals[idx], - ts1.sub(timestamp2.clone()).unwrap(), - "index:{idx}, operands: {ts1:?} (-) {timestamp2:?}" - ); + let back = timestamp2.add(intervals[idx].clone()).unwrap(); + assert_eq!(ts1, &back); }; } } @@ -5806,289 +5269,40 @@ mod tests { let arrays = scalars .iter() .map(ScalarValue::to_array) - .collect::>(); + .collect::>>() + .expect("Failed to convert to array"); let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); let array = concat(&arrays).unwrap(); check_array(array); } - fn get_timestamp_test_data( - sign: i32, - ) -> Vec<(ScalarValue, ScalarValue, ScalarValue)> { - vec![ - ( - // 1st test case, having the same time but different with timezones - // Since they are timestamps with nanosecond precision, expected type is - // [`IntervalMonthDayNanoType`] - ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_nano_opt(12, 0, 0, 000_000_000) - .unwrap() - .timestamp_nanos(), - ), - Some("+12:00".into()), - ), - ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_nano_opt(0, 0, 0, 000_000_000) - .unwrap() - .timestamp_nanos(), - ), - Some("+00:00".into()), - ), - ScalarValue::new_interval_mdn(0, 0, 0), - ), - // 2nd test case, january with 31 days plus february with 28 days, with timezone - ( - ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(2023, 3, 1) - .unwrap() - .and_hms_micro_opt(2, 0, 0, 000_000) - .unwrap() - .timestamp_micros(), - ), - Some("+01:00".into()), - ), - ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_micro_opt(0, 0, 0, 000_000) - .unwrap() - .timestamp_micros(), - ), - Some("-01:00".into()), - ), - ScalarValue::new_interval_mdn(0, sign * 59, 0), - ), - // 3rd test case, 29-days long february minus previous, year with timezone - ( - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2024, 2, 29) - .unwrap() - .and_hms_milli_opt(10, 10, 0, 000) - .unwrap() - .timestamp_millis(), - ), - Some("+10:10".into()), - ), - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 12, 31) - .unwrap() - .and_hms_milli_opt(1, 0, 0, 000) - .unwrap() - .timestamp_millis(), - ), - Some("+01:00".into()), - ), - ScalarValue::new_interval_dt(sign * 60, 0), - ), - // 4th test case, leap years occur mostly every 4 years, but every 100 years - // we skip a leap year unless the year is divisible by 400, so 31 + 28 = 59 - ( - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2100, 3, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - Some("-11:59".into()), - ), - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2100, 1, 1) - .unwrap() - .and_hms_opt(23, 58, 0) - .unwrap() - .timestamp(), - ), - Some("+11:59".into()), - ), - ScalarValue::new_interval_dt(sign * 59, 0), - ), - // 5th test case, without timezone positively seemed, but with timezone, - // negative resulting interval - ( - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(6, 00, 0, 000) - .unwrap() - .timestamp_millis(), - ), - Some("+06:00".into()), - ), - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(0, 0, 0, 000) - .unwrap() - .timestamp_millis(), - ), - Some("-12:00".into()), - ), - ScalarValue::new_interval_dt(0, sign * -43_200_000), - ), - // 6th test case, no problem before unix epoch beginning - ( - ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(1970, 1, 1) - .unwrap() - .and_hms_micro_opt(1, 2, 3, 15) - .unwrap() - .timestamp_micros(), - ), - None, - ), - ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(1969, 1, 1) - .unwrap() - .and_hms_micro_opt(0, 0, 0, 000_000) - .unwrap() - .timestamp_micros(), - ), - None, - ), - ScalarValue::new_interval_mdn( - 0, - 365 * sign, - sign as i64 * 3_723_000_015_000, - ), - ), - // 7th test case, no problem with big intervals - ( - ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2100, 1, 1) - .unwrap() - .and_hms_nano_opt(0, 0, 0, 0) - .unwrap() - .timestamp_nanos(), - ), - None, - ), - ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2000, 1, 1) - .unwrap() - .and_hms_nano_opt(0, 0, 0, 000_000_000) - .unwrap() - .timestamp_nanos(), - ), - None, - ), - ScalarValue::new_interval_mdn(0, sign * 36525, 0), - ), - // 8th test case, no problem detecting 366-days long years - ( - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2041, 1, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - None, - ), - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2040, 1, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - None, - ), - ScalarValue::new_interval_dt(sign * 366, 0), - ), - // 9th test case, no problem with unrealistic timezones - ( - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 3) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - Some("+23:59".into()), - ), - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_opt(0, 2, 0) - .unwrap() - .timestamp(), - ), - Some("-23:59".into()), - ), - ScalarValue::new_interval_dt(0, 0), - ), - // 10th test case, parsing different types of timezone input - ( - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 3, 17) - .unwrap() - .and_hms_opt(14, 10, 0) - .unwrap() - .timestamp(), - ), - Some("Europe/Istanbul".into()), - ), - ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 3, 17) - .unwrap() - .and_hms_opt(4, 10, 0) - .unwrap() - .timestamp(), - ), - Some("America/Los_Angeles".into()), - ), - ScalarValue::new_interval_dt(0, 0), - ), - // 11th test case, negative results - ( - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 3, 17) - .unwrap() - .and_hms_milli_opt(4, 10, 0, 0) - .unwrap() - .timestamp_millis(), - ), - None, - ), - ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 3, 17) - .unwrap() - .and_hms_milli_opt(4, 10, 0, 1) - .unwrap() - .timestamp_millis(), - ), - None, - ), - ScalarValue::new_interval_dt(0, -sign), - ), - ] + #[test] + fn test_build_timestamp_millisecond_list() { + let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; + let arr = ScalarValue::new_list( + &values, + &DataType::Timestamp(TimeUnit::Millisecond, None), + ); + assert_eq!(1, arr.len()); + } + + #[test] + fn test_newlist_timestamp_zone() { + let s: &'static str = "UTC"; + let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))]; + let arr = ScalarValue::new_list( + &values, + &DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), + ); + assert_eq!(1, arr.len()); + assert_eq!( + arr.data_type(), + &DataType::List(Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), + true + ))) + ); } fn get_random_timestamps(sample_size: u64) -> Vec { @@ -6145,7 +5359,8 @@ mod tests { .unwrap() .and_hms_nano_opt(hour, minute, second, nanosec) .unwrap() - .timestamp_nanos(), + .timestamp_nanos_opt() + .unwrap(), ), None, )) @@ -6155,6 +5370,9 @@ mod tests { } fn get_random_intervals(sample_size: u64) -> Vec { + const MILLISECS_IN_ONE_DAY: i64 = 86_400_000; + const NANOSECS_IN_ONE_DAY: i64 = 86_400_000_000_000; + let vector_size = sample_size; let mut intervals = vec![]; let mut rng = rand::thread_rng(); diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index d0f150a3166e4..7ad8992ca9aec 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -17,35 +17,445 @@ //! This module provides data structures to represent statistics +use std::fmt::{self, Debug, Display}; + use crate::ScalarValue; +use arrow_schema::Schema; + +/// Represents a value with a degree of certainty. `Precision` is used to +/// propagate information the precision of statistical values. +#[derive(Clone, PartialEq, Eq, Default)] +pub enum Precision { + /// The exact value is known + Exact(T), + /// The value is not known exactly, but is likely close to this value + Inexact(T), + /// Nothing is known about the value + #[default] + Absent, +} + +impl Precision { + /// If we have some value (exact or inexact), it returns that value. + /// Otherwise, it returns `None`. + pub fn get_value(&self) -> Option<&T> { + match self { + Precision::Exact(value) | Precision::Inexact(value) => Some(value), + Precision::Absent => None, + } + } + + /// Transform the value in this [`Precision`] object, if one exists, using + /// the given function. Preserves the exactness state. + pub fn map(self, f: F) -> Precision + where + F: Fn(T) -> T, + { + match self { + Precision::Exact(val) => Precision::Exact(f(val)), + Precision::Inexact(val) => Precision::Inexact(f(val)), + _ => self, + } + } + + /// Returns `Some(true)` if we have an exact value, `Some(false)` if we + /// have an inexact value, and `None` if there is no value. + pub fn is_exact(&self) -> Option { + match self { + Precision::Exact(_) => Some(true), + Precision::Inexact(_) => Some(false), + _ => None, + } + } + + /// Returns the maximum of two (possibly inexact) values, conservatively + /// propagating exactness information. If one of the input values is + /// [`Precision::Absent`], the result is `Absent` too. + pub fn max(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + Precision::Exact(if a >= b { a.clone() } else { b.clone() }) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(if a >= b { a.clone() } else { b.clone() }) + } + (_, _) => Precision::Absent, + } + } + + /// Returns the minimum of two (possibly inexact) values, conservatively + /// propagating exactness information. If one of the input values is + /// [`Precision::Absent`], the result is `Absent` too. + pub fn min(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + Precision::Exact(if a >= b { b.clone() } else { a.clone() }) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(if a >= b { b.clone() } else { a.clone() }) + } + (_, _) => Precision::Absent, + } + } + + /// Demotes the precision state from exact to inexact (if present). + pub fn to_inexact(self) -> Self { + match self { + Precision::Exact(value) => Precision::Inexact(value), + _ => self, + } + } +} + +impl Precision { + /// Calculates the sum of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn add(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a + b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a + b), + (_, _) => Precision::Absent, + } + } + + /// Calculates the difference of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn sub(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a - b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a - b), + (_, _) => Precision::Absent, + } + } + + /// Calculates the multiplication of two (possibly inexact) [`usize`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn multiply(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a * b), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a * b), + (_, _) => Precision::Absent, + } + } + + /// Return the estimate of applying a filter with estimated selectivity + /// `selectivity` to this Precision. A selectivity of `1.0` means that all + /// rows are selected. A selectivity of `0.5` means half the rows are + /// selected. Will always return inexact statistics. + pub fn with_estimated_selectivity(self, selectivity: f64) -> Self { + self.map(|v| ((v as f64 * selectivity).ceil()) as usize) + .to_inexact() + } +} + +impl Precision { + /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values, + /// conservatively propagating exactness information. If one of the input + /// values is [`Precision::Absent`], the result is `Absent` too. + pub fn add(&self, other: &Precision) -> Precision { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => { + if let Ok(result) = a.add(b) { + Precision::Exact(result) + } else { + Precision::Absent + } + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + if let Ok(result) = a.add(b) { + Precision::Inexact(result) + } else { + Precision::Absent + } + } + (_, _) => Precision::Absent, + } + } +} + +impl Debug for Precision { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Precision::Exact(inner) => write!(f, "Exact({:?})", inner), + Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Absent => write!(f, "Absent"), + } + } +} + +impl Display for Precision { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Precision::Exact(inner) => write!(f, "Exact({:?})", inner), + Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Absent => write!(f, "Absent"), + } + } +} + /// Statistics for a relation /// Fields are optional and can be inexact because the sources /// sometimes provide approximate estimates for performance reasons /// and the transformations output are not always predictable. -#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Statistics { - /// The number of table rows - pub num_rows: Option, - /// total bytes of the table rows - pub total_byte_size: Option, - /// Statistics on a column level - pub column_statistics: Option>, - /// If true, any field that is `Some(..)` is the actual value in the data provided by the operator (it is not - /// an estimate). Any or all other fields might still be None, in which case no information is known. - /// if false, any field that is `Some(..)` may contain an inexact estimate and may not be the actual value. - pub is_exact: bool, + /// The number of table rows. + pub num_rows: Precision, + /// Total bytes of the table rows. + pub total_byte_size: Precision, + /// Statistics on a column level. It contains a [`ColumnStatistics`] for + /// each field in the schema of the the table to which the [`Statistics`] refer. + pub column_statistics: Vec, +} + +impl Statistics { + /// Returns a [`Statistics`] instance for the given schema by assigning + /// unknown statistics to each column in the schema. + pub fn new_unknown(schema: &Schema) -> Self { + Self { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: Statistics::unknown_column(schema), + } + } + + /// Returns an unbounded `ColumnStatistics` for each field in the schema. + pub fn unknown_column(schema: &Schema) -> Vec { + schema + .fields() + .iter() + .map(|_| ColumnStatistics::new_unknown()) + .collect() + } + + /// If the exactness of a [`Statistics`] instance is lost, this function relaxes + /// the exactness of all information by converting them [`Precision::Inexact`]. + pub fn into_inexact(self) -> Self { + Statistics { + num_rows: self.num_rows.to_inexact(), + total_byte_size: self.total_byte_size.to_inexact(), + column_statistics: self + .column_statistics + .into_iter() + .map(|cs| ColumnStatistics { + null_count: cs.null_count.to_inexact(), + max_value: cs.max_value.to_inexact(), + min_value: cs.min_value.to_inexact(), + distinct_count: cs.distinct_count.to_inexact(), + }) + .collect::>(), + } + } +} + +impl Display for Statistics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // string of column statistics + let column_stats = self + .column_statistics + .iter() + .enumerate() + .map(|(i, cs)| { + let s = format!("(Col[{}]:", i); + let s = if cs.min_value != Precision::Absent { + format!("{} Min={}", s, cs.min_value) + } else { + s + }; + let s = if cs.max_value != Precision::Absent { + format!("{} Max={}", s, cs.max_value) + } else { + s + }; + let s = if cs.null_count != Precision::Absent { + format!("{} Null={}", s, cs.null_count) + } else { + s + }; + let s = if cs.distinct_count != Precision::Absent { + format!("{} Distinct={}", s, cs.distinct_count) + } else { + s + }; + + s + ")" + }) + .collect::>() + .join(","); + + write!( + f, + "Rows={}, Bytes={}, [{}]", + self.num_rows, self.total_byte_size, column_stats + )?; + + Ok(()) + } } /// Statistics for a column within a relation -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub struct ColumnStatistics { /// Number of null values on column - pub null_count: Option, + pub null_count: Precision, /// Maximum value of column - pub max_value: Option, + pub max_value: Precision, /// Minimum value of column - pub min_value: Option, + pub min_value: Precision, /// Number of distinct values - pub distinct_count: Option, + pub distinct_count: Precision, +} + +impl ColumnStatistics { + /// Column contains a single non null value (e.g constant). + pub fn is_singleton(&self) -> bool { + match (&self.min_value, &self.max_value) { + // Min and max values are the same and not infinity. + (Precision::Exact(min), Precision::Exact(max)) => { + !min.is_null() && !max.is_null() && (min == max) + } + (_, _) => false, + } + } + + /// Returns a [`ColumnStatistics`] instance having all [`Precision::Absent`] parameters. + pub fn new_unknown() -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + distinct_count: Precision::Absent, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_value() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::::Absent; + + assert_eq!(*exact_precision.get_value().unwrap(), 42); + assert_eq!(*inexact_precision.get_value().unwrap(), 23); + assert_eq!(absent_precision.get_value(), None); + } + + #[test] + fn test_map() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::Absent; + + let squared = |x| x * x; + + assert_eq!(exact_precision.map(squared), Precision::Exact(1764)); + assert_eq!(inexact_precision.map(squared), Precision::Inexact(529)); + assert_eq!(absent_precision.map(squared), Precision::Absent); + } + + #[test] + fn test_is_exact() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(23); + let absent_precision = Precision::::Absent; + + assert_eq!(exact_precision.is_exact(), Some(true)); + assert_eq!(inexact_precision.is_exact(), Some(false)); + assert_eq!(absent_precision.is_exact(), None); + } + + #[test] + fn test_max() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.max(&precision2), Precision::Inexact(42)); + assert_eq!(precision1.max(&precision3), Precision::Exact(42)); + assert_eq!(precision2.max(&precision3), Precision::Inexact(30)); + assert_eq!(precision1.max(&absent_precision), Precision::Absent); + } + + #[test] + fn test_min() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.min(&precision2), Precision::Inexact(23)); + assert_eq!(precision1.min(&precision3), Precision::Exact(30)); + assert_eq!(precision2.min(&precision3), Precision::Inexact(23)); + assert_eq!(precision1.min(&absent_precision), Precision::Absent); + } + + #[test] + fn test_to_inexact() { + let exact_precision = Precision::Exact(42); + let inexact_precision = Precision::Inexact(42); + let absent_precision = Precision::::Absent; + + assert_eq!(exact_precision.clone().to_inexact(), inexact_precision); + assert_eq!(inexact_precision.clone().to_inexact(), inexact_precision); + assert_eq!(absent_precision.clone().to_inexact(), absent_precision); + } + + #[test] + fn test_add() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.add(&precision2), Precision::Inexact(65)); + assert_eq!(precision1.add(&precision3), Precision::Exact(72)); + assert_eq!(precision2.add(&precision3), Precision::Inexact(53)); + assert_eq!(precision1.add(&absent_precision), Precision::Absent); + } + + #[test] + fn test_sub() { + let precision1 = Precision::Exact(42); + let precision2 = Precision::Inexact(23); + let precision3 = Precision::Exact(30); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.sub(&precision2), Precision::Inexact(19)); + assert_eq!(precision1.sub(&precision3), Precision::Exact(12)); + assert_eq!(precision1.sub(&absent_precision), Precision::Absent); + } + + #[test] + fn test_multiply() { + let precision1 = Precision::Exact(6); + let precision2 = Precision::Inexact(3); + let precision3 = Precision::Exact(5); + let absent_precision = Precision::Absent; + + assert_eq!(precision1.multiply(&precision2), Precision::Inexact(18)); + assert_eq!(precision1.multiply(&precision3), Precision::Exact(30)); + assert_eq!(precision2.multiply(&precision3), Precision::Inexact(15)); + assert_eq!(precision1.multiply(&absent_precision), Precision::Absent); + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index cd05f8082dab3..55681ece1016b 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -299,7 +299,7 @@ impl<'a> TableReference<'a> { /// Forms a [`TableReference`] by parsing `s` as a multipart SQL /// identifier. See docs on [`TableReference`] for more details. pub fn parse_str(s: &'a str) -> Self { - let mut parts = parse_identifiers_normalized(s); + let mut parts = parse_identifiers_normalized(s, false); match parts.len() { 1 => Self::Bare { diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index d6f80b8b16267..9a44337821570 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -17,6 +17,87 @@ //! Utility functions to make testing DataFusion based crates easier +use std::{error::Error, path::PathBuf}; + +/// Compares formatted output of a record batch with an expected +/// vector of strings, with the result of pretty formatting record +/// batches. This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// +/// `assert_batch_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +#[macro_export] +macro_rules! assert_batches_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let expected_lines: Vec = + $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); + + let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( + $CHUNKS, + &$crate::format::DEFAULT_FORMAT_OPTIONS, + ) + .unwrap() + .to_string(); + + let actual_lines: Vec<&str> = formatted.trim().lines().collect(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} + +/// Compares formatted output of a record batch with an expected +/// vector of strings in a way that order does not matter. +/// This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// +/// `assert_batch_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +#[macro_export] +macro_rules! assert_batches_sorted_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let mut expected_lines: Vec = + $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); + + // sort except for header + footer + let num_lines = expected_lines.len(); + if num_lines > 3 { + expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() + } + + let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( + $CHUNKS, + &$crate::format::DEFAULT_FORMAT_OPTIONS, + ) + .unwrap() + .to_string(); + // fix for windows: \r\n --> + + let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); + + // sort except for header + footer + let num_lines = actual_lines.len(); + if num_lines > 3 { + actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() + } + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} + /// A macro to assert that one string is contained within another with /// a nice error message if they are not. /// @@ -62,3 +143,153 @@ macro_rules! assert_not_contains { ); }; } + +/// Returns the arrow test data directory, which is by default stored +/// in a git submodule rooted at `testing/data`. +/// +/// The default can be overridden by the optional environment +/// variable `ARROW_TEST_DATA` +/// +/// panics when the directory can not be found. +/// +/// Example: +/// ``` +/// let testdata = datafusion_common::test_util::arrow_test_data(); +/// let csvdata = format!("{}/csv/aggregate_test_100.csv", testdata); +/// assert!(std::path::PathBuf::from(csvdata).exists()); +/// ``` +pub fn arrow_test_data() -> String { + match get_data_dir("ARROW_TEST_DATA", "../../testing/data") { + Ok(pb) => pb.display().to_string(), + Err(err) => panic!("failed to get arrow data dir: {err}"), + } +} + +/// Returns the parquet test data directory, which is by default +/// stored in a git submodule rooted at +/// `parquet-testing/data`. +/// +/// The default can be overridden by the optional environment variable +/// `PARQUET_TEST_DATA` +/// +/// panics when the directory can not be found. +/// +/// Example: +/// ``` +/// let testdata = datafusion_common::test_util::parquet_test_data(); +/// let filename = format!("{}/binary.parquet", testdata); +/// assert!(std::path::PathBuf::from(filename).exists()); +/// ``` +#[cfg(feature = "parquet")] +pub fn parquet_test_data() -> String { + match get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data") { + Ok(pb) => pb.display().to_string(), + Err(err) => panic!("failed to get parquet data dir: {err}"), + } +} + +/// Returns a directory path for finding test data. +/// +/// udf_env: name of an environment variable +/// +/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR) +/// +/// Returns either: +/// The path referred to in `udf_env` if that variable is set and refers to a directory +/// The submodule_data directory relative to CARGO_MANIFEST_PATH +pub fn get_data_dir( + udf_env: &str, + submodule_data: &str, +) -> Result> { + // Try user defined env. + if let Ok(dir) = std::env::var(udf_env) { + let trimmed = dir.trim().to_string(); + if !trimmed.is_empty() { + let pb = PathBuf::from(trimmed); + if pb.is_dir() { + return Ok(pb); + } else { + return Err(format!( + "the data dir `{}` defined by env {} not found", + pb.display(), + udf_env + ) + .into()); + } + } + } + + // The env is undefined or its value is trimmed to empty, let's try default dir. + + // env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package", + // set by `cargo run` or `cargo test`, see: + // https://doc.rust-lang.org/cargo/reference/environment-variables.html + let dir = env!("CARGO_MANIFEST_DIR"); + + let pb = PathBuf::from(dir).join(submodule_data); + if pb.is_dir() { + Ok(pb) + } else { + Err(format!( + "env `{}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\ + HINT: try running `git submodule update --init`", + udf_env, + pb.display(), + ).into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + #[test] + fn test_data_dir() { + let udf_env = "get_data_dir"; + let cwd = env::current_dir().unwrap(); + + let existing_pb = cwd.join(".."); + let existing = existing_pb.display().to_string(); + let existing_str = existing.as_str(); + + let non_existing = cwd.join("non-existing-dir").display().to_string(); + let non_existing_str = non_existing.as_str(); + + env::set_var(udf_env, non_existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_err()); + + env::set_var(udf_env, ""); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, " "); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::set_var(udf_env, existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + + env::remove_var(udf_env); + let res = get_data_dir(udf_env, non_existing_str); + assert!(res.is_err()); + + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + } + + #[test] + fn test_happy() { + let res = arrow_test_data(); + assert!(PathBuf::from(res).is_dir()); + + let res = parquet_test_data(); + assert!(PathBuf::from(res).is_dir()); + } +} diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2919d9a39c9c8..5da9636ffe185 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -125,6 +125,17 @@ pub trait TreeNode: Sized { after_op.map_children(|node| node.transform_down(op)) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op = op(self)?.into(); + after_op.map_children(|node| node.transform_down_mut(op)) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. @@ -138,6 +149,19 @@ pub trait TreeNode: Sized { Ok(new_node) } + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal) using a mutable function, `F`. + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up_mut(self, op: &mut F) -> Result + where + F: FnMut(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; + + let new_node = op(after_op_children)?.into(); + Ok(new_node) + } + /// Transform the tree node using the given [TreeNodeRewriter] /// It performs a depth first walk of an node and its children. /// diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs new file mode 100644 index 0000000000000..fd92267f9b4c3 --- /dev/null +++ b/datafusion/common/src/unnest.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`UnnestOptions`] for unnesting structured types + +/// Options for unnesting a column that contains a list type, +/// replicating values in the other, non nested rows. +/// +/// Conceptually this operation is like joining each row with all the +/// values in the list column. +/// +/// If `preserve_nulls` is false, nulls and empty lists +/// from the input column are not carried through to the output. This +/// is the default behavior for other systems such as ClickHouse and +/// DuckDB +/// +/// If `preserve_nulls` is true (the default), nulls from the input +/// column are carried through to the output. +/// +/// # Examples +/// +/// ## `Unnest(c1)`, preserve_nulls: false +/// ```text +/// ┌─────────┐ ┌─────┐ ┌─────────┐ ┌─────┐ +/// │ {1, 2} │ │ A │ Unnest │ 1 │ │ A │ +/// ├─────────┤ ├─────┤ ├─────────┤ ├─────┤ +/// │ null │ │ B │ │ 2 │ │ A │ +/// ├─────────┤ ├─────┤ ────────────▶ ├─────────┤ ├─────┤ +/// │ {} │ │ D │ │ 3 │ │ E │ +/// ├─────────┤ ├─────┤ └─────────┘ └─────┘ +/// │ {3} │ │ E │ c1 c2 +/// └─────────┘ └─────┘ +/// c1 c2 +/// ``` +/// +/// ## `Unnest(c1)`, preserve_nulls: true +/// ```text +/// ┌─────────┐ ┌─────┐ ┌─────────┐ ┌─────┐ +/// │ {1, 2} │ │ A │ Unnest │ 1 │ │ A │ +/// ├─────────┤ ├─────┤ ├─────────┤ ├─────┤ +/// │ null │ │ B │ │ 2 │ │ A │ +/// ├─────────┤ ├─────┤ ────────────▶ ├─────────┤ ├─────┤ +/// │ {} │ │ D │ │ null │ │ B │ +/// ├─────────┤ ├─────┤ ├─────────┤ ├─────┤ +/// │ {3} │ │ E │ │ 3 │ │ E │ +/// └─────────┘ └─────┘ └─────────┘ └─────┘ +/// c1 c2 c1 c2 +/// ``` +#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] +pub struct UnnestOptions { + /// Should nulls in the input be preserved? Defaults to true + pub preserve_nulls: bool, +} + +impl Default for UnnestOptions { + fn default() -> Self { + Self { + // default to true to maintain backwards compatible behavior + preserve_nulls: true, + } + } +} + +impl UnnestOptions { + /// Create a new [`UnnestOptions`] with default values + pub fn new() -> Self { + Default::default() + } + + /// Set the behavior with nulls in the input as described on + /// [`Self`] + pub fn with_preserve_nulls(mut self, preserve_nulls: bool) -> Self { + self.preserve_nulls = preserve_nulls; + self + } +} diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 2edcd07846da1..fecab8835e50a 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -17,20 +17,65 @@ //! This module provides the bisect function, which implements binary search. +use crate::error::_internal_err; use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; +use arrow::buffer::OffsetBuffer; use arrow::compute; -use arrow::compute::{lexicographical_partition_ranges, SortColumn, SortOptions}; -use arrow::datatypes::UInt32Type; +use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; use std::borrow::{Borrow, Cow}; use std::cmp::Ordering; +use std::collections::HashSet; use std::ops::Range; use std::sync::Arc; +/// Applies an optional projection to a [`SchemaRef`], returning the +/// projected schema +/// +/// Example: +/// ``` +/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; +/// use datafusion_common::project_schema; +/// +/// // Schema with columns 'a', 'b', and 'c' +/// let schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Int64, true), +/// Field::new("c", DataType::Utf8, true), +/// ])); +/// +/// // Pick columns 'c' and 'b' +/// let projection = Some(vec![2,1]); +/// let projected_schema = project_schema( +/// &schema, +/// projection.as_ref() +/// ).unwrap(); +/// +/// let expected_schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Int64, true), +/// ])); +/// +/// assert_eq!(projected_schema, expected_schema); +/// ``` +pub fn project_schema( + schema: &SchemaRef, + projection: Option<&Vec>, +) -> Result { + let schema = match projection { + Some(columns) => Arc::new(schema.project(columns)?), + None => Arc::clone(schema), + }; + Ok(schema) +} + /// Given column vectors, returns row at `idx`. pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result> { columns @@ -45,8 +90,12 @@ pub fn get_record_batch_at_indices( indices: &PrimitiveArray, ) -> Result { let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new(record_batch.schema(), new_columns) - .map_err(DataFusionError::ArrowError) + RecordBatch::try_new_with_options( + record_batch.schema(), + new_columns, + &RecordBatchOptions::new().with_row_count(Some(indices.len())), + ) + .map_err(DataFusionError::ArrowError) } /// This function compares two tuples depending on the given sort options. @@ -90,7 +139,7 @@ pub fn bisect( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -141,7 +190,7 @@ pub fn linear_search( ) -> Result { let low: usize = 0; let high: usize = item_columns - .get(0) + .first() .ok_or_else(|| { DataFusionError::Internal("Column array shouldn't be empty".to_string()) })? @@ -177,9 +226,10 @@ where Ok(low) } -/// This function finds the partition points according to `partition_columns`. -/// If there are no sort columns, then the result will be a single element -/// vector containing one partition range spanning all data. +/// Given a list of 0 or more already sorted columns, finds the +/// partition ranges that would partition equally across columns. +/// +/// See [`partition`] for more details. pub fn evaluate_partition_ranges( num_rows: usize, partition_columns: &[SortColumn], @@ -190,7 +240,8 @@ pub fn evaluate_partition_ranges( end: num_rows, }] } else { - lexicographical_partition_ranges(partition_columns)?.collect() + let cols: Vec<_> = partition_columns.iter().map(|x| x.values.clone()).collect(); + partition(&cols)?.ranges() }) } @@ -245,13 +296,14 @@ pub fn get_arrayref_at_indices( .collect() } -pub(crate) fn parse_identifiers_normalized(s: &str) -> Vec { +pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() .into_iter() .map(|id| match id.quote_style { Some(_) => id.value, - None => id.value.to_ascii_lowercase(), + None if ignore_case => id.value, + _ => id.value.to_ascii_lowercase(), }) .collect::>() } @@ -290,6 +342,102 @@ pub fn longest_consecutive_prefix>( count } +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_list_array(arr: ArrayRef) -> ListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + ListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + +/// Wrap an array into a single element `LargeListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + +/// Wrap arrays into a single element `ListArray`. +/// +/// Example: +/// ``` +/// use arrow::array::{Int32Array, ListArray, ArrayRef}; +/// use arrow::datatypes::{Int32Type, Field}; +/// use std::sync::Arc; +/// +/// let arr1 = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; +/// let arr2 = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef; +/// +/// let list_arr = datafusion_common::utils::arrays_into_list_array([arr1, arr2]).unwrap(); +/// +/// let expected = ListArray::from_iter_primitive::( +/// vec![ +/// Some(vec![Some(1), Some(2), Some(3)]), +/// Some(vec![Some(4), Some(5), Some(6)]), +/// ] +/// ); +/// +/// assert_eq!(list_arr, expected); +pub fn arrays_into_list_array( + arr: impl IntoIterator, +) -> Result { + let arr = arr.into_iter().collect::>(); + if arr.is_empty() { + return _internal_err!("Cannot wrap empty array into list array"); + } + + let lens = arr.iter().map(|x| x.len()).collect::>(); + // Assume data type is consistent + let data_type = arr[0].data_type().to_owned(); + let values = arr.iter().map(|x| x.as_ref()).collect::>(); + Ok(ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(lens), + arrow::compute::concat(values.as_slice())?, + None, + )) +} + +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + if let DataType::List(field) = data_type { + base_type(field.data_type()) + } else { + data_type.to_owned() + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + if let DataType::List(field) = data_type { + 1 + list_ndims(field.data_type()) + } else { + 0 + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// @@ -386,6 +534,64 @@ pub mod datafusion_strsim { } } +/// Merges collections `first` and `second`, removes duplicates and sorts the +/// result, returning it as a [`Vec`]. +pub fn merge_and_order_indices, S: Borrow>( + first: impl IntoIterator, + second: impl IntoIterator, +) -> Vec { + let mut result: Vec<_> = first + .into_iter() + .map(|e| *e.borrow()) + .chain(second.into_iter().map(|e| *e.borrow())) + .collect::>() + .into_iter() + .collect(); + result.sort(); + result +} + +/// Calculates the set difference between sequences `first` and `second`, +/// returning the result as a [`Vec`]. Preserves the ordering of `first`. +pub fn set_difference, S: Borrow>( + first: impl IntoIterator, + second: impl IntoIterator, +) -> Vec { + let set: HashSet<_> = second.into_iter().map(|e| *e.borrow()).collect(); + first + .into_iter() + .map(|e| *e.borrow()) + .filter(|e| !set.contains(e)) + .collect() +} + +/// Checks whether the given index sequence is monotonically non-decreasing. +pub fn is_sorted>(sequence: impl IntoIterator) -> bool { + // TODO: Remove this function when `is_sorted` graduates from Rust nightly. + let mut previous = 0; + for item in sequence.into_iter() { + let current = *item.borrow(); + if current < previous { + return false; + } + previous = current; + } + true +} + +/// Find indices of each element in `targets` inside `items`. If one of the +/// elements is absent in `items`, returns an error. +pub fn find_indices>( + items: &[T], + targets: impl IntoIterator, +) -> Result> { + targets + .into_iter() + .map(|target| items.iter().position(|e| target.borrow().eq(e))) + .collect::>() + .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) +} + #[cfg(test)] mod tests { use crate::ScalarValue; @@ -704,4 +910,49 @@ mod tests { "cloned `Arc` should point to same data as the original" ); } + + #[test] + fn test_merge_and_order_indices() { + assert_eq!( + merge_and_order_indices([0, 3, 4], [1, 3, 5]), + vec![0, 1, 3, 4, 5] + ); + // Result should be ordered, even if inputs are not + assert_eq!( + merge_and_order_indices([3, 0, 4], [5, 1, 3]), + vec![0, 1, 3, 4, 5] + ); + } + + #[test] + fn test_set_difference() { + assert_eq!(set_difference([0, 3, 4], [1, 2]), vec![0, 3, 4]); + assert_eq!(set_difference([0, 3, 4], [1, 2, 4]), vec![0, 3]); + // return value should have same ordering with the in1 + assert_eq!(set_difference([3, 4, 0], [1, 2, 4]), vec![3, 0]); + assert_eq!(set_difference([0, 3, 4], [4, 1, 2]), vec![0, 3]); + assert_eq!(set_difference([3, 4, 0], [4, 1, 2]), vec![3, 0]); + } + + #[test] + fn test_is_sorted() { + assert!(is_sorted::([])); + assert!(is_sorted([0])); + assert!(is_sorted([0, 3, 4])); + assert!(is_sorted([0, 1, 2])); + assert!(is_sorted([0, 1, 4])); + assert!(is_sorted([0usize; 0])); + assert!(is_sorted([1, 2])); + assert!(!is_sorted([3, 2])); + } + + #[test] + fn test_find_indices() -> Result<()> { + assert_eq!(find_indices(&[0, 3, 4], [0, 3, 4])?, vec![0, 1, 2]); + assert_eq!(find_indices(&[0, 3, 4], [0, 4, 3])?, vec![0, 2, 1]); + assert_eq!(find_indices(&[3, 0, 4], [0, 3])?, vec![1, 0]); + assert!(find_indices(&[0, 3], [0, 3, 4]).is_err()); + assert!(find_indices(&[0, 3, 4], [0, 2]).is_err()); + Ok(()) + } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 65d203eca60bb..7caf91e24f2f3 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -20,14 +20,14 @@ name = "datafusion" description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" keywords = ["arrow", "query", "sql"] include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70" [lib] name = "datafusion" @@ -36,89 +36,94 @@ path = "src/lib.rs" [features] # Used to enable the avro format avro = ["apache-avro", "num-traits", "datafusion-common/avro"] +backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression"] crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"] -default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "compression"] -# Enables support for non-scalar, binary operations on dictionaries -# Note: this results in significant additional codegen -dictionary_expressions = ["datafusion-physical-expr/dictionary_expressions", "datafusion-optimizer/dictionary_expressions"] +default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"] +encoding_expressions = ["datafusion-physical-expr/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] -pyarrow = ["datafusion-common/pyarrow"] +parquet = ["datafusion-common/parquet", "dep:parquet"] +pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = ["datafusion-physical-expr/regex_expressions", "datafusion-optimizer/regex_expressions"] +serde = ["arrow-schema/serde"] simd = ["arrow/simd"] unicode_expressions = ["datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions"] [dependencies] ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } -apache-avro = { version = "0.14", optional = true } +apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-schema = { workspace = true } async-compression = { version = "0.4.0", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } -async-trait = "0.1.41" -bytes = "1.4" +async-trait = { workspace = true } +bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } -chrono = { version = "0.4.23", default-features = false } -dashmap = "5.4.0" -datafusion-common = { path = "../common", version = "26.0.0", features = ["parquet", "object_store"] } -datafusion-execution = { path = "../execution", version = "26.0.0" } -datafusion-expr = { path = "../expr", version = "26.0.0" } -datafusion-optimizer = { path = "../optimizer", version = "26.0.0", default-features = false } -datafusion-physical-expr = { path = "../physical-expr", version = "26.0.0", default-features = false } -datafusion-row = { path = "../row", version = "26.0.0" } -datafusion-sql = { path = "../sql", version = "26.0.0" } +chrono = { workspace = true } +dashmap = { workspace = true } +datafusion-common = { path = "../common", version = "33.0.0", features = ["object_store"], default-features = false } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-optimizer = { path = "../optimizer", version = "33.0.0", default-features = false } +datafusion-physical-expr = { path = "../physical-expr", version = "33.0.0", default-features = false } +datafusion-physical-plan = { workspace = true } +datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } -futures = "0.3" +futures = { workspace = true } glob = "0.3.0" +half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } -indexmap = "1.9.2" -itertools = "0.10" -lazy_static = { version = "^1.4.0" } -log = "^0.4" +indexmap = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } num-traits = { version = "0.2", optional = true } -num_cpus = "1.13.0" -object_store = "0.6.1" -parking_lot = "0.12" -parquet = { workspace = true } -percent-encoding = "2.2.0" +num_cpus = { workspace = true } +object_store = { workspace = true } +parking_lot = { workspace = true } +parquet = { workspace = true, optional = true, default-features = true } pin-project-lite = "^0.2.7" -rand = "0.8" -smallvec = { version = "1.6", features = ["union"] } -sqlparser = { version = "0.34", features = ["visitor"] } -tempfile = "3" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } +rand = { workspace = true } +sqlparser = { workspace = true } +tempfile = { workspace = true } +tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-util = { version = "0.7.4", features = ["io"] } -url = "2.2" +url = { workspace = true } uuid = { version = "1.0", features = ["v4"] } xz2 = { version = "0.1", optional = true } -zstd = { version = "0.12", optional = true, default-features = false } - +zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] -async-trait = "0.1.53" -bigdecimal = "0.3.0" +async-trait = { workspace = true } +bigdecimal = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" -ctor = "0.2.0" -doc-comment = "0.3" -env_logger = "0.10" -half = "2.2.1" +ctor = { workspace = true } +doc-comment = { workspace = true } +env_logger = { workspace = true } +half = { workspace = true } postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } -rstest = "0.17.0" +rand = { version = "0.8", features = ["small_rng"] } +rand_distr = "0.4.3" +regex = "1.5.4" +rstest = { workspace = true } rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } -sqllogictest = "0.13.2" +serde_json = { workspace = true } test-utils = { path = "../../test-utils" } -thiserror = "1.0.37" +thiserror = { workspace = true } tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = "0.26.1" +nix = { version = "0.27.1", features = ["fs"] } [[bench]] harness = false name = "aggregate_query_sql" +[[bench]] +harness = false +name = "distinct_query_sql" + [[bench]] harness = false name = "sort_limit_query_sql" @@ -159,7 +164,10 @@ name = "sql_query_with_io" harness = false name = "sort" -[[test]] +[[bench]] +harness = false +name = "topk_aggregate" + +[[bench]] harness = false -name = "sqllogictests" -path = "tests/sqllogictests/src/main.rs" +name = "array_expression" diff --git a/datafusion/row/README.md b/datafusion/core/README.md similarity index 72% rename from datafusion/row/README.md rename to datafusion/core/README.md index eef4dfd554e38..5a9493d086cd1 100644 --- a/datafusion/row/README.md +++ b/datafusion/core/README.md @@ -17,13 +17,10 @@ under the License. --> -# DataFusion Row +# DataFusion Common -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate is a submodule of DataFusion that provides an optimized row based format for row-based operations. - -See the documentation in [`lib.rs`] for more details. +This crate contains the main entrypoints and high level DataFusion APIs such as SessionContext, and DataFrame and ListingTable. [df]: https://crates.io/crates/datafusion -[`lib.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion/row/src/lib.rs diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs new file mode 100644 index 0000000000000..95bc93e0e353a --- /dev/null +++ b/datafusion/core/benches/array_expression.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::{ArrayRef, Int64Array, ListArray}; +use datafusion_physical_expr::array_expressions; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // Construct large arrays for benchmarking + + let array_len = 100000000; + + let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); + let list_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + let from_array = Int64Array::from_value(2, 3); + let to_array = Int64Array::from_value(-2, 3); + + let args = vec![ + Arc::new(list_array) as ArrayRef, + Arc::new(from_array) as ArrayRef, + Arc::new(to_array) as ArrayRef, + ]; + + let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); + let expected_array = ListArray::from_iter_primitive::(vec![ + Some(array.clone()), + Some(array.clone()), + Some(array), + ]); + + // Benchmark array functions + + c.bench_function("array_replace", |b| { + b.iter(|| { + assert_eq!( + array_expressions::array_replace_all(args.as_slice()) + .unwrap() + .as_list::(), + criterion::black_box(&expected_array) + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 9169c3dda48ed..9d2864919225a 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -25,11 +25,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use arrow_array::builder::{Int64Builder, StringBuilder}; use datafusion::datasource::MemTable; use datafusion::error::Result; +use datafusion_common::DataFusionError; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; +use rand_distr::Distribution; +use rand_distr::{Normal, Pareto}; +use std::fmt::Write; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, @@ -107,7 +112,7 @@ fn create_record_batch( ) -> RecordBatch { // the 4 here is the number of different keys. // a higher number increase sparseness - let vs = vec![0, 1, 2, 3]; + let vs = [0, 1, 2, 3]; let keys: Vec = (0..batch_size) .map( // use random numbers to avoid spurious compiler optimizations wrt to branching @@ -156,3 +161,83 @@ pub fn create_record_batches( }) .collect::>() } + +/// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition +/// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[allow(dead_code)] +pub(crate) fn make_data( + partition_cnt: i32, + sample_cnt: i32, + asc: bool, +) -> Result<(Arc, Vec>), DataFusionError> { + // constants observed from trace data + let simultaneous_group_cnt = 2000; + let fitted_shape = 12f64; + let fitted_scale = 5f64; + let mean = 0.1; + let stddev = 1.1; + let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); + let normal = Normal::new(mean, stddev).unwrap(); + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + + // populate data + let schema = test_schema(); + let mut partitions = vec![]; + let mut cur_time = 16909000000000i64; + for _ in 0..partition_cnt { + let mut id_builder = StringBuilder::new(); + let mut ts_builder = Int64Builder::new(); + let gen_id = |rng: &mut rand::rngs::SmallRng| { + rng.gen::<[u8; 16]>() + .iter() + .fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02X}"); + output + }) + }; + let gen_sample_cnt = + |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; + let mut group_ids = (0..simultaneous_group_cnt) + .map(|_| gen_id(&mut rng)) + .collect::>(); + let mut group_sample_cnts = (0..simultaneous_group_cnt) + .map(|_| gen_sample_cnt(&mut rng)) + .collect::>(); + for _ in 0..sample_cnt { + let random_index = rng.gen_range(0..simultaneous_group_cnt); + let trace_id = &mut group_ids[random_index]; + let sample_cnt = &mut group_sample_cnts[random_index]; + *sample_cnt -= 1; + if *sample_cnt == 0 { + *trace_id = gen_id(&mut rng); + *sample_cnt = gen_sample_cnt(&mut rng); + } + + id_builder.append_value(trace_id); + ts_builder.append_value(cur_time); + + if asc { + cur_time += 1; + } else { + let samp: f64 = normal.sample(&mut rng); + let samp = samp.round(); + cur_time += samp as i64; + } + } + + // convert to MemTable + let id_col = Arc::new(id_builder.finish()); + let ts_col = Arc::new(ts_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; + partitions.push(vec![batch]); + } + Ok((schema, partitions)) +} + +/// The Schema used by make_data +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) +} diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs new file mode 100644 index 0000000000000..c242798a56f00 --- /dev/null +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use data_utils::{create_table_provider, make_data}; +use datafusion::execution::context::SessionContext; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; + +use parking_lot::Mutex; +use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; + +fn query(ctx: Arc>, sql: &str) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + array_len: usize, + batch_size: usize, +) -> Result>> { + let ctx = SessionContext::new(); + let provider = create_table_provider(partitions_len, array_len, batch_size)?; + ctx.register_table("t", provider)?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark_limited_distinct(c: &mut Criterion) { + let partitions_len = 10; + let array_len = 1 << 26; // 64 M + let batch_size = 8192; + let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); + + let mut group = c.benchmark_group("custom-measurement-time"); + group.measurement_time(Duration::from_secs(40)); + + group.bench_function("distinct_group_by_u64_narrow_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_100", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 100", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_1000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 1000", + ) + }) + }); + + group.bench_function("distinct_group_by_u64_narrow_limit_10000", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT DISTINCT u64_narrow FROM t GROUP BY u64_narrow LIMIT 10000", + ) + }) + }); + + group.bench_function("group_by_multiple_columns_limit_10", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT u64_narrow, u64_wide, utf8, f64 FROM t GROUP BY 1, 2, 3, 4 LIMIT 10", + ) + }) + }); + group.finish(); +} + +async fn distinct_with_limit( + plan: Arc, + ctx: Arc, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + Ok(()) +} + +fn run(plan: Arc, ctx: Arc) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { distinct_with_limit(plan.clone(), ctx.clone()).await }), + ) + .unwrap(); +} + +pub async fn create_context_sampled_data( + sql: &str, + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, false /* asc */).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let cfg = SessionConfig::new(); + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + Ok((physical_plan, ctx.task_ctx())) +} + +fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let limit = 10; + let partitions = 100; + let samples = 100_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_100_partitions_100_000_samples_limit_100 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_100_partitions_100_000_samples_limit_100.0.clone(), + distinct_trace_id_100_partitions_100_000_samples_limit_100.1.clone())), + ); + + let partitions = 10; + let samples = 1_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let distinct_trace_id_10_partitions_1_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_10_partitions_1_000_000_samples_limit_10.0.clone(), + distinct_trace_id_10_partitions_1_000_000_samples_limit_10.1.clone())), + ); + + let partitions = 1; + let samples = 10_000_000; + let sql = + format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); + + let rt = Runtime::new().unwrap(); + let distinct_trace_id_1_partition_10_000_000_samples_limit_10 = rt.block_on(async { + create_context_sampled_data(sql.as_str(), partitions, samples) + .await + .unwrap() + }); + + c.bench_function( + format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + |b| b.iter(|| run(distinct_trace_id_1_partition_10_000_000_samples_limit_10.0.clone(), + distinct_trace_id_1_partition_10_000_000_samples_limit_10.1.clone())), + ); +} + +criterion_group!( + benches, + criterion_benchmark_limited_distinct, + criterion_benchmark_limited_distinct_sampled +); +criterion_main!(benches); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index 876b1fe7e198c..6c9ab315761e3 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { let partitions = 4; let config = SessionConfig::new().with_target_partitions(partitions); - let context = SessionContext::with_config(config); + let context = SessionContext::new_with_config(config); let local_rt = tokio::runtime::Builder::new_current_thread() .build() diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 30f21a964d5f7..540f7212e96e9 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -22,7 +22,15 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_array_of_size 100000", |b| { let scalar = ScalarValue::Int32(Some(100)); - b.iter(|| assert_eq!(scalar.to_array_of_size(100000).null_count(), 0)) + b.iter(|| { + assert_eq!( + scalar + .to_array_of_size(100000) + .expect("Failed to convert to array of size") + .null_count(), + 0 + ) + }) }); } diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 4045702d6308e..fbb94d66db581 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -329,8 +329,8 @@ fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { let mut tuples: Vec<_> = gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values().into_iter()) - .zip(gen.utf8_high_cardinality_values().into_iter()) + .zip(gen.utf8_low_cardinality_values()) + .zip(gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -362,9 +362,9 @@ fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { let mut tuples: Vec<_> = gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values().into_iter()) - .zip(gen.utf8_low_cardinality_values().into_iter()) - .zip(gen.i64_values().into_iter()) + .zip(gen.utf8_low_cardinality_values()) + .zip(gen.utf8_low_cardinality_values()) + .zip(gen.i64_values()) .collect(); if sorted { diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index 62160067143e3..cfd4b8bc4bba8 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -86,8 +86,9 @@ fn create_context() -> Arc> { rt.block_on(async { // create local session context - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(1)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(1), + ); let table_provider = Arc::new(csv.await); let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state()) @@ -98,7 +99,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().first().unwrap().clone(); ctx } diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 1d96df0cecaa6..c7a838385bd68 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -93,10 +93,9 @@ async fn setup_files(store: Arc) { for partition in 0..TABLE_PARTITIONS { for file in 0..PARTITION_FILES { let data = create_parquet_file(&mut rng, file * FILE_ROWS); - let location = Path::try_from(format!( + let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" - )) - .unwrap(); + )); store.put(&location, data).await.unwrap(); } } @@ -120,7 +119,7 @@ async fn setup_context(object_store: Arc) -> SessionContext { let config = SessionConfig::new().with_target_partitions(THREADS); let rt = Arc::new(RuntimeEnv::default()); rt.register_object_store(&Url::parse("data://my_store").unwrap(), object_store); - let context = SessionContext::with_config_rt(config, rt); + let context = SessionContext::new_with_config_rt(config, rt); for table_id in 0..TABLES { let table_name = table_name(table_id); diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs new file mode 100644 index 0000000000000..922cbd2b42292 --- /dev/null +++ b/datafusion/core/benches/topk_aggregate.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod data_utils; +use arrow::util::pretty::pretty_format_batches; +use criterion::{criterion_group, criterion_main, Criterion}; +use data_utils::make_data; +use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; +use std::sync::Arc; +use tokio::runtime::Runtime; + +async fn create_context( + limit: usize, + partition_cnt: i32, + sample_cnt: i32, + asc: bool, + use_topk: bool, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, asc).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"); + let df = ctx.sql(sql.as_str()).await?; + let physical_plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + + Ok((physical_plan, ctx.task_ctx())) +} + +fn run(plan: Arc, ctx: Arc, asc: bool) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await }), + ) + .unwrap(); +} + +async fn aggregate( + plan: Arc, + ctx: Arc, + asc: bool, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); + let expected_asc = r#" ++----------------------------------+--------------------------+ +| trace_id | max(traces.timestamp_ms) | ++----------------------------------+--------------------------+ +| 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | +| 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | +| 02801bbe533190a9f8713d75222f445d | 16909009999997 | +| 9e31b3b5a620de32b68fefa5aeea57f1 | 16909009999996 | +| 2d88a860e9bd1cfaa632d8e7caeaa934 | 16909009999995 | +| a47edcef8364ab6f191dd9103e51c171 | 16909009999994 | +| 36a3fa2ccfbf8e00337f0b1254384db6 | 16909009999993 | +| 0756be84f57369012e10de18b57d8a2f | 16909009999992 | +| d4d6bf9845fa5897710e3a8db81d5907 | 16909009999991 | +| 3c2cc1abe728a66b61e14880b53482a0 | 16909009999990 | ++----------------------------------+--------------------------+ + "# + .trim(); + if asc { + assert_eq!(actual.trim(), expected_asc); + } + + Ok(()) +} + +fn criterion_benchmark(c: &mut Criterion) { + let limit = 10; + let partitions = 10; + let samples = 1_000_000; + + let rt = Runtime::new().unwrap(); + let topk_real = rt.block_on(async { + create_context(limit, partitions, samples, false, true) + .await + .unwrap() + }); + let topk_asc = rt.block_on(async { + create_context(limit, partitions, samples, true, true) + .await + .unwrap() + }); + let real = rt.block_on(async { + create_context(limit, partitions, samples, false, false) + .await + .unwrap() + }); + let asc = rt.block_on(async { + create_context(limit, partitions, samples, true, false) + .await + .unwrap() + }); + + c.bench_function( + format!("aggregate {} time-series rows", partitions * samples).as_str(), + |b| b.iter(|| run(real.0.clone(), real.1.clone(), false)), + ); + + c.bench_function( + format!("aggregate {} worst-case rows", partitions * samples).as_str(), + |b| b.iter(|| run(asc.0.clone(), asc.1.clone(), true)), + ); + + c.bench_function( + format!( + "top k={limit} aggregate {} time-series rows", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run(topk_real.0.clone(), topk_real.1.clone(), false)), + ); + + c.bench_function( + format!( + "top k={limit} aggregate {} worst-case rows", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run(topk_asc.0.clone(), topk_asc.1.clone(), true)), + ); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/src/catalog/catalog.rs b/datafusion/core/src/catalog/catalog.rs deleted file mode 100644 index 393d98dcb8848..0000000000000 --- a/datafusion/core/src/catalog/catalog.rs +++ /dev/null @@ -1,276 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Describes the interface and built-in implementations of catalogs, -//! representing collections of named schemas. - -use crate::catalog::schema::SchemaProvider; -use dashmap::DashMap; -use datafusion_common::{DataFusionError, Result}; -use std::any::Any; -use std::sync::Arc; - -/// Represent a list of named catalogs -pub trait CatalogList: Sync + Send { - /// Returns the catalog list as [`Any`](std::any::Any) - /// so that it can be downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Adds a new catalog to this catalog list - /// If a catalog of the same name existed before, it is replaced in the list and returned. - fn register_catalog( - &self, - name: String, - catalog: Arc, - ) -> Option>; - - /// Retrieves the list of available catalog names - fn catalog_names(&self) -> Vec; - - /// Retrieves a specific catalog by name, provided it exists. - fn catalog(&self, name: &str) -> Option>; -} - -/// Simple in-memory list of catalogs -pub struct MemoryCatalogList { - /// Collection of catalogs containing schemas and ultimately TableProviders - pub catalogs: DashMap>, -} - -impl MemoryCatalogList { - /// Instantiates a new `MemoryCatalogList` with an empty collection of catalogs - pub fn new() -> Self { - Self { - catalogs: DashMap::new(), - } - } -} - -impl Default for MemoryCatalogList { - fn default() -> Self { - Self::new() - } -} - -impl CatalogList for MemoryCatalogList { - fn as_any(&self) -> &dyn Any { - self - } - - fn register_catalog( - &self, - name: String, - catalog: Arc, - ) -> Option> { - self.catalogs.insert(name, catalog) - } - - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() - } - - fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| c.value().clone()) - } -} - -impl Default for MemoryCatalogProvider { - fn default() -> Self { - Self::new() - } -} - -/// Represents a catalog, comprising a number of named schemas. -pub trait CatalogProvider: Sync + Send { - /// Returns the catalog provider as [`Any`](std::any::Any) - /// so that it can be downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Retrieves the list of available schema names in this catalog. - fn schema_names(&self) -> Vec; - - /// Retrieves a specific schema from the catalog by name, provided it exists. - fn schema(&self, name: &str) -> Option>; - - /// Adds a new schema to this catalog. - /// - /// If a schema of the same name existed before, it is replaced in - /// the catalog and returned. - /// - /// By default returns a "Not Implemented" error - fn register_schema( - &self, - name: &str, - schema: Arc, - ) -> Result>> { - // use variables to avoid unused variable warnings - let _ = name; - let _ = schema; - Err(DataFusionError::NotImplemented( - "Registering new schemas is not supported".to_string(), - )) - } - - /// Removes a schema from this catalog. Implementations of this method should return - /// errors if the schema exists but cannot be dropped. For example, in DataFusion's - /// default in-memory catalog, [`MemoryCatalogProvider`], a non-empty schema - /// will only be successfully dropped when `cascade` is true. - /// This is equivalent to how DROP SCHEMA works in PostgreSQL. - /// - /// Implementations of this method should return None if schema with `name` - /// does not exist. - /// - /// By default returns a "Not Implemented" error - fn deregister_schema( - &self, - _name: &str, - _cascade: bool, - ) -> Result>> { - Err(DataFusionError::NotImplemented( - "Deregistering new schemas is not supported".to_string(), - )) - } -} - -/// Simple in-memory implementation of a catalog. -pub struct MemoryCatalogProvider { - schemas: DashMap>, -} - -impl MemoryCatalogProvider { - /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. - pub fn new() -> Self { - Self { - schemas: DashMap::new(), - } - } -} - -impl CatalogProvider for MemoryCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - self.schemas.iter().map(|s| s.key().clone()).collect() - } - - fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| s.value().clone()) - } - - fn register_schema( - &self, - name: &str, - schema: Arc, - ) -> Result>> { - Ok(self.schemas.insert(name.into(), schema)) - } - - fn deregister_schema( - &self, - name: &str, - cascade: bool, - ) -> Result>> { - if let Some(schema) = self.schema(name) { - let table_names = schema.table_names(); - match (table_names.is_empty(), cascade) { - (true, _) | (false, true) => { - let (_, removed) = self.schemas.remove(name).unwrap(); - Ok(Some(removed)) - } - (false, false) => Err(DataFusionError::Execution(format!( - "Cannot drop schema {} because other tables depend on it: {}", - name, - itertools::join(table_names.iter(), ", ") - ))), - } - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::catalog::schema::MemorySchemaProvider; - use crate::datasource::empty::EmptyTable; - use crate::datasource::TableProvider; - use arrow::datatypes::Schema; - - #[test] - fn default_register_schema_not_supported() { - // mimic a new CatalogProvider and ensure it does not support registering schemas - struct TestProvider {} - impl CatalogProvider for TestProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - unimplemented!() - } - - fn schema(&self, _name: &str) -> Option> { - unimplemented!() - } - } - - let schema = Arc::new(MemorySchemaProvider::new()) as _; - let catalog = Arc::new(TestProvider {}); - - match catalog.register_schema("foo", schema) { - Ok(_) => panic!("unexpected OK"), - Err(e) => assert_eq!(e.to_string(), "This feature is not implemented: Registering new schemas is not supported"), - }; - } - - #[test] - fn memory_catalog_dereg_nonempty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) - as Arc; - schema.register_table("t".into(), test_table).unwrap(); - - cat.register_schema("foo", schema.clone()).unwrap(); - - assert!( - cat.deregister_schema("foo", false).is_err(), - "dropping empty schema without cascade should error" - ); - assert!(cat.deregister_schema("foo", true).unwrap().is_some()); - } - - #[test] - fn memory_catalog_dereg_empty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - cat.register_schema("foo", schema.clone()).unwrap(); - - assert!(cat.deregister_schema("foo", false).unwrap().is_some()); - } - - #[test] - fn memory_catalog_dereg_missing() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - assert!(cat.deregister_schema("foo", false).unwrap().is_none()); - } -} diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index d30b490f28a4b..3a8fef2d25ab0 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -28,15 +28,18 @@ use arrow::{ record_batch::RecordBatch, }; -use crate::config::{ConfigEntry, ConfigOptions}; -use crate::datasource::streaming::{PartitionStream, StreamingTable}; +use crate::datasource::streaming::StreamingTable; use crate::datasource::TableProvider; use crate::execution::context::TaskContext; use crate::logical_expr::TableType; use crate::physical_plan::stream::RecordBatchStreamAdapter; use crate::physical_plan::SendableRecordBatchStream; +use crate::{ + config::{ConfigEntry, ConfigOptions}, + physical_plan::streaming::PartitionStream, +}; -use super::{catalog::CatalogList, schema::SchemaProvider}; +use super::{schema::SchemaProvider, CatalogList}; pub(crate) const INFORMATION_SCHEMA: &str = "information_schema"; pub(crate) const TABLES: &str = "tables"; @@ -623,7 +626,8 @@ impl InformationSchemaDfSettings { fn new(config: InformationSchemaConfig) -> Self { let schema = Arc::new(Schema::new(vec![ Field::new("name", DataType::Utf8, false), - Field::new("setting", DataType::Utf8, true), + Field::new("value", DataType::Utf8, true), + Field::new("description", DataType::Utf8, true), ])); Self { schema, config } @@ -632,7 +636,8 @@ impl InformationSchemaDfSettings { fn builder(&self) -> InformationSchemaDfSettingsBuilder { InformationSchemaDfSettingsBuilder { names: StringBuilder::new(), - settings: StringBuilder::new(), + values: StringBuilder::new(), + descriptions: StringBuilder::new(), schema: self.schema.clone(), } } @@ -661,13 +666,15 @@ impl PartitionStream for InformationSchemaDfSettings { struct InformationSchemaDfSettingsBuilder { schema: SchemaRef, names: StringBuilder, - settings: StringBuilder, + values: StringBuilder, + descriptions: StringBuilder, } impl InformationSchemaDfSettingsBuilder { fn add_setting(&mut self, entry: ConfigEntry) { self.names.append_value(entry.key); - self.settings.append_option(entry.value); + self.values.append_option(entry.value); + self.descriptions.append_value(entry.description); } fn finish(&mut self) -> RecordBatch { @@ -675,7 +682,8 @@ impl InformationSchemaDfSettingsBuilder { self.schema.clone(), vec![ Arc::new(self.names.finish()), - Arc::new(self.settings.finish()), + Arc::new(self.values.finish()), + Arc::new(self.descriptions.finish()), ], ) .unwrap() diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index cb63659997b58..c3c6826895421 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -16,21 +16,25 @@ // under the License. //! listing_schema contains a SchemaProvider that scans ObjectStores for tables automatically + +use std::any::Any; +use std::collections::{HashMap, HashSet}; +use std::path::Path; +use std::sync::{Arc, Mutex}; + use crate::catalog::schema::SchemaProvider; -use crate::datasource::datasource::TableProviderFactory; +use crate::datasource::provider::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; -use async_trait::async_trait; + use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference}; +use datafusion_common::{Constraints, DFSchema, DataFusionError, OwnedTableReference}; use datafusion_expr::CreateExternalTable; + +use async_trait::async_trait; use futures::TryStreamExt; use itertools::Itertools; use object_store::ObjectStore; -use std::any::Any; -use std::collections::{HashMap, HashSet}; -use std::path::Path; -use std::sync::{Arc, Mutex}; /// A [`SchemaProvider`] that scans an [`ObjectStore`] to automatically discover tables /// @@ -88,12 +92,7 @@ impl ListingSchemaProvider { /// Reload table information from ObjectStore pub async fn refresh(&self, state: &SessionState) -> datafusion_common::Result<()> { - let entries: Vec<_> = self - .store - .list(Some(&self.path)) - .await? - .try_collect() - .await?; + let entries: Vec<_> = self.store.list(Some(&self.path)).try_collect().await?; let base = Path::new(self.path.as_ref()); let mut tables = HashSet::new(); for file in entries.iter() { @@ -149,6 +148,8 @@ impl ListingSchemaProvider { order_exprs: vec![], unbounded: false, options: Default::default(), + constraints: Constraints::empty(), + column_defaults: Default::default(), }, ) .await?; diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs index b7843ed66b832..ce27d57da00d8 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/core/src/catalog/mod.rs @@ -15,17 +15,263 @@ // specific language governing permissions and limitations // under the License. -//! This module contains interfaces and default implementations -//! of table namespacing concepts, including catalogs and schemas. - -// TODO(clippy): Having a `catalog::catalog` module path is unclear and ambiguous. -// The parent module should probably be renamed to something that more accurately -// describes its content. Something along the lines of `database_meta`, `metadata` -// or `meta`, perhaps? -#![allow(clippy::module_inception)] -pub mod catalog; +//! Interfaces and default implementations of catalogs and schemas. + pub mod information_schema; pub mod listing_schema; pub mod schema; pub use datafusion_sql::{ResolvedTableReference, TableReference}; + +use crate::catalog::schema::SchemaProvider; +use dashmap::DashMap; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use std::any::Any; +use std::sync::Arc; + +/// Represent a list of named catalogs +pub trait CatalogList: Sync + Send { + /// Returns the catalog list as [`Any`] + /// so that it can be downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Adds a new catalog to this catalog list + /// If a catalog of the same name existed before, it is replaced in the list and returned. + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option>; + + /// Retrieves the list of available catalog names + fn catalog_names(&self) -> Vec; + + /// Retrieves a specific catalog by name, provided it exists. + fn catalog(&self, name: &str) -> Option>; +} + +/// Simple in-memory list of catalogs +pub struct MemoryCatalogList { + /// Collection of catalogs containing schemas and ultimately TableProviders + pub catalogs: DashMap>, +} + +impl MemoryCatalogList { + /// Instantiates a new `MemoryCatalogList` with an empty collection of catalogs + pub fn new() -> Self { + Self { + catalogs: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogList { + fn default() -> Self { + Self::new() + } +} + +impl CatalogList for MemoryCatalogList { + fn as_any(&self) -> &dyn Any { + self + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.catalogs.insert(name, catalog) + } + + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.catalogs.get(name).map(|c| c.value().clone()) + } +} + +/// Represents a catalog, comprising a number of named schemas. +pub trait CatalogProvider: Sync + Send { + /// Returns the catalog provider as [`Any`] + /// so that it can be downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Retrieves the list of available schema names in this catalog. + fn schema_names(&self) -> Vec; + + /// Retrieves a specific schema from the catalog by name, provided it exists. + fn schema(&self, name: &str) -> Option>; + + /// Adds a new schema to this catalog. + /// + /// If a schema of the same name existed before, it is replaced in + /// the catalog and returned. + /// + /// By default returns a "Not Implemented" error + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + // use variables to avoid unused variable warnings + let _ = name; + let _ = schema; + not_impl_err!("Registering new schemas is not supported") + } + + /// Removes a schema from this catalog. Implementations of this method should return + /// errors if the schema exists but cannot be dropped. For example, in DataFusion's + /// default in-memory catalog, [`MemoryCatalogProvider`], a non-empty schema + /// will only be successfully dropped when `cascade` is true. + /// This is equivalent to how DROP SCHEMA works in PostgreSQL. + /// + /// Implementations of this method should return None if schema with `name` + /// does not exist. + /// + /// By default returns a "Not Implemented" error + fn deregister_schema( + &self, + _name: &str, + _cascade: bool, + ) -> Result>> { + not_impl_err!("Deregistering new schemas is not supported") + } +} + +/// Simple in-memory implementation of a catalog. +pub struct MemoryCatalogProvider { + schemas: DashMap>, +} + +impl MemoryCatalogProvider { + /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. + pub fn new() -> Self { + Self { + schemas: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogProvider { + fn default() -> Self { + Self::new() + } +} + +impl CatalogProvider for MemoryCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.schemas.iter().map(|s| s.key().clone()).collect() + } + + fn schema(&self, name: &str) -> Option> { + self.schemas.get(name).map(|s| s.value().clone()) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + Ok(self.schemas.insert(name.into(), schema)) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + if let Some(schema) = self.schema(name) { + let table_names = schema.table_names(); + match (table_names.is_empty(), cascade) { + (true, _) | (false, true) => { + let (_, removed) = self.schemas.remove(name).unwrap(); + Ok(Some(removed)) + } + (false, false) => exec_err!( + "Cannot drop schema {} because other tables depend on it: {}", + name, + itertools::join(table_names.iter(), ", ") + ), + } + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::catalog::schema::MemorySchemaProvider; + use crate::datasource::empty::EmptyTable; + use crate::datasource::TableProvider; + use arrow::datatypes::Schema; + + #[test] + fn default_register_schema_not_supported() { + // mimic a new CatalogProvider and ensure it does not support registering schemas + struct TestProvider {} + impl CatalogProvider for TestProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + unimplemented!() + } + + fn schema(&self, _name: &str) -> Option> { + unimplemented!() + } + } + + let schema = Arc::new(MemorySchemaProvider::new()) as _; + let catalog = Arc::new(TestProvider {}); + + match catalog.register_schema("foo", schema) { + Ok(_) => panic!("unexpected OK"), + Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), + }; + } + + #[test] + fn memory_catalog_dereg_nonempty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) + as Arc; + schema.register_table("t".into(), test_table).unwrap(); + + cat.register_schema("foo", schema.clone()).unwrap(); + + assert!( + cat.deregister_schema("foo", false).is_err(), + "dropping empty schema without cascade should error" + ); + assert!(cat.deregister_schema("foo", true).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_empty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + cat.register_schema("foo", schema.clone()).unwrap(); + + assert!(cat.deregister_schema("foo", false).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_missing() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + assert!(cat.deregister_schema("foo", false).unwrap().is_none()); + } +} diff --git a/datafusion/core/src/catalog/schema.rs b/datafusion/core/src/catalog/schema.rs index 9d3b47546e392..1bb2df914ab25 100644 --- a/datafusion/core/src/catalog/schema.rs +++ b/datafusion/core/src/catalog/schema.rs @@ -20,6 +20,7 @@ use async_trait::async_trait; use dashmap::DashMap; +use datafusion_common::exec_err; use std::any::Any; use std::sync::Arc; @@ -47,18 +48,14 @@ pub trait SchemaProvider: Sync + Send { name: String, table: Arc, ) -> Result>> { - Err(DataFusionError::Execution( - "schema provider does not support registering tables".to_owned(), - )) + exec_err!("schema provider does not support registering tables") } /// If supported by the implementation, removes an existing table from this schema and returns it. /// If no table of that name exists, returns Ok(None). #[allow(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { - Err(DataFusionError::Execution( - "schema provider does not support deregistering tables".to_owned(), - )) + exec_err!("schema provider does not support deregistering tables") } /// If supported by the implementation, checks the table exist in the schema provider or not. @@ -110,9 +107,7 @@ impl SchemaProvider for MemorySchemaProvider { table: Arc, ) -> Result>> { if self.table_exist(name.as_str()) { - return Err(DataFusionError::Execution(format!( - "The table {name} already exists" - ))); + return exec_err!("The table {name} already exists"); } Ok(self.tables.insert(name, table)) } @@ -133,8 +128,8 @@ mod tests { use arrow::datatypes::Schema; use crate::assert_batches_eq; - use crate::catalog::catalog::{CatalogProvider, MemoryCatalogProvider}; use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; + use crate::catalog::{CatalogProvider, MemoryCatalogProvider}; use crate::datasource::empty::EmptyTable; use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; use crate::prelude::SessionContext; @@ -199,7 +194,7 @@ mod tests { let actual = df.collect().await.unwrap(); - let expected = vec![ + let expected = [ "+----+----------+", "| id | bool_col |", "+----+----------+", diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe/mod.rs similarity index 70% rename from datafusion/core/src/dataframe.rs rename to datafusion/core/src/dataframe/mod.rs index 7d0fddcf82268..c40dd522a4579 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -15,44 +15,100 @@ // specific language governing permissions and limitations // under the License. -//! DataFrame API for building and executing query plans. +//! [`DataFrame`] API for building and executing query plans. + +#[cfg(feature = "parquet")] +mod parquet; use std::any::Any; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; -use arrow::compute::{cast, concat}; -use arrow::datatypes::{DataType, Field}; -use async_trait::async_trait; -use datafusion_common::{DataFusionError, SchemaError}; -use parquet::file::properties::WriterProperties; - -use datafusion_common::{Column, DFSchema, ScalarValue}; -use datafusion_expr::{ - avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, - TableProviderFilterPushDown, UNNAMED_TABLE, -}; - -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; +use crate::arrow::datatypes::{Schema, SchemaRef}; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::pretty; -use crate::datasource::physical_plan::{plan_to_csv, plan_to_json, plan_to_parquet}; use crate::datasource::{provider_as_source, MemTable, TableProvider}; use crate::error::Result; use crate::execution::{ context::{SessionState, TaskContext}, FunctionRegistry, }; +use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, utils::find_window_exprs, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Partitioning, TableType, + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, +}; +use crate::physical_plan::{ + collect, collect_partitioned, execute_stream, execute_stream_partitioned, + ExecutionPlan, SendableRecordBatchStream, }; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{collect, collect_partitioned}; -use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; use crate::prelude::SessionContext; +use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; +use arrow::compute::{cast, concat}; +use arrow::csv::WriterBuilder; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + Column, DFSchema, DataFusionError, FileType, FileTypeWriterOptions, ParamValues, + SchemaError, UnnestOptions, +}; +use datafusion_expr::dml::CopyOptions; +use datafusion_expr::{ + avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, + TableProviderFilterPushDown, UNNAMED_TABLE, +}; + +use async_trait::async_trait; + +/// Contains options that control how data is +/// written out from a DataFrame +pub struct DataFrameWriteOptions { + /// Controls if existing data should be overwritten + overwrite: bool, + /// Controls if all partitions should be coalesced into a single output file + /// Generally will have slower performance when set to true. + single_file_output: bool, + /// Sets compression by DataFusion applied after file serialization. + /// Allows compression of CSV and JSON. + /// Not supported for parquet. + compression: CompressionTypeVariant, +} + +impl DataFrameWriteOptions { + /// Create a new DataFrameWriteOptions with default values + pub fn new() -> Self { + DataFrameWriteOptions { + overwrite: false, + single_file_output: false, + compression: CompressionTypeVariant::UNCOMPRESSED, + } + } + /// Set the overwrite option to true or false + pub fn with_overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } + + /// Set the single_file_output value to true or false + pub fn with_single_file_output(mut self, single_file_output: bool) -> Self { + self.single_file_output = single_file_output; + self + } + + /// Sets the compression type applied to the output file(s) + pub fn with_compression(mut self, compression: CompressionTypeVariant) -> Self { + self.compression = compression; + self + } +} + +impl Default for DataFrameWriteOptions { + fn default() -> Self { + Self::new() + } +} + /// DataFrame represents a logical set of rows with the same named columns. /// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or /// [Spark DataFrame](https://spark.apache.org/docs/latest/sql-programming-guide.html) @@ -151,6 +207,11 @@ impl DataFrame { /// Expand each list element of a column to multiple rows. /// + /// Seee also: + /// + /// 1. [`UnnestOptions`] documentation for the behavior of `unnest` + /// 2. [`Self::unnest_column_with_options`] + /// /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -163,8 +224,21 @@ impl DataFrame { /// # } /// ``` pub fn unnest_column(self, column: &str) -> Result { + self.unnest_column_with_options(column, UnnestOptions::new()) + } + + /// Expand each list element of a column to multiple rows, with + /// behavior controlled by [`UnnestOptions`]. + /// + /// Please see the documentation on [`UnnestOptions`] for more + /// details about the meaning of unnest. + pub fn unnest_column_with_options( + self, + column: &str, + options: UnnestOptions, + ) -> Result { let plan = LogicalPlanBuilder::from(self.plan) - .unnest_column(column)? + .unnest_column_with_options(column, options)? .build()?; Ok(DataFrame::new(self.session_state, plan)) } @@ -218,6 +292,14 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } + /// Apply one or more window functions ([`Expr::WindowFunction`]) to extend the schema + pub fn window(self, window_exprs: Vec) -> Result { + let plan = LogicalPlanBuilder::from(self.plan) + .window(window_exprs)? + .build()?; + Ok(DataFrame::new(self.session_state, plan)) + } + /// Limit the number of rows returned from this DataFrame. /// /// `skip` - Number of rows to skip before fetch any row @@ -499,12 +581,21 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } - /// Join this DataFrame with another DataFrame using the specified columns as join keys. + /// Join this `DataFrame` with another `DataFrame` using explicitly specified + /// columns and an optional filter expression. + /// + /// See [`join_on`](Self::join_on) for a more concise way to specify the + /// join condition. Since DataFusion will automatically identify and + /// optimize equality predicates there is no performance difference between + /// this function and `join_on` + /// + /// `left_cols` and `right_cols` are used to form "equijoin" predicates (see + /// example below), which are then combined with the optional `filter` + /// expression. /// - /// Filter expression expected to contain non-equality predicates that can not be pushed - /// down to any of join inputs. - /// In case of outer join, filter applied to only matched rows. + /// Note that in case of outer join, the `filter` is applied to only matched rows. /// + /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; @@ -517,11 +608,14 @@ impl DataFrame { /// col("a").alias("a2"), /// col("b").alias("b2"), /// col("c").alias("c2")])?; + /// // Perform the equivalent of `left INNER JOIN right ON (a = a2 AND b = b2)` + /// // finding all pairs of rows from `left` and `right` where `a = a2` and `b = b2`. /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"], None)?; /// let batches = join.collect().await?; /// # Ok(()) /// # } /// ``` + /// pub fn join( self, right: DataFrame, @@ -541,10 +635,13 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, plan)) } - /// Join this DataFrame with another DataFrame using the specified expressions. + /// Join this `DataFrame` with another `DataFrame` using the specified + /// expressions. + /// + /// Note that DataFusion automatically optimizes joins, including + /// identifying and optimizing equality predicates. /// - /// Simply a thin wrapper over [`join`](Self::join) where the join keys are not provided, - /// and the provided expressions are AND'ed together to form the filter expression. + /// # Example /// /// ``` /// # use datafusion::prelude::*; @@ -563,6 +660,10 @@ impl DataFrame { /// col("b").alias("b2"), /// col("c").alias("c2"), /// ])?; + /// + /// // Perform the equivalent of `left INNER JOIN right ON (a != a2 AND b != b2)` + /// // finding all pairs of rows from `left` and `right` where + /// // where `a != a2` and `b != b2`. /// let join_on = left.join_on( /// right, /// JoinType::Inner, @@ -580,12 +681,7 @@ impl DataFrame { ) -> Result { let expr = on_exprs.into_iter().reduce(Expr::and); let plan = LogicalPlanBuilder::from(self.plan) - .join( - right.plan, - join_type, - (Vec::::new(), Vec::::new()), - expr, - )? + .join_on(right.plan, join_type, expr)? .build()?; Ok(DataFrame::new(self.session_state, plan)) } @@ -699,7 +795,8 @@ impl DataFrame { Ok(pretty::print_batches(&results)?) } - fn task_ctx(&self) -> TaskContext { + /// Get a new TaskContext to run in this session + pub fn task_ctx(&self) -> TaskContext { TaskContext::from(&self.session_state) } @@ -916,29 +1013,82 @@ impl DataFrame { )) } - /// Write a `DataFrame` to a CSV file. - pub async fn write_csv(self, path: &str) -> Result<()> { - let plan = self.session_state.create_physical_plan(&self.plan).await?; - let task_ctx = Arc::new(self.task_ctx()); - plan_to_csv(task_ctx, plan, path).await + /// Write this DataFrame to the referenced table + /// This method uses on the same underlying implementation + /// as the SQL Insert Into statement. + /// Unlike most other DataFrame methods, this method executes + /// eagerly, writing data, and returning the count of rows written. + pub async fn write_table( + self, + table_name: &str, + write_options: DataFrameWriteOptions, + ) -> Result, DataFusionError> { + let arrow_schema = Schema::from(self.schema()); + let plan = LogicalPlanBuilder::insert_into( + self.plan, + table_name.to_owned(), + &arrow_schema, + write_options.overwrite, + )? + .build()?; + DataFrame::new(self.session_state, plan).collect().await } - /// Write a `DataFrame` to a Parquet file. - pub async fn write_parquet( + /// Write a `DataFrame` to a CSV file. + pub async fn write_csv( self, path: &str, - writer_properties: Option, - ) -> Result<()> { - let plan = self.session_state.create_physical_plan(&self.plan).await?; - let task_ctx = Arc::new(self.task_ctx()); - plan_to_parquet(task_ctx, plan, path, writer_properties).await + options: DataFrameWriteOptions, + writer_properties: Option, + ) -> Result, DataFusionError> { + if options.overwrite { + return Err(DataFusionError::NotImplemented( + "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), + )); + } + let props = match writer_properties { + Some(props) => props, + None => WriterBuilder::new(), + }; + + let file_type_writer_options = + FileTypeWriterOptions::CSV(CsvWriterOptions::new(props, options.compression)); + let copy_options = CopyOptions::WriterOptions(Box::new(file_type_writer_options)); + + let plan = LogicalPlanBuilder::copy_to( + self.plan, + path.into(), + FileType::CSV, + options.single_file_output, + copy_options, + )? + .build()?; + DataFrame::new(self.session_state, plan).collect().await } /// Executes a query and writes the results to a partitioned JSON file. - pub async fn write_json(self, path: impl AsRef) -> Result<()> { - let plan = self.session_state.create_physical_plan(&self.plan).await?; - let task_ctx = Arc::new(self.task_ctx()); - plan_to_json(task_ctx, plan, path).await + pub async fn write_json( + self, + path: &str, + options: DataFrameWriteOptions, + ) -> Result, DataFusionError> { + if options.overwrite { + return Err(DataFusionError::NotImplemented( + "Overwrites are not implemented for DataFrame::write_json.".to_owned(), + )); + } + let file_type_writer_options = + FileTypeWriterOptions::JSON(JsonWriterOptions::new(options.compression)); + let copy_options = CopyOptions::WriterOptions(Box::new(file_type_writer_options)); + let plan = LogicalPlanBuilder::copy_to( + self.plan, + path.into(), + FileType::JSON, + options.single_file_output, + copy_options, + )? + .build()?; + DataFrame::new(self.session_state, plan).collect().await } /// Add an additional column to the DataFrame. @@ -973,10 +1123,7 @@ impl DataFrame { col_exists = true; new_column.clone() } else { - Expr::Column(Column { - relation: None, - name: f.name().into(), - }) + col(f.qualified_column()) } }) .collect(); @@ -1006,12 +1153,21 @@ impl DataFrame { /// ``` pub fn with_column_renamed( self, - old_name: impl Into, + old_name: impl Into, new_name: &str, ) -> Result { - let old_name: Column = old_name.into(); + let ident_opts = self + .session_state + .config_options() + .sql_parser + .enable_ident_normalization; + let old_column: Column = if ident_opts { + Column::from_qualified_name(old_name) + } else { + Column::from_qualified_name_ignore_case(old_name) + }; - let field_to_rename = match self.plan.schema().field_from_column(&old_name) { + let field_to_rename = match self.plan.schema().field_from_column(&old_column) { Ok(field) => field, // no-op if field not found Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { @@ -1038,9 +1194,65 @@ impl DataFrame { Ok(DataFrame::new(self.session_state, project_plan)) } - /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values - pub fn with_param_values(self, param_values: Vec) -> Result { - let plan = self.plan.with_param_values(param_values)?; + /// Replace all parameters in logical plan with the specified + /// values, in preparation for execution. + /// + /// # Example + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::{error::Result, assert_batches_eq}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// # use datafusion_common::ScalarValue; + /// let mut ctx = SessionContext::new(); + /// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $1") + /// .await? + /// // replace $1 with value 2 + /// .with_param_values(vec![ + /// // value at index 0 --> $1 + /// ScalarValue::from(2i64) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); + /// // Note you can also provide named parameters + /// let results = ctx + /// .sql("SELECT a FROM example WHERE b = $my_param") + /// .await? + /// // replace $my_param with value 2 + /// // Note you can also use a HashMap as well + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(2i64)) + /// ])? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+", + /// "| a |", + /// "+---+", + /// "| 1 |", + /// "+---+", + /// ], + /// &results + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn with_param_values(self, query_values: impl Into) -> Result { + let plan = self.plan.with_param_values(query_values)?; Ok(Self::new(self.session_state, plan)) } @@ -1058,7 +1270,7 @@ impl DataFrame { /// # } /// ``` pub async fn cache(self) -> Result { - let context = SessionContext::with_state(self.session_state.clone()); + let context = SessionContext::new_with_state(self.session_state.clone()); let mem_table = MemTable::try_new( SchemaRef::from(self.schema().clone()), self.collect_partitioned().await?, @@ -1107,15 +1319,16 @@ impl TableProvider for DataFrameTableProvider { limit: Option, ) -> Result> { let mut expr = LogicalPlanBuilder::from(self.plan.clone()); - if let Some(p) = projection { - expr = expr.select(p.iter().copied())? - } - // Add filter when given let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new)); if let Some(filter) = filter { expr = expr.filter(filter)? } + + if let Some(p) = projection { + expr = expr.select(p.iter().copied())? + } + // add a limit if given if let Some(l) = limit { expr = expr.limit(0, Some(l))? @@ -1129,26 +1342,129 @@ impl TableProvider for DataFrameTableProvider { mod tests { use std::vec; - use arrow::array::Int32Array; - use arrow::datatypes::DataType; + use super::*; + use crate::execution::context::SessionConfig; + use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; + use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + use arrow::array::{self, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, - BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + BinaryExpr, BuiltInWindowFunction, Operator, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunction, }; use datafusion_physical_expr::expressions::Column; + use datafusion_physical_plan::get_plan_string; + + pub fn table_with_constraints() -> Arc { + let dual_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + dual_schema.clone(), + vec![ + Arc::new(array::Int32Array::from(vec![1])), + Arc::new(array::StringArray::from(vec!["a"])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(dual_schema, vec![vec![batch]]) + .unwrap() + .with_constraints(Constraints::new_unverified(vec![Constraint::PrimaryKey( + vec![0], + )])); + Arc::new(provider) + } - use crate::execution::context::SessionConfig; - use crate::execution::options::{CsvReadOptions, ParquetReadOptions}; - use crate::physical_plan::ColumnarValue; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use crate::test_util; - use crate::test_util::parquet_test_data; - use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } - use super::*; + #[tokio::test] + async fn test_array_agg_ord_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } #[tokio::test] async fn select_columns() -> Result<()> { @@ -1220,7 +1536,7 @@ mod tests { let df_results = df.collect().await?; assert_batches_sorted_eq!( - vec!["+------+", "| f.c1 |", "+------+", "| 1 |", "| 10 |", "+------+",], + ["+------+", "| f.c1 |", "+------+", "| 1 |", "| 10 |", "+------+"], &df_results ); @@ -1244,8 +1560,7 @@ mod tests { let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_batches_sorted_eq!( - vec![ - "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", + ["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |", "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |", @@ -1253,14 +1568,269 @@ mod tests { "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |", "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |", "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |", - "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", - ], + "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+"], &df ); Ok(()) } + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // expr list contains id, name + let expr_list = vec![col_id, col_name]; + let df = df.select(expr_list)?; + let physical_plan = df.clone().create_physical_plan().await?; + let expected = vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + let condition2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_name), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))), + )); + // Predicate refers to id, and name fields + let predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(condition1), + Operator::And, + Box::new(condition2), + )); + let df = df.filter(predicate)?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1 AND name@1 = a", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk3() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + let col_name = Expr::Column(datafusion_common::Column { + relation: None, + name: "name".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=0 + let df = df.filter(predicate)?; + // Select expression refers to id, and name columns. + // id, name + let df = df.select(vec![col_id.clone(), col_name.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+",], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk4() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let table1 = table_with_constraints(); + let df = ctx.read_table(table1)?; + let col_id = Expr::Column(datafusion_common::Column { + relation: None, + name: "id".to_string(), + }); + + // group by contains id column + let group_expr = vec![col_id.clone()]; + let aggr_expr = vec![]; + // group by id, + let df = df.aggregate(group_expr, aggr_expr)?; + + let condition1 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col_id.clone()), + Operator::Eq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))), + )); + // Predicate refers to id field + let predicate = condition1; + // id=1 + let df = df.filter(predicate)?; + // Select expression refers to id column. + // id + let df = df.select(vec![col_id.clone()])?; + let physical_plan = df.clone().create_physical_plan().await?; + + // In this case aggregate shouldn't be expanded, since these + // columns are not used. + let expected = vec![ + "CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: id@0 = 1", + " AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // Since id and name are functionally dependant, we can use name among expression + // even if it is not part of the group by expression. + let df_results = collect(physical_plan, ctx.task_ctx()).await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ "+----+", + "| id |", + "+----+", + "| 1 |", + "+----+",], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; @@ -1293,8 +1863,7 @@ mod tests { #[rustfmt::skip] assert_batches_sorted_eq!( - vec![ - "+----+", + ["+----+", "| c1 |", "+----+", "| a |", @@ -1302,8 +1871,7 @@ mod tests { "| c |", "| d |", "| e |", - "+----+", - ], + "+----+"], &df_results ); @@ -1321,7 +1889,7 @@ mod tests { // try to sort on some value not present in input to distinct .sort(vec![col("c2").sort(true, true)]) .unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); + assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); Ok(()) } @@ -1379,7 +1947,7 @@ mod tests { .join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))]) .expect_err("join didn't fail check"); let expected = "Schema error: Ambiguous reference to unqualified field c1"; - assert_eq!(join.to_string(), expected); + assert_eq!(join.strip_backtrace(), expected); Ok(()) } @@ -1526,7 +2094,7 @@ mod tests { let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+----+-----------------------------+", "| c1 | SUM(aggregate_test_100.c12) |", "+----+-----------------------------+", @@ -1535,14 +2103,14 @@ mod tests { "| c | 13.860958726523545 |", "| d | 8.793968289758968 |", "| e | 10.206140546981722 |", - "+----+-----------------------------+", + "+----+-----------------------------+" ], &df_results ); // the results are the same as the results from the view, modulo the leaf table name assert_batches_sorted_eq!( - vec![ + [ "+----+---------------------+", "| c1 | SUM(test_table.c12) |", "+----+---------------------+", @@ -1551,7 +2119,7 @@ mod tests { "| c | 13.860958726523545 |", "| d | 8.793968289758968 |", "| e | 10.206140546981722 |", - "+----+---------------------+", + "+----+---------------------+" ], table_results ); @@ -1570,31 +2138,6 @@ mod tests { Ok(ctx.sql(sql).await?.into_unoptimized_plan()) } - async fn test_table_with_name(name: &str) -> Result { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, name).await?; - ctx.table(name).await - } - - async fn test_table() -> Result { - test_table_with_name("aggregate_test_100").await - } - - async fn register_aggregate_csv( - ctx: &mut SessionContext, - table_name: &str, - ) -> Result<()> { - let schema = test_util::aggr_test_schema(); - let testdata = test_util::arrow_test_data(); - ctx.register_csv( - table_name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new().schema(schema.as_ref()), - ) - .await?; - Ok(()) - } - #[tokio::test] async fn with_column() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; @@ -1609,7 +2152,7 @@ mod tests { let df_results = df.clone().collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+----+----+-----+-----+", "| c1 | c2 | c3 | sum |", "+----+----+-----+-----+", @@ -1619,7 +2162,7 @@ mod tests { "| a | 3 | 13 | 16 |", "| a | 3 | 14 | 17 |", "| a | 3 | 17 | 20 |", - "+----+----+-----+-----+", + "+----+----+-----+-----+" ], &df_results ); @@ -1632,7 +2175,7 @@ mod tests { .await?; assert_batches_sorted_eq!( - vec![ + [ "+-----+----+-----+-----+", "| c1 | c2 | c3 | sum |", "+-----+----+-----+-----+", @@ -1642,7 +2185,7 @@ mod tests { "| 16 | 3 | 13 | 16 |", "| 17 | 3 | 14 | 17 |", "| 20 | 3 | 17 | 20 |", - "+-----+----+-----+-----+", + "+-----+----+-----+-----+" ], &df_results_overwrite ); @@ -1655,7 +2198,7 @@ mod tests { .await?; assert_batches_sorted_eq!( - vec![ + [ "+----+----+-----+-----+", "| c1 | c2 | c3 | sum |", "+----+----+-----+-----+", @@ -1665,7 +2208,7 @@ mod tests { "| a | 4 | 13 | 16 |", "| a | 4 | 14 | 17 |", "| a | 4 | 17 | 20 |", - "+----+----+-----+-----+", + "+----+----+-----+-----+" ], &df_results_overwrite_self ); @@ -1673,6 +2216,131 @@ mod tests { Ok(()) } + // Test issue: https://github.com/apache/arrow-datafusion/issues/7790 + // The join operation outputs two identical column names, but they belong to different relations. + #[tokio::test] + async fn with_column_join_same_columns() -> Result<()> { + let df = test_table().await?.select_columns(&["c1"])?; + let ctx = SessionContext::new(); + + let table = df.into_view(); + ctx.register_table("t1", table.clone())?; + ctx.register_table("t2", table)?; + let df = ctx + .table("t1") + .await? + .join( + ctx.table("t2").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + ])? + .limit(0, Some(1))?; + + let df_results = df.clone().collect().await?; + assert_batches_sorted_eq!( + [ + "+----+----+", + "| c1 | c1 |", + "+----+----+", + "| a | a |", + "+----+----+", + ], + &df_results + ); + + let df_with_column = df.clone().with_column("new_column", lit(true))?; + + assert_eq!( + "\ + Projection: t1.c1, t2.c1, Boolean(true) AS new_column\ + \n Limit: skip=0, fetch=1\ + \n Sort: t1.c1 ASC NULLS FIRST\ + \n Inner Join: t1.c1 = t2.c1\ + \n TableScan: t1\ + \n TableScan: t2", + format!("{:?}", df_with_column.logical_plan()) + ); + + assert_eq!( + "\ + Projection: t1.c1, t2.c1, Boolean(true) AS new_column\ + \n Limit: skip=0, fetch=1\ + \n Sort: t1.c1 ASC NULLS FIRST, fetch=1\ + \n Inner Join: t1.c1 = t2.c1\ + \n SubqueryAlias: t1\ + \n TableScan: aggregate_test_100 projection=[c1]\ + \n SubqueryAlias: t2\ + \n TableScan: aggregate_test_100 projection=[c1]", + format!("{:?}", df_with_column.clone().into_optimized_plan()?) + ); + + let df_results = df_with_column.collect().await?; + + assert_batches_sorted_eq!( + [ + "+----+----+------------+", + "| c1 | c1 | new_column |", + "+----+----+------------+", + "| a | a | true |", + "+----+----+------------+", + ], + &df_results + ); + Ok(()) + } + + // Table 't1' self join + // Supplementary test of issue: https://github.com/apache/arrow-datafusion/issues/7790 + #[tokio::test] + async fn with_column_self_join() -> Result<()> { + let df = test_table().await?.select_columns(&["c1"])?; + let ctx = SessionContext::new(); + + ctx.register_table("t1", df.into_view())?; + + let df = ctx + .table("t1") + .await? + .join( + ctx.table("t1").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + ])? + .limit(0, Some(1))?; + + let df_results = df.clone().collect().await?; + assert_batches_sorted_eq!( + [ + "+----+----+", + "| c1 | c1 |", + "+----+----+", + "| a | a |", + "+----+----+", + ], + &df_results + ); + + let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err(); + let expected_err = "Error during planning: Projections require unique expression names \ + but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. \ + Consider aliasing (\"AS\") one of them."; + assert_eq!(actual_err.strip_backtrace(), expected_err); + + Ok(()) + } + #[tokio::test] async fn with_column_renamed() -> Result<()> { let df = test_table() @@ -1700,12 +2368,12 @@ mod tests { .await?; assert_batches_sorted_eq!( - vec![ + [ "+-----+-----+----+-------+", "| one | two | c3 | total |", "+-----+-----+----+-------+", "| a | 3 | 13 | 16 |", - "+-----+-----+----+-------+", + "+-----+-----+----+-------+" ], &df_sum_renamed ); @@ -1736,7 +2404,7 @@ mod tests { .with_column_renamed("c2", "AAA") .unwrap_err(); let expected_err = "Schema error: Ambiguous reference to unqualified field c2"; - assert_eq!(actual_err.to_string(), expected_err); + assert_eq!(actual_err.strip_backtrace(), expected_err); Ok(()) } @@ -1772,12 +2440,12 @@ mod tests { let df_results = df.clone().collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+----+----+-----+----+----+-----+", "| c1 | c2 | c3 | c1 | c2 | c3 |", "+----+----+-----+----+----+-----+", "| a | 1 | -85 | a | 1 | -85 |", - "+----+----+-----+----+----+-----+", + "+----+----+-----+----+----+-----+" ], &df_results ); @@ -1809,12 +2477,12 @@ mod tests { let df_results = df_renamed.collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+-----+----+-----+----+----+-----+", "| AAA | c2 | c3 | c1 | c2 | c3 |", "+-----+----+-----+----+----+-----+", "| a | 1 | -85 | a | 1 | -85 |", - "+-----+----+-----+----+----+-----+", + "+-----+----+-----+----+----+-----+" ], &df_results ); @@ -1823,28 +2491,53 @@ mod tests { } #[tokio::test] - async fn filter_pushdown_dataframe() -> Result<()> { - let ctx = SessionContext::new(); + async fn with_column_renamed_case_sensitive() -> Result<()> { + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; + let mut ctx = SessionContext::new_with_config(config); + let name = "aggregate_test_100"; + register_aggregate_csv(&mut ctx, name).await?; + let df = ctx.table(name); - ctx.register_parquet( - "test", - &format!("{}/alltypes_plain.snappy.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; + let df = df + .await? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(0, Some(1))? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .select_columns(&["c1"])?; - ctx.register_table("t1", ctx.table("test").await?.into_view())?; + let df_renamed = df.clone().with_column_renamed("c1", "CoLuMn1")?; - let df = ctx - .table("t1") - .await? - .filter(col("id").eq(lit(1)))? - .select_columns(&["bool_col", "int_col"])?; + let res = &df_renamed.clone().collect().await?; + + assert_batches_sorted_eq!( + [ + "+---------+", + "| CoLuMn1 |", + "+---------+", + "| a |", + "+---------+" + ], + res + ); - let plan = df.explain(false, false)?.collect().await?; - // Filters all the way to Parquet - let formatted = pretty::pretty_format_batches(&plan)?.to_string(); - assert!(formatted.contains("FilterExec: id@0 = 1")); + let df_renamed = df_renamed + .with_column_renamed("CoLuMn1", "c1")? + .collect() + .await?; + + assert_batches_sorted_eq!( + ["+----+", "| c1 |", "+----+", "| a |", "+----+"], + &df_renamed + ); Ok(()) } @@ -1860,12 +2553,12 @@ mod tests { let df_results = df.clone().collect().await?; df.clone().show().await?; assert_batches_sorted_eq!( - vec![ + [ "+----+----+-----+", "| c2 | c3 | sum |", "+----+----+-----+", "| 2 | 1 | 3 |", - "+----+----+-----+", + "+----+----+-----+" ], &df_results ); @@ -1926,13 +2619,13 @@ mod tests { let df_results = df.collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+------+-------+", "| f.c1 | f.c2 |", "+------+-------+", "| 1 | hello |", "| 10 | hello |", - "+------+-------+", + "+------+-------+" ], &df_results ); @@ -1958,12 +2651,12 @@ mod tests { let df_results = df.collect().await?; let cached_df_results = cached_df.collect().await?; assert_batches_sorted_eq!( - vec![ + [ "+----+----+-----+", "| c2 | c3 | sum |", "+----+----+-----+", "| 2 | 1 | 3 |", - "+----+----+-----+", + "+----+----+-----+" ], &cached_df_results ); diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs new file mode 100644 index 0000000000000..36ef90c987e35 --- /dev/null +++ b/datafusion/core/src/dataframe/parquet.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::file_options::parquet_writer::{ + default_builder, ParquetWriterOptions, +}; +use parquet::file::properties::WriterProperties; + +use super::{ + CompressionTypeVariant, CopyOptions, DataFrame, DataFrameWriteOptions, + DataFusionError, FileType, FileTypeWriterOptions, LogicalPlanBuilder, RecordBatch, +}; + +impl DataFrame { + /// Write a `DataFrame` to a Parquet file. + pub async fn write_parquet( + self, + path: &str, + options: DataFrameWriteOptions, + writer_properties: Option, + ) -> Result, DataFusionError> { + if options.overwrite { + return Err(DataFusionError::NotImplemented( + "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), + )); + } + match options.compression{ + CompressionTypeVariant::UNCOMPRESSED => (), + _ => return Err(DataFusionError::Configuration("DataFrame::write_parquet method does not support compression set via DataFrameWriteOptions. Set parquet compression via writer_properties instead.".to_owned())) + } + let props = match writer_properties { + Some(props) => props, + None => default_builder(self.session_state.config_options())?.build(), + }; + let file_type_writer_options = + FileTypeWriterOptions::Parquet(ParquetWriterOptions::new(props)); + let copy_options = CopyOptions::WriterOptions(Box::new(file_type_writer_options)); + let plan = LogicalPlanBuilder::copy_to( + self.plan, + path.into(), + FileType::PARQUET, + options.single_file_output, + copy_options, + )? + .build()?; + DataFrame::new(self.session_state, plan).collect().await + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use object_store::local::LocalFileSystem; + use parquet::basic::{BrotliLevel, GzipLevel, ZstdLevel}; + use parquet::file::reader::FileReader; + use tempfile::TempDir; + use url::Url; + + use datafusion_expr::{col, lit}; + + use crate::arrow::util::pretty; + use crate::execution::context::SessionContext; + use crate::execution::options::ParquetReadOptions; + use crate::test_util; + + use super::super::Result; + use super::*; + + #[tokio::test] + async fn filter_pushdown_dataframe() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_parquet( + "test", + &format!( + "{}/alltypes_plain.snappy.parquet", + test_util::parquet_test_data() + ), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_table("t1", ctx.table("test").await?.into_view())?; + + let df = ctx + .table("t1") + .await? + .filter(col("id").eq(lit(1)))? + .select_columns(&["bool_col", "int_col"])?; + + let plan = df.explain(false, false)?.collect().await?; + // Filters all the way to Parquet + let formatted = pretty::pretty_format_batches(&plan)?.to_string(); + assert!(formatted.contains("FilterExec: id@0 = 1")); + + Ok(()) + } + + #[tokio::test] + async fn write_parquet_with_compression() -> Result<()> { + let test_df = test_util::test_table().await?; + + let output_path = "file://local/test.parquet"; + let test_compressions = vec![ + parquet::basic::Compression::SNAPPY, + parquet::basic::Compression::LZ4, + parquet::basic::Compression::LZ4_RAW, + parquet::basic::Compression::GZIP(GzipLevel::default()), + parquet::basic::Compression::BROTLI(BrotliLevel::default()), + parquet::basic::Compression::ZSTD(ZstdLevel::default()), + ]; + for compression in test_compressions.into_iter() { + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + let ctx = &test_df.session_state; + ctx.runtime_env().register_object_store(&local_url, local); + df.write_parquet( + output_path, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(compression) + .build(), + ), + ) + .await?; + + // Check that file actually used the specified compression + let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?; + + let reader = + parquet::file::serialized_reader::SerializedFileReader::new(file) + .unwrap(); + + let parquet_metadata = reader.metadata(); + + let written_compression = + parquet_metadata.row_group(0).column(0).compression(); + + assert_eq!(written_compression, compression); + } + + Ok(()) + } +} diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs similarity index 60% rename from datafusion/core/src/avro_to_arrow/arrow_array_reader.rs rename to datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 311e199f28c46..855a8d0dbf40c 100644 --- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -35,12 +35,13 @@ use crate::arrow::error::ArrowError; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::bit_util; use crate::error::{DataFusionError, Result}; +use apache_avro::schema::RecordSchema; use apache_avro::{ schema::{Schema as AvroSchema, SchemaKind}, types::Value, AvroResult, Error as AvroError, Reader as AvroReader, }; -use arrow::array::{BinaryArray, GenericListArray}; +use arrow::array::{BinaryArray, FixedSizeBinaryArray, GenericListArray}; use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::ArrowError::SchemaError; use arrow::error::Result as ArrowResult; @@ -77,16 +78,72 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { pub fn schema_lookup(schema: AvroSchema) -> Result> { match schema { - AvroSchema::Record { - lookup: ref schema_lookup, - .. - } => Ok(schema_lookup.clone()), + AvroSchema::Record(RecordSchema { + fields, mut lookup, .. + }) => { + for field in fields { + Self::child_schema_lookup(&field.name, &field.schema, &mut lookup)?; + } + Ok(lookup) + } _ => Err(DataFusionError::ArrowError(SchemaError( "expected avro schema to be a record".to_string(), ))), } } + fn child_schema_lookup<'b>( + parent_field_name: &str, + schema: &AvroSchema, + schema_lookup: &'b mut BTreeMap, + ) -> Result<&'b BTreeMap> { + match schema { + AvroSchema::Union(us) => { + let has_nullable = us + .find_schema_with_known_schemata::( + &Value::Null, + None, + &None, + ) + .is_some(); + let sub_schemas = us.variants(); + if has_nullable && sub_schemas.len() == 2 { + if let Some(sub_schema) = + sub_schemas.iter().find(|&s| !matches!(s, AvroSchema::Null)) + { + Self::child_schema_lookup( + parent_field_name, + sub_schema, + schema_lookup, + )?; + } + } + } + AvroSchema::Record(RecordSchema { fields, lookup, .. }) => { + lookup.iter().for_each(|(field_name, pos)| { + schema_lookup + .insert(format!("{}.{}", parent_field_name, field_name), *pos); + }); + + for field in fields { + let sub_parent_field_name = + format!("{}.{}", parent_field_name, field.name); + Self::child_schema_lookup( + &sub_parent_field_name, + &field.schema, + schema_lookup, + )?; + } + } + AvroSchema::Array(schema) => { + let sub_parent_field_name = format!("{}.element", parent_field_name); + Self::child_schema_lookup(&sub_parent_field_name, schema, schema_lookup)?; + } + _ => (), + } + Ok(schema_lookup) + } + /// Read the next batch of records pub fn next_batch(&mut self, batch_size: usize) -> Option> { let rows_result = self @@ -114,7 +171,8 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let rows = rows.iter().collect::>>(); let projection = self.projection.clone().unwrap_or_default(); - let arrays = self.build_struct_array(&rows, self.schema.fields(), &projection); + let arrays = + self.build_struct_array(&rows, "", self.schema.fields(), &projection); let projected_fields = if projection.is_empty() { self.schema.fields().clone() } else { @@ -272,6 +330,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { for row in rows { if let Some(value) = self.field_lookup(col_name, row) { + let value = maybe_resolve_union(value); // value can be an array or a scalar let vals: Vec> = if let Value::String(v) = value { vec![Some(v.to_string())] @@ -411,6 +470,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { /// Build a nested GenericListArray from a list of unnested `Value`s fn build_nested_list_array( &self, + parent_field_name: &str, rows: &[&Value], list_field: &Field, ) -> ArrowResult { @@ -497,13 +557,19 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .collect::() .into_data(), DataType::List(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; + let child = self.build_nested_list_array::( + parent_field_name, + &flatten_values(rows), + field, + )?; child.to_data() } DataType::LargeList(field) => { - let child = - self.build_nested_list_array::(&flatten_values(rows), field)?; + let child = self.build_nested_list_array::( + parent_field_name, + &flatten_values(rows), + field, + )?; child.to_data() } DataType::Struct(fields) => { @@ -518,26 +584,39 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let num_bytes = bit_util::ceil(array_item_count, 8); let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); let mut struct_index = 0; - let rows: Vec> = rows + let null_struct_array = vec![("null".to_string(), Value::Null)]; + let rows: Vec<&Vec<(String, Value)>> = rows .iter() - .map(|row| { + .map(|v| maybe_resolve_union(v)) + .flat_map(|row| { if let Value::Array(values) = row { - values.iter().for_each(|_| { - bit_util::set_bit(&mut null_buffer, struct_index); - struct_index += 1; - }); values .iter() - .map(|v| ("".to_string(), v.clone())) - .collect::>() + .map(maybe_resolve_union) + .map(|v| match v { + Value::Record(record) => { + bit_util::set_bit(&mut null_buffer, struct_index); + struct_index += 1; + record + } + Value::Null => { + struct_index += 1; + &null_struct_array + } + other => panic!("expected Record, got {other:?}"), + }) + .collect::>>() } else { struct_index += 1; - vec![("null".to_string(), Value::Null)] + vec![&null_struct_array] } }) .collect(); - let rows = rows.iter().collect::>>(); - let arrays = self.build_struct_array(&rows, fields, &[])?; + + let sub_parent_field_name = + format!("{}.{}", parent_field_name, list_field.name()); + let arrays = + self.build_struct_array(&rows, &sub_parent_field_name, fields, &[])?; let data_type = DataType::Struct(fields.clone()); ArrayDataBuilder::new(data_type) .len(rows.len()) @@ -574,6 +653,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn build_struct_array( &self, rows: RecordSlice, + parent_field_name: &str, struct_fields: &Fields, projection: &[String], ) -> ArrowResult> { @@ -581,78 +661,83 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .iter() .filter(|field| projection.is_empty() || projection.contains(field.name())) .map(|field| { + let field_path = if parent_field_name.is_empty() { + field.name().to_string() + } else { + format!("{}.{}", parent_field_name, field.name()) + }; let arr = match field.data_type() { DataType::Null => Arc::new(NullArray::new(rows.len())) as ArrayRef, - DataType::Boolean => self.build_boolean_array(rows, field.name()), + DataType::Boolean => self.build_boolean_array(rows, &field_path), DataType::Float64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Float32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int16 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Int8 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt16 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::UInt8 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } // TODO: this is incomplete DataType::Timestamp(unit, _) => match unit { TimeUnit::Second => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Microsecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Millisecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Nanosecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), }, DataType::Date64 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Date32 => { - self.build_primitive_array::(rows, field.name()) + self.build_primitive_array::(rows, &field_path) } DataType::Time64(unit) => match unit { TimeUnit::Microsecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), TimeUnit::Nanosecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), t => { return Err(ArrowError::SchemaError(format!( @@ -662,14 +747,11 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { }, DataType::Time32(unit) => match unit { TimeUnit::Second => self - .build_primitive_array::( - rows, - field.name(), - ), + .build_primitive_array::(rows, &field_path), TimeUnit::Millisecond => self .build_primitive_array::( rows, - field.name(), + &field_path, ), t => { return Err(ArrowError::SchemaError(format!( @@ -680,7 +762,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Utf8 | DataType::LargeUtf8 => Arc::new( rows.iter() .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); + let maybe_value = self.field_lookup(&field_path, row); match maybe_value { None => Ok(None), Some(v) => resolve_string(v), @@ -692,27 +774,37 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Binary | DataType::LargeBinary => Arc::new( rows.iter() .map(|row| { - let maybe_value = self.field_lookup(field.name(), row); + let maybe_value = self.field_lookup(&field_path, row); maybe_value.and_then(resolve_bytes) }) .collect::(), ) as ArrayRef, + DataType::FixedSizeBinary(ref size) => { + Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size( + rows.iter().map(|row| { + let maybe_value = self.field_lookup(&field_path, row); + maybe_value.and_then(|v| resolve_fixed(v, *size as usize)) + }), + *size, + )?) as ArrayRef + } DataType::List(ref list_field) => { match list_field.data_type() { DataType::Dictionary(ref key_ty, _) => { - self.build_wrapped_list_array(rows, field.name(), key_ty)? + self.build_wrapped_list_array(rows, &field_path, key_ty)? } _ => { // extract rows by name let extracted_rows = rows .iter() .map(|row| { - self.field_lookup(field.name(), row) + self.field_lookup(&field_path, row) .unwrap_or(&Value::Null) }) .collect::>(); self.build_nested_list_array::( + &field_path, &extracted_rows, list_field, )? @@ -722,7 +814,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Dictionary(ref key_ty, ref val_ty) => self .build_string_dictionary_array( rows, - field.name(), + &field_path, key_ty, val_ty, )?, @@ -730,21 +822,31 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let len = rows.len(); let num_bytes = bit_util::ceil(len, 8); let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); + let empty_vec = vec![]; let struct_rows = rows .iter() .enumerate() - .map(|(i, row)| (i, self.field_lookup(field.name(), row))) + .map(|(i, row)| (i, self.field_lookup(&field_path, row))) .map(|(i, v)| { - if let Some(Value::Record(value)) = v { - bit_util::set_bit(&mut null_buffer, i); - value - } else { - panic!("expected struct got {v:?}"); + let v = v.map(maybe_resolve_union); + match v { + Some(Value::Record(value)) => { + bit_util::set_bit(&mut null_buffer, i); + value + } + None | Some(Value::Null) => &empty_vec, + other => { + panic!("expected struct got {other:?}"); + } } }) .collect::>>(); - let arrays = - self.build_struct_array(&struct_rows, fields, &[])?; + let arrays = self.build_struct_array( + &struct_rows, + &field_path, + fields, + &[], + )?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -857,6 +959,7 @@ fn resolve_string(v: &Value) -> ArrowResult> { Value::Bytes(bytes) => String::from_utf8(bytes.to_vec()) .map_err(AvroError::ConvertToUtf8) .map(Some), + Value::Enum(_, s) => Ok(Some(s.clone())), Value::Null => Ok(None), other => Err(AvroError::GetString(other.into())), } @@ -899,6 +1002,20 @@ fn resolve_bytes(v: &Value) -> Option> { }) } +fn resolve_fixed(v: &Value, size: usize) -> Option> { + let v = if let Value::Union(_, b) = v { b } else { v }; + match v { + Value::Fixed(n, bytes) => { + if *n == size { + Some(bytes.clone()) + } else { + None + } + } + _ => None, + } +} + fn resolve_boolean(value: &Value) -> Option { let v = if let Value::Union(_, b) = value { b @@ -957,8 +1074,9 @@ where mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; - use crate::avro_to_arrow::{Reader, ReaderBuilder}; + use crate::datasource::avro_to_arrow::{Reader, ReaderBuilder}; use arrow::datatypes::DataType; + use datafusion_common::assert_batches_eq; use datafusion_common::cast::{ as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array, }; @@ -1019,7 +1137,7 @@ mod test { let a_array = as_list_array(batch.column(col_id_index)).unwrap(); assert_eq!( *a_array.data_type(), - DataType::List(Arc::new(Field::new("bigint", DataType::Int64, true))) + DataType::List(Arc::new(Field::new("element", DataType::Int64, true))) ); let array = a_array.value(0); assert_eq!(*array.data_type(), DataType::Int64); @@ -1041,6 +1159,493 @@ mod test { assert_eq!(batch.num_rows(), 3); } + #[test] + fn test_complex_list() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "headers", + "type": ["null", { + "type": "array", + "items": ["null",{ + "name":"r2", + "type": "record", + "fields":[ + {"name":"name", "type": ["null", "string"], "default": null}, + {"name":"value", "type": ["null", "string"], "default": null} + ] + }] + }], + "default": null + } + ] + }"#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ + "headers": [ + { + "name": "a", + "value": "b" + } + ] + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + let expected = [ + "+-----------------------+", + "| headers |", + "+-----------------------+", + "| [{name: a, value: b}] |", + "+-----------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_complex_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "dns", + "type": [ + "null", + { + "type": "record", + "name": "r13", + "fields": [ + { + "name": "answers", + "type": [ + "null", + { + "type": "array", + "items": [ + "null", + { + "type": "record", + "name": "r292", + "fields": [ + { + "name": "class", + "type": ["null", "string"], + "default": null + }, + { + "name": "data", + "type": ["null", "string"], + "default": null + }, + { + "name": "name", + "type": ["null", "string"], + "default": null + }, + { + "name": "ttl", + "type": ["null", "long"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ] + } + ], + "default": null + }, + { + "name": "header_flags", + "type": [ + "null", + { + "type": "array", + "items": ["null", "string"] + } + ], + "default": null + }, + { + "name": "id", + "type": ["null", "string"], + "default": null + }, + { + "name": "op_code", + "type": ["null", "string"], + "default": null + }, + { + "name": "question", + "type": [ + "null", + { + "type": "record", + "name": "r288", + "fields": [ + { + "name": "class", + "type": ["null", "string"], + "default": null + }, + { + "name": "name", + "type": ["null", "string"], + "default": null + }, + { + "name": "registered_domain", + "type": ["null", "string"], + "default": null + }, + { + "name": "subdomain", + "type": ["null", "string"], + "default": null + }, + { + "name": "top_level_domain", + "type": ["null", "string"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + }, + { + "name": "resolved_ip", + "type": [ + "null", + { + "type": "array", + "items": ["null", "string"] + } + ], + "default": null + }, + { + "name": "response_code", + "type": ["null", "string"], + "default": null + }, + { + "name": "type", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + + let jv1 = serde_json::json!({ + "dns": { + "answers": [ + { + "data": "CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=", + "type": "1" + }, + { + "data": "CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=", + "type": "1" + }, + { + "data": "CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=", + "type": "1" + } + ], + "question": { + "name": "security.ubuntu.com", + "type": "A" + }, + "resolved_ip": [ + "67.43.156.1", + "67.43.156.2", + "67.43.156.3" + ], + "response_code": "0" + } + }); + let r1 = apache_avro::to_value(jv1) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(1) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| dns |", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| {answers: [{class: , data: CHNlY3VyaXR5BnVidW50dQMjb20AAAEAAQAAAAgABLl9vic=, name: , ttl: , type: 1}, {class: , data: CHNlY3VyaXR5BnVidW50dQNjb20AAAEAABAAAAgABLl9viQ=, name: , ttl: , type: 1}, {class: , data: CHNlT3VyaXR5BnVidW50dQNjb20AAAEAAQAAAAgABFu9Wyc=, name: , ttl: , type: 1}], header_flags: , id: , op_code: , question: {class: , name: security.ubuntu.com, registered_domain: , subdomain: , top_level_domain: , type: A}, resolved_ip: [67.43.156.1, 67.43.156.2, 67.43.156.3], response_code: 0, type: } |", + "+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_deep_nullable_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": [ + "null", + { + "type": "record", + "name": "r3", + "fields": [ + { + "name": "col3", + "type": [ + "null", + { + "type": "record", + "name": "r4", + "fields": [ + { + "name": "col4", + "type": [ + "null", + { + "type": "record", + "name": "r5", + "fields": [ + { + "name": "col5", + "type": ["null", "string"] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + ] + } + "#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": { + "col4": { + "col5": "hello" + } + } + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": { + "col4": { + "col5": null + } + } + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r3 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": { + "col3": null + } + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r4 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + w.append(r2).unwrap(); + w.append(r3).unwrap(); + w.append(r4).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(4) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + let expected = [ + "+---------------------------------------+", + "| col1 |", + "+---------------------------------------+", + "| {col2: {col3: {col4: {col5: hello}}}} |", + "| {col2: {col3: {col4: {col5: }}}} |", + "| {col2: {col3: }} |", + "| |", + "+---------------------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + + #[test] + fn test_avro_nullable_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": ["null", "string"] + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + let r1 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": "hello" + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + let r3 = apache_avro::to_value(serde_json::json!({ + "col1": { + "col2": null + } + })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + w.append(r2).unwrap(); + w.append(r3).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(3) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+---------------+", + "| col1 |", + "+---------------+", + "| |", + "| {col2: hello} |", + "| {col2: } |", + "+---------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + #[test] fn test_avro_iterator() { let reader = build_reader("alltypes_plain.avro", 5); diff --git a/datafusion/core/src/avro_to_arrow/mod.rs b/datafusion/core/src/datasource/avro_to_arrow/mod.rs similarity index 92% rename from datafusion/core/src/avro_to_arrow/mod.rs rename to datafusion/core/src/datasource/avro_to_arrow/mod.rs index 8ca7f22ef3b12..af0bb86a3e273 100644 --- a/datafusion/core/src/avro_to_arrow/mod.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/mod.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! This module contains utilities to manipulate avro metadata. +//! This module contains code for reading [Avro] data into `RecordBatch`es +//! +//! [Avro]: https://avro.apache.org/docs/1.2.0/ #[cfg(feature = "avro")] mod arrow_array_reader; diff --git a/datafusion/core/src/avro_to_arrow/reader.rs b/datafusion/core/src/datasource/avro_to_arrow/reader.rs similarity index 96% rename from datafusion/core/src/avro_to_arrow/reader.rs rename to datafusion/core/src/datasource/avro_to_arrow/reader.rs index c5dab22a2d00c..5dc53c5c86c87 100644 --- a/datafusion/core/src/avro_to_arrow/reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/reader.rs @@ -56,17 +56,21 @@ impl ReaderBuilder { /// # Example /// /// ``` - /// extern crate apache_avro; - /// /// use std::fs::File; /// - /// fn example() -> crate::datafusion::avro_to_arrow::Reader<'static, File> { + /// use datafusion::datasource::avro_to_arrow::{Reader, ReaderBuilder}; + /// + /// fn example() -> Reader<'static, File> { /// let file = File::open("test/data/basic.avro").unwrap(); /// /// // create a builder, inferring the schema with the first 100 records - /// let builder = crate::datafusion::avro_to_arrow::ReaderBuilder::new().read_schema().with_batch_size(100); + /// let builder = ReaderBuilder::new() + /// .read_schema() + /// .with_batch_size(100); /// - /// let reader = builder.build::(file).unwrap(); + /// let reader = builder + /// .build::(file) + /// .unwrap(); /// /// reader /// } diff --git a/datafusion/core/src/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs similarity index 82% rename from datafusion/core/src/avro_to_arrow/schema.rs rename to datafusion/core/src/datasource/avro_to_arrow/schema.rs index d4c881ca54eb1..761e6b62680f5 100644 --- a/datafusion/core/src/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -17,24 +17,25 @@ use crate::arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode}; use crate::error::{DataFusionError, Result}; -use apache_avro::schema::{Alias, Name}; +use apache_avro::schema::{ + Alias, DecimalSchema, EnumSchema, FixedSchema, Name, RecordSchema, +}; use apache_avro::types::Value; use apache_avro::Schema as AvroSchema; use arrow::datatypes::{Field, UnionFields}; use std::collections::HashMap; -use std::convert::TryFrom; use std::sync::Arc; /// Converts an avro schema to an arrow schema pub fn to_arrow_schema(avro_schema: &apache_avro::Schema) -> Result { let mut schema_fields = vec![]; match avro_schema { - AvroSchema::Record { fields, .. } => { + AvroSchema::Record(RecordSchema { fields, .. }) => { for field in fields { schema_fields.push(schema_to_field_with_props( &field.schema, Some(&field.name), - false, + field.is_nullable(), Some(external_props(&field.schema)), )?) } @@ -72,7 +73,7 @@ fn schema_to_field_with_props( AvroSchema::Bytes => DataType::Binary, AvroSchema::String => DataType::Utf8, AvroSchema::Array(item_schema) => DataType::List(Arc::new( - schema_to_field_with_props(item_schema, None, false, None)?, + schema_to_field_with_props(item_schema, Some("element"), false, None)?, )), AvroSchema::Map(value_schema) => { let value_field = @@ -84,7 +85,13 @@ fn schema_to_field_with_props( } AvroSchema::Union(us) => { // If there are only two variants and one of them is null, set the other type as the field data type - let has_nullable = us.find_schema(&Value::Null).is_some(); + let has_nullable = us + .find_schema_with_known_schemata::( + &Value::Null, + None, + &None, + ) + .is_some(); let sub_schemas = us.variants(); if has_nullable && sub_schemas.len() == 2 { nullable = true; @@ -109,7 +116,7 @@ fn schema_to_field_with_props( DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) } } - AvroSchema::Record { name, fields, .. } => { + AvroSchema::Record(RecordSchema { fields, .. }) => { let fields: Result<_> = fields .iter() .map(|field| { @@ -122,7 +129,7 @@ fn schema_to_field_with_props( }*/ schema_to_field_with_props( &field.schema, - Some(&format!("{}.{}", name.fullname(None), field.name)), + Some(&field.name), false, Some(props), ) @@ -130,25 +137,21 @@ fn schema_to_field_with_props( .collect(); DataType::Struct(fields?) } - AvroSchema::Enum { symbols, name, .. } => { - return Ok(Field::new_dict( - name.fullname(None), - index_type(symbols.len()), - false, - 0, - false, - )) + AvroSchema::Enum(EnumSchema { .. }) => DataType::Utf8, + AvroSchema::Fixed(FixedSchema { size, .. }) => { + DataType::FixedSizeBinary(*size as i32) } - AvroSchema::Fixed { size, .. } => DataType::FixedSizeBinary(*size as i32), - AvroSchema::Decimal { + AvroSchema::Decimal(DecimalSchema { precision, scale, .. - } => DataType::Decimal128(*precision as u8, *scale as i8), + }) => DataType::Decimal128(*precision as u8, *scale as i8), AvroSchema::Uuid => DataType::FixedSizeBinary(16), AvroSchema::Date => DataType::Date32, AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond), AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond), AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None), AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None), + AvroSchema::LocalTimestampMillis => todo!(), + AvroSchema::LocalTimestampMicros => todo!(), AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond), }; @@ -226,50 +229,38 @@ fn default_field_name(dt: &DataType) -> &str { } } -fn index_type(len: usize) -> DataType { - if len <= usize::from(u8::MAX) { - DataType::Int8 - } else if len <= usize::from(u16::MAX) { - DataType::Int16 - } else if usize::try_from(u32::MAX).map(|i| len < i).unwrap_or(false) { - DataType::Int32 - } else { - DataType::Int64 - } -} - fn external_props(schema: &AvroSchema) -> HashMap { let mut props = HashMap::new(); match &schema { - AvroSchema::Record { + AvroSchema::Record(RecordSchema { doc: Some(ref doc), .. - } - | AvroSchema::Enum { + }) + | AvroSchema::Enum(EnumSchema { doc: Some(ref doc), .. - } - | AvroSchema::Fixed { + }) + | AvroSchema::Fixed(FixedSchema { doc: Some(ref doc), .. - } => { + }) => { props.insert("avro::doc".to_string(), doc.clone()); } _ => {} } match &schema { - AvroSchema::Record { + AvroSchema::Record(RecordSchema { name: Name { namespace, .. }, aliases: Some(aliases), .. - } - | AvroSchema::Enum { + }) + | AvroSchema::Enum(EnumSchema { name: Name { namespace, .. }, aliases: Some(aliases), .. - } - | AvroSchema::Fixed { + }) + | AvroSchema::Fixed(FixedSchema { name: Name { namespace, .. }, aliases: Some(aliases), .. - } => { + }) => { let aliases: Vec = aliases .iter() .map(|alias| aliased(alias, namespace.as_deref(), None)) @@ -308,7 +299,7 @@ mod test { use crate::arrow::datatypes::DataType::{Binary, Float32, Float64, Timestamp, Utf8}; use crate::arrow::datatypes::TimeUnit::Microsecond; use crate::arrow::datatypes::{Field, Schema}; - use apache_avro::schema::{Alias, Name}; + use apache_avro::schema::{Alias, EnumSchema, FixedSchema, Name, RecordSchema}; use apache_avro::Schema as AvroSchema; use arrow::datatypes::DataType::{Boolean, Int32, Int64}; @@ -326,7 +317,7 @@ mod test { #[test] fn test_external_props() { - let record_schema = AvroSchema::Record { + let record_schema = AvroSchema::Record(RecordSchema { name: Name { name: "record".to_string(), namespace: None, @@ -335,7 +326,8 @@ mod test { doc: Some("record documentation".to_string()), fields: vec![], lookup: Default::default(), - }; + attributes: Default::default(), + }); let props = external_props(&record_schema); assert_eq!( props.get("avro::doc"), @@ -345,7 +337,7 @@ mod test { props.get("avro::aliases"), Some(&"[fooalias,baralias]".to_string()) ); - let enum_schema = AvroSchema::Enum { + let enum_schema = AvroSchema::Enum(EnumSchema { name: Name { name: "enum".to_string(), namespace: None, @@ -353,7 +345,9 @@ mod test { aliases: Some(vec![alias("fooenum"), alias("barenum")]), doc: Some("enum documentation".to_string()), symbols: vec![], - }; + default: None, + attributes: Default::default(), + }); let props = external_props(&enum_schema); assert_eq!( props.get("avro::doc"), @@ -363,7 +357,7 @@ mod test { props.get("avro::aliases"), Some(&"[fooenum,barenum]".to_string()) ); - let fixed_schema = AvroSchema::Fixed { + let fixed_schema = AvroSchema::Fixed(FixedSchema { name: Name { name: "fixed".to_string(), namespace: None, @@ -371,7 +365,8 @@ mod test { aliases: Some(vec![alias("foofixed"), alias("barfixed")]), size: 1, doc: None, - }; + attributes: Default::default(), + }); let props = external_props(&fixed_schema); assert_eq!( props.get("avro::aliases"), @@ -447,6 +442,58 @@ mod test { assert_eq!(arrow_schema.unwrap(), expected); } + #[test] + fn test_nested_schema() { + let avro_schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": "string" + }, + { + "name": "col3", + "type": ["null", "string"], + "default": null + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + // should not use Avro Record names. + let expected_arrow_schema = Schema::new(vec![Field::new( + "col1", + arrow::datatypes::DataType::Struct( + vec![ + Field::new("col2", Utf8, false), + Field::new("col3", Utf8, true), + ] + .into(), + ), + true, + )]); + assert_eq!( + to_arrow_schema(&avro_schema).unwrap(), + expected_arrow_schema + ); + } + #[test] fn test_non_record_schema() { let arrow_schema = to_arrow_schema(&AvroSchema::String); diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs deleted file mode 100644 index 11f30f33d1399..0000000000000 --- a/datafusion/core/src/datasource/datasource.rs +++ /dev/null @@ -1,143 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Data source traits - -use std::any::Any; -use std::sync::Arc; - -use async_trait::async_trait; -use datafusion_common::{DataFusionError, Statistics}; -use datafusion_expr::{CreateExternalTable, LogicalPlan}; -pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; - -use crate::arrow::datatypes::SchemaRef; -use crate::error::Result; -use crate::execution::context::SessionState; -use crate::logical_expr::Expr; -use crate::physical_plan::ExecutionPlan; - -/// Source table -#[async_trait] -pub trait TableProvider: Sync + Send { - /// Returns the table provider as [`Any`](std::any::Any) so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Get a reference to the schema for this table - fn schema(&self) -> SchemaRef; - - /// Get the type of this table for metadata/catalog purposes. - fn table_type(&self) -> TableType; - - /// Get the create statement used to create this table, if available. - fn get_table_definition(&self) -> Option<&str> { - None - } - - /// Get the Logical Plan of this table, if available. - fn get_logical_plan(&self) -> Option<&LogicalPlan> { - None - } - - /// Create an ExecutionPlan that will scan the table. - /// The table provider will be usually responsible of grouping - /// the source data into partitions that can be efficiently - /// parallelized or distributed. - async fn scan( - &self, - state: &SessionState, - projection: Option<&Vec>, - filters: &[Expr], - // limit can be used to reduce the amount scanned - // from the datasource as a performance optimization. - // If set, it contains the amount of rows needed by the `LogicalPlan`, - // The datasource should return *at least* this number of rows if available. - limit: Option, - ) -> Result>; - - /// Tests whether the table provider can make use of a filter expression - /// to optimise data retrieval. - #[deprecated(since = "20.0.0", note = "use supports_filters_pushdown instead")] - fn supports_filter_pushdown( - &self, - _filter: &Expr, - ) -> Result { - Ok(TableProviderFilterPushDown::Unsupported) - } - - /// Tests whether the table provider can make use of any or all filter expressions - /// to optimise data retrieval. - #[allow(deprecated)] - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - filters - .iter() - .map(|f| self.supports_filter_pushdown(f)) - .collect() - } - - /// Get statistics for this table, if available - fn statistics(&self) -> Option { - None - } - - /// Return an [`ExecutionPlan`] to insert data into this table, if - /// supported. - /// - /// The returned plan should return a single row in a UInt64 - /// column called "count" such as the following - /// - /// ```text - /// +-------+, - /// | count |, - /// +-------+, - /// | 6 |, - /// +-------+, - /// ``` - /// - /// # See Also - /// - /// See [`InsertExec`] for the common pattern of inserting a - /// single stream of `RecordBatch`es. - /// - /// [`InsertExec`]: crate::physical_plan::insert::InsertExec - async fn insert_into( - &self, - _state: &SessionState, - _input: Arc, - ) -> Result> { - let msg = "Insertion not implemented for this table".to_owned(); - Err(DataFusionError::NotImplemented(msg)) - } -} - -/// A factory which creates [`TableProvider`]s at runtime given a URL. -/// -/// For example, this can be used to create a table "on the fly" -/// from a directory of files only when that name is referenced. -#[async_trait] -pub trait TableProviderFactory: Sync + Send { - /// Create a TableProvider with the given url - async fn create( - &self, - state: &SessionState, - cmd: &CreateExternalTable, - ) -> Result>; -} diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index c6fd87e7f18b3..fadf01c74c5d4 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -17,17 +17,21 @@ //! Default TableSource implementation used in DataFusion physical plans +use std::any::Any; +use std::sync::Arc; + use crate::datasource::TableProvider; + use arrow::datatypes::SchemaRef; -use datafusion_common::DataFusionError; +use datafusion_common::{internal_err, Constraints, DataFusionError}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource}; -use std::any::Any; -use std::sync::Arc; -/// DataFusion default table source, wrapping TableProvider +/// DataFusion default table source, wrapping TableProvider. /// /// This structure adapts a `TableProvider` (physical plan trait) to the `TableSource` -/// (logical plan trait) +/// (logical plan trait) and is necessary because the logical plan is contained in +/// the `datafusion_expr` crate, and is not aware of table providers, which exist in +/// the core `datafusion` crate. pub struct DefaultTableSource { /// table provider pub table_provider: Arc, @@ -41,7 +45,7 @@ impl DefaultTableSource { } impl TableSource for DefaultTableSource { - /// Returns the table source as [`Any`](std::any::Any) so that it can be + /// Returns the table source as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any { self @@ -52,6 +56,11 @@ impl TableSource for DefaultTableSource { self.table_provider.schema() } + /// Get a reference to applicable constraints, if any exists. + fn constraints(&self) -> Option<&Constraints> { + self.table_provider.constraints() + } + /// Tests whether the table provider can make use of any or all filter expressions /// to optimise data retrieval. fn supports_filters_pushdown( @@ -64,6 +73,10 @@ impl TableSource for DefaultTableSource { fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> { self.table_provider.get_logical_plan() } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.table_provider.get_column_default(column) + } } /// Wrap TableProvider in TableSource @@ -84,8 +97,6 @@ pub fn source_as_provider( .downcast_ref::() { Some(source) => Ok(source.table_provider.clone()), - _ => Err(DataFusionError::Internal( - "TableSource was not DefaultTableSource".to_string(), - )), + _ => internal_err!("TableSource was not DefaultTableSource"), } } diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 37434002c1c71..5100987520ee1 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -22,12 +22,12 @@ use std::sync::Arc; use arrow::datatypes::*; use async_trait::async_trait; +use datafusion_common::project_schema; use crate::datasource::{TableProvider, TableType}; use crate::error::Result; use crate::execution::context::SessionState; use crate::logical_expr::Expr; -use crate::physical_plan::project_schema; use crate::physical_plan::{empty::EmptyExec, ExecutionPlan}; /// An empty plan that is useful for testing and generating plans @@ -77,7 +77,7 @@ impl TableProvider for EmptyTable { // even though there is no data, projections apply let projected_schema = project_schema(&self.schema, projection)?; Ok(Arc::new( - EmptyExec::new(false, projected_schema).with_partitions(self.partitions), + EmptyExec::new(projected_schema).with_partitions(self.partitions), )) } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 2b3ef7ee4eab8..07c96bdae1b41 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -19,23 +19,30 @@ //! //! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) +use std::any::Any; +use std::borrow::Cow; +use std::sync::Arc; + use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::{ArrowExec, FileScanConfig}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; + +use arrow::ipc::convert::fb_to_schema; use arrow::ipc::reader::FileReader; -use arrow_schema::{Schema, SchemaRef}; -use async_trait::async_trait; -use datafusion_common::Statistics; +use arrow::ipc::root_as_message; +use arrow_schema::{ArrowError, Schema, SchemaRef}; + +use bytes::Bytes; +use datafusion_common::{FileType, Statistics}; use datafusion_physical_expr::PhysicalExpr; -use object_store::{GetResult, ObjectMeta, ObjectStore}; -use std::any::Any; -use std::io::{Read, Seek}; -use std::sync::Arc; -/// The default file extension of arrow files -pub const DEFAULT_ARROW_EXTENSION: &str = ".arrow"; +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::StreamExt; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; + /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -54,13 +61,14 @@ impl FileFormat for ArrowFormat { ) -> Result { let mut schemas = vec![]; for object in objects { - let schema = match store.get(&object.location).await? { - GetResult::File(mut file, _) => read_arrow_schema_from_reader(&mut file)?, - r @ GetResult::Stream(_) => { - // TODO: Fetching entire file to get schema is potentially wasteful - let data = r.bytes().await?; - let mut cursor = std::io::Cursor::new(&data); - read_arrow_schema_from_reader(&mut cursor)? + let r = store.as_ref().get(&object.location).await?; + let schema = match r.payload { + GetResultPayload::File(mut file, _) => { + let reader = FileReader::try_new(&mut file, None)?; + reader.schema() + } + GetResultPayload::Stream(stream) => { + infer_schema_from_file_stream(stream).await? } }; schemas.push(schema.as_ref().clone()); @@ -73,10 +81,10 @@ impl FileFormat for ArrowFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -88,9 +96,187 @@ impl FileFormat for ArrowFormat { let exec = ArrowExec::new(conf); Ok(Arc::new(exec)) } + + fn file_type(&self) -> FileType { + FileType::ARROW + } } -fn read_arrow_schema_from_reader(reader: R) -> Result { - let reader = FileReader::try_new(reader, None)?; - Ok(reader.schema()) +const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +/// See +async fn infer_schema_from_file_stream( + mut stream: BoxStream<'static, object_store::Result>, +) -> Result { + // Expected format: + // - 6 bytes + // - 2 bytes + // - 4 bytes, not present below v0.15.0 + // - 4 bytes + // + // + + // So in first read we need at least all known sized sections, + // which is 6 + 2 + 4 + 4 = 16 bytes. + let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; + + // Files should start with these magic bytes + if bytes[0..6] != ARROW_MAGIC { + return Err(ArrowError::ParseError( + "Arrow file does not contian correct header".to_string(), + ))?; + } + + // Since continuation marker bytes added in later versions + let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { + (&bytes[12..16], 16) + } else { + (&bytes[8..12], 12) + }; + + let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]]; + let meta_len = i32::from_le_bytes(meta_len); + + // Read bytes for Schema message + let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize { + // Need to read more bytes to decode Message + let mut block_data = Vec::with_capacity(meta_len as usize); + // In case we had some spare bytes in our initial read chunk + block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]); + let size_to_read = meta_len as usize - block_data.len(); + let block_data = + collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?; + Cow::Owned(block_data) + } else { + // Already have the bytes we need + let end_index = meta_len as usize + rest_of_bytes_start_index; + let block_data = &bytes[rest_of_bytes_start_index..end_index]; + Cow::Borrowed(block_data) + }; + + // Decode Schema message + let message = root_as_message(&block_data).map_err(|err| { + ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}")) + })?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + ArrowError::IpcError("Unable to read IPC message as schema".to_string()) + })?; + let schema = fb_to_schema(ipc_schema); + + Ok(Arc::new(schema)) +} + +async fn collect_at_least_n_bytes( + stream: &mut BoxStream<'static, object_store::Result>, + n: usize, + extend_from: Option>, +) -> Result> { + let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n)); + // If extending existing buffer then ensure we read n additional bytes + let n = n + buf.len(); + while let Some(bytes) = stream.next().await.transpose()? { + buf.extend_from_slice(&bytes); + if buf.len() >= n { + break; + } + } + if buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected end of byte stream for Arrow IPC file".to_string(), + ))?; + } + Ok(buf) +} + +#[cfg(test)] +mod tests { + use chrono::DateTime; + use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; + + use crate::execution::context::SessionContext; + + use super::*; + + #[tokio::test] + async fn test_infer_schema_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(bytes.len() - 20); // mangle end to show we don't need to read whole file + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"]; + + // Test chunk sizes where too small so we keep having to read more bytes + // And when large enough that first read contains all we need + for chunk_size in [7, 3000] { + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size)); + let inferred_schema = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + &[object_meta.clone()], + ) + .await?; + let actual_fields = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect::>(); + assert_eq!(expected, actual_fields); + } + + Ok(()) + } + + #[tokio::test] + async fn test_infer_schema_short_stream() -> Result<()> { + let mut bytes = std::fs::read("tests/data/example.arrow")?; + bytes.truncate(20); // should cause error that file shorter than expected + let location = Path::parse("example.arrow")?; + let in_memory_store: Arc = Arc::new(InMemory::new()); + in_memory_store.put(&location, bytes.into()).await?; + + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let object_meta = ObjectMeta { + location, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let arrow_format = ArrowFormat {}; + + let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7)); + let err = arrow_format + .infer_schema( + &state, + &(store.clone() as Arc), + &[object_meta.clone()], + ) + .await; + + assert!(err.is_err()); + assert_eq!( + "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", + err.unwrap_err().to_string() + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index ab9f1f5dd0006..a24a28ad6fdd4 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -23,19 +23,18 @@ use std::sync::Arc; use arrow::datatypes::Schema; use arrow::{self, datatypes::SchemaRef}; use async_trait::async_trait; +use datafusion_common::FileType; use datafusion_physical_expr::PhysicalExpr; -use object_store::{GetResult, ObjectMeta, ObjectStore}; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; use super::FileFormat; -use crate::avro_to_arrow::read_avro_schema_from_reader; +use crate::datasource::avro_to_arrow::read_avro_schema_from_reader; use crate::datasource::physical_plan::{AvroExec, FileScanConfig}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; -/// The default file extension of avro files -pub const DEFAULT_AVRO_EXTENSION: &str = ".avro"; /// Avro `FileFormat` implementation. #[derive(Default, Debug)] pub struct AvroFormat; @@ -54,9 +53,12 @@ impl FileFormat for AvroFormat { ) -> Result { let mut schemas = vec![]; for object in objects { - let schema = match store.get(&object.location).await? { - GetResult::File(mut file, _) => read_avro_schema_from_reader(&mut file)?, - r @ GetResult::Stream(_) => { + let r = store.as_ref().get(&object.location).await?; + let schema = match r.payload { + GetResultPayload::File(mut file, _) => { + read_avro_schema_from_reader(&mut file)? + } + GetResultPayload::Stream(_) => { // TODO: Fetching entire file to get schema is potentially wasteful let data = r.bytes().await?; read_avro_schema_from_reader(&mut data.as_ref())? @@ -72,10 +74,10 @@ impl FileFormat for AvroFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -87,6 +89,10 @@ impl FileFormat for AvroFormat { let exec = AvroExec::new(conf); Ok(Arc::new(exec)) } + + fn file_type(&self) -> FileType { + FileType::AVRO + } } #[cfg(test)] @@ -106,7 +112,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; @@ -176,8 +182,7 @@ mod tests { let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); - let expected = vec![ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + let expected = ["+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", @@ -188,8 +193,7 @@ mod tests { "| 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00 |", "| 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 |", "| 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - ]; + "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+"]; crate::assert_batches_eq!(expected, &batches); Ok(()) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 01bf76ccf48d1..df6689af6b73c 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -20,47 +20,46 @@ use std::any::Any; use std::collections::HashSet; use std::fmt; -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::sync::Arc; -use arrow::csv::WriterBuilder; -use arrow::datatypes::{DataType, Field, Fields, Schema}; -use arrow::{self, datatypes::SchemaRef}; use arrow_array::RecordBatch; -use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; -use async_trait::async_trait; use bytes::{Buf, Bytes}; +use datafusion_physical_plan::metrics::MetricsSet; use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; -use tokio::io::{AsyncWrite, AsyncWriteExt}; - -use super::FileFormat; -use crate::datasource::file_format::file_type::FileCompressionType; -use crate::datasource::file_format::FileWriterMode; -use crate::datasource::file_format::{ - AbortMode, AbortableWrite, AsyncPutWriter, BatchSerializer, MultiPart, - DEFAULT_SCHEMA_INFER_MAX_RECORD, -}; + +use super::write::orchestration::stateless_multipart_put; +use super::{FileFormat, DEFAULT_SCHEMA_INFER_MAX_RECORD}; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ - CsvExec, FileGroupDisplay, FileMeta, FileScanConfig, FileSinkConfig, + CsvExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, }; use crate::error::Result; use crate::execution::context::SessionState; -use crate::physical_plan::insert::{DataSink, InsertExec}; -use crate::physical_plan::Statistics; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; -/// The default file extension of csv files -pub const DEFAULT_CSV_EXTENSION: &str = ".csv"; +use arrow::csv::WriterBuilder; +use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::{self, datatypes::SchemaRef}; + +use async_trait::async_trait; + /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] pub struct CsvFormat { has_header: bool, delimiter: u8, + quote: u8, + escape: Option, schema_infer_max_rec: Option, file_compression_type: FileCompressionType, } @@ -71,6 +70,8 @@ impl Default for CsvFormat { schema_infer_max_rec: Some(DEFAULT_SCHEMA_INFER_MAX_RECORD), has_header: true, delimiter: b',', + quote: b'"', + escape: None, file_compression_type: FileCompressionType::UNCOMPRESSED, } } @@ -159,6 +160,20 @@ impl CsvFormat { self } + /// The quote character in a row. + /// - default to '"' + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// The escape character in a row. + /// - default is None + pub fn with_escape(mut self, escape: Option) -> Self { + self.escape = escape; + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -173,6 +188,16 @@ impl CsvFormat { pub fn delimiter(&self) -> u8 { self.delimiter } + + /// The quote character. + pub fn quote(&self) -> u8 { + self.quote + } + + /// The escape character. + pub fn escape(&self) -> Option { + self.escape + } } #[async_trait] @@ -211,10 +236,10 @@ impl FileFormat for CsvFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -227,6 +252,8 @@ impl FileFormat for CsvFormat { conf, self.has_header, self.delimiter, + self.quote, + self.escape, self.file_compression_type.to_owned(), ); Ok(Arc::new(exec)) @@ -237,15 +264,29 @@ impl FileFormat for CsvFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, + order_requirements: Option>, ) -> Result> { - let sink = Arc::new(CsvSink::new( - conf, - self.has_header, - self.delimiter, - self.file_compression_type.clone(), - )); + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for CSV"); + } + + if self.file_compression_type != FileCompressionType::UNCOMPRESSED { + return not_impl_err!("Inserting compressed CSV is not implemented yet."); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(CsvSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } - Ok(Arc::new(InsertExec::new(input, sink)) as _) + fn file_type(&self) -> FileType { + FileType::CSV } } @@ -292,14 +333,12 @@ impl CsvFormat { first_chunk = false; } else { if fields.len() != column_type_possibilities.len() { - return Err(DataFusionError::Execution( - format!( + return exec_err!( "Encountered unequal lengths between records on CSV file whilst inferring schema. \ Expected {} records, found {} records", column_type_possibilities.len(), fields.len() - ) - )); + ); } column_type_possibilities.iter_mut().zip(&fields).for_each( @@ -394,33 +433,19 @@ impl CsvSerializer { impl BatchSerializer for CsvSerializer { async fn serialize(&mut self, batch: RecordBatch) -> Result { let builder = self.builder.clone(); - let mut writer = builder.has_headers(self.header).build(&mut self.buffer); + let mut writer = builder.with_header(self.header).build(&mut self.buffer); writer.write(&batch)?; drop(writer); self.header = false; Ok(Bytes::from(self.buffer.drain(..).collect::>())) } -} -async fn check_for_errors( - result: Result, - writers: &mut [AbortableWrite], -) -> Result { - match result { - Ok(value) => Ok(value), - Err(e) => { - // Abort all writers before returning the error: - for writer in writers { - let mut abort_future = writer.abort_writer(); - if let Ok(abort_future) = &mut abort_future { - let _ = abort_future.await; - } - // Ignore errors that occur during abortion, - // We do try to abort all writers before returning error. - } - // After aborting writers return original error. - Err(e) - } + fn duplicate(&mut self) -> Result> { + let new_self = CsvSerializer::new() + .with_builder(self.builder.clone()) + .with_header(self.header); + self.header = false; + Ok(Box::new(new_self)) } } @@ -428,163 +453,79 @@ async fn check_for_errors( struct CsvSink { /// Config options for writing data config: FileSinkConfig, - has_header: bool, - delimiter: u8, - file_compression_type: FileCompressionType, } impl Debug for CsvSink { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CsvSink") - .field("has_header", &self.has_header) - .field("delimiter", &self.delimiter) - .field("file_compression_type", &self.file_compression_type) - .finish() + f.debug_struct("CsvSink").finish() } } -impl Display for CsvSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "CsvSink(writer_mode={:?}, file_groups={})", - self.config.writer_mode, - FileGroupDisplay(&self.config.file_groups), - ) +impl DisplayAs for CsvSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CsvSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } } } impl CsvSink { - fn new( - config: FileSinkConfig, - has_header: bool, - delimiter: u8, - file_compression_type: FileCompressionType, - ) -> Self { - Self { - config, - has_header, - delimiter, - file_compression_type, - } + fn new(config: FileSinkConfig) -> Self { + Self { config } } - - // Create a write for Csv files - async fn create_writer( + async fn multipartput_all( &self, - file_meta: FileMeta, - object_store: Arc, - ) -> Result>> { - let object = &file_meta.object_meta; - match self.config.writer_mode { - // If the mode is append, call the store's append method and return wrapped in - // a boxed trait object. - FileWriterMode::Append => { - let writer = object_store - .append(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - let writer = AbortableWrite::new( - self.file_compression_type.convert_async_writer(writer)?, - AbortMode::Append, - ); - Ok(writer) - } - // If the mode is put, create a new AsyncPut writer and return it wrapped in - // a boxed trait object - FileWriterMode::Put => { - let writer = Box::new(AsyncPutWriter::new(object.clone(), object_store)); - let writer = AbortableWrite::new( - self.file_compression_type.convert_async_writer(writer)?, - AbortMode::Put, - ); - Ok(writer) - } - // If the mode is put multipart, call the store's put_multipart method and - // return the writer wrapped in a boxed trait object. - FileWriterMode::PutMultipart => { - let (multipart_id, writer) = object_store - .put_multipart(&object.location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - self.file_compression_type.convert_async_writer(writer)?, - AbortMode::MultiPart(MultiPart::new( - object_store, - multipart_id, - object.location.clone(), - )), - )) - } - } + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let writer_options = self.config.file_type_writer_options.try_into_csv()?; + let builder = &writer_options.writer_options; + + let builder_clone = builder.clone(); + let options_clone = writer_options.clone(); + let get_serializer = move || { + let inner_clone = builder_clone.clone(); + let serializer: Box = Box::new( + CsvSerializer::new() + .with_builder(inner_clone) + .with_header(options_clone.writer_options.header()), + ); + serializer + }; + + stateless_multipart_put( + data, + context, + "csv".into(), + Box::new(get_serializer), + &self.config, + writer_options.compression.into(), + ) + .await } } #[async_trait] impl DataSink for CsvSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + async fn write_all( &self, - mut data: SendableRecordBatchStream, + data: SendableRecordBatchStream, context: &Arc, ) -> Result { - let num_partitions = self.config.file_groups.len(); - - let object_store = context - .runtime_env() - .object_store(&self.config.object_store_url)?; - - // Construct serializer and writer for each file group - let mut serializers = vec![]; - let mut writers = vec![]; - for file_group in &self.config.file_groups { - // In append mode, consider has_header flag only when file is empty (at the start). - // For other modes, use has_header flag as is. - let header = self.has_header - && (!matches!(&self.config.writer_mode, FileWriterMode::Append) - || file_group.object_meta.size == 0); - let builder = WriterBuilder::new().with_delimiter(self.delimiter); - let serializer = CsvSerializer::new() - .with_builder(builder) - .with_header(header); - serializers.push(serializer); - - let file = file_group.clone(); - let writer = self - .create_writer(file.object_meta.clone().into(), object_store.clone()) - .await?; - writers.push(writer); - } - - let mut idx = 0; - let mut row_count = 0; - // Map errors to DatafusionError. - let err_converter = - |_| DataFusionError::Internal("Unexpected FileSink Error".to_string()); - while let Some(maybe_batch) = data.next().await { - // Write data to files in a round robin fashion: - idx = (idx + 1) % num_partitions; - let serializer = &mut serializers[idx]; - let batch = check_for_errors(maybe_batch, &mut writers).await?; - row_count += batch.num_rows(); - let bytes = - check_for_errors(serializer.serialize(batch).await, &mut writers).await?; - let writer = &mut writers[idx]; - check_for_errors( - writer.write_all(&bytes).await.map_err(err_converter), - &mut writers, - ) - .await?; - } - // Perform cleanup: - let n_writers = writers.len(); - for idx in 0..n_writers { - check_for_errors( - writers[idx].shutdown().await.map_err(err_converter), - &mut writers, - ) - .await?; - } - Ok(row_count as u64) + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) } } @@ -592,8 +533,11 @@ impl DataSink for CsvSink { mod tests { use super::super::test_util::scan_format; use super::*; + use crate::arrow::util::pretty; use crate::assert_batches_eq; + use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::test_util::VariableStream; + use crate::datasource::listing::ListingOptions; use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; @@ -601,16 +545,21 @@ mod tests { use bytes::Bytes; use chrono::DateTime; use datafusion_common::cast::as_string_array; + use datafusion_common::internal_err; + use datafusion_common::stats::Precision; + use datafusion_common::FileType; + use datafusion_common::GetExt; use datafusion_expr::{col, lit}; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; + use regex::Regex; use rstest::*; #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) @@ -630,8 +579,8 @@ mod tests { assert_eq!(tt_batches, 50 /* 100/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } @@ -724,6 +673,7 @@ mod tests { last_modified: DateTime::default(), size: usize::MAX, e_tag: None, + version: None, }; let num_rows_to_read = 100; @@ -773,6 +723,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn query_compress_data( file_compression_type: FileCompressionType, @@ -811,8 +762,7 @@ mod tests { Field::new("c13", DataType::Utf8, true), ]); - let compressed_csv = - csv.with_file_compression_type(file_compression_type.clone()); + let compressed_csv = csv.with_file_compression_type(file_compression_type); //convert compressed_stream to decoded_stream let decoded_stream = compressed_csv @@ -827,6 +777,7 @@ mod tests { Ok(()) } + #[cfg(feature = "compression")] #[tokio::test] async fn query_compress_csv() -> Result<()> { let ctx = SessionContext::new(); @@ -848,15 +799,13 @@ mod tests { .collect() .await?; #[rustfmt::skip] - let expected = vec![ - "+----+------+", + let expected = ["+----+------+", "| c2 | c3 |", "+----+------+", "| 5 | 36 |", "| 5 | -31 |", "| 5 | -101 |", - "+----+------+", - ]; + "+----+------+"]; assert_batches_eq!(expected, &record_batch); Ok(()) } @@ -919,4 +868,348 @@ mod tests { ); Ok(()) } + + /// Explain the `sql` query under `ctx` to make sure the underlying csv scan is parallelized + /// e.g. "CsvExec: file_groups={2 groups:" in plan means 2 CsvExec runs concurrently + async fn count_query_csv_partitions( + ctx: &SessionContext, + sql: &str, + ) -> Result { + let df = ctx.sql(&format!("EXPLAIN {sql}")).await?; + let result = df.collect().await?; + let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + + let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap(); + + if let Some(captures) = re.captures(&plan) { + if let Some(match_) = captures.get(1) { + let n_partitions = match_.as_str().parse::().unwrap(); + return Ok(n_partitions); + } + } + + internal_err!("query contains no CsvExec") + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_basic(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + let testdata = arrow_test_data(); + ctx.register_csv( + "aggr", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().has_header(true), + ) + .await?; + + let query = "select sum(c2) from aggr;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+--------------+", + "| SUM(aggr.c2) |", + "+--------------+", + "| 285 |", + "+--------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[cfg(feature = "compression")] + #[tokio::test] + async fn test_csv_parallel_compressed(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let csv_options = CsvReadOptions::default() + .has_header(true) + .file_compression_type(FileCompressionType::GZIP) + .file_extension("csv.gz"); + let ctx = SessionContext::new_with_config(config); + let testdata = arrow_test_data(); + ctx.register_csv( + "aggr", + &format!("{testdata}/csv/aggregate_test_100.csv.gz"), + csv_options, + ) + .await?; + + let query = "select sum(c3) from aggr;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+--------------+", + "| SUM(aggr.c3) |", + "+--------------+", + "| 781 |", + "+--------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // Compressed csv won't be scanned in parallel + + Ok(()) + } + + /// Read a single empty csv file in parallel + /// + /// empty_0_byte.csv: + /// (file is empty) + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_empty_file(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + ctx.register_csv( + "empty", + "tests/data/empty_0_byte.csv", + CsvReadOptions::new().has_header(false), + ) + .await?; + + // Require a predicate to enable repartition for the optimizer + let query = "select * from empty where random() > 0.5;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["++", + "++"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // Won't get partitioned if all files are empty + + Ok(()) + } + + /// Read a single empty csv file with header in parallel + /// + /// empty.csv: + /// c1,c2,c3 + #[rstest(n_partitions, case(1), case(2), case(3))] + #[tokio::test] + async fn test_csv_parallel_empty_with_header(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + ctx.register_csv( + "empty", + "tests/data/empty.csv", + CsvReadOptions::new().has_header(true), + ) + .await?; + + // Require a predicate to enable repartition for the optimizer + let query = "select * from empty where random() > 0.5;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["++", + "++"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } + + /// Read multiple empty csv files in parallel + /// + /// all_empty + /// ├── empty0.csv + /// ├── empty1.csv + /// └── empty2.csv + /// + /// empty0.csv/empty1.csv/empty2.csv: + /// (file is empty) + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_multiple_empty_files(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + let file_format = CsvFormat::default().with_has_header(false); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::CSV.get_ext()); + ctx.register_listing_table( + "empty", + "tests/data/empty_files/all_empty/", + listing_options, + None, + None, + ) + .await + .unwrap(); + + // Require a predicate to enable repartition for the optimizer + let query = "select * from empty where random() > 0.5;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["++", + "++"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // Won't get partitioned if all files are empty + + Ok(()) + } + + /// Read multiple csv files (some are empty) in parallel + /// + /// some_empty + /// ├── a_empty.csv + /// ├── b.csv + /// ├── c_empty.csv + /// ├── d.csv + /// └── e_empty.csv + /// + /// a_empty.csv/c_empty.csv/e_empty.csv: + /// (file is empty) + /// + /// b.csv/d.csv: + /// 1\n + /// 1\n + /// 1\n + /// 1\n + /// 1\n + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_some_file_empty(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + let file_format = CsvFormat::default().with_has_header(false); + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension(FileType::CSV.get_ext()); + ctx.register_listing_table( + "empty", + "tests/data/empty_files/some_empty", + listing_options, + None, + None, + ) + .await + .unwrap(); + + // Require a predicate to enable repartition for the optimizer + let query = "select sum(column_1) from empty where column_1 > 0;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+---------------------+", + "| SUM(empty.column_1) |", + "+---------------------+", + "| 10 |", + "+---------------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(n_partitions, actual_partitions); // Won't get partitioned if all files are empty + + Ok(()) + } + + /// Parallel scan on a csv file with only 1 byte in each line + /// Testing partition byte range land on line boundaries + /// + /// one_col.csv: + /// 5\n + /// 5\n + /// (...10 rows total) + #[rstest(n_partitions, case(1), case(2), case(3), case(5), case(10), case(32))] + #[tokio::test] + async fn test_csv_parallel_one_col(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + + ctx.register_csv( + "one_col", + "tests/data/one_col.csv", + CsvReadOptions::new().has_header(false), + ) + .await?; + + let query = "select sum(column_1) from one_col where column_1 > 0;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+-----------------------+", + "| SUM(one_col.column_1) |", + "+-----------------------+", + "| 50 |", + "+-----------------------+"]; + let file_size = if cfg!(target_os = "windows") { + 30 // new line on Win is '\r\n' + } else { + 20 + }; + // A 20-Byte file at most get partitioned into 20 chunks + let expected_partitions = if n_partitions <= file_size { + n_partitions + } else { + file_size + }; + assert_batches_eq!(expected, &query_result); + assert_eq!(expected_partitions, actual_partitions); + + Ok(()) + } + + /// Parallel scan on a csv file with 2 wide rows + /// The byte range of a partition might be within some line + /// + /// wode_rows.csv: + /// 1, 1, ..., 1\n (100 columns total) + /// 2, 2, ..., 2\n + #[rstest(n_partitions, case(1), case(2), case(10), case(16))] + #[tokio::test] + async fn test_csv_parallel_wide_rows(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let ctx = SessionContext::new_with_config(config); + ctx.register_csv( + "wide_rows", + "tests/data/wide_rows.csv", + CsvReadOptions::new().has_header(false), + ) + .await?; + + let query = "select sum(column_1) + sum(column_33) + sum(column_50) + sum(column_77) + sum(column_100) as sum_of_5_cols from wide_rows where column_1 > 0;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+---------------+", + "| sum_of_5_cols |", + "+---------------+", + "| 15 |", + "+---------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(n_partitions, actual_partitions); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/file_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs similarity index 80% rename from datafusion/core/src/datasource/file_format/file_type.rs rename to datafusion/core/src/datasource/file_format/file_compression_type.rs index 567fffb323675..3dac7c293050c 100644 --- a/datafusion/core/src/datasource/file_format/file_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -15,15 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! File type abstraction +//! File Compression type abstraction use crate::error::{DataFusionError, Result}; - -use crate::datasource::file_format::arrow::DEFAULT_ARROW_EXTENSION; -use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION; -use crate::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; -use crate::datasource::file_format::json::DEFAULT_JSON_EXTENSION; -use crate::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; #[cfg(feature = "compression")] use async_compression::tokio::bufread::{ BzDecoder as AsyncBzDecoder, BzEncoder as AsyncBzEncoder, @@ -37,7 +31,7 @@ use async_compression::tokio::write::{BzEncoder, GzipEncoder, XzEncoder, ZstdEnc use bytes::Bytes; #[cfg(feature = "compression")] use bzip2::read::MultiBzDecoder; -use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{parsers::CompressionTypeVariant, FileType, GetExt}; #[cfg(feature = "compression")] use flate2::read::MultiGzDecoder; @@ -55,14 +49,8 @@ use xz2::read::XzDecoder; use zstd::Decoder as ZstdDecoder; use CompressionTypeVariant::*; -/// Define each `FileType`/`FileCompressionType`'s extension -pub trait GetExt { - /// File extension getter - fn get_ext(&self) -> String; -} - /// Readable file compression type -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct FileCompressionType { variant: CompressionTypeVariant, } @@ -237,59 +225,26 @@ impl FileCompressionType { } } -/// Readable file type -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum FileType { - /// Apache Arrow file - ARROW, - /// Apache Avro file - AVRO, - /// Apache Parquet file - PARQUET, - /// CSV file - CSV, - /// JSON file - JSON, -} - -impl GetExt for FileType { - fn get_ext(&self) -> String { - match self { - FileType::ARROW => DEFAULT_ARROW_EXTENSION.to_owned(), - FileType::AVRO => DEFAULT_AVRO_EXTENSION.to_owned(), - FileType::PARQUET => DEFAULT_PARQUET_EXTENSION.to_owned(), - FileType::CSV => DEFAULT_CSV_EXTENSION.to_owned(), - FileType::JSON => DEFAULT_JSON_EXTENSION.to_owned(), - } - } -} - -impl FromStr for FileType { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - let s = s.to_uppercase(); - match s.as_str() { - "ARROW" => Ok(FileType::ARROW), - "AVRO" => Ok(FileType::AVRO), - "PARQUET" => Ok(FileType::PARQUET), - "CSV" => Ok(FileType::CSV), - "JSON" | "NDJSON" => Ok(FileType::JSON), - _ => Err(DataFusionError::NotImplemented(format!( - "Unknown FileType: {s}" - ))), - } - } +/// Trait for extending the functionality of the `FileType` enum. +pub trait FileTypeExt { + /// Given a `FileCompressionType`, return the `FileType`'s extension with compression suffix + fn get_ext_with_compression(&self, c: FileCompressionType) -> Result; } -impl FileType { - /// Given a `FileCompressionType`, return the `FileType`'s extension with compression suffix - pub fn get_ext_with_compression(&self, c: FileCompressionType) -> Result { +impl FileTypeExt for FileType { + fn get_ext_with_compression(&self, c: FileCompressionType) -> Result { let ext = self.get_ext(); match self { FileType::JSON | FileType::CSV => Ok(format!("{}{}", ext, c.get_ext())), - FileType::PARQUET | FileType::AVRO | FileType::ARROW => match c.variant { + FileType::AVRO | FileType::ARROW => match c.variant { + UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "FileCompressionType can be specified for CSV/JSON FileType.".into(), + )), + }, + #[cfg(feature = "parquet")] + FileType::PARQUET => match c.variant { UNCOMPRESSED => Ok(ext), _ => Err(DataFusionError::Internal( "FileCompressionType can be specified for CSV/JSON FileType.".into(), @@ -301,8 +256,11 @@ impl FileType { #[cfg(test)] mod tests { - use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; + use crate::datasource::file_format::file_compression_type::{ + FileCompressionType, FileTypeExt, + }; use crate::error::DataFusionError; + use datafusion_common::file_options::file_type::FileType; use std::str::FromStr; #[test] @@ -325,10 +283,13 @@ mod tests { ); } + let mut ty_ext_tuple = vec![]; + ty_ext_tuple.push((FileType::AVRO, ".avro")); + #[cfg(feature = "parquet")] + ty_ext_tuple.push((FileType::PARQUET, ".parquet")); + // Cannot specify compression for these file types - for (file_type, extension) in - [(FileType::AVRO, ".avro"), (FileType::PARQUET, ".parquet")] - { + for (file_type, extension) in ty_ext_tuple { assert_eq!( file_type .get_ext_with_compression(FileCompressionType::UNCOMPRESSED) @@ -351,24 +312,6 @@ mod tests { #[test] fn from_str() { - for (ext, file_type) in [ - ("csv", FileType::CSV), - ("CSV", FileType::CSV), - ("json", FileType::JSON), - ("JSON", FileType::JSON), - ("avro", FileType::AVRO), - ("AVRO", FileType::AVRO), - ("parquet", FileType::PARQUET), - ("PARQUET", FileType::PARQUET), - ] { - assert_eq!(FileType::from_str(ext).unwrap(), file_type); - } - - assert!(matches!( - FileType::from_str("Unknown"), - Err(DataFusionError::NotImplemented(_)) - )); - for (ext, compression_type) in [ ("gz", FileCompressionType::GZIP), ("GZ", FileCompressionType::GZIP), diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 6247e85ba8793..9893a1db45de9 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -18,32 +18,46 @@ //! Line delimited JSON format abstractions use std::any::Any; - +use std::fmt; +use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; +use super::{FileFormat, FileScanConfig}; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; +use arrow::json; use arrow::json::reader::infer_json_schema_from_iterator; use arrow::json::reader::ValueIter; +use arrow_array::RecordBatch; use async_trait::async_trait; use bytes::Buf; +use bytes::Bytes; use datafusion_physical_expr::PhysicalExpr; -use object_store::{GetResult, ObjectMeta, ObjectStore}; +use datafusion_physical_expr::PhysicalSortRequirement; +use datafusion_physical_plan::ExecutionPlan; +use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; + +use crate::datasource::physical_plan::FileGroupDisplay; +use crate::physical_plan::insert::DataSink; +use crate::physical_plan::insert::FileSinkExec; +use crate::physical_plan::SendableRecordBatchStream; +use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; -use super::FileFormat; -use super::FileScanConfig; -use crate::datasource::file_format::file_type::FileCompressionType; +use super::write::orchestration::stateless_multipart_put; + +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; -use crate::datasource::physical_plan::NdJsonExec; +use crate::datasource::physical_plan::{FileSinkConfig, NdJsonExec}; use crate::error::Result; use crate::execution::context::SessionState; -use crate::physical_plan::ExecutionPlan; -use crate::physical_plan::Statistics; -/// The default file extension of json files -pub const DEFAULT_JSON_EXTENSION: &str = ".json"; +use datafusion_common::{not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::MetricsSet; + /// New line delimited JSON `FileFormat` implementation. #[derive(Debug)] pub struct JsonFormat { @@ -103,14 +117,15 @@ impl FileFormat for JsonFormat { should_take }; - let schema = match store.get(&object.location).await? { - GetResult::File(file, _) => { + let r = store.as_ref().get(&object.location).await?; + let schema = match r.payload { + GetResultPayload::File(file, _) => { let decoder = file_compression_type.convert_read(file)?; let mut reader = BufReader::new(decoder); let iter = ValueIter::new(&mut reader, None); infer_json_schema_from_iterator(iter.take_while(|_| take_while()))? } - r @ GetResult::Stream(_) => { + GetResultPayload::Stream(_) => { let data = r.bytes().await?; let decoder = file_compression_type.convert_read(data.reader())?; let mut reader = BufReader::new(decoder); @@ -133,10 +148,10 @@ impl FileFormat for JsonFormat { &self, _state: &SessionState, _store: &Arc, - _table_schema: SchemaRef, + table_schema: SchemaRef, _object: &ObjectMeta, ) -> Result { - Ok(Statistics::default()) + Ok(Statistics::new_unknown(&table_schema)) } async fn create_physical_plan( @@ -148,12 +163,157 @@ impl FileFormat for JsonFormat { let exec = NdJsonExec::new(conf, self.file_compression_type.to_owned()); Ok(Arc::new(exec)) } + + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Json"); + } + + if self.file_compression_type != FileCompressionType::UNCOMPRESSED { + return not_impl_err!("Inserting compressed JSON is not implemented yet."); + } + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(JsonSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + + fn file_type(&self) -> FileType { + FileType::JSON + } +} + +impl Default for JsonSerializer { + fn default() -> Self { + Self::new() + } +} + +/// Define a struct for serializing Json records to a stream +pub struct JsonSerializer { + // Inner buffer for avoiding reallocation + buffer: Vec, +} + +impl JsonSerializer { + /// Constructor for the JsonSerializer object + pub fn new() -> Self { + Self { + buffer: Vec::with_capacity(4096), + } + } +} + +#[async_trait] +impl BatchSerializer for JsonSerializer { + async fn serialize(&mut self, batch: RecordBatch) -> Result { + let mut writer = json::LineDelimitedWriter::new(&mut self.buffer); + writer.write(&batch)?; + //drop(writer); + Ok(Bytes::from(self.buffer.drain(..).collect::>())) + } + + fn duplicate(&mut self) -> Result> { + Ok(Box::new(JsonSerializer::new())) + } +} + +/// Implements [`DataSink`] for writing to a Json file. +pub struct JsonSink { + /// Config options for writing data + config: FileSinkConfig, +} + +impl Debug for JsonSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JsonSink").finish() + } +} + +impl DisplayAs for JsonSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "JsonSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +impl JsonSink { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn multipartput_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let writer_options = self.config.file_type_writer_options.try_into_json()?; + let compression = &writer_options.compression; + + let get_serializer = move || { + let serializer: Box = Box::new(JsonSerializer::new()); + serializer + }; + + stateless_multipart_put( + data, + context, + "json".into(), + Box::new(get_serializer), + &self.config, + (*compression).into(), + ) + .await + } +} + +#[async_trait] +impl DataSink for JsonSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let total_count = self.multipartput_all(data, context).await?; + Ok(total_count) + } } #[cfg(test)] mod tests { use super::super::test_util::scan_format; use datafusion_common::cast::as_int64_array; + use datafusion_common::stats::Precision; use futures::StreamExt; use object_store::local::LocalFileSystem; @@ -165,7 +325,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; @@ -184,8 +344,8 @@ mod tests { assert_eq!(tt_batches, 6 /* 12/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index a6848b0d122d5..7c2331548e5ee 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -16,6 +16,7 @@ // under the License. //! Module containing helper methods for the various file formats +//! See write.rs for write related helper methods /// Default max records to scan to infer the schema pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; @@ -23,17 +24,16 @@ pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; pub mod arrow; pub mod avro; pub mod csv; -pub mod file_type; +pub mod file_compression_type; pub mod json; pub mod options; +#[cfg(feature = "parquet")] pub mod parquet; +pub mod write; use std::any::Any; -use std::io::Error; -use std::pin::Pin; +use std::fmt; use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{fmt, mem}; use crate::arrow::datatypes::SchemaRef; use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -41,23 +41,17 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_array::RecordBatch; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::{not_impl_err, DataFusionError, FileType}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; -use bytes::Bytes; -use futures::future::BoxFuture; -use futures::ready; -use futures::FutureExt; -use object_store::path::Path; -use object_store::{MultipartId, ObjectMeta, ObjectStore}; -use tokio::io::AsyncWrite; +use object_store::{ObjectMeta, ObjectStore}; + /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the the same file formats. /// -/// [`TableProvider`]: crate::datasource::datasource::TableProvider +/// [`TableProvider`]: crate::datasource::provider::TableProvider #[async_trait] pub trait FileFormat: Send + Sync + fmt::Debug { /// Returns the table provider as [`Any`](std::any::Any) so that it can be @@ -106,211 +100,13 @@ pub trait FileFormat: Send + Sync + fmt::Debug { _input: Arc, _state: &SessionState, _conf: FileSinkConfig, + _order_requirements: Option>, ) -> Result> { - let msg = "Writer not implemented for this format".to_owned(); - Err(DataFusionError::NotImplemented(msg)) - } -} - -/// `AsyncPutWriter` is an object that facilitates asynchronous writing to object stores. -/// It is specifically designed for the `object_store` crate's `put` method and sends -/// whole bytes at once when the buffer is flushed. -pub struct AsyncPutWriter { - /// Object metadata - object_meta: ObjectMeta, - /// A shared reference to the object store - store: Arc, - /// A buffer that stores the bytes to be sent - current_buffer: Vec, - /// Used for async handling in flush method - inner_state: AsyncPutState, -} - -impl AsyncPutWriter { - /// Constructor for the `AsyncPutWriter` object - pub fn new(object_meta: ObjectMeta, store: Arc) -> Self { - Self { - object_meta, - store, - current_buffer: vec![], - // The writer starts out in buffering mode - inner_state: AsyncPutState::Buffer, - } - } - - /// Separate implementation function that unpins the [`AsyncPutWriter`] so - /// that partial borrows work correctly - fn poll_shutdown_inner( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - loop { - match &mut self.inner_state { - AsyncPutState::Buffer => { - // Convert the current buffer to bytes and take ownership of it - let bytes = Bytes::from(mem::take(&mut self.current_buffer)); - // Set the inner state to Put variant with the bytes - self.inner_state = AsyncPutState::Put { bytes } - } - AsyncPutState::Put { bytes } => { - // Send the bytes to the object store's put method - return Poll::Ready( - ready!(self - .store - .put(&self.object_meta.location, bytes.clone()) - .poll_unpin(cx)) - .map_err(Error::from), - ); - } - } - } - } -} - -/// An enum that represents the inner state of AsyncPut -enum AsyncPutState { - /// Building Bytes struct in this state - Buffer, - /// Data in the buffer is being sent to the object store - Put { bytes: Bytes }, -} - -impl AsyncWrite for AsyncPutWriter { - // Define the implementation of the AsyncWrite trait for the `AsyncPutWriter` struct - fn poll_write( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // Extend the current buffer with the incoming buffer - self.current_buffer.extend_from_slice(buf); - // Return a ready poll with the length of the incoming buffer - Poll::Ready(Ok(buf.len())) + not_impl_err!("Writer not implemented for this format") } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - // Return a ready poll with an empty result - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - // Call the poll_shutdown_inner method to handle the actual sending of data to the object store - self.poll_shutdown_inner(cx) - } -} - -/// Stores data needed during abortion of MultiPart writers -pub(crate) struct MultiPart { - /// A shared reference to the object store - store: Arc, - multipart_id: MultipartId, - location: Path, -} - -impl MultiPart { - /// Create a new `MultiPart` - pub fn new( - store: Arc, - multipart_id: MultipartId, - location: Path, - ) -> Self { - Self { - store, - multipart_id, - location, - } - } -} - -pub(crate) enum AbortMode { - Put, - Append, - MultiPart(MultiPart), -} - -/// A wrapper struct with abort method and writer -struct AbortableWrite { - writer: W, - mode: AbortMode, -} - -impl AbortableWrite { - /// Create a new `AbortableWrite` instance with the given writer, and write mode. - fn new(writer: W, mode: AbortMode) -> Self { - Self { writer, mode } - } - - /// handling of abort for different write modes - fn abort_writer(&self) -> Result>> { - match &self.mode { - AbortMode::Put => Ok(async { Ok(()) }.boxed()), - AbortMode::Append => Err(DataFusionError::Execution( - "Cannot abort in append mode".to_string(), - )), - AbortMode::MultiPart(MultiPart { - store, - multipart_id, - location, - }) => { - let location = location.clone(); - let multipart_id = multipart_id.clone(); - let store = store.clone(); - Ok(Box::pin(async move { - store - .abort_multipart(&location, &multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } - } - } -} - -impl AsyncWrite for AbortableWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - -/// An enum that defines different file writer modes. -#[derive(Debug, Clone, Copy)] -pub enum FileWriterMode { - /// Data is appended to an existing file. - Append, - /// Data is written to a new file. - Put, - /// Data is written to a new file in multiple parts. - PutMultipart, -} -/// A trait that defines the methods required for a RecordBatch serializer. -#[async_trait] -pub trait BatchSerializer: Unpin + Send { - /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&mut self, batch: RecordBatch) -> Result; + /// Returns the FileType corresponding to this FileFormat + fn file_type(&self) -> FileType; } #[cfg(test)] @@ -327,7 +123,10 @@ pub(crate) mod test_util { use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::{GetOptions, GetResult, ListResult, MultipartId}; + use object_store::{ + GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, PutOptions, + PutResult, + }; use tokio::io::AsyncWrite; pub async fn scan_format( @@ -391,7 +190,12 @@ pub(crate) mod test_util { #[async_trait] impl ObjectStore for VariableStream { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -411,18 +215,29 @@ pub(crate) mod test_util { unimplemented!() } - async fn get(&self, _location: &Path) -> object_store::Result { + async fn get(&self, location: &Path) -> object_store::Result { let bytes = self.bytes_to_repeat.clone(); + let range = 0..bytes.len() * self.max_iterations; let arc = self.iterations_detected.clone(); - Ok(GetResult::Stream( - futures::stream::repeat_with(move || { - let arc_inner = arc.clone(); - *arc_inner.lock().unwrap() += 1; - Ok(bytes.clone()) - }) - .take(self.max_iterations) - .boxed(), - )) + let stream = futures::stream::repeat_with(move || { + let arc_inner = arc.clone(); + *arc_inner.lock().unwrap() += 1; + Ok(bytes.clone()) + }) + .take(self.max_iterations) + .boxed(); + + Ok(GetResult { + payload: GetResultPayload::Stream(stream), + meta: ObjectMeta { + location: location.clone(), + last_modified: Default::default(), + size: range.end, + e_tag: None, + version: None, + }, + range: Default::default(), + }) } async fn get_opts( @@ -449,11 +264,10 @@ pub(crate) mod test_util { unimplemented!() } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { + ) -> BoxStream<'_, object_store::Result> { unimplemented!() } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 3e802362d3ae0..4c7557a4a9c06 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -21,24 +21,25 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::DataFusionError; - -use crate::datasource::file_format::arrow::{ArrowFormat, DEFAULT_ARROW_EXTENSION}; -use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION; -use crate::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; -use crate::datasource::file_format::file_type::FileCompressionType; -use crate::datasource::file_format::json::DEFAULT_JSON_EXTENSION; -use crate::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; +use datafusion_common::{plan_err, DataFusionError}; + +use crate::datasource::file_format::arrow::ArrowFormat; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::listing::ListingTableUrl; use crate::datasource::{ - file_format::{ - avro::AvroFormat, csv::CsvFormat, json::JsonFormat, parquet::ParquetFormat, - }, + file_format::{avro::AvroFormat, csv::CsvFormat, json::JsonFormat}, listing::ListingOptions, }; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; +use crate::logical_expr::Expr; +use datafusion_common::{ + DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, +}; /// Options that control the reading of CSV files. /// @@ -55,6 +56,10 @@ pub struct CsvReadOptions<'a> { pub has_header: bool, /// An optional column delimiter. Defaults to `b','`. pub delimiter: u8, + /// An optional quote character. Defaults to `b'"'`. + pub quote: u8, + /// An optional escape character. Defaults to None. + pub escape: Option, /// An optional schema representing the CSV files. If None, CSV reader will try to infer it /// based on data in file. pub schema: Option<&'a Schema>, @@ -69,6 +74,8 @@ pub struct CsvReadOptions<'a> { pub file_compression_type: FileCompressionType, /// Flag indicating whether this file may be unbounded (as in a FIFO file). pub infinite: bool, + /// Indicates how the file is sorted + pub file_sort_order: Vec>, } impl<'a> Default for CsvReadOptions<'a> { @@ -85,10 +92,13 @@ impl<'a> CsvReadOptions<'a> { schema: None, schema_infer_max_records: DEFAULT_SCHEMA_INFER_MAX_RECORD, delimiter: b',', + quote: b'"', + escape: None, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, + file_sort_order: vec![], } } @@ -110,6 +120,18 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specify quote to use for CSV read + pub fn quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// Specify delimiter to use for CSV read + pub fn escape(mut self, escape: u8) -> Self { + self.escape = Some(escape); + self + } + /// Specify the file extension for CSV file selection pub fn file_extension(mut self, file_extension: &'a str) -> Self { self.file_extension = file_extension; @@ -153,6 +175,12 @@ impl<'a> CsvReadOptions<'a> { self.file_compression_type = file_compression_type; self } + + /// Configure if file has known sort order + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = file_sort_order; + self + } } /// Options that control the reading of Parquet files. @@ -177,6 +205,11 @@ pub struct ParquetReadOptions<'a> { /// /// If None specified, uses value in SessionConfig pub skip_metadata: Option, + /// An optional schema representing the parquet files. If None, parquet reader will try to infer it + /// based on data in file. + pub schema: Option<&'a Schema>, + /// Indicates how the file is sorted + pub file_sort_order: Vec>, } impl<'a> Default for ParquetReadOptions<'a> { @@ -186,6 +219,8 @@ impl<'a> Default for ParquetReadOptions<'a> { table_partition_cols: vec![], parquet_pruning: None, skip_metadata: None, + schema: None, + file_sort_order: vec![], } } } @@ -205,6 +240,12 @@ impl<'a> ParquetReadOptions<'a> { self } + /// Specify schema to use for parquet read + pub fn schema(mut self, schema: &'a Schema) -> Self { + self.schema = Some(schema); + self + } + /// Specify table_partition_cols for partition pruning pub fn table_partition_cols( mut self, @@ -213,6 +254,12 @@ impl<'a> ParquetReadOptions<'a> { self.table_partition_cols = table_partition_cols; self } + + /// Configure if file has known sort order + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = file_sort_order; + self + } } /// Options that control the reading of ARROW files. @@ -336,6 +383,8 @@ pub struct NdJsonReadOptions<'a> { pub file_compression_type: FileCompressionType, /// Flag indicating whether this file may be unbounded (as in a FIFO file). pub infinite: bool, + /// Indicates how the file is sorted + pub file_sort_order: Vec>, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -347,6 +396,7 @@ impl<'a> Default for NdJsonReadOptions<'a> { table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, + file_sort_order: vec![], } } } @@ -387,6 +437,12 @@ impl<'a> NdJsonReadOptions<'a> { self.schema = Some(schema); self } + + /// Configure if file has known sort order + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = file_sort_order; + self + } } #[async_trait] @@ -421,10 +477,9 @@ pub trait ReadOptions<'a> { .to_listing_options(config) .infer_schema(&state, &table_path) .await?), - (None, true) => Err(DataFusionError::Plan( - "Schema inference for infinite data sources is not supported." - .to_string(), - )), + (None, true) => { + plan_err!("Schema inference for infinite data sources is not supported.") + } } } } @@ -435,6 +490,8 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { let file_format = CsvFormat::default() .with_has_header(self.has_header) .with_delimiter(self.delimiter) + .with_quote(self.quote) + .with_escape(self.escape) .with_schema_infer_max_rec(Some(self.schema_infer_max_records)) .with_file_compression_type(self.file_compression_type.to_owned()); @@ -442,8 +499,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) - // TODO: Add file sort order into CsvReadOptions and introduce here. - .with_file_sort_order(vec![]) + .with_file_sort_order(self.file_sort_order.clone()) .with_infinite_source(self.infinite) } @@ -458,6 +514,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { } } +#[cfg(feature = "parquet")] #[async_trait] impl ReadOptions<'_> for ParquetReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { @@ -469,6 +526,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { .with_file_extension(self.file_extension) .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) + .with_file_sort_order(self.file_sort_order.clone()) } async fn get_resolved_schema( @@ -477,11 +535,8 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { state: SessionState, table_path: ListingTableUrl, ) -> Result { - // with parquet we resolve the schema in all cases - Ok(self - .to_listing_options(config) - .infer_schema(&state, &table_path) - .await?) + self._get_resolved_schema(config, state, table_path, self.schema, false) + .await } } @@ -489,6 +544,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { impl ReadOptions<'_> for NdJsonReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { let file_format = JsonFormat::default() + .with_schema_infer_max_rec(Some(self.schema_infer_max_records)) .with_file_compression_type(self.file_compression_type.to_owned()); ListingOptions::new(Arc::new(file_format)) @@ -496,6 +552,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_infinite_source(self.infinite) + .with_file_sort_order(self.file_sort_order.clone()) } async fn get_resolved_schema( @@ -512,7 +569,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { #[async_trait] impl ReadOptions<'_> for AvroReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { - let file_format = AvroFormat::default(); + let file_format = AvroFormat; ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) @@ -535,7 +592,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { #[async_trait] impl ReadOptions<'_> for ArrowReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { - let file_format = ArrowFormat::default(); + let file_format = ArrowFormat; ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index afdf9ab7a781e..41fc4c56f90c9 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -17,44 +17,64 @@ //! Parquet format abstractions +use arrow_array::RecordBatch; +use async_trait::async_trait; +use datafusion_common::stats::Precision; +use datafusion_physical_plan::metrics::MetricsSet; +use parquet::arrow::arrow_writer::{ + compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, + ArrowLeafColumn, +}; +use parquet::file::writer::SerializedFileWriter; use std::any::Any; +use std::fmt; +use std::fmt::Debug; +use std::io::Write; use std::sync::Arc; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{self, Receiver, Sender}; +use tokio::task::{JoinHandle, JoinSet}; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{Fields, Schema}; -use async_trait::async_trait; use bytes::{BufMut, BytesMut}; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::PhysicalExpr; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use futures::{StreamExt, TryStreamExt}; use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -use parquet::arrow::parquet_to_arrow_schema; +use parquet::arrow::{ + arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, +}; use parquet::file::footer::{decode_footer, decode_metadata}; use parquet::file::metadata::ParquetMetaData; +use parquet::file::properties::WriterProperties; use parquet::file::statistics::Statistics as ParquetStatistics; -use super::FileFormat; -use super::FileScanConfig; +use super::write::demux::start_demuxer_task; +use super::write::{create_writer, AbortableWrite}; +use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, }; use crate::arrow::datatypes::DataType; use crate::config::ConfigOptions; -use crate::datasource::physical_plan::{ParquetExec, SchemaAdapter}; -use crate::datasource::{create_max_min_accs, get_col_stats}; +use crate::datasource::physical_plan::{ + FileGroupDisplay, FileSinkConfig, ParquetExec, SchemaAdapter, +}; use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; -use crate::physical_plan::{Accumulator, ExecutionPlan, Statistics}; - -/// The default file extension of parquet files -pub const DEFAULT_PARQUET_EXTENSION: &str = ".parquet"; - -/// The number of files to read in parallel when inferring schema -const SCHEMA_INFERENCE_CONCURRENCY: usize = 32; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; +use crate::physical_plan::{ + Accumulator, DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, + Statistics, +}; /// The Apache Parquet `FileFormat` implementation /// @@ -175,7 +195,7 @@ impl FileFormat for ParquetFormat { ) }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 - .buffered(SCHEMA_INFERENCE_CONCURRENCY) + .buffered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; @@ -232,6 +252,32 @@ impl FileFormat for ParquetFormat { self.metadata_size_hint(state.config_options()), ))) } + + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + if conf.overwrite { + return not_impl_err!("Overwrites are not implemented yet for Parquet"); + } + + let sink_schema = conf.output_schema().clone(); + let sink = Arc::new(ParquetSink::new(conf)); + + Ok(Arc::new(FileSinkExec::new( + input, + sink, + sink_schema, + order_requirements, + )) as _) + } + + fn file_type(&self) -> FileType { + FileType::PARQUET + } } fn summarize_min_max( @@ -410,10 +456,7 @@ pub async fn fetch_parquet_metadata( size_hint: Option, ) -> Result { if meta.size < 8 { - return Err(DataFusionError::Execution(format!( - "file size of {} is less than footer", - meta.size - ))); + return exec_err!("file size of {} is less than footer", meta.size); } // If a size hint is provided, read more than the minimum size @@ -436,11 +479,11 @@ pub async fn fetch_parquet_metadata( let length = decode_footer(&footer)?; if meta.size < length + 8 { - return Err(DataFusionError::Execution(format!( + return exec_err!( "file size of {} is less than footer + metadata {}", meta.size, length + 8 - ))); + ); } // Did not fetch the entire file metadata in the initial read, need to make a second request @@ -500,7 +543,7 @@ async fn fetch_statistics( let mut num_rows = 0; let mut total_byte_size = 0; - let mut null_counts = vec![0; num_fields]; + let mut null_counts = vec![Precision::Exact(0); num_fields]; let mut has_statistics = false; let schema_adapter = SchemaAdapter::new(table_schema.clone()); @@ -526,7 +569,7 @@ async fn fetch_statistics( schema_adapter.map_column_index(table_idx, &file_schema) { if let Some((null_count, stats)) = column_stats.get(&file_idx) { - *null_cnt += *null_count as usize; + *null_cnt = null_cnt.add(&Precision::Exact(*null_count as usize)); summarize_min_max( &mut max_values, &mut min_values, @@ -540,33 +583,549 @@ async fn fetch_statistics( min_values[table_idx] = None; } } else { - *null_cnt += num_rows as usize; + *null_cnt = null_cnt.add(&Precision::Exact(num_rows as usize)); } } } } let column_stats = if has_statistics { - Some(get_col_stats( - &table_schema, - null_counts, - &mut max_values, - &mut min_values, - )) + get_col_stats(&table_schema, null_counts, &mut max_values, &mut min_values) } else { - None + Statistics::unknown_column(&table_schema) }; let statistics = Statistics { - num_rows: Some(num_rows as usize), - total_byte_size: Some(total_byte_size as usize), + num_rows: Precision::Exact(num_rows as usize), + total_byte_size: Precision::Exact(total_byte_size as usize), column_statistics: column_stats, - is_exact: true, }; Ok(statistics) } +/// Implements [`DataSink`] for writing to a parquet file. +struct ParquetSink { + /// Config options for writing data + config: FileSinkConfig, +} + +impl Debug for ParquetSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetSink").finish() + } +} + +impl DisplayAs for ParquetSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ParquetSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_groups).fmt_as(t, f)?; + write!(f, ")") + } + } + } +} + +impl ParquetSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } + + /// Converts table schema to writer schema, which may differ in the case + /// of hive style partitioning where some columns are removed from the + /// underlying files. + fn get_writer_schema(&self) -> Arc { + if !self.config.table_partition_cols.is_empty() { + let schema = self.config.output_schema(); + let partition_names: Vec<_> = self + .config + .table_partition_cols + .iter() + .map(|(s, _)| s) + .collect(); + Arc::new(Schema::new( + schema + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + )) + } else { + self.config.output_schema().clone() + } + } + + /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore + /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized + async fn create_async_arrow_writer( + &self, + location: &Path, + object_store: Arc, + parquet_props: WriterProperties, + ) -> Result< + AsyncArrowWriter>, + > { + let (_, multipart_writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + let writer = AsyncArrowWriter::try_new( + multipart_writer, + self.get_writer_schema(), + 10485760, + Some(parquet_props), + )?; + Ok(writer) + } +} + +#[async_trait] +impl DataSink for ParquetSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + let parquet_props = self + .config + .file_type_writer_options + .try_into_parquet()? + .writer_options(); + + let object_store = context + .runtime_env() + .object_store(&self.config.object_store_url)?; + + let parquet_opts = &context.session_config().options().execution.parquet; + let allow_single_file_parallelism = parquet_opts.allow_single_file_parallelism; + + let part_col = if !self.config.table_partition_cols.is_empty() { + Some(self.config.table_partition_cols.clone()) + } else { + None + }; + + let parallel_options = ParallelParquetWriterOptions { + max_parallel_row_groups: parquet_opts.maximum_parallel_row_group_writers, + max_buffered_record_batches_per_stream: parquet_opts + .maximum_buffered_record_batches_per_stream, + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_col, + self.config.table_paths[0].clone(), + "parquet".into(), + self.config.single_file_output, + ); + + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + while let Some((path, mut rx)) = file_stream_rx.recv().await { + if !allow_single_file_parallelism { + let mut writer = self + .create_async_arrow_writer( + &path, + object_store.clone(), + parquet_props.clone(), + ) + .await?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + writer.write(&batch).await?; + } + writer.close().await?; + Ok(row_count) + }); + } else { + let writer = create_writer( + // Parquet files as a whole are never compressed, since they + // manage compressed blocks themselves. + FileCompressionType::UNCOMPRESSED, + &path, + object_store.clone(), + ) + .await?; + let schema = self.get_writer_schema(); + let props = parquet_props.clone(); + let parallel_options_clone = parallel_options.clone(); + file_write_tasks.spawn(async move { + output_single_parquet_file_parallelized( + writer, + rx, + schema, + &props, + parallel_options_clone, + ) + .await + }); + } + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + match demux_task.await { + Ok(r) => r?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + Ok(row_count as u64) + } +} + +/// Consumes a stream of [ArrowLeafColumn] via a channel and serializes them using an [ArrowColumnWriter] +/// Once the channel is exhausted, returns the ArrowColumnWriter. +async fn column_serializer_task( + mut rx: Receiver, + mut writer: ArrowColumnWriter, +) -> Result { + while let Some(col) = rx.recv().await { + writer.write(&col)?; + } + Ok(writer) +} + +type ColumnJoinHandle = JoinHandle>; +type ColSender = Sender; +/// Spawns a parallel serialization task for each column +/// Returns join handles for each columns serialization task along with a send channel +/// to send arrow arrays to each serialization task. +fn spawn_column_parallel_row_group_writer( + schema: Arc, + parquet_props: Arc, + max_buffer_size: usize, +) -> Result<(Vec, Vec)> { + let schema_desc = arrow_to_parquet_schema(&schema)?; + let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; + let num_columns = col_writers.len(); + + let mut col_writer_handles = Vec::with_capacity(num_columns); + let mut col_array_channels = Vec::with_capacity(num_columns); + for writer in col_writers.into_iter() { + // Buffer size of this channel limits the number of arrays queued up for column level serialization + let (send_array, recieve_array) = + mpsc::channel::(max_buffer_size); + col_array_channels.push(send_array); + col_writer_handles + .push(tokio::spawn(column_serializer_task(recieve_array, writer))) + } + + Ok((col_writer_handles, col_array_channels)) +} + +/// Settings related to writing parquet files in parallel +#[derive(Clone)] +struct ParallelParquetWriterOptions { + max_parallel_row_groups: usize, + max_buffered_record_batches_per_stream: usize, +} + +/// This is the return type of calling [ArrowColumnWriter].close() on each column +/// i.e. the Vec of encoded columns which can be appended to a row group +type RBStreamSerializeResult = Result<(Vec, usize)>; + +/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective +/// parallel column serializers. +async fn send_arrays_to_col_writers( + col_array_channels: &[ColSender], + rb: &RecordBatch, + schema: Arc, +) -> Result<()> { + for (tx, array, field) in col_array_channels + .iter() + .zip(rb.columns()) + .zip(schema.fields()) + .map(|((a, b), c)| (a, b, c)) + { + for c in compute_leaves(field, array)? { + tx.send(c).await.map_err(|_| { + DataFusionError::Internal("Unable to send array to writer!".into()) + })?; + } + } + + Ok(()) +} + +/// Spawns a tokio task which joins the parallel column writer tasks, +/// and finalizes the row group. +fn spawn_rg_join_and_finalize_task( + column_writer_handles: Vec>>, + rg_rows: usize, +) -> JoinHandle { + tokio::spawn(async move { + let num_cols = column_writer_handles.len(); + let mut finalized_rg = Vec::with_capacity(num_cols); + for handle in column_writer_handles.into_iter() { + match handle.await { + Ok(r) => { + let w = r?; + finalized_rg.push(w.close()?); + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()) + } else { + unreachable!() + } + } + } + } + + Ok((finalized_rg, rg_rows)) + }) +} + +/// This task coordinates the serialization of a parquet file in parallel. +/// As the query produces RecordBatches, these are written to a RowGroup +/// via parallel [ArrowColumnWriter] tasks. Once the desired max rows per +/// row group is reached, the parallel tasks are joined on another separate task +/// and sent to a concatenation task. This task immediately continues to work +/// on the next row group in parallel. So, parquet serialization is parallelized +/// accross both columns and row_groups, with a theoretical max number of parallel tasks +/// given by n_columns * num_row_groups. +fn spawn_parquet_parallel_serialization_task( + mut data: Receiver, + serialize_tx: Sender>, + schema: Arc, + writer_props: Arc, + parallel_options: ParallelParquetWriterOptions, +) -> JoinHandle> { + tokio::spawn(async move { + let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; + let max_row_group_rows = writer_props.max_row_group_size(); + let (mut column_writer_handles, mut col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + let mut current_rg_rows = 0; + + while let Some(rb) = data.recv().await { + if current_rg_rows + rb.num_rows() < max_row_group_rows { + send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) + .await?; + current_rg_rows += rb.num_rows(); + } else { + let rows_left = max_row_group_rows - current_rg_rows; + let a = rb.slice(0, rows_left); + send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) + .await?; + + // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup + // on a separate task, so that we can immediately start on the next RG before waiting + // for the current one to finish. + drop(col_array_channels); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + max_row_group_rows, + ); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; + + let b = rb.slice(rows_left, rb.num_rows() - rows_left); + (column_writer_handles, col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + send_arrays_to_col_writers(&col_array_channels, &b, schema.clone()) + .await?; + current_rg_rows = b.num_rows(); + } + } + + drop(col_array_channels); + // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows + if current_rg_rows > 0 { + let finalize_rg_task = + spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; + } + + Ok(()) + }) +} + +/// Consume RowGroups serialized by other parallel tasks and concatenate them in +/// to the final parquet file, while flushing finalized bytes to an [ObjectStore] +async fn concatenate_parallel_row_groups( + mut serialize_rx: Receiver>, + schema: Arc, + writer_props: Arc, + mut object_store_writer: AbortableWrite>, +) -> Result { + let merged_buff = SharedBuffer::new(1048576); + + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; + let mut parquet_writer = SerializedFileWriter::new( + merged_buff.clone(), + schema_desc.root_schema_ptr(), + writer_props, + )?; + + let mut row_count = 0; + + while let Some(handle) = serialize_rx.recv().await { + let join_result = handle.await; + match join_result { + Ok(result) => { + let mut rg_out = parquet_writer.next_row_group()?; + let (serialized_columns, cnt) = result?; + row_count += cnt; + for chunk in serialized_columns { + chunk.append_to_row_group(&mut rg_out)?; + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + if buff_to_flush.len() > 1024000 { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + rg_out.close()?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + let inner_writer = parquet_writer.into_inner()?; + let final_buff = inner_writer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + + Ok(row_count) +} + +/// Parallelizes the serialization of a single parquet file, by first serializing N +/// independent RecordBatch streams in parallel to RowGroups in memory. Another +/// task then stitches these independent RowGroups together and streams this large +/// single parquet file to an ObjectStore in multiple parts. +async fn output_single_parquet_file_parallelized( + object_store_writer: AbortableWrite>, + data: Receiver, + output_schema: Arc, + parquet_props: &WriterProperties, + parallel_options: ParallelParquetWriterOptions, +) -> Result { + let max_rowgroups = parallel_options.max_parallel_row_groups; + // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel + let (serialize_tx, serialize_rx) = + mpsc::channel::>(max_rowgroups); + + let arc_props = Arc::new(parquet_props.clone()); + let launch_serialization_task = spawn_parquet_parallel_serialization_task( + data, + serialize_tx, + output_schema.clone(), + arc_props.clone(), + parallel_options, + ); + let row_count = concatenate_parallel_row_groups( + serialize_rx, + output_schema.clone(), + arc_props.clone(), + object_store_writer, + ) + .await?; + + match launch_serialization_task.await { + Ok(Ok(_)) => (), + Ok(Err(e)) => return Err(e), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()) + } else { + unreachable!() + } + } + }; + + Ok(row_count) +} + +/// A buffer with interior mutability shared by the SerializedFileWriter and +/// ObjectStore writer +#[derive(Clone)] +struct SharedBuffer { + /// The inner buffer for reading and writing + /// + /// The lock is used to obtain internal mutability, so no worry about the + /// lock contention. + buffer: Arc>>, +} + +impl SharedBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: Arc::new(futures::lock::Mutex::new(Vec::with_capacity(capacity))), + } + } +} + +impl Write for SharedBuffer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::write(&mut *buffer, buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut buffer = self.buffer.try_lock().unwrap(); + Write::flush(&mut *buffer) + } +} + #[cfg(test)] pub(crate) mod test_util { use super::*; @@ -622,7 +1181,7 @@ pub(crate) mod test_util { Ok((meta, files)) } - //// write batches chunk_size rows at a time + /// write batches chunk_size rows at a time fn write_in_chunks( writer: &mut ArrowWriter, batch: &RecordBatch, @@ -647,7 +1206,6 @@ mod tests { use super::*; use crate::datasource::file_format::parquet::test_util::store_parquet; - use crate::datasource::physical_plan::get_scan_files; use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; @@ -664,7 +1222,9 @@ mod tests { use log::error; use object_store::local::LocalFileSystem; use object_store::path::Path; - use object_store::{GetOptions, GetResult, ListResult, MultipartId}; + use object_store::{ + GetOptions, GetResult, ListResult, MultipartId, PutOptions, PutResult, + }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; @@ -742,7 +1302,12 @@ mod tests { #[async_trait] impl ObjectStore for RequestCountingObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { Err(object_store::Error::NotImplemented) } @@ -779,12 +1344,13 @@ mod tests { Err(object_store::Error::NotImplemented) } - async fn list( + fn list( &self, _prefix: Option<&Path>, - ) -> object_store::Result>> - { - Err(object_store::Error::NotImplemented) + ) -> BoxStream<'_, object_store::Result> { + Box::pin(futures::stream::once(async { + Err(object_store::Error::NotImplemented) + })) } async fn list_with_delimiter( @@ -842,11 +1408,11 @@ mod tests { fetch_statistics(store.upcast().as_ref(), schema.clone(), &meta[0], Some(9)) .await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(1)); - assert_eq!(c2_stats.null_count, Some(3)); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.null_count, Precision::Exact(3)); let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), @@ -875,11 +1441,11 @@ mod tests { ) .await?; - assert_eq!(stats.num_rows, Some(3)); - let c1_stats = &stats.column_statistics.as_ref().expect("missing c1 stats")[0]; - let c2_stats = &stats.column_statistics.as_ref().expect("missing c2 stats")[1]; - assert_eq!(c1_stats.null_count, Some(1)); - assert_eq!(c2_stats.null_count, Some(3)); + assert_eq!(stats.num_rows, Precision::Exact(3)); + let c1_stats = &stats.column_statistics[0]; + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.null_count, Precision::Exact(3)); let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), @@ -900,7 +1466,7 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session_ctx = SessionContext::with_config(config); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); let projection = None; @@ -919,8 +1485,8 @@ mod tests { assert_eq!(tt_batches, 4 /* 8/2 */); // test metadata - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } @@ -928,7 +1494,7 @@ mod tests { #[tokio::test] async fn capture_bytes_scanned_metric() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); - let session = SessionContext::with_config(config); + let session = SessionContext::new_with_config(config); let ctx = session.state(); // Read the full file @@ -961,9 +1527,8 @@ mod tests { get_exec(&state, "alltypes_plain.parquet", projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); - assert!(exec.statistics().is_exact); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -1205,10 +1770,13 @@ mod tests { let column = batches[0].column(0); assert_eq!(&DataType::Decimal128(13, 2), column.data_type()); - // parquet use the fixed length binary as the physical type to store decimal - // TODO: arrow-rs don't support convert the physical type of binary to decimal - // https://github.com/apache/arrow-rs/pull/2160 - // let exec = get_exec(&session_ctx, "byte_array_decimal.parquet", None, None).await?; + // parquet use the byte array as the physical type to store decimal + let exec = get_exec(&state, "byte_array_decimal.parquet", None, None).await?; + let batches = collect(exec, task_ctx.clone()).await?; + assert_eq!(1, batches.len()); + assert_eq!(1, batches[0].num_columns()); + let column = batches[0].column(0); + assert_eq!(&DataType::Decimal128(4, 2), column.data_type()); Ok(()) } @@ -1239,25 +1807,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_get_scan_files() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let projection = Some(vec![9]); - let exec = get_exec(&state, "alltypes_plain.parquet", projection, None).await?; - let scan_files = get_scan_files(exec)?; - assert_eq!(scan_files.len(), 1); - assert_eq!(scan_files[0].len(), 1); - assert_eq!(scan_files[0][0].len(), 1); - assert!(scan_files[0][0][0] - .object_meta - .location - .to_string() - .contains("alltypes_plain.parquet")); - - Ok(()) - } - fn check_page_index_validation( page_index: Option<&ParquetColumnIndex>, offset_index: Option<&ParquetOffsetIndex>, @@ -1271,8 +1820,8 @@ mod tests { // there is only one row group in one file. assert_eq!(page_index.len(), 1); assert_eq!(offset_index.len(), 1); - let page_index = page_index.get(0).unwrap(); - let offset_index = offset_index.get(0).unwrap(); + let page_index = page_index.first().unwrap(); + let offset_index = offset_index.first().unwrap(); // 13 col in one row group assert_eq!(page_index.len(), 13); diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs new file mode 100644 index 0000000000000..fa4ed8437015d --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -0,0 +1,420 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! dividing input stream into multiple output files at execution time + +use std::collections::HashMap; + +use std::sync::Arc; + +use crate::datasource::listing::ListingTableUrl; + +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::builder::UInt64Builder; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::cast::as_string_array; +use datafusion_common::DataFusionError; + +use datafusion_execution::TaskContext; + +use futures::StreamExt; +use object_store::path::Path; + +use rand::distributions::DistString; + +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; + +type RecordBatchReceiver = Receiver; +type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; + +/// Splits a single [SendableRecordBatchStream] into a dynamically determined +/// number of partitions at execution time. The partitions are determined by +/// factors known only at execution time, such as total number of rows and +/// partition column values. The demuxer task communicates to the caller +/// by sending channels over a channel. The inner channels send RecordBatches +/// which should be contained within the same output file. The outer channel +/// is used to send a dynamic number of inner channels, representing a dynamic +/// number of total output files. The caller is also responsible to monitor +/// the demux task for errors and abort accordingly. The single_file_ouput parameter +/// overrides all other settings to force only a single file to be written. +/// partition_by parameter will additionally split the input based on the unique +/// values of a specific column ``` +/// ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌──────▶ │ batch 1 ├────▶...──────▶│ Batch a │ │ Output File1│ +/// │ └───────────┘ └────────────┘ └─────────────┘ +/// │ +/// ┌──────────┐ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// ┌───────────┐ ┌────────────┐ │ │ ├──────▶ │ batch a+1├────▶...──────▶│ Batch b │ │ Output File2│ +/// │ batch 1 ├────▶...──────▶│ Batch N ├─────▶│ Demux ├────────┤ ... └───────────┘ └────────────┘ └─────────────┘ +/// └───────────┘ └────────────┘ │ │ │ +/// └──────────┘ │ ┌───────────┐ ┌────────────┐ ┌─────────────┐ +/// └──────▶ │ batch d ├────▶...──────▶│ Batch n │ │ Output FileN│ +/// └───────────┘ └────────────┘ └─────────────┘ +pub(crate) fn start_demuxer_task( + input: SendableRecordBatchStream, + context: &Arc, + partition_by: Option>, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> (JoinHandle>, DemuxedStreamReceiver) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let context = context.clone(); + let task: JoinHandle> = match partition_by { + Some(parts) => { + // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot + // bound this channel without risking a deadlock. + tokio::spawn(async move { + hive_style_partitions_demuxer( + tx, + input, + context, + parts, + base_output_path, + file_extension, + ) + .await + }) + } + None => tokio::spawn(async move { + row_count_demuxer( + tx, + input, + context, + base_output_path, + file_extension, + single_file_output, + ) + .await + }), + }; + + (task, rx) +} + +/// Dynamically partitions input stream to acheive desired maximum rows per file +async fn row_count_demuxer( + mut tx: UnboundedSender<(Path, Receiver)>, + mut input: SendableRecordBatchStream, + context: Arc, + base_output_path: ListingTableUrl, + file_extension: String, + single_file_output: bool, +) -> Result<()> { + let exec_options = &context.session_config().options().execution; + + let max_rows_per_file = exec_options.soft_max_rows_per_output_file; + let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; + let minimum_parallel_files = exec_options.minimum_parallel_output_files; + let mut part_idx = 0; + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let mut open_file_streams = Vec::with_capacity(minimum_parallel_files); + + let mut next_send_steam = 0; + let mut row_counts = Vec::with_capacity(minimum_parallel_files); + + // Overrides if single_file_output is set + let minimum_parallel_files = if single_file_output { + 1 + } else { + minimum_parallel_files + }; + + let max_rows_per_file = if single_file_output { + usize::MAX + } else { + max_rows_per_file + }; + + while let Some(rb) = input.next().await.transpose()? { + // ensure we have at least minimum_parallel_files open + if open_file_streams.len() < minimum_parallel_files { + open_file_streams.push(create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?); + row_counts.push(0); + part_idx += 1; + } else if row_counts[next_send_steam] >= max_rows_per_file { + row_counts[next_send_steam] = 0; + open_file_streams[next_send_steam] = create_new_file_stream( + &base_output_path, + &write_id, + part_idx, + &file_extension, + single_file_output, + max_buffered_batches, + &mut tx, + )?; + part_idx += 1; + } + row_counts[next_send_steam] += rb.num_rows(); + open_file_streams[next_send_steam] + .send(rb) + .await + .map_err(|_| { + DataFusionError::Execution( + "Error sending RecordBatch to file stream!".into(), + ) + })?; + + next_send_steam = (next_send_steam + 1) % minimum_parallel_files; + } + Ok(()) +} + +/// Helper for row count demuxer +fn generate_file_path( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, +) -> Path { + if !single_file_output { + base_output_path + .prefix() + .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + } else { + base_output_path.prefix().to_owned() + } +} + +/// Helper for row count demuxer +fn create_new_file_stream( + base_output_path: &ListingTableUrl, + write_id: &str, + part_idx: usize, + file_extension: &str, + single_file_output: bool, + max_buffered_batches: usize, + tx: &mut UnboundedSender<(Path, Receiver)>, +) -> Result> { + let file_path = generate_file_path( + base_output_path, + write_id, + part_idx, + file_extension, + single_file_output, + ); + let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); + tx.send((file_path, rx_file)).map_err(|_| { + DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) + })?; + Ok(tx_file) +} + +/// Splits an input stream based on the distinct values of a set of columns +/// Assumes standard hive style partition paths such as +/// /col1=val1/col2=val2/outputfile.parquet +async fn hive_style_partitions_demuxer( + tx: UnboundedSender<(Path, Receiver)>, + mut input: SendableRecordBatchStream, + context: Arc, + partition_by: Vec<(String, DataType)>, + base_output_path: ListingTableUrl, + file_extension: String, +) -> Result<()> { + let write_id = + rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + + let exec_options = &context.session_config().options().execution; + let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; + + // To support non string partition col types, cast the type to &str first + let mut value_map: HashMap, Sender> = HashMap::new(); + + while let Some(rb) = input.next().await.transpose()? { + // First compute partition key for each row of batch, e.g. (col1=val1, col2=val2, ...) + let all_partition_values = compute_partition_keys_by_row(&rb, &partition_by)?; + + // Next compute how the batch should be split up to take each distinct key to its own batch + let take_map = compute_take_arrays(&rb, all_partition_values); + + // Divide up the batch into distinct partition key batches and send each batch + for (part_key, mut builder) in take_map.into_iter() { + // Take method adapted from https://github.com/lancedb/lance/pull/1337/files + // TODO: upstream RecordBatch::take to arrow-rs + let take_indices = builder.finish(); + let struct_array: StructArray = rb.clone().into(); + let parted_batch = RecordBatch::from( + arrow::compute::take(&struct_array, &take_indices, None)?.as_struct(), + ); + + // Get or create channel for this batch + let part_tx = match value_map.get_mut(&part_key) { + Some(part_tx) => part_tx, + None => { + // Create channel for previously unseen distinct partition key and notify consumer of new file + let (part_tx, part_rx) = tokio::sync::mpsc::channel::( + max_buffered_recordbatches, + ); + let file_path = compute_hive_style_file_path( + &part_key, + &partition_by, + &write_id, + &file_extension, + &base_output_path, + ); + + tx.send((file_path, part_rx)).map_err(|_| { + DataFusionError::Execution( + "Error sending new file stream!".into(), + ) + })?; + + value_map.insert(part_key.clone(), part_tx); + value_map + .get_mut(&part_key) + .ok_or(DataFusionError::Internal( + "Key must exist since it was just inserted!".into(), + ))? + } + }; + + // remove partitions columns + let final_batch_to_send = + remove_partition_by_columns(&parted_batch, &partition_by)?; + + // Finally send the partial batch partitioned by distinct value! + part_tx.send(final_batch_to_send).await.map_err(|_| { + DataFusionError::Internal("Unexpected error sending parted batch!".into()) + })?; + } + } + + Ok(()) +} + +fn compute_partition_keys_by_row<'a>( + rb: &'a RecordBatch, + partition_by: &'a [(String, DataType)], +) -> Result>> { + let mut all_partition_values = vec![]; + + for (col, dtype) in partition_by.iter() { + let mut partition_values = vec![]; + let col_array = + rb.column_by_name(col) + .ok_or(DataFusionError::Execution(format!( + "PartitionBy Column {} does not exist in source data!", + col + )))?; + + match dtype { + DataType::Utf8 => { + let array = as_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(array.value(i)); + } + } + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + col_array => { + let array = col_array.downcast_dict::() + .ok_or(DataFusionError::Execution(format!("it is not yet supported to write to hive partitions with datatype {}", + dtype)))?; + + for val in array.values() { + partition_values.push( + val.ok_or(DataFusionError::Execution(format!("Cannot partition by null value for column {}", col)))? + ); + } + }, + _ => unreachable!(), + ) + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ))) + } + } + + all_partition_values.push(partition_values); + } + + Ok(all_partition_values) +} + +fn compute_take_arrays( + rb: &RecordBatch, + all_partition_values: Vec>, +) -> HashMap, UInt64Builder> { + let mut take_map = HashMap::new(); + for i in 0..rb.num_rows() { + let mut part_key = vec![]; + for vals in all_partition_values.iter() { + part_key.push(vals[i].to_owned()); + } + let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + builder.append_value(i as u64); + } + take_map +} + +fn remove_partition_by_columns( + parted_batch: &RecordBatch, + partition_by: &Vec<(String, DataType)>, +) -> Result { + let end_idx = parted_batch.num_columns() - partition_by.len(); + let non_part_cols = &parted_batch.columns()[..end_idx]; + + let partition_names: Vec<_> = partition_by.iter().map(|(s, _)| s).collect(); + let non_part_schema = Schema::new( + parted_batch + .schema() + .fields() + .iter() + .filter(|f| !partition_names.contains(&f.name())) + .map(|f| (**f).clone()) + .collect::>(), + ); + let final_batch_to_send = + RecordBatch::try_new(Arc::new(non_part_schema), non_part_cols.into())?; + + Ok(final_batch_to_send) +} + +fn compute_hive_style_file_path( + part_key: &Vec, + partition_by: &[(String, DataType)], + write_id: &str, + file_extension: &str, + base_output_path: &ListingTableUrl, +) -> Path { + let mut file_path = base_output_path.prefix().clone(); + for j in 0..part_key.len() { + file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); + } + + file_path.child(format!("{}.{}", write_id, file_extension)) +} diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs new file mode 100644 index 0000000000000..cfcdbd8c464ec --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to enabling +//! write support for the various file formats + +use std::io::Error; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::datasource::file_format::file_compression_type::FileCompressionType; + +use crate::error::Result; + +use arrow_array::RecordBatch; + +use datafusion_common::DataFusionError; + +use async_trait::async_trait; +use bytes::Bytes; + +use futures::future::BoxFuture; +use object_store::path::Path; +use object_store::{MultipartId, ObjectStore}; + +use tokio::io::AsyncWrite; + +pub(crate) mod demux; +pub(crate) mod orchestration; + +/// Stores data needed during abortion of MultiPart writers +#[derive(Clone)] +pub(crate) struct MultiPart { + /// A shared reference to the object store + store: Arc, + multipart_id: MultipartId, + location: Path, +} + +impl MultiPart { + /// Create a new `MultiPart` + pub fn new( + store: Arc, + multipart_id: MultipartId, + location: Path, + ) -> Self { + Self { + store, + multipart_id, + location, + } + } +} + +/// A wrapper struct with abort method and writer +pub(crate) struct AbortableWrite { + writer: W, + multipart: MultiPart, +} + +impl AbortableWrite { + /// Create a new `AbortableWrite` instance with the given writer, and write mode. + pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { + Self { writer, multipart } + } + + /// handling of abort for different write modes + pub(crate) fn abort_writer(&self) -> Result>> { + let multi = self.multipart.clone(); + Ok(Box::pin(async move { + multi + .store + .abort_multipart(&multi.location, &multi.multipart_id) + .await + .map_err(DataFusionError::ObjectStore) + })) + } +} + +impl AsyncWrite for AbortableWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) + } +} + +/// A trait that defines the methods required for a RecordBatch serializer. +#[async_trait] +pub trait BatchSerializer: Unpin + Send { + /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. + async fn serialize(&mut self, batch: RecordBatch) -> Result; + /// Duplicates self to support serializing multiple batches in parallel on multiple cores + fn duplicate(&mut self) -> Result> { + Err(DataFusionError::NotImplemented( + "Parallel serialization is not implemented for this file type".into(), + )) + } +} + +/// Returns an [`AbortableWrite`] which writes to the given object store location +/// with the specified compression +pub(crate) async fn create_writer( + file_compression_type: FileCompressionType, + location: &Path, + object_store: Arc, +) -> Result>> { + let (multipart_id, writer) = object_store + .put_multipart(location) + .await + .map_err(DataFusionError::ObjectStore)?; + Ok(AbortableWrite::new( + file_compression_type.convert_async_writer(writer)?, + MultiPart::new(object_store, multipart_id, location.clone()), + )) +} diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs new file mode 100644 index 0000000000000..2ae6b70ed1c5a --- /dev/null +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module containing helper methods/traits related to +//! orchestrating file serialization, streaming to object store, +//! parallelization, and abort handling + +use std::sync::Arc; + +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::physical_plan::FileSinkConfig; +use crate::error::Result; +use crate::physical_plan::SendableRecordBatchStream; + +use arrow_array::RecordBatch; + +use datafusion_common::DataFusionError; + +use bytes::Bytes; +use datafusion_execution::TaskContext; + +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{self, Receiver}; +use tokio::task::{JoinHandle, JoinSet}; +use tokio::try_join; + +use super::demux::start_demuxer_task; +use super::{create_writer, AbortableWrite, BatchSerializer}; + +type WriterType = AbortableWrite>; +type SerializerType = Box; + +/// Serializes a single data stream in parallel and writes to an ObjectStore +/// concurrently. Data order is preserved. In the event of an error, +/// the ObjectStore writer is returned to the caller in addition to an error, +/// so that the caller may handle aborting failed writes. +pub(crate) async fn serialize_rb_stream_to_object_store( + mut data_rx: Receiver, + mut serializer: Box, + mut writer: AbortableWrite>, + unbounded_input: bool, +) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { + let (tx, mut rx) = + mpsc::channel::>>(100); + + let serialize_task = tokio::spawn(async move { + while let Some(batch) = data_rx.recv().await { + match serializer.duplicate() { + Ok(mut serializer_clone) => { + let handle = tokio::spawn(async move { + let num_rows = batch.num_rows(); + let bytes = serializer_clone.serialize(batch).await?; + Ok((num_rows, bytes)) + }); + tx.send(handle).await.map_err(|_| { + DataFusionError::Internal( + "Unknown error writing to object store".into(), + ) + })?; + if unbounded_input { + tokio::task::yield_now().await; + } + } + Err(_) => { + return Err(DataFusionError::Internal( + "Unknown error writing to object store".into(), + )) + } + } + } + Ok(()) + }); + + let mut row_count = 0; + while let Some(handle) = rx.recv().await { + match handle.await { + Ok(Ok((cnt, bytes))) => { + match writer.write_all(&bytes).await { + Ok(_) => (), + Err(e) => { + return Err(( + writer, + DataFusionError::Execution(format!( + "Error writing to object store: {e}" + )), + )) + } + }; + row_count += cnt; + } + Ok(Err(e)) => { + // Return the writer along with the error + return Err((writer, e)); + } + Err(e) => { + // Handle task panic or cancellation + return Err(( + writer, + DataFusionError::Execution(format!( + "Serialization task panicked or was cancelled: {e}" + )), + )); + } + } + } + + match serialize_task.await { + Ok(Ok(_)) => (), + Ok(Err(e)) => return Err((writer, e)), + Err(_) => { + return Err(( + writer, + DataFusionError::Internal("Unknown error writing to object store".into()), + )) + } + }; + Ok((writer, row_count as u64)) +} + +type FileWriteBundle = (Receiver, SerializerType, WriterType); +/// Contains the common logic for serializing RecordBatches and +/// writing the resulting bytes to an ObjectStore. +/// Serialization is assumed to be stateless, i.e. +/// each RecordBatch can be serialized without any +/// dependency on the RecordBatches before or after. +pub(crate) async fn stateless_serialize_and_write_files( + mut rx: Receiver, + tx: tokio::sync::oneshot::Sender, + unbounded_input: bool, +) -> Result<()> { + let mut row_count = 0; + // tracks if any writers encountered an error triggering the need to abort + let mut any_errors = false; + // tracks the specific error triggering abort + let mut triggering_error = None; + // tracks if any errors were encountered in the process of aborting writers. + // if true, we may not have a guarentee that all written data was cleaned up. + let mut any_abort_errors = false; + let mut join_set = JoinSet::new(); + while let Some((data_rx, serializer, writer)) = rx.recv().await { + join_set.spawn(async move { + serialize_rb_stream_to_object_store( + data_rx, + serializer, + writer, + unbounded_input, + ) + .await + }); + } + let mut finished_writers = Vec::new(); + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => match res { + Ok((writer, cnt)) => { + finished_writers.push(writer); + row_count += cnt; + } + Err((writer, e)) => { + finished_writers.push(writer); + any_errors = true; + triggering_error = Some(e); + } + }, + Err(e) => { + // Don't panic, instead try to clean up as many writers as possible. + // If we hit this code, ownership of a writer was not joined back to + // this thread, so we cannot clean it up (hence any_abort_errors is true) + any_errors = true; + any_abort_errors = true; + triggering_error = Some(DataFusionError::Internal(format!( + "Unexpected join error while serializing file {e}" + ))); + } + } + } + + // Finalize or abort writers as appropriate + for mut writer in finished_writers.into_iter() { + match any_errors { + true => { + let abort_result = writer.abort_writer(); + if abort_result.is_err() { + any_abort_errors = true; + } + } + false => { + writer.shutdown() + .await + .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?; + } + } + } + + if any_errors { + match any_abort_errors{ + true => return Err(DataFusionError::Internal("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written.".into())), + false => match triggering_error { + Some(e) => return Err(e), + None => return Err(DataFusionError::Internal("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.".into())) + } + } + } + + tx.send(row_count).map_err(|_| { + DataFusionError::Internal( + "Error encountered while sending row count back to file sink!".into(), + ) + })?; + Ok(()) +} + +/// Orchestrates multipart put of a dynamic number of output files from a single input stream +/// for any statelessly serialized file type. That is, any file type for which each [RecordBatch] +/// can be serialized independently of all other [RecordBatch]s. +pub(crate) async fn stateless_multipart_put( + data: SendableRecordBatchStream, + context: &Arc, + file_extension: String, + get_serializer: Box Box + Send>, + config: &FileSinkConfig, + compression: FileCompressionType, +) -> Result { + let object_store = context + .runtime_env() + .object_store(&config.object_store_url)?; + + let single_file_output = config.single_file_output; + let base_output_path = &config.table_paths[0]; + let unbounded_input = config.unbounded_input; + let part_cols = if !config.table_partition_cols.is_empty() { + Some(config.table_partition_cols.clone()) + } else { + None + }; + + let (demux_task, mut file_stream_rx) = start_demuxer_task( + data, + context, + part_cols, + base_output_path.clone(), + file_extension, + single_file_output, + ); + + let rb_buffer_size = &context + .session_config() + .options() + .execution + .max_buffered_batches_per_output_file; + + let (tx_file_bundle, rx_file_bundle) = tokio::sync::mpsc::channel(rb_buffer_size / 2); + let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel(); + let write_coordinater_task = tokio::spawn(async move { + stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt, unbounded_input) + .await + }); + while let Some((location, rb_stream)) = file_stream_rx.recv().await { + let serializer = get_serializer(); + let writer = create_writer(compression, &location, object_store.clone()).await?; + + tx_file_bundle + .send((rb_stream, serializer, writer)) + .await + .map_err(|_| { + DataFusionError::Internal( + "Writer receive file bundle channel closed unexpectedly!".into(), + ) + })?; + } + + // Signal to the write coordinater that no more files are coming + drop(tx_file_bundle); + + match try_join!(write_coordinater_task, demux_task) { + Ok((r1, r2)) => { + r1?; + r2?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + + let total_count = rx_row_cnt.await.map_err(|_| { + DataFusionError::Internal( + "Did not receieve row count from write coordinater".into(), + ) + })?; + + Ok(total_count) +} diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs new file mode 100644 index 0000000000000..2fd352ee4eb31 --- /dev/null +++ b/datafusion/core/src/datasource/function.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A table that uses a function to generate data + +use super::TableProvider; + +use datafusion_common::Result; +use datafusion_expr::Expr; + +use std::sync::Arc; + +/// A trait for table function implementations +pub trait TableFunctionImpl: Sync + Send { + /// Create a table provider + fn call(&self, args: &[Expr]) -> Result>; +} + +/// A table that uses a function to generate data +pub struct TableFunction { + /// Name of the table function + name: String, + /// Function implementation + fun: Arc, +} + +impl TableFunction { + /// Create a new table function + pub fn new(name: String, fun: Arc) -> Self { + Self { name, fun } + } + + /// Get the name of the table function + pub fn name(&self) -> &str { + &self.name + } + + /// Get the function implementation and generate a table + pub fn create_table_provider(&self, args: &[Expr]) -> Result> { + self.fun.call(args) + } +} diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index efe1f7b59afbc..3536c098bd76f 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -36,10 +36,10 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; +use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError}; -use datafusion_expr::expr::ScalarUDF; -use datafusion_expr::{Expr, Volatility}; +use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; +use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; @@ -53,17 +53,17 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(&mut |expr| { - Ok(match expr { + match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - VisitRecursion::Skip + Ok(VisitRecursion::Skip) } else { - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } Expr::Literal(_) - | Expr::Alias(_, _) + | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Not(_) @@ -81,7 +81,6 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::BinaryExpr { .. } | Expr::Between { .. } | Expr::Like { .. } - | Expr::ILike { .. } | Expr::SimilarTo { .. } | Expr::InList { .. } | Expr::Exists { .. } @@ -89,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => VisitRecursion::Continue, + | Expr::Case { .. } => Ok(VisitRecursion::Continue), Expr::ScalarFunction(scalar_function) => { - match scalar_function.fun.volatility() { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + match &scalar_function.func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + match fun.volatility() { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } } - } - } - Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature.volatility { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + ScalarFunctionDefinition::UDF(fun) => { + match fun.signature().volatility { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } } } @@ -116,17 +122,15 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } - }) + } }) .unwrap(); is_applicable @@ -276,7 +280,10 @@ async fn prune_partitions( // Applies `filter` to `batch` returning `None` on error let do_filter = |filter| -> Option { let expr = create_physical_expr(filter, &df_schema, &schema, &props).ok()?; - Some(expr.evaluate(&batch).ok()?.into_array(partitions.len())) + expr.evaluate(&batch) + .ok()? + .into_array(partitions.len()) + .ok() }; //.Compute the conjunction of the filters, ignoring errors @@ -316,17 +323,21 @@ async fn prune_partitions( /// `filters` might contain expressions that can be resolved only at the /// file level (e.g. Parquet row group pruning). pub async fn pruned_partition_list<'a>( + ctx: &'a SessionState, store: &'a dyn ObjectStore, table_path: &'a ListingTableUrl, filters: &'a [Expr], file_extension: &'a str, partition_cols: &'a [(String, DataType)], ) -> Result>> { - let list = table_path.list_all_files(store, file_extension); - // if no partition col => simply list all the files if partition_cols.is_empty() { - return Ok(Box::pin(list.map_ok(|object_meta| object_meta.into()))); + return Ok(Box::pin( + table_path + .list_all_files(ctx, store, file_extension) + .await? + .map_ok(|object_meta| object_meta.into()), + )); } let partitions = list_partitions(store, table_path, partition_cols.len()).await?; @@ -355,8 +366,7 @@ pub async fn pruned_partition_list<'a>( Some(files) => files, None => { trace!("Recursively listing partition {}", partition.path); - let s = store.list(Some(&partition.path)).await?; - s.try_collect().await? + store.list(Some(&partition.path)).try_collect().await? } }; @@ -421,7 +431,7 @@ mod tests { use futures::StreamExt; use crate::logical_expr::{case, col, lit}; - use crate::test::object_store::make_test_store; + use crate::test::object_store::make_test_store_and_state; use super::*; @@ -467,12 +477,13 @@ mod tests { #[tokio::test] async fn test_pruned_partition_list_empty() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/notparquetfile", 100), ("tablepath/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter], @@ -489,13 +500,14 @@ mod tests { #[tokio::test] async fn test_pruned_partition_list() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/mypartition=val1/file.parquet", 100), ("tablepath/mypartition=val2/file.parquet", 100), ("tablepath/mypartition=val1/other=val3/file.parquet", 100), ]); let filter = Expr::eq(col("mypartition"), lit("val1")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter], @@ -514,24 +526,18 @@ mod tests { f1.object_meta.location.as_ref(), "tablepath/mypartition=val1/file.parquet" ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); let f2 = &pruned[1]; assert_eq!( f2.object_meta.location.as_ref(), "tablepath/mypartition=val1/other=val3/file.parquet" ); - assert_eq!( - f2.partition_values, - &[ScalarValue::Utf8(Some(String::from("val1"))),] - ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); } #[tokio::test] async fn test_pruned_partition_list_multi() { - let store = make_test_store(&[ + let (store, state) = make_test_store_and_state(&[ ("tablepath/part1=p1v1/file.parquet", 100), ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), @@ -543,6 +549,7 @@ mod tests { // filter3 cannot be resolved at partition pruning let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( + &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), &[filter1, filter2, filter3], @@ -566,10 +573,7 @@ mod tests { ); assert_eq!( &f1.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] ); let f2 = &pruned[1]; assert_eq!( @@ -578,10 +582,7 @@ mod tests { ); assert_eq!( &f2.partition_values, - &[ - ScalarValue::Utf8(Some(String::from("p1v2"))), - ScalarValue::Utf8(Some(String::from("p2v1"))) - ] + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] ); } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index 427cfc8501b39..87c1663ae7183 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -50,7 +50,6 @@ pub struct FileRange { #[derive(Debug, Clone)] /// A single file or part of a file that should be read, along with its schema, statistics -/// A single file that should be read, along with its schema, statistics /// and partition column values that need to be appended to each row. pub struct PartitionedFile { /// Path for the file (e.g. URL, filesystem path, etc) @@ -81,6 +80,7 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -96,12 +96,19 @@ impl PartitionedFile { last_modified: chrono::Utc.timestamp_nanos(0), size: size as usize, e_tag: None, + version: None, }, partition_values: vec![], range: Some(FileRange { start, end }), extensions: None, } } + + /// Return a file reference from the given path + pub fn from_path(path: String) -> Result { + let size = std::fs::metadata(path.clone())?.len(); + Ok(Self::new(path, size)) + } } impl From for PartitionedFile { diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0252e99ab8a57..0ce1b43fe4564 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,44 +17,51 @@ //! The table implementation. +use std::collections::HashMap; use std::str::FromStr; use std::{any::Any, sync::Arc}; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; -use async_trait::async_trait; -use dashmap::DashMap; -use datafusion_common::ToDFSchema; -use datafusion_expr::expr::Sort; -use datafusion_optimizer::utils::conjunction; -use datafusion_physical_expr::{create_physical_expr, LexOrdering, PhysicalSortExpr}; -use futures::{future, stream, StreamExt, TryStreamExt}; -use object_store::path::Path; -use object_store::ObjectMeta; +use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; +use super::PartitionedFile; -use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; -use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ + create_ordering, file_format::{ - arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, json::JsonFormat, - parquet::ParquetFormat, FileFormat, + arrow::ArrowFormat, + avro::AvroFormat, + csv::CsvFormat, + file_compression_type::{FileCompressionType, FileTypeExt}, + json::JsonFormat, + FileFormat, }, get_statistics_with_limit, listing::ListingTableUrl, + physical_plan::{is_plan_streaming, FileScanConfig, FileSinkConfig}, TableProvider, TableType, }; -use crate::logical_expr::TableProviderFilterPushDown; -use crate::physical_plan; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, - logical_expr::Expr, - physical_plan::{empty::EmptyExec, project_schema, ExecutionPlan, Statistics}, + logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}, + physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, }; -use super::PartitionedFile; +use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; +use arrow_schema::Schema; +use datafusion_common::{ + internal_err, plan_err, project_schema, Constraints, FileType, FileTypeWriterOptions, + SchemaExt, ToDFSchema, +}; +use datafusion_execution::cache::cache_manager::FileStatisticsCache; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; +use datafusion_physical_expr::{ + create_physical_expr, LexOrdering, PhysicalSortRequirement, +}; -use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; +use async_trait::async_trait; +use futures::{future, stream, StreamExt, TryStreamExt}; /// Configuration for creating a [`ListingTable`] #[derive(Debug, Clone)] @@ -133,14 +140,15 @@ impl ListingTableConfig { .map_err(|_| DataFusionError::Internal(err_msg))?; let file_format: Arc = match file_type { - FileType::ARROW => Arc::new(ArrowFormat::default()), - FileType::AVRO => Arc::new(AvroFormat::default()), + FileType::ARROW => Arc::new(ArrowFormat), + FileType::AVRO => Arc::new(AvroFormat), FileType::CSV => Arc::new( CsvFormat::default().with_file_compression_type(file_compression_type), ), FileType::JSON => Arc::new( JsonFormat::default().with_file_compression_type(file_compression_type), ), + #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), }; @@ -149,15 +157,18 @@ impl ListingTableConfig { /// Infer `ListingOptions` based on `table_path` suffix. pub async fn infer_options(self, state: &SessionState) -> Result { - let store = state - .runtime_env() - .object_store(self.table_paths.get(0).unwrap())?; + let store = if let Some(url) = self.table_paths.first() { + state.runtime_env().object_store(url)? + } else { + return Ok(self); + }; let file = self .table_paths - .get(0) + .first() .unwrap() - .list_all_files(store.as_ref(), "") + .list_all_files(state, store.as_ref(), "") + .await? .next() .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; @@ -180,9 +191,11 @@ impl ListingTableConfig { pub async fn infer_schema(self, state: &SessionState) -> Result { match self.options { Some(options) => { - let schema = options - .infer_schema(state, self.table_paths.get(0).unwrap()) - .await?; + let schema = if let Some(url) = self.table_paths.first() { + options.infer_schema(state, url).await? + } else { + Arc::new(Schema::empty()) + }; Ok(Self { table_paths: self.table_paths, @@ -190,9 +203,7 @@ impl ListingTableConfig { options: Some(options), }) } - None => Err(DataFusionError::Internal( - "No `ListingOptions` set for inferring schema".into(), - )), + None => internal_err!("No `ListingOptions` set for inferring schema"), } } @@ -240,6 +251,12 @@ pub struct ListingOptions { /// In order to support infinite inputs, DataFusion may adjust query /// plans (e.g. joins) to run the given query in full pipelining mode. pub infinite_source: bool, + /// This setting when true indicates that the table is backed by a single file. + /// Any inserts to the table may only append to this existing file. + pub single_file: bool, + /// This setting holds file format specific options which should be used + /// when inserting into this table. + pub file_type_write_options: Option, } impl ListingOptions { @@ -258,6 +275,8 @@ impl ListingOptions { target_partitions: 1, file_sort_order: vec![], infinite_source: false, + single_file: false, + file_type_write_options: None, } } @@ -426,6 +445,21 @@ impl ListingOptions { self } + /// Configure if this table is backed by a sigle file + pub fn with_single_file(mut self, single_file: bool) -> Self { + self.single_file = single_file; + self + } + + /// Configure file format specific writing options. + pub fn with_write_options( + mut self, + file_type_write_options: FileTypeWriterOptions, + ) -> Self { + self.file_type_write_options = Some(file_type_write_options); + self + } + /// Infer the schema of the files at the given path on the provided object store. /// The inferred schema does not include the partitioning columns. /// @@ -440,7 +474,8 @@ impl ListingOptions { let store = state.runtime_env().object_store(table_path)?; let files: Vec<_> = table_path - .list_all_files(store.as_ref(), &self.file_extension) + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? .try_collect() .await?; @@ -448,39 +483,6 @@ impl ListingOptions { } } -/// Collected statistics for files -/// Cache is invalided when file size or last modification has changed -#[derive(Default)] -struct StatisticsCache { - statistics: DashMap, -} - -impl StatisticsCache { - /// Get `Statistics` for file location. Returns None if file has changed or not found. - fn get(&self, meta: &ObjectMeta) -> Option { - self.statistics - .get(&meta.location) - .map(|s| { - let (saved_meta, statistics) = s.value(); - if saved_meta.size != meta.size - || saved_meta.last_modified != meta.last_modified - { - // file has changed - None - } else { - Some(statistics.clone()) - } - }) - .unwrap_or(None) - } - - /// Save collected file statistics - fn save(&self, meta: ObjectMeta, statistics: Statistics) { - self.statistics - .insert(meta.location.clone(), (meta, statistics)); - } -} - /// Reads data from one or more files via an /// [`ObjectStore`](object_store::ObjectStore). For example, from /// local files or objects from AWS S3. Implements [`TableProvider`], @@ -488,7 +490,7 @@ impl StatisticsCache { /// /// # Features /// -/// 1. Merges schemas if the files have compatible but not indentical schemas +/// 1. Merges schemas if the files have compatible but not identical schemas /// /// 2. Hive-style partitioning support, where a path such as /// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. @@ -554,8 +556,10 @@ pub struct ListingTable { table_schema: SchemaRef, options: ListingOptions, definition: Option, - collected_statistics: StatisticsCache, + collected_statistics: FileStatisticsCache, infinite_source: bool, + constraints: Constraints, + column_defaults: HashMap, } impl ListingTable { @@ -591,13 +595,42 @@ impl ListingTable { table_schema: Arc::new(builder.finish()), options, definition: None, - collected_statistics: Default::default(), + collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), infinite_source, + constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; Ok(table) } + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. + /// + /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics + /// multiple times in the same session. + /// + /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. + pub fn with_cache(mut self, cache: Option) -> Self { + self.collected_statistics = + cache.unwrap_or(Arc::new(DefaultFileStatisticsCache::default())); + self + } + /// Specify the SQL definition for this table, if any pub fn with_definition(mut self, defintion: Option) -> Self { self.definition = defintion; @@ -616,39 +649,7 @@ impl ListingTable { /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { - let mut all_sort_orders = vec![]; - - for exprs in &self.options.file_sort_order { - // Construct PhsyicalSortExpr objects from Expr objects: - let sort_exprs = exprs - .iter() - .map(|expr| { - if let Expr::Sort(Sort { expr, asc, nulls_first }) = expr { - if let Expr::Column(col) = expr.as_ref() { - let expr = physical_plan::expressions::col(&col.name, self.table_schema.as_ref())?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } - else { - Err(DataFusionError::Plan( - format!("Expected single column references in output_ordering, got {expr:?}") - )) - } - } else { - Err(DataFusionError::Plan( - format!("Expected Expr::Sort in output_ordering, but got {expr:?}") - )) - } - }) - .collect::>>()?; - all_sort_orders.push(sort_exprs); - } - Ok(all_sort_orders) + create_ordering(&self.table_schema, &self.options.file_sort_order) } } @@ -662,6 +663,10 @@ impl TableProvider for ListingTable { Arc::clone(&self.table_schema) } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + fn table_type(&self) -> TableType { TableType::Base } @@ -680,7 +685,7 @@ impl TableProvider for ListingTable { if partitioned_file_lists.is_empty() { let schema = self.schema(); let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(false, projected_schema))); + return Ok(Arc::new(EmptyExec::new(projected_schema))); } // extract types of partition columns @@ -688,15 +693,7 @@ impl TableProvider for ListingTable { .options .table_partition_cols .iter() - .map(|col| { - Ok(( - col.0.to_owned(), - self.table_schema - .field_with_name(&col.0)? - .data_type() - .clone(), - )) - }) + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) .collect::>>()?; let filters = if let Some(expr) = conjunction(filters.to_vec()) { @@ -713,13 +710,18 @@ impl TableProvider for ListingTable { None }; + let object_store_url = if let Some(url) = self.table_paths.first() { + url.object_store() + } else { + return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); + }; // create the execution plan self.options .format .create_physical_plan( state, FileScanConfig { - object_store_url: self.table_paths.get(0).unwrap().object_store(), + object_store_url, file_schema: Arc::clone(&self.file_schema), file_groups: partitioned_file_lists, statistics, @@ -764,27 +766,32 @@ impl TableProvider for ListingTable { &self, state: &SessionState, input: Arc, + overwrite: bool, ) -> Result> { // Check that the schema of the plan matches the schema of this table. - if !input.schema().eq(&self.schema()) { - return Err(DataFusionError::Plan( + if !self + .schema() + .logically_equivalent_names_and_types(&input.schema()) + { + return plan_err!( // Return an error if schema of the input query does not match with the table schema. - "Inserting query must have the same schema with the table.".to_string(), - )); + "Inserting query must have the same schema with the table." + ); } - if self.table_paths().len() > 1 { - return Err(DataFusionError::Plan( - "Writing to a table backed by multiple files is not supported yet" - .to_owned(), - )); + let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { + return plan_err!( + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" + ); } - let table_path = &self.table_paths()[0]; // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; let file_list_stream = pruned_partition_list( + state, store.as_ref(), table_path, &[], @@ -794,28 +801,65 @@ impl TableProvider for ListingTable { .await?; let file_groups = file_list_stream.try_collect::>().await?; - - if file_groups.len() > 1 { - return Err(DataFusionError::Plan( - "Datafusion currently supports tables from single partition and/or file." - .to_owned(), - )); - } + let file_format = self.options().format.as_ref(); + + let file_type_writer_options = match &self.options().file_type_write_options { + Some(opt) => opt.clone(), + None => FileTypeWriterOptions::build_default( + &file_format.file_type(), + state.config_options(), + )?, + }; // Sink related option, apart from format let config = FileSinkConfig { object_store_url: self.table_paths()[0].object_store(), + table_paths: self.table_paths().clone(), file_groups, - output_schema: input.schema(), + output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - writer_mode: crate::datasource::file_format::FileWriterMode::Append, + // A plan can produce finite number of rows even if it has unbounded sources, like LIMIT + // queries. Thus, we can check if the plan is streaming to ensure file sink input is + // unbounded. When `unbounded_input` flag is `true` for sink, we occasionally call `yield_now` + // to consume data at the input. When `unbounded_input` flag is `false` (e.g non-streaming data), + // all of the data at the input is sink after execution finishes. See discussion for rationale: + // https://github.com/apache/arrow-datafusion/pull/7610#issuecomment-1728979918 + unbounded_input: is_plan_streaming(&input)?, + single_file_output: self.options.single_file, + overwrite, + file_type_writer_options, + }; + + let unsorted: Vec> = vec![]; + let order_requirements = if self.options().file_sort_order != unsorted { + // Multiple sort orders in outer vec are equivalent, so we pass only the first one + let ordering = self + .try_create_output_ordering()? + .first() + .ok_or(DataFusionError::Internal( + "Expected ListingTable to have a sort order, but none found!".into(), + ))? + .clone(); + // Converts Vec> into type required by execution plan to specify its required input ordering + Some( + ordering + .into_iter() + .map(PhysicalSortRequirement::from) + .collect::>(), + ) + } else { + None }; self.options() .format - .create_writer_physical_plan(input, state, config) + .create_writer_physical_plan(input, state, config, order_requirements) .await } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } } impl ListingTable { @@ -828,14 +872,15 @@ impl ListingTable { filters: &'a [Expr], limit: Option, ) -> Result<(Vec>, Statistics)> { - let store = if let Some(url) = self.table_paths.get(0) { + let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? } else { - return Ok((vec![], Statistics::default())); + return Ok((vec![], Statistics::new_unknown(&self.file_schema))); }; // list files (with partitions) let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { pruned_partition_list( + ctx, store.as_ref(), table_path, filters, @@ -844,36 +889,46 @@ impl ListingTable { ) })) .await?; - let file_list = stream::iter(file_list).flatten(); - // collect the statistics if required by the config - let files = file_list.then(|part_file| async { - let part_file = part_file?; - let statistics = if self.options.collect_stat { - match self.collected_statistics.get(&part_file.object_meta) { - Some(statistics) => statistics, - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - &store, - self.file_schema.clone(), + let files = file_list + .map(|part_file| async { + let part_file = part_file?; + let mut statistics_result = Statistics::new_unknown(&self.file_schema); + if self.options.collect_stat { + let statistics_cache = self.collected_statistics.clone(); + match statistics_cache.get_with_extra( + &part_file.object_meta.location, + &part_file.object_meta, + ) { + Some(statistics) => { + statistics_result = statistics.as_ref().clone() + } + None => { + let statistics = self + .options + .format + .infer_stats( + ctx, + &store, + self.file_schema.clone(), + &part_file.object_meta, + ) + .await?; + statistics_cache.put_with_extra( + &part_file.object_meta.location, + statistics.clone().into(), &part_file.object_meta, - ) - .await?; - self.collected_statistics - .save(part_file.object_meta.clone(), statistics.clone()); - statistics + ); + statistics_result = statistics; + } } } - } else { - Statistics::default() - }; - Ok((part_file, statistics)) as Result<(PartitionedFile, Statistics)> - }); + Ok((part_file, statistics_result)) + as Result<(PartitionedFile, Statistics)> + }) + .boxed() + .buffered(ctx.config_options().execution.meta_fetch_concurrency); let (files, statistics) = get_statistics_with_limit(files, self.schema(), limit).await?; @@ -887,27 +942,32 @@ impl ListingTable { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::fs::File; + use super::*; - use crate::datasource::file_format::file_type::GetExt; + #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{provider_as_source, MemTable}; + use crate::execution::options::ArrowReadOptions; use crate::physical_plan::collect; use crate::prelude::*; use crate::{ assert_batches_eq, - datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat}, + datasource::file_format::avro::AvroFormat, execution::options::ReadOptions, logical_expr::{col, lit}, test::{columns, object_store::register_test_store}, }; - use arrow::csv; + use arrow::datatypes::{DataType, Schema}; - use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; - use chrono::DateTime; - use datafusion_common::assert_contains; - use datafusion_expr::LogicalPlanBuilder; + use arrow_schema::SortOptions; + use datafusion_common::stats::Precision; + use datafusion_common::{assert_contains, GetExt, ScalarValue}; + use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; + use datafusion_physical_expr::PhysicalSortExpr; use rstest::*; - use std::fs::File; use tempfile::TempDir; /// It creates dummy file and checks if it can create unbounded input executors. @@ -953,12 +1013,13 @@ mod tests { assert_eq!(exec.output_partitioning().partition_count(), 1); // test metadata - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_by_default() -> Result<()> { let testdata = crate::test_util::parquet_test_data(); @@ -976,12 +1037,13 @@ mod tests { let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics().num_rows, Some(8)); - assert_eq!(exec.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_when_no_stats() -> Result<()> { let testdata = crate::test_util::parquet_test_data(); @@ -1000,12 +1062,13 @@ mod tests { let table = ListingTable::try_new(config)?; let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics().num_rows, None); - assert_eq!(exec.statistics().total_byte_size, None); + assert_eq!(exec.statistics()?.num_rows, Precision::Absent); + assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); Ok(()) } + #[cfg(feature = "parquet")] #[tokio::test] async fn test_try_create_output_ordering() { let testdata = crate::test_util::parquet_test_data(); @@ -1045,7 +1108,6 @@ mod tests { nulls_first: false, }, }]]) - ), // ok with two columns, different options ( @@ -1069,9 +1131,7 @@ mod tests { }, }, ]]) - ), - ]; for (file_sort_order, expected_result) in cases { @@ -1389,24 +1449,6 @@ mod tests { Ok(Arc::new(table)) } - fn load_empty_schema_csv_table( - schema: SchemaRef, - temp_path: &str, - ) -> Result> { - File::create(temp_path)?; - let table_path = ListingTableUrl::parse(temp_path).unwrap(); - - let file_format = CsvFormat::default(); - let listing_options = ListingOptions::new(Arc::new(file_format)); - - let config = ListingTableConfig::new(table_path) - .with_listing_options(listing_options) - .with_schema(schema); - - let table = ListingTable::try_new(config)?; - Ok(Arc::new(table)) - } - /// Check that the files listed by the table match the specified `output_partitioning` /// when the object store contains `files`. async fn assert_list_files_for_scan_grouping( @@ -1476,48 +1518,300 @@ mod tests { Ok(()) } - #[test] - fn test_statistics_cache() { - let meta = ObjectMeta { - location: Path::from("test"), - last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") - .unwrap() - .into(), - size: 1024, - e_tag: None, - }; + #[tokio::test] + async fn test_insert_into_append_new_json_files() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); + helper_test_append_new_files_to_table( + FileType::JSON, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 2, + ) + .await?; + Ok(()) + } - let cache = StatisticsCache::default(); - assert!(cache.get(&meta).is_none()); + #[tokio::test] + async fn test_insert_into_append_new_csv_files() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); + helper_test_append_new_files_to_table( + FileType::CSV, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 2, + ) + .await?; + Ok(()) + } - cache.save(meta.clone(), Statistics::default()); - assert!(cache.get(&meta).is_some()); + #[tokio::test] + async fn test_insert_into_append_2_new_parquet_files_defaults() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); + helper_test_append_new_files_to_table( + FileType::PARQUET, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 2, + ) + .await?; + Ok(()) + } - // file size changed - let mut meta2 = meta.clone(); - meta2.size = 2048; - assert!(cache.get(&meta2).is_none()); + #[tokio::test] + async fn test_insert_into_append_1_new_parquet_files_defaults() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "20".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "20".into(), + ); + helper_test_append_new_files_to_table( + FileType::PARQUET, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 1, + ) + .await?; + Ok(()) + } - // file last_modified changed - let mut meta2 = meta.clone(); - meta2.last_modified = DateTime::parse_from_rfc3339("2022-09-27T22:40:00+02:00") - .unwrap() - .into(); - assert!(cache.get(&meta2).is_none()); + #[tokio::test] + async fn test_insert_into_sql_csv_defaults() -> Result<()> { + helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_sql_csv_defaults_header_row() -> Result<()> { + helper_test_insert_into_sql( + "csv", + FileCompressionType::UNCOMPRESSED, + "WITH HEADER ROW", + None, + ) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_sql_json_defaults() -> Result<()> { + helper_test_insert_into_sql("json", FileCompressionType::UNCOMPRESSED, "", None) + .await?; + Ok(()) + } - // different file - let mut meta2 = meta; - meta2.location = Path::from("test2"); - assert!(cache.get(&meta2).is_none()); + #[tokio::test] + async fn test_insert_into_sql_parquet_defaults() -> Result<()> { + helper_test_insert_into_sql( + "parquet", + FileCompressionType::UNCOMPRESSED, + "", + None, + ) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_sql_parquet_session_overrides() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert( + "datafusion.execution.parquet.compression".into(), + "zstd(5)".into(), + ); + config_map.insert( + "datafusion.execution.parquet.dictionary_enabled".into(), + "false".into(), + ); + config_map.insert( + "datafusion.execution.parquet.dictionary_page_size_limit".into(), + "100".into(), + ); + config_map.insert( + "datafusion.execution.parquet.staistics_enabled".into(), + "none".into(), + ); + config_map.insert( + "datafusion.execution.parquet.max_statistics_size".into(), + "10".into(), + ); + config_map.insert( + "datafusion.execution.parquet.max_row_group_size".into(), + "5".into(), + ); + config_map.insert( + "datafusion.execution.parquet.created_by".into(), + "datafusion test".into(), + ); + config_map.insert( + "datafusion.execution.parquet.column_index_truncate_length".into(), + "50".into(), + ); + config_map.insert( + "datafusion.execution.parquet.data_page_row_count_limit".into(), + "50".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_enabled".into(), + "true".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_fpp".into(), + "0.01".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_ndv".into(), + "1000".into(), + ); + config_map.insert( + "datafusion.execution.parquet.writer_version".into(), + "2.0".into(), + ); + config_map.insert( + "datafusion.execution.parquet.write_batch_size".into(), + "5".into(), + ); + helper_test_insert_into_sql( + "parquet", + FileCompressionType::UNCOMPRESSED, + "", + Some(config_map), + ) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_append_new_parquet_files_session_overrides() -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + "10".into(), + ); + config_map.insert( + "datafusion.execution.parquet.compression".into(), + "zstd(5)".into(), + ); + config_map.insert( + "datafusion.execution.parquet.dictionary_enabled".into(), + "false".into(), + ); + config_map.insert( + "datafusion.execution.parquet.dictionary_page_size_limit".into(), + "100".into(), + ); + config_map.insert( + "datafusion.execution.parquet.staistics_enabled".into(), + "none".into(), + ); + config_map.insert( + "datafusion.execution.parquet.max_statistics_size".into(), + "10".into(), + ); + config_map.insert( + "datafusion.execution.parquet.max_row_group_size".into(), + "5".into(), + ); + config_map.insert( + "datafusion.execution.parquet.created_by".into(), + "datafusion test".into(), + ); + config_map.insert( + "datafusion.execution.parquet.column_index_truncate_length".into(), + "50".into(), + ); + config_map.insert( + "datafusion.execution.parquet.data_page_row_count_limit".into(), + "50".into(), + ); + config_map.insert( + "datafusion.execution.parquet.encoding".into(), + "delta_binary_packed".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_enabled".into(), + "true".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_fpp".into(), + "0.01".into(), + ); + config_map.insert( + "datafusion.execution.parquet.bloom_filter_ndv".into(), + "1000".into(), + ); + config_map.insert( + "datafusion.execution.parquet.writer_version".into(), + "2.0".into(), + ); + config_map.insert( + "datafusion.execution.parquet.write_batch_size".into(), + "5".into(), + ); + config_map.insert("datafusion.execution.batch_size".into(), "1".into()); + helper_test_append_new_files_to_table( + FileType::PARQUET, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 2, + ) + .await?; + Ok(()) } #[tokio::test] - async fn test_append_plan_to_external_table_stored_as_csv() -> Result<()> { - let file_type = FileType::CSV; - let file_compression_type = FileCompressionType::UNCOMPRESSED; + async fn test_insert_into_append_new_parquet_files_invalid_session_fails( + ) -> Result<()> { + let mut config_map: HashMap = HashMap::new(); + config_map.insert( + "datafusion.execution.parquet.compression".into(), + "zstd".into(), + ); + let e = helper_test_append_new_files_to_table( + FileType::PARQUET, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + 2, + ) + .await + .expect_err("Example should fail!"); + assert_eq!(e.strip_backtrace(), "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)"); + + Ok(()) + } + async fn helper_test_append_new_files_to_table( + file_type: FileType, + file_compression_type: FileCompressionType, + session_config_map: Option>, + expected_n_files_per_insert: usize, + ) -> Result<()> { // Create the initial context, schema, and batch. - let session_ctx = SessionContext::new(); + let session_ctx = match session_config_map { + Some(cfg) => { + let config = SessionConfig::from_string_hash_map(cfg)?; + SessionContext::new_with_config(config) + } + None => SessionContext::new(), + }; + // Create a new schema with one field called "a" of type Int32 let schema = Arc::new(Schema::new(vec![Field::new( "column1", @@ -1525,31 +1819,74 @@ mod tests { false, )])); + let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( + Box::new(Expr::Column("column1".into())), + Operator::GtEq, + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + )); + // Create a new batch of data to insert into the table let batch = RecordBatch::try_new( schema.clone(), - vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + vec![Arc::new(arrow_array::Int32Array::from(vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + ]))], )?; - // Filename with extension - let filename = format!( - "path{}", - file_type - .to_owned() - .get_ext_with_compression(file_compression_type.clone()) - .unwrap() - ); - - // Define batch size for file reader - let batch_size = batch.num_rows(); - - // Create a temporary directory and a CSV file within it. + // Register appropriate table depending on file_type we want to test let tmp_dir = TempDir::new()?; - let path = tmp_dir.path().join(filename); + match file_type { + FileType::CSV => { + session_ctx + .register_csv( + "t", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(schema.as_ref()) + .file_compression_type(file_compression_type), + ) + .await?; + } + FileType::JSON => { + session_ctx + .register_json( + "t", + tmp_dir.path().to_str().unwrap(), + NdJsonReadOptions::default() + .schema(schema.as_ref()) + .file_compression_type(file_compression_type), + ) + .await?; + } + FileType::PARQUET => { + session_ctx + .register_parquet( + "t", + tmp_dir.path().to_str().unwrap(), + ParquetReadOptions::default().schema(schema.as_ref()), + ) + .await?; + } + FileType::AVRO => { + session_ctx + .register_avro( + "t", + tmp_dir.path().to_str().unwrap(), + AvroReadOptions::default().schema(schema.as_ref()), + ) + .await?; + } + FileType::ARROW => { + session_ctx + .register_arrow( + "t", + tmp_dir.path().to_str().unwrap(), + ArrowReadOptions::default().schema(schema.as_ref()), + ) + .await?; + } + } - let initial_table = - load_empty_schema_csv_table(schema.clone(), path.to_str().unwrap())?; - session_ctx.register_table("t", initial_table)?; // Create and register the source table with the provided schema and inserted data let source_table = Arc::new(MemTable::try_new( schema.clone(), @@ -1559,60 +1896,54 @@ mod tests { // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); // Create a table scan logical plan to read from the source table - let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; + let scan_plan = LogicalPlanBuilder::scan("source", source, None)? + .filter(filter_predicate)? + .build()?; + // Since logical plan contains a filter, increasing parallelism is helpful. + // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() .create_physical_plan(&insert_into_table) .await?; - // Execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. - let expected = vec![ + let expected = [ "+-------+", "| count |", "+-------+", - "| 6 |", + "| 20 |", "+-------+", ]; // Assert that the batches read from the file match the expected result. assert_batches_eq!(expected, &res); - // Open the CSV file, read its contents as a record batch, and collect the batches into a vector. - let file = File::open(path.clone())?; - let reader = csv::ReaderBuilder::new(schema.clone()) - .has_header(true) - .with_batch_size(batch_size) - .build(file) - .map_err(|e| DataFusionError::Internal(e.to_string()))?; - - let batches = reader - .collect::>>() - .into_iter() - .collect::>>() - .map_err(|e| DataFusionError::Internal(e.to_string()))?; - - // Define the expected result as a vector of strings. - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", + + // Read the records in the table + let batches = session_ctx + .sql("select count(*) as count from t") + .await? + .collect() + .await?; + let expected = [ + "+-------+", + "| count |", + "+-------+", + "| 20 |", + "+-------+", ]; // Assert that the batches read from the file match the expected result. assert_batches_eq!(expected, &batches); + // Assert that `target_partition_number` many files were added to the table. + let num_files = tmp_dir.path().read_dir()?.count(); + assert_eq!(num_files, expected_n_files_per_insert); + // Create a physical plan from the insert plan let plan = session_ctx .state() @@ -1622,55 +1953,101 @@ mod tests { // Again, execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. - let expected = vec![ + let expected = [ "+-------+", "| count |", "+-------+", - "| 6 |", + "| 20 |", "+-------+", ]; // Assert that the batches read from the file match the expected result. assert_batches_eq!(expected, &res); - // Open the CSV file, read its contents as a record batch, and collect the batches into a vector. - let file = File::open(path.clone())?; - let reader = csv::ReaderBuilder::new(schema.clone()) - .has_header(true) - .with_batch_size(batch_size) - .build(file) - .map_err(|e| DataFusionError::Internal(e.to_string()))?; - - let batches = reader - .collect::>>() - .into_iter() - .collect::>>() - .map_err(|e| DataFusionError::Internal(e.to_string())); + // Read the contents of the table + let batches = session_ctx + .sql("select count(*) AS count from t") + .await? + .collect() + .await?; // Define the expected result after the second append. - let expected = vec![ - "+---------+", - "| column1 |", - "+---------+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| 3 |", - "+---------+", + let expected = [ + "+-------+", + "| count |", + "+-------+", + "| 40 |", + "+-------+", ]; // Assert that the batches read from the file after the second append match the expected result. - assert_batches_eq!(expected, &batches?); + assert_batches_eq!(expected, &batches); + + // Assert that another `target_partition_number` many files were added to the table. + let num_files = tmp_dir.path().read_dir()?.count(); + assert_eq!(num_files, expected_n_files_per_insert * 2); // Return Ok if the function Ok(()) } + + /// tests insert into with end to end sql + /// create external table + insert into statements + async fn helper_test_insert_into_sql( + file_type: &str, + // TODO test with create statement options such as compression + _file_compression_type: FileCompressionType, + external_table_options: &str, + session_config_map: Option>, + ) -> Result<()> { + // Create the initial context + let session_ctx = match session_config_map { + Some(cfg) => { + let config = SessionConfig::from_string_hash_map(cfg)?; + SessionContext::new_with_config(config) + } + None => SessionContext::new(), + }; + + // create table + let tmp_dir = TempDir::new()?; + let tmp_path = tmp_dir.into_path(); + let str_path = tmp_path.to_str().expect("Temp path should convert to &str"); + session_ctx + .sql(&format!( + "create external table foo(a varchar, b varchar, c int) \ + stored as {file_type} \ + location '{str_path}' \ + {external_table_options}" + )) + .await? + .collect() + .await?; + + // insert data + session_ctx.sql("insert into foo values ('foo', 'bar', 1),('foo', 'bar', 2), ('foo', 'bar', 3)") + .await? + .collect() + .await?; + + // check count + let batches = session_ctx + .sql("select * from foo") + .await? + .collect() + .await?; + + let expected = [ + "+-----+-----+---+", + "| a | b | c |", + "+-----+-----+---+", + "| foo | bar | 1 |", + "| foo | bar | 2 |", + "| foo | bar | 3 |", + "+-----+-----+---+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index dc96f959e4438..9e9fb9210071b 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -15,20 +15,24 @@ // specific language governing permissions and limitations // under the License. +use std::fs; + use crate::datasource::object_store::ObjectStoreUrl; +use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use glob::Pattern; use itertools::Itertools; +use log::debug; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -use percent_encoding; +use std::sync::Arc; use url::Url; /// A parsed URL identifying files for a listing table, see [`ListingTableUrl::parse`] /// for more information on the supported expressions -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct ListingTableUrl { /// A URL that identifies a file or directory to list files from url: Url, @@ -41,22 +45,45 @@ pub struct ListingTableUrl { impl ListingTableUrl { /// Parse a provided string as a `ListingTableUrl` /// + /// A URL can either refer to a single object, or a collection of objects with a + /// common prefix, with the presence of a trailing `/` indicating a collection. + /// + /// For example, `file:///foo.txt` refers to the file at `/foo.txt`, whereas + /// `file:///foo/` refers to all the files under the directory `/foo` and its + /// subdirectories. + /// + /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`, + /// wherease `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the + /// S3 bucket `BUCKET` + /// + /// # URL Encoding + /// + /// URL paths are expected to be URL-encoded. That is, the URL for a file named `bar%2Efoo` + /// would be `file:///bar%252Efoo`, as per the [URL] specification. + /// + /// It should be noted that some tools, such as the AWS CLI, take a different approach and + /// instead interpret the URL path verbatim. For example the object `bar%2Efoo` would be + /// addressed as `s3://BUCKET/bar%252Efoo` using [`ListingTableUrl`] but `s3://BUCKET/bar%2Efoo` + /// when using the aws-cli. + /// /// # Paths without a Scheme /// /// If no scheme is provided, or the string is an absolute filesystem path - /// as determined [`std::path::Path::is_absolute`], the string will be + /// as determined by [`std::path::Path::is_absolute`], the string will be /// interpreted as a path on the local filesystem using the operating /// system's standard path delimiter, i.e. `\` on Windows, `/` on Unix. /// /// If the path contains any of `'?', '*', '['`, it will be considered /// a glob expression and resolved as described in the section below. /// - /// Otherwise, the path will be resolved to an absolute path, returning - /// an error if it does not exist, and converted to a [file URI] + /// Otherwise, the path will be resolved to an absolute path based on the current + /// working directory, and converted to a [file URI]. /// - /// If you wish to specify a path that does not exist on the local - /// machine you must provide it as a fully-qualified [file URI] - /// e.g. `file:///myfile.txt` + /// If the path already exists in the local filesystem this will be used to determine if this + /// [`ListingTableUrl`] refers to a collection or a single object, otherwise the presence + /// of a trailing path delimiter will be used to indicate a directory. For the avoidance + /// of ambiguity it is recommended users always include trailing `/` when intending to + /// refer to a directory. /// /// ## Glob File Paths /// @@ -64,14 +91,13 @@ impl ListingTableUrl { /// be resolved as follows. /// /// The string up to the first path segment containing a glob expression will be extracted, - /// and resolved in the same manner as a normal scheme-less path. That is, resolved to - /// an absolute path on the local filesystem, returning an error if it does not exist, - /// and converted to a [file URI] + /// and resolved in the same manner as a normal scheme-less path above. /// /// The remaining string will be interpreted as a [`glob::Pattern`] and used as a /// filter when listing files from object storage /// /// [file URI]: https://en.wikipedia.org/wiki/File_URI_scheme + /// [URL]: https://url.spec.whatwg.org/ pub fn parse(s: impl AsRef) -> Result { let s = s.as_ref(); @@ -81,15 +107,41 @@ impl ListingTableUrl { } match Url::parse(s) { - Ok(url) => Ok(Self::new(url, None)), + Ok(url) => Self::try_new(url, None), Err(url::ParseError::RelativeUrlWithoutBase) => Self::parse_path(s), Err(e) => Err(DataFusionError::External(Box::new(e))), } } + /// Get object store for specified input_url + /// if input_url is actually not a url, we assume it is a local file path + /// if we have a local path, create it if not exists so ListingTableUrl::parse works + pub fn parse_create_local_if_not_exists( + s: impl AsRef, + is_directory: bool, + ) -> Result { + let s = s.as_ref(); + let is_valid_url = Url::parse(s).is_ok(); + + match is_valid_url { + true => ListingTableUrl::parse(s), + false => { + let path = std::path::PathBuf::from(s); + if !path.exists() { + if is_directory { + fs::create_dir_all(path)?; + } else { + fs::File::create(path)?; + } + } + ListingTableUrl::parse(s) + } + } + } + /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path fn parse_path(s: &str) -> Result { - let (prefix, glob) = match split_glob_expression(s) { + let (path, glob) = match split_glob_expression(s) { Some((prefix, glob)) => { let glob = Pattern::new(glob) .map_err(|e| DataFusionError::External(Box::new(e)))?; @@ -98,25 +150,19 @@ impl ListingTableUrl { None => (s, None), }; - let path = std::path::Path::new(prefix).canonicalize()?; - let url = if path.is_dir() { - Url::from_directory_path(path) - } else { - Url::from_file_path(path) - } - .map_err(|_| DataFusionError::Internal(format!("Can not open path: {s}")))?; - // TODO: Currently we do not have an IO-related error variant that accepts () - // or a string. Once we have such a variant, change the error type above. + let url = url_from_filesystem_path(path).ok_or_else(|| { + DataFusionError::External( + format!("Failed to convert path to URL: {path}").into(), + ) + })?; - Ok(Self::new(url, glob)) + Self::try_new(url, glob) } /// Creates a new [`ListingTableUrl`] from a url and optional glob expression - fn new(url: Url, glob: Option) -> Self { - let decoded_path = - percent_encoding::percent_decode_str(url.path()).decode_utf8_lossy(); - let prefix = Path::from(decoded_path.as_ref()); - Self { url, prefix, glob } + fn try_new(url: Url, glob: Option) -> Result { + let prefix = Path::from_url_path(url.path())?; + Ok(Self { url, prefix, glob }) } /// Returns the URL scheme @@ -124,7 +170,10 @@ impl ListingTableUrl { self.url.scheme() } - /// Return the prefix from which to list files + /// Return the URL path not excluding any glob expression + /// + /// If [`Self::is_collection`], this is the listing prefix + /// Otherwise, this is the path to the object pub fn prefix(&self) -> &Path { &self.prefix } @@ -143,6 +192,11 @@ impl ListingTableUrl { } } + /// Returns `true` if `path` refers to a collection of objects + pub fn is_collection(&self) -> bool { + self.url.as_str().ends_with('/') + } + /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning /// an iterator of the remaining path segments pub(crate) fn strip_prefix<'a, 'b: 'a>( @@ -158,28 +212,40 @@ impl ListingTableUrl { } /// List all files identified by this [`ListingTableUrl`] for the provided `file_extension` - pub(crate) fn list_all_files<'a>( + pub(crate) async fn list_all_files<'a>( &'a self, + ctx: &'a SessionState, store: &'a dyn ObjectStore, file_extension: &'a str, - ) -> BoxStream<'a, Result> { + ) -> Result>> { // If the prefix is a file, use a head request, otherwise list - let is_dir = self.url.as_str().ends_with('/'); - let list = match is_dir { - true => futures::stream::once(store.list(Some(&self.prefix))) - .try_flatten() - .boxed(), + let list = match self.is_collection() { + true => match ctx.runtime_env().cache_manager.get_list_files_cache() { + None => store.list(Some(&self.prefix)), + Some(cache) => { + if let Some(res) = cache.get(&self.prefix) { + debug!("Hit list all files cache"); + futures::stream::iter(res.as_ref().clone().into_iter().map(Ok)) + .boxed() + } else { + let list_res = store.list(Some(&self.prefix)); + let vec = list_res.try_collect::>().await?; + cache.put(&self.prefix, Arc::new(vec.clone())); + futures::stream::iter(vec.into_iter().map(Ok)).boxed() + } + } + }, false => futures::stream::once(store.head(&self.prefix)).boxed(), }; - - list.map_err(Into::into) + Ok(list .try_filter(move |meta| { let path = &meta.location; let extension_match = path.as_ref().ends_with(file_extension); let glob_match = self.contains(path); futures::future::ready(extension_match && glob_match) }) - .boxed() + .map_err(DataFusionError::ObjectStore) + .boxed()) } /// Returns this [`ListingTableUrl`] as a string @@ -194,6 +260,34 @@ impl ListingTableUrl { } } +/// Creates a file URL from a potentially relative filesystem path +fn url_from_filesystem_path(s: &str) -> Option { + let path = std::path::Path::new(s); + let is_dir = match path.exists() { + true => path.is_dir(), + // Fallback to inferring from trailing separator + false => std::path::is_separator(s.chars().last()?), + }; + + let from_absolute_path = |p| { + let first = match is_dir { + true => Url::from_directory_path(p).ok(), + false => Url::from_file_path(p).ok(), + }?; + + // By default from_*_path preserve relative path segments + // We therefore parse the URL again to resolve these + Url::parse(first.as_str()).ok() + }; + + if path.is_absolute() { + return from_absolute_path(path); + } + + let absolute = std::env::current_dir().ok()?.join(path); + from_absolute_path(&absolute) +} + impl AsRef for ListingTableUrl { fn as_ref(&self) -> &str { self.url.as_ref() @@ -241,6 +335,7 @@ fn split_glob_expression(path: &str) -> Option<(&str, &str)> { #[cfg(test)] mod tests { use super::*; + use tempfile::tempdir; #[test] fn test_prefix_path() { @@ -273,7 +368,57 @@ mod tests { assert_eq!(url.prefix.as_ref(), "foo/bar"); let url = ListingTableUrl::parse("file:///foo/😺").unwrap(); - assert_eq!(url.prefix.as_ref(), "foo/%F0%9F%98%BA"); + assert_eq!(url.prefix.as_ref(), "foo/😺"); + + let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); + + let url = ListingTableUrl::parse("file:///foo/bar%2Efoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar.foo"); + + let url = ListingTableUrl::parse("file:///foo/bar%252Ffoo").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/bar%2Ffoo"); + + let url = ListingTableUrl::parse("file:///foo/a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "foo/a%2Fb.txt"); + + let dir = tempdir().unwrap(); + let path = dir.path().join("bar%2Ffoo"); + std::fs::File::create(&path).unwrap(); + + let url = ListingTableUrl::parse(path.to_str().unwrap()).unwrap(); + assert!(url.prefix.as_ref().ends_with("bar%2Ffoo"), "{}", url.prefix); + + let url = ListingTableUrl::parse("file:///foo/../a%252Fb.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "a%2Fb.txt"); + + let url = + ListingTableUrl::parse("file:///foo/./bar/../../baz/./test.txt").unwrap(); + assert_eq!(url.prefix.as_ref(), "baz/test.txt"); + + let workdir = std::env::current_dir().unwrap(); + let t = workdir.join("non-existent"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("non-existent").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("non-existent")); + + let t = workdir.parent().unwrap(); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("..").unwrap(); + assert_eq!(a, b); + + let t = t.join("bar"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar")); + + let t = t.join(".").join("foo").join("..").join("baz"); + let a = ListingTableUrl::parse(t.to_str().unwrap()).unwrap(); + let b = ListingTableUrl::parse("../bar/./foo/../baz").unwrap(); + assert_eq!(a, b); + assert!(a.prefix.as_ref().ends_with("bar/baz")); } #[test] diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 7d10fc8e0e894..96436306c641e 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -21,38 +21,34 @@ use std::path::Path; use std::str::FromStr; use std::sync::Arc; -use arrow::datatypes::{DataType, SchemaRef}; -use async_trait::async_trait; -use datafusion_common::DataFusionError; -use datafusion_expr::CreateExternalTable; - -use crate::datasource::datasource::TableProviderFactory; -use crate::datasource::file_format::arrow::ArrowFormat; -use crate::datasource::file_format::avro::AvroFormat; -use crate::datasource::file_format::csv::CsvFormat; -use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; -use crate::datasource::file_format::json::JsonFormat; +#[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::FileFormat; +use crate::datasource::file_format::{ + arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, + file_compression_type::FileCompressionType, json::JsonFormat, FileFormat, +}; use crate::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; +use crate::datasource::provider::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_common::file_options::{FileTypeWriterOptions, StatementOptions}; +use datafusion_common::{plan_err, DataFusionError, FileType}; +use datafusion_expr::CreateExternalTable; + +use async_trait::async_trait; + /// A `TableProviderFactory` capable of creating new `ListingTable`s +#[derive(Debug, Default)] pub struct ListingTableFactory {} impl ListingTableFactory { /// Creates a new `ListingTableFactory` pub fn new() -> Self { - Self {} - } -} - -impl Default for ListingTableFactory { - fn default() -> Self { - Self::new() + Self::default() } } @@ -71,18 +67,27 @@ impl TableProviderFactory for ListingTableFactory { let file_extension = get_extension(cmd.location.as_str()); let file_format: Arc = match file_type { - FileType::CSV => Arc::new( - CsvFormat::default() + FileType::CSV => { + let mut statement_options = StatementOptions::from(&cmd.options); + let mut csv_format = CsvFormat::default() .with_has_header(cmd.has_header) .with_delimiter(cmd.delimiter as u8) - .with_file_compression_type(file_compression_type), - ), + .with_file_compression_type(file_compression_type); + if let Some(quote) = statement_options.take_str_option("quote") { + csv_format = csv_format.with_quote(quote.as_bytes()[0]) + } + if let Some(escape) = statement_options.take_str_option("escape") { + csv_format = csv_format.with_escape(Some(escape.as_bytes()[0])) + } + Arc::new(csv_format) + } + #[cfg(feature = "parquet")] FileType::PARQUET => Arc::new(ParquetFormat::default()), - FileType::AVRO => Arc::new(AvroFormat::default()), + FileType::AVRO => Arc::new(AvroFormat), FileType::JSON => Arc::new( JsonFormat::default().with_file_compression_type(file_compression_type), ), - FileType::ARROW => Arc::new(ArrowFormat::default()), + FileType::ARROW => Arc::new(ArrowFormat), }; let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { @@ -131,15 +136,87 @@ impl TableProviderFactory for ListingTableFactory { // look for 'infinite' as an option let infinite_source = cmd.unbounded; + let mut statement_options = StatementOptions::from(&cmd.options); + + // Extract ListingTable specific options if present or set default + let unbounded = if infinite_source { + statement_options.take_str_option("unbounded"); + infinite_source + } else { + statement_options + .take_bool_option("unbounded")? + .unwrap_or(false) + }; + + let create_local_path = statement_options + .take_bool_option("create_local_path")? + .unwrap_or(false); + let single_file = statement_options + .take_bool_option("single_file")? + .unwrap_or(false); + + // Backwards compatibility + if let Some(s) = statement_options.take_str_option("insert_mode") { + if !s.eq_ignore_ascii_case("append_new_files") { + return plan_err!("Unknown or unsupported insert mode {s}. Only append_to_file supported"); + } + } + let file_type = file_format.file_type(); + + // Use remaining options and session state to build FileTypeWriterOptions + let file_type_writer_options = FileTypeWriterOptions::build( + &file_type, + state.config_options(), + &statement_options, + )?; + + // Some options have special syntax which takes precedence + // e.g. "WITH HEADER ROW" overrides (header false, ...) + let file_type_writer_options = match file_type { + FileType::CSV => { + let mut csv_writer_options = + file_type_writer_options.try_into_csv()?.clone(); + csv_writer_options.writer_options = csv_writer_options + .writer_options + .with_header(cmd.has_header) + .with_delimiter(cmd.delimiter.try_into().map_err(|_| { + DataFusionError::Internal( + "Unable to convert CSV delimiter into u8".into(), + ) + })?); + csv_writer_options.compression = cmd.file_compression_type; + FileTypeWriterOptions::CSV(csv_writer_options) + } + FileType::JSON => { + let mut json_writer_options = + file_type_writer_options.try_into_json()?.clone(); + json_writer_options.compression = cmd.file_compression_type; + FileTypeWriterOptions::JSON(json_writer_options) + } + #[cfg(feature = "parquet")] + FileType::PARQUET => file_type_writer_options, + FileType::ARROW => file_type_writer_options, + FileType::AVRO => file_type_writer_options, + }; + + let table_path = match create_local_path { + true => ListingTableUrl::parse_create_local_if_not_exists( + &cmd.location, + !single_file, + ), + false => ListingTableUrl::parse(&cmd.location), + }?; + let options = ListingOptions::new(file_format) .with_collect_stat(state.config().collect_statistics()) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) - .with_infinite_source(infinite_source) - .with_file_sort_order(cmd.order_exprs.clone()); + .with_file_sort_order(cmd.order_exprs.clone()) + .with_single_file(single_file) + .with_write_options(file_type_writer_options) + .with_infinite_source(unbounded); - let table_path = ListingTableUrl::parse(&cmd.location)?; let resolved_schema = match provided_schema { None => options.infer_schema(state, &table_path).await?, Some(s) => s, @@ -147,8 +224,12 @@ impl TableProviderFactory for ListingTableFactory { let config = ListingTableConfig::new(table_path) .with_listing_options(options) .with_schema(resolved_schema); - let table = - ListingTable::try_new(config)?.with_definition(cmd.definition.clone()); + let provider = ListingTable::try_new(config)? + .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); + let table = provider + .with_definition(cmd.definition.clone()) + .with_constraints(cmd.constraints.clone()) + .with_column_defaults(cmd.column_defaults.clone()); Ok(Arc::new(table)) } } @@ -164,13 +245,13 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { - use super::*; - use std::collections::HashMap; + use super::*; use crate::execution::context::SessionContext; + use datafusion_common::parsers::CompressionTypeVariant; - use datafusion_common::{DFSchema, OwnedTableReference}; + use datafusion_common::{Constraints, DFSchema, OwnedTableReference}; #[tokio::test] async fn test_create_using_non_std_file_ext() { @@ -198,6 +279,8 @@ mod tests { order_exprs: vec![], unbounded: false, options: HashMap::new(), + constraints: Constraints::empty(), + column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index f66b44e9d1f9a..7c044b29366d5 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -17,27 +17,33 @@ //! [`MemTable`] for querying `Vec` by DataFusion. +use datafusion_physical_plan::metrics::MetricsSet; use futures::StreamExt; +use log::debug; use std::any::Any; -use std::fmt::{self, Debug, Display}; +use std::collections::HashMap; +use std::fmt::{self, Debug}; use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use datafusion_common::{ + not_impl_err, plan_err, Constraints, DataFusionError, SchemaExt, +}; use datafusion_execution::TaskContext; use tokio::sync::RwLock; +use tokio::task::JoinSet; use crate::datasource::{TableProvider, TableType}; -use crate::error::{DataFusionError, Result}; +use crate::error::Result; use crate::execution::context::SessionState; use crate::logical_expr::Expr; -use crate::physical_plan::common::AbortOnDropSingle; -use crate::physical_plan::insert::{DataSink, InsertExec}; +use crate::physical_plan::insert::{DataSink, FileSinkExec}; use crate::physical_plan::memory::MemoryExec; -use crate::physical_plan::ExecutionPlan; use crate::physical_plan::{common, SendableRecordBatchStream}; use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; +use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; /// Type alias for partition data pub type PartitionData = Arc>>; @@ -50,28 +56,48 @@ pub type PartitionData = Arc>>; pub struct MemTable { schema: SchemaRef, pub(crate) batches: Vec, + constraints: Constraints, + column_defaults: HashMap, } impl MemTable { /// Create a new in-memory table from the provided schema and record batches pub fn try_new(schema: SchemaRef, partitions: Vec>) -> Result { - if partitions - .iter() - .flatten() - .all(|batches| schema.contains(&batches.schema())) - { - Ok(Self { - schema, - batches: partitions - .into_iter() - .map(|e| Arc::new(RwLock::new(e))) - .collect::>(), - }) - } else { - Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )) + for batches in partitions.iter().flatten() { + let batches_schema = batches.schema(); + if !schema.contains(&batches_schema) { + debug!( + "mem table schema does not contain batches schema. \ + Target_schema: {schema:?}. Batches Schema: {batches_schema:?}" + ); + return plan_err!("Mismatch between schema and batches"); + } } + + Ok(Self { + schema, + batches: partitions + .into_iter() + .map(|e| Arc::new(RwLock::new(e))) + .collect::>(), + constraints: Constraints::empty(), + column_defaults: HashMap::new(), + }) + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self } /// Create a mem table by reading from another data source @@ -84,26 +110,31 @@ impl MemTable { let exec = t.scan(state, None, &[], None).await?; let partition_count = exec.output_partitioning().partition_count(); - let tasks = (0..partition_count) - .map(|part_i| { - let task = state.task_ctx(); - let exec = exec.clone(); - let task = tokio::spawn(async move { - let stream = exec.execute(part_i, task)?; - common::collect(stream).await - }); - - AbortOnDropSingle::new(task) - }) - // this collect *is needed* so that the join below can - // switch between tasks - .collect::>(); + let mut join_set = JoinSet::new(); + + for part_idx in 0..partition_count { + let task = state.task_ctx(); + let exec = exec.clone(); + join_set.spawn(async move { + let stream = exec.execute(part_idx, task)?; + common::collect(stream).await + }); + } let mut data: Vec> = Vec::with_capacity(exec.output_partitioning().partition_count()); - for result in futures::future::join_all(tasks).await { - data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??) + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => data.push(res?), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } } let exec = MemoryExec::try_new(&data, schema.clone(), None)?; @@ -143,6 +174,10 @@ impl TableProvider for MemTable { self.schema.clone() } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + fn table_type(&self) -> TableType { TableType::Base } @@ -159,8 +194,8 @@ impl TableProvider for MemTable { let inner_vec = arc_inner_vec.read().await; partitions.push(inner_vec.clone()) } - Ok(Arc::new(MemoryExec::try_new_owned_data( - partitions, + Ok(Arc::new(MemoryExec::try_new( + &partitions, self.schema(), projection.cloned(), )?)) @@ -182,16 +217,32 @@ impl TableProvider for MemTable { &self, _state: &SessionState, input: Arc, + overwrite: bool, ) -> Result> { // Create a physical plan from the logical plan. // Check that the schema of the plan matches the schema of this table. - if !input.schema().eq(&self.schema) { - return Err(DataFusionError::Plan( - "Inserting query must have the same schema with the table.".to_string(), - )); + if !self + .schema() + .logically_equivalent_names_and_types(&input.schema()) + { + return plan_err!( + "Inserting query must have the same schema with the table." + ); + } + if overwrite { + return not_impl_err!("Overwrite not implemented for MemoryTable yet"); } let sink = Arc::new(MemSink::new(self.batches.clone())); - Ok(Arc::new(InsertExec::new(input, sink))) + Ok(Arc::new(FileSinkExec::new( + input, + sink, + self.schema.clone(), + None, + ))) + } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) } } @@ -209,10 +260,14 @@ impl Debug for MemSink { } } -impl Display for MemSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let partition_count = self.batches.len(); - write!(f, "MemoryTable (partitions={partition_count})") +impl DisplayAs for MemSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let partition_count = self.batches.len(); + write!(f, "MemoryTable (partitions={partition_count})") + } + } } } @@ -224,6 +279,14 @@ impl MemSink { #[async_trait] impl DataSink for MemSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + async fn write_all( &self, mut data: SendableRecordBatchStream, @@ -395,12 +458,11 @@ mod tests { ], )?; - match MemTable::try_new(schema2, vec![vec![batch]]) { - Err(DataFusionError::Plan(e)) => { - assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}")) - } - _ => panic!("MemTable::new should have failed due to schema mismatch"), - } + let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err(); + assert_eq!( + "Error during planning: Mismatch between schema and batches", + e.strip_backtrace() + ); Ok(()) } @@ -426,12 +488,11 @@ mod tests { ], )?; - match MemTable::try_new(schema2, vec![vec![batch]]) { - Err(DataFusionError::Plan(e)) => { - assert_eq!("\"Mismatch between schema and batches\"", format!("{e:?}")) - } - _ => panic!("MemTable::new should have failed due to schema mismatch"), - } + let e = MemTable::try_new(schema2, vec![vec![batch]]).unwrap_err(); + assert_eq!( + "Error during planning: Mismatch between schema and batches", + e.strip_backtrace() + ); Ok(()) } @@ -516,7 +577,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 683afb7902e5f..2e516cc36a01d 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -15,188 +15,76 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion data sources +//! DataFusion data sources: [`TableProvider`] and [`ListingTable`] +//! +//! [`ListingTable`]: crate::datasource::listing::ListingTable -// TODO(clippy): Having a `datasource::datasource` module path is unclear and ambiguous. -// The child module should probably be renamed to something that more accurately -// describes its content. Something along the lines of `provider`, or `providers`. -#![allow(clippy::module_inception)] -pub mod datasource; +pub mod avro_to_arrow; pub mod default_table_source; pub mod empty; pub mod file_format; +pub mod function; pub mod listing; pub mod listing_table_factory; pub mod memory; pub mod physical_plan; +pub mod provider; +mod statistics; +pub mod stream; pub mod streaming; pub mod view; // backwards compatibility pub use datafusion_execution::object_store; -use futures::Stream; - -pub use self::datasource::TableProvider; pub use self::default_table_source::{ provider_as_source, source_as_provider, DefaultTableSource, }; -use self::listing::PartitionedFile; pub use self::memory::MemTable; +pub use self::provider::TableProvider; pub use self::view::ViewTable; -use crate::arrow::datatypes::{Schema, SchemaRef}; -use crate::error::Result; pub use crate::logical_expr::TableType; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; -use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; -use futures::StreamExt; - -/// Get all files as well as the file level summary statistics (no statistic for partition columns). -/// If the optional `limit` is provided, includes only sufficient files. -/// Needed to read up to `limit` number of rows. -pub async fn get_statistics_with_limit( - all_files: impl Stream>, - file_schema: SchemaRef, - limit: Option, -) -> Result<(Vec, Statistics)> { - let mut result_files = vec![]; - - let mut null_counts = vec![0; file_schema.fields().len()]; - let mut has_statistics = false; - let (mut max_values, mut min_values) = create_max_min_accs(&file_schema); +pub use statistics::get_statistics_with_limit; - let mut is_exact = true; +use arrow_schema::{Schema, SortOptions}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::Expr; +use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; - // The number of rows and the total byte size can be calculated as long as - // at least one file has them. If none of the files provide them, then they - // will be omitted from the statistics. The missing values will be counted - // as zero. - let mut num_rows = None; - let mut total_byte_size = None; - - // fusing the stream allows us to call next safely even once it is finished - let mut all_files = Box::pin(all_files.fuse()); - while let Some(res) = all_files.next().await { - let (file, file_stats) = res?; - result_files.push(file); - is_exact &= file_stats.is_exact; - num_rows = if let Some(num_rows) = num_rows { - Some(num_rows + file_stats.num_rows.unwrap_or(0)) - } else { - file_stats.num_rows - }; - total_byte_size = if let Some(total_byte_size) = total_byte_size { - Some(total_byte_size + file_stats.total_byte_size.unwrap_or(0)) - } else { - file_stats.total_byte_size - }; - if let Some(vec) = &file_stats.column_statistics { - has_statistics = true; - for (i, cs) in vec.iter().enumerate() { - null_counts[i] += cs.null_count.unwrap_or(0); - - if let Some(max_value) = &mut max_values[i] { - if let Some(file_max) = cs.max_value.clone() { - match max_value.update_batch(&[file_max.to_array()]) { - Ok(_) => {} - Err(_) => { - max_values[i] = None; - } - } - } else { - max_values[i] = None; - } - } - - if let Some(min_value) = &mut min_values[i] { - if let Some(file_min) = cs.min_value.clone() { - match min_value.update_batch(&[file_min.to_array()]) { - Ok(_) => {} - Err(_) => { - min_values[i] = None; - } +fn create_ordering( + schema: &Schema, + sort_order: &[Vec], +) -> Result> { + let mut all_sort_orders = vec![]; + + for exprs in sort_order { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + match expr { + Expr::Sort(sort) => match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } - } else { - min_values[i] = None; + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, } + expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), } + expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), } } - - // If the number of rows exceeds the limit, we can stop processing - // files. This only applies when we know the number of rows. It also - // currently ignores tables that have no statistics regarding the - // number of rows. - if num_rows.unwrap_or(usize::MIN) > limit.unwrap_or(usize::MAX) { - break; + if !sort_exprs.is_empty() { + all_sort_orders.push(sort_exprs); } } - // if we still have files in the stream, it means that the limit kicked - // in and that the statistic could have been different if we processed - // the files in a different order. - if all_files.next().await.is_some() { - is_exact = false; - } - - let column_stats = if has_statistics { - Some(get_col_stats( - &file_schema, - null_counts, - &mut max_values, - &mut min_values, - )) - } else { - None - }; - - let statistics = Statistics { - num_rows, - total_byte_size, - column_statistics: column_stats, - is_exact, - }; - - Ok((result_files, statistics)) -} - -fn create_max_min_accs( - schema: &Schema, -) -> (Vec>, Vec>) { - let max_values: Vec> = schema - .fields() - .iter() - .map(|field| MaxAccumulator::try_new(field.data_type()).ok()) - .collect::>(); - let min_values: Vec> = schema - .fields() - .iter() - .map(|field| MinAccumulator::try_new(field.data_type()).ok()) - .collect::>(); - (max_values, min_values) -} - -fn get_col_stats( - schema: &Schema, - null_counts: Vec, - max_values: &mut [Option], - min_values: &mut [Option], -) -> Vec { - (0..schema.fields().len()) - .map(|i| { - let max_value = match &max_values[i] { - Some(max_value) => max_value.evaluate().ok(), - None => None, - }; - let min_value = match &min_values[i] { - Some(min_value) => min_value.evaluate().ok(), - None => None, - }; - ColumnStatistics { - null_count: Some(null_counts[i]), - max_value, - min_value, - distinct_count: None, - } - }) - .collect() + Ok(all_sort_orders) } diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 43074ccb77c1d..30b55db284918 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -16,25 +16,26 @@ // under the License. //! Execution plan for reading Arrow files + +use std::any::Any; +use std::sync::Arc; + use crate::datasource::physical_plan::{ FileMeta, FileOpenFuture, FileOpener, FileScanConfig, }; use crate::error::Result; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }; + use arrow_schema::SchemaRef; use datafusion_common::Statistics; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - LexOrdering, OrderingEquivalenceProperties, PhysicalSortExpr, -}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; + use futures::StreamExt; -use object_store::{GetResult, ObjectStore}; -use std::any::Any; -use std::sync::Arc; +use object_store::{GetResultPayload, ObjectStore}; /// Execution plan for scanning Arrow data source #[derive(Debug, Clone)] @@ -68,6 +69,17 @@ impl ArrowExec { } } +impl DisplayAs for ArrowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "ArrowExec: ")?; + self.base_config.fmt_as(t, f) + } +} + impl ExecutionPlan for ArrowExec { fn as_any(&self) -> &dyn Any { self @@ -91,8 +103,8 @@ impl ExecutionPlan for ArrowExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -132,20 +144,8 @@ impl ExecutionPlan for ArrowExec { Some(self.metrics.clone_inner()) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "ArrowExec: {}", self.base_config) - } - } - } - - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } } @@ -159,13 +159,14 @@ impl FileOpener for ArrowOpener { let object_store = self.object_store.clone(); let projection = self.projection.clone(); Ok(Box::pin(async move { - match object_store.get(file_meta.location()).await? { - GetResult::File(file, _) => { + let r = object_store.get(file_meta.location()).await?; + match r.payload { + GetResultPayload::File(file, _) => { let arrow_reader = arrow::ipc::reader::FileReader::try_new(file, projection)?; Ok(futures::stream::iter(arrow_reader).boxed()) } - r @ GetResult::Stream(_) => { + GetResultPayload::Stream(_) => { let bytes = r.bytes().await?; let cursor = std::io::Cursor::new(bytes); let arrow_reader = diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 704a97ba7e886..885b4c5d3911e 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -16,22 +16,22 @@ // under the License. //! Execution plan for reading line-delimited Avro files + +use std::any::Any; +use std::sync::Arc; + +use super::FileScanConfig; use crate::error::Result; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }; -use datafusion_execution::TaskContext; use arrow::datatypes::SchemaRef; -use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties}; - -use std::any::Any; -use std::sync::Arc; - -use super::FileScanConfig; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; /// Execution plan for scanning Avro data source #[derive(Debug, Clone)] @@ -65,6 +65,17 @@ impl AvroExec { } } +impl DisplayAs for AvroExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "AvroExec: ")?; + self.base_config.fmt_as(t, f) + } +} + impl ExecutionPlan for AvroExec { fn as_any(&self) -> &dyn Any { self @@ -88,8 +99,8 @@ impl ExecutionPlan for AvroExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -141,20 +152,8 @@ impl ExecutionPlan for AvroExec { Ok(Box::pin(stream)) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "AvroExec: {}", self.base_config) - } - } - } - - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -165,11 +164,12 @@ impl ExecutionPlan for AvroExec { #[cfg(feature = "avro")] mod private { use super::*; + use crate::datasource::avro_to_arrow::Reader as AvroReader; use crate::datasource::physical_plan::file_stream::{FileOpenFuture, FileOpener}; use crate::datasource::physical_plan::FileMeta; use bytes::Buf; use futures::StreamExt; - use object_store::{GetResult, ObjectStore}; + use object_store::{GetResultPayload, ObjectStore}; pub struct AvroConfig { pub schema: SchemaRef, @@ -179,11 +179,8 @@ mod private { } impl AvroConfig { - fn open( - &self, - reader: R, - ) -> Result> { - crate::avro_to_arrow::Reader::try_new( + fn open(&self, reader: R) -> Result> { + AvroReader::try_new( reader, self.schema.clone(), self.batch_size, @@ -200,12 +197,13 @@ mod private { fn open(&self, file_meta: FileMeta) -> Result { let config = self.config.clone(); Ok(Box::pin(async move { - match config.object_store.get(file_meta.location()).await? { - GetResult::File(file, _) => { + let r = config.object_store.get(file_meta.location()).await?; + match r.payload { + GetResultPayload::File(file, _) => { let reader = config.open(file)?; Ok(futures::stream::iter(reader).boxed()) } - r @ GetResult::Stream(_) => { + GetResultPayload::Stream(_) => { let bytes = r.bytes().await?; let reader = config.open(bytes.reader())?; Ok(futures::stream::iter(reader).boxed()) @@ -222,12 +220,12 @@ mod tests { use crate::datasource::file_format::{avro::AvroFormat, FileFormat}; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::chunked_store::ChunkedStore; use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test::object_store::local_unpartitioned_file; use arrow::datatypes::{DataType, Field, SchemaBuilder}; use futures::StreamExt; + use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use object_store::ObjectStore; use rstest::*; @@ -272,8 +270,8 @@ mod tests { let avro_exec = AvroExec::new(FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![vec![meta.into()]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![0, 1, 2]), limit: None, table_partition_cols: vec![], @@ -291,7 +289,7 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - let expected = vec![ + let expected = [ "+----+----------+-------------+", "| id | bool_col | tinyint_col |", "+----+----------+-------------+", @@ -344,8 +342,8 @@ mod tests { let avro_exec = AvroExec::new(FileScanConfig { object_store_url, file_groups: vec![vec![meta.into()]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection, limit: None, table_partition_cols: vec![], @@ -364,7 +362,7 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - let expected = vec![ + let expected = [ "+----+----------+-------------+-------------+", "| id | bool_col | tinyint_col | missing_col |", "+----+----------+-------------+-------------+", @@ -408,8 +406,7 @@ mod tests { .await?; let mut partitioned_file = PartitionedFile::from(meta); - partitioned_file.partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let avro_exec = AvroExec::new(FileScanConfig { // select specific columns of the files as well as the partitioning @@ -417,10 +414,10 @@ mod tests { projection: Some(vec![0, 1, file_schema.fields().len(), 2]), object_store_url, file_groups: vec![vec![partitioned_file]], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), limit: None, - table_partition_cols: vec![("date".to_owned(), DataType::Utf8)], + table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], output_ordering: vec![], infinite_source: false, }); @@ -436,7 +433,7 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - let expected = vec![ + let expected = [ "+----+----------+------------+-------------+", "| id | bool_col | date | tinyint_col |", "+----+----------+------------+-------------+", diff --git a/datafusion/core/src/datasource/physical_plan/chunked_store.rs b/datafusion/core/src/datasource/physical_plan/chunked_store.rs deleted file mode 100644 index 05528ed8a2b6a..0000000000000 --- a/datafusion/core/src/datasource/physical_plan/chunked_store.rs +++ /dev/null @@ -1,223 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; -use futures::stream::BoxStream; -use futures::StreamExt; -use object_store::path::Path; -use object_store::{GetOptions, GetResult, ListResult, ObjectMeta, ObjectStore}; -use object_store::{MultipartId, Result}; -use std::fmt::{Debug, Display, Formatter}; -use std::ops::Range; -use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWrite, BufReader}; - -/// Wraps a [`ObjectStore`] and makes its get response return chunks -/// in a controllable manner. -/// -/// A `ChunkedStore` makes the memory consumption and performance of -/// the wrapped [`ObjectStore`] worse. It is intended for use within -/// tests, to control the chunks in the produced output streams. For -/// example, it is used to verify the delimiting logic in -/// newline_delimited_stream. -/// -/// TODO: Upstream into object_store_rs -#[derive(Debug)] -pub struct ChunkedStore { - inner: Arc, - chunk_size: usize, -} - -impl ChunkedStore { - pub fn new(inner: Arc, chunk_size: usize) -> Self { - Self { inner, chunk_size } - } -} - -impl Display for ChunkedStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ChunkedStore({})", self.inner) - } -} - -#[async_trait] -impl ObjectStore for ChunkedStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - self.inner.put(location, bytes).await - } - - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - self.inner.put_multipart(location).await - } - - async fn abort_multipart( - &self, - location: &Path, - multipart_id: &MultipartId, - ) -> Result<()> { - self.inner.abort_multipart(location, multipart_id).await - } - - async fn get(&self, location: &Path) -> Result { - match self.inner.get(location).await? { - GetResult::File(std_file, ..) => { - let file = tokio::fs::File::from_std(std_file); - let reader = BufReader::new(file); - Ok(GetResult::Stream( - futures::stream::unfold( - (reader, self.chunk_size), - |(mut reader, chunk_size)| async move { - let mut buffer = BytesMut::zeroed(chunk_size); - let size = reader.read(&mut buffer).await.map_err(|e| { - object_store::Error::Generic { - store: "ChunkedStore", - source: Box::new(e), - } - }); - match size { - Ok(0) => None, - Ok(value) => Some(( - Ok(buffer.split_to(value).freeze()), - (reader, chunk_size), - )), - Err(e) => Some((Err(e), (reader, chunk_size))), - } - }, - ) - .boxed(), - )) - } - GetResult::Stream(stream) => { - let buffer = BytesMut::new(); - Ok(GetResult::Stream( - futures::stream::unfold( - (stream, buffer, false, self.chunk_size), - |(mut stream, mut buffer, mut exhausted, chunk_size)| async move { - // Keep accumulating bytes until we reach capacity as long as - // the stream can provide them: - if exhausted { - return None; - } - while buffer.len() < chunk_size { - match stream.next().await { - None => { - exhausted = true; - let slice = buffer.split_off(0).freeze(); - return Some(( - Ok(slice), - (stream, buffer, exhausted, chunk_size), - )); - } - Some(Ok(bytes)) => { - buffer.put(bytes); - } - Some(Err(e)) => { - return Some(( - Err(object_store::Error::Generic { - store: "ChunkedStore", - source: Box::new(e), - }), - (stream, buffer, exhausted, chunk_size), - )) - } - }; - } - // Return the chunked values as the next value in the stream - let slice = buffer.split_to(chunk_size).freeze(); - Some((Ok(slice), (stream, buffer, exhausted, chunk_size))) - }, - ) - .boxed(), - )) - } - } - } - - async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { - self.inner.get_opts(location, options).await - } - - async fn get_range(&self, location: &Path, range: Range) -> Result { - self.inner.get_range(location, range).await - } - - async fn head(&self, location: &Path) -> Result { - self.inner.head(location).await - } - - async fn delete(&self, location: &Path) -> Result<()> { - self.inner.delete(location).await - } - - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - self.inner.list(prefix).await - } - - async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { - self.inner.list_with_delimiter(prefix).await - } - - async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - self.inner.copy_if_not_exists(from, to).await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures::StreamExt; - use object_store::memory::InMemory; - use object_store::path::Path; - - #[tokio::test] - async fn test_chunked() { - let location = Path::parse("test").unwrap(); - let store = Arc::new(InMemory::new()); - store - .put(&location, Bytes::from(vec![0; 1001])) - .await - .unwrap(); - - for chunk_size in [10, 20, 31] { - let store = ChunkedStore::new(store.clone(), chunk_size); - let mut s = match store.get(&location).await.unwrap() { - GetResult::Stream(s) => s, - _ => unreachable!(), - }; - - let mut remaining = 1001; - while let Some(next) = s.next().await { - let size = next.unwrap().len(); - let expected = remaining.min(chunk_size); - assert_eq!(size, expected); - remaining -= expected; - } - assert_eq!(remaining, 0); - } - } -} diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index d2c76ecaf5eac..816a82543bab8 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -17,36 +17,38 @@ //! Execution plan for reading CSV files -use crate::datasource::file_format::file_type::FileCompressionType; +use std::any::Any; +use std::io::{Read, Seek, SeekFrom}; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; + +use super::FileScanConfig; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::listing::{FileRange, ListingTableUrl}; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }; + use arrow::csv; use arrow::datatypes::SchemaRef; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties}; - -use super::FileScanConfig; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; -use futures::ready; -use futures::{StreamExt, TryStreamExt}; -use object_store::{GetResult, ObjectStore}; -use std::any::Any; -use std::fs; -use std::path::Path; -use std::sync::Arc; -use std::task::Poll; -use tokio::task::{self, JoinHandle}; +use datafusion_common::config::ConfigOptions; +use futures::{ready, StreamExt, TryStreamExt}; +use object_store::{GetOptions, GetResultPayload, ObjectStore}; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; /// Execution plan for scanning a CSV file #[derive(Debug, Clone)] @@ -57,9 +59,12 @@ pub struct CsvExec { projected_output_ordering: Vec, has_header: bool, delimiter: u8, + quote: u8, + escape: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, - file_compression_type: FileCompressionType, + /// Compression type of the file associated with CsvExec + pub file_compression_type: FileCompressionType, } impl CsvExec { @@ -68,6 +73,8 @@ impl CsvExec { base_config: FileScanConfig, has_header: bool, delimiter: u8, + quote: u8, + escape: Option, file_compression_type: FileCompressionType, ) -> Self { let (projected_schema, projected_statistics, projected_output_ordering) = @@ -80,6 +87,8 @@ impl CsvExec { projected_output_ordering, has_header, delimiter, + quote, + escape, metrics: ExecutionPlanMetricsSet::new(), file_compression_type, } @@ -97,6 +106,28 @@ impl CsvExec { pub fn delimiter(&self) -> u8 { self.delimiter } + + /// The quote character + pub fn quote(&self) -> u8 { + self.quote + } + + /// The escape character + pub fn escape(&self) -> Option { + self.escape + } +} + +impl DisplayAs for CsvExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "CsvExec: ")?; + self.base_config.fmt_as(t, f)?; + write!(f, ", has_header={}", self.has_header) + } } impl ExecutionPlan for CsvExec { @@ -126,8 +157,8 @@ impl ExecutionPlan for CsvExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -145,6 +176,35 @@ impl ExecutionPlan for CsvExec { Ok(self) } + /// Redistribute files across partitions according to their size + /// See comments on `repartition_file_groups()` for more detail. + /// + /// Return `None` if can't get repartitioned(empty/compressed file). + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + // Parallel execution on compressed CSV file is not supported yet. + if self.file_compression_type.is_compressed() { + return Ok(None); + } + + let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( + self.base_config.file_groups.clone(), + target_partitions, + repartition_file_min_size, + ); + + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + let mut new_plan = self.clone(); + new_plan.base_config.file_groups = repartitioned_file_groups; + return Ok(Some(Arc::new(new_plan))); + } + Ok(None) + } + fn execute( &self, partition: usize, @@ -160,6 +220,8 @@ impl ExecutionPlan for CsvExec { file_projection: self.base_config.file_column_projection_indices(), has_header: self.has_header, delimiter: self.delimiter, + quote: self.quote, + escape: self.escape, object_store, }); @@ -172,24 +234,8 @@ impl ExecutionPlan for CsvExec { Ok(Box::pin(stream) as SendableRecordBatchStream) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "CsvExec: {}, has_header={}", - self.base_config, self.has_header, - ) - } - } - } - - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -205,6 +251,8 @@ pub struct CsvConfig { file_projection: Option>, has_header: bool, delimiter: u8, + quote: u8, + escape: Option, object_store: Arc, } @@ -216,6 +264,7 @@ impl CsvConfig { file_projection: Option>, has_header: bool, delimiter: u8, + quote: u8, object_store: Arc, ) -> Self { Self { @@ -224,34 +273,31 @@ impl CsvConfig { file_projection, has_header, delimiter, + quote, + escape: None, object_store, } } } impl CsvConfig { - fn open(&self, reader: R) -> Result> { - let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) - .has_header(self.has_header) - .with_delimiter(self.delimiter) - .with_batch_size(self.batch_size); - - if let Some(p) = &self.file_projection { - builder = builder.with_projection(p.clone()); - } - - Ok(builder.build(reader)?) + fn open(&self, reader: R) -> Result> { + Ok(self.builder().build(reader)?) } fn builder(&self) -> csv::ReaderBuilder { let mut builder = csv::ReaderBuilder::new(self.file_schema.clone()) .with_delimiter(self.delimiter) .with_batch_size(self.batch_size) - .has_header(self.has_header); + .with_header(self.has_header) + .with_quote(self.quote); if let Some(proj) = &self.file_projection { builder = builder.with_projection(proj.clone()); } + if let Some(escape) = self.escape { + builder = builder.with_escape(escape) + } builder } @@ -276,17 +322,162 @@ impl CsvOpener { } } +/// Returns the offset of the first newline in the object store range [start, end), or the end offset if no newline is found. +async fn find_first_newline( + object_store: &Arc, + location: &object_store::path::Path, + start_byte: usize, + end_byte: usize, +) -> Result { + let options = GetOptions { + range: Some(Range { + start: start_byte, + end: end_byte, + }), + ..Default::default() + }; + + let r = object_store.get_opts(location, options).await?; + let mut input = r.into_stream(); + + let mut buffered = Bytes::new(); + let mut index = 0; + + loop { + if buffered.is_empty() { + match input.next().await { + Some(Ok(b)) => buffered = b, + Some(Err(e)) => return Err(e.into()), + None => return Ok(index), + }; + } + + for byte in &buffered { + if *byte == b'\n' { + return Ok(index); + } + index += 1; + } + + buffered.advance(buffered.len()); + } +} + impl FileOpener for CsvOpener { + /// Open a partitioned CSV file. + /// + /// If `file_meta.range` is `None`, the entire file is opened. + /// If `file_meta.range` is `Some(FileRange {start, end})`, this signifies that the partition + /// corresponds to the byte range [start, end) within the file. + /// + /// Note: `start` or `end` might be in the middle of some lines. In such cases, the following rules + /// are applied to determine which lines to read: + /// 1. The first line of the partition is the line in which the index of the first character >= `start`. + /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. + /// + /// Examples: + /// Consider the following partitions enclosed by braces `{}`: + /// + /// {A,1,2,3,4,5,6,7,8,9\n + /// A,1,2,3,4,5,6,7,8,9\n} + /// A,1,2,3,4,5,6,7,8,9\n + /// The lines read would be: [0, 1] + /// + /// A,{1,2,3,4,5,6,7,8,9\n + /// A,1,2,3,4,5,6,7,8,9\n + /// A},1,2,3,4,5,6,7,8,9\n + /// The lines read would be: [1, 2] fn open(&self, file_meta: FileMeta) -> Result { - let config = self.config.clone(); + // `self.config.has_header` controls whether to skip reading the 1st line header + // If the .csv file is read in parallel and this `CsvOpener` is only reading some middle + // partition, then don't skip first line + let mut csv_has_header = self.config.has_header; + if let Some(FileRange { start, .. }) = file_meta.range { + if start != 0 { + csv_has_header = false; + } + } + + let config = CsvConfig { + has_header: csv_has_header, + ..(*self.config).clone() + }; + let file_compression_type = self.file_compression_type.to_owned(); + + if file_meta.range.is_some() { + assert!( + !file_compression_type.is_compressed(), + "Reading compressed .csv in parallel is not supported" + ); + } + Ok(Box::pin(async move { - match config.object_store.get(file_meta.location()).await? { - GetResult::File(file, _) => { - let decoder = file_compression_type.convert_read(file)?; + let file_size = file_meta.object_meta.size; + // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) + let range = match file_meta.range { + None => None, + Some(FileRange { start, end }) => { + let (start, end) = (start as usize, end as usize); + // Partition byte range is [start, end), the boundary might be in the middle of + // some line. Need to find out the exact line boundaries. + let start_delta = if start != 0 { + find_first_newline( + &config.object_store, + file_meta.location(), + start - 1, + file_size, + ) + .await? + } else { + 0 + }; + let end_delta = if end != file_size { + find_first_newline( + &config.object_store, + file_meta.location(), + end - 1, + file_size, + ) + .await? + } else { + 0 + }; + let range = start + start_delta..end + end_delta; + if range.start == range.end { + return Ok( + futures::stream::poll_fn(move |_| Poll::Ready(None)).boxed() + ); + } + Some(range) + } + }; + + let options = GetOptions { + range, + ..Default::default() + }; + let result = config + .object_store + .get_opts(file_meta.location(), options) + .await?; + + match result.payload { + GetResultPayload::File(mut file, _) => { + let is_whole_file_scanned = file_meta.range.is_none(); + let decoder = if is_whole_file_scanned { + // Don't seek if no range as breaks FIFO files + file_compression_type.convert_read(file)? + } else { + file.seek(SeekFrom::Start(result.range.start as _))?; + file_compression_type.convert_read( + file.take((result.range.end - result.range.start) as u64), + )? + }; + Ok(futures::stream::iter(config.open(decoder)?).boxed()) } - GetResult::Stream(s) => { + GetResultPayload::Stream(s) => { let mut decoder = config.builder().build_decoder(); let s = s.map_err(DataFusionError::from); let mut input = @@ -329,56 +520,71 @@ pub async fn plan_to_csv( path: impl AsRef, ) -> Result<()> { let path = path.as_ref(); - // create directory to contain the CSV files (one per partition) - let fs_path = Path::new(path); - if let Err(e) = fs::create_dir(fs_path) { - return Err(DataFusionError::Execution(format!( - "Could not create directory {path}: {e:?}" - ))); - } - - let mut tasks = vec![]; + let parsed = ListingTableUrl::parse(path)?; + let object_store_url = parsed.object_store(); + let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let plan = plan.clone(); - let filename = format!("part-{i}.csv"); - let path = fs_path.join(filename); - let file = fs::File::create(path)?; - let mut writer = csv::Writer::new(file); - let stream = plan.execute(i, task_ctx.clone())?; - - let handle: JoinHandle> = task::spawn(async move { - stream - .map(|batch| writer.write(&batch?)) - .try_collect() + let storeref = store.clone(); + let plan: Arc = plan.clone(); + let filename = format!("{}/part-{i}.csv", parsed.prefix()); + let file = object_store::path::Path::parse(filename)?; + + let mut stream = plan.execute(i, task_ctx.clone())?; + join_set.spawn(async move { + let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buffer = Vec::with_capacity(1024); + //only write headers on first iteration + let mut write_headers = true; + while let Some(batch) = stream.next().await.transpose()? { + let mut writer = csv::WriterBuilder::new() + .with_header(write_headers) + .build(buffer); + writer.write(&batch)?; + buffer = writer.into_inner(); + multipart_writer.write_all(&buffer).await?; + buffer.clear(); + //prevent writing headers more than once + write_headers = false; + } + multipart_writer + .shutdown() .await .map_err(DataFusionError::from) }); - tasks.push(AbortOnDropSingle::new(handle)); } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? - })?; + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, // propagate DataFusion error + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + Ok(()) } #[cfg(test)] mod tests { use super::*; - use crate::datasource::file_format::file_type::FileType; - use crate::datasource::physical_plan::chunked_store::ChunkedStore; + use crate::dataframe::DataFrameWriteOptions; use crate::prelude::*; use crate::test::{partitioned_csv_config, partitioned_file_groups}; - use crate::test_util::{aggr_test_schema_with_missing_col, arrow_test_data}; use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; + use datafusion_common::test_util::arrow_test_data; + use datafusion_common::FileType; use futures::StreamExt; + use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; - use std::fs::File; + use std::fs::{self, File}; use std::io::Write; use tempfile::TempDir; use url::Url; @@ -391,6 +597,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn csv_exec_with_projection( file_compression_type: FileCompressionType, @@ -400,6 +607,7 @@ mod tests { let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -407,12 +615,20 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), )?; let mut config = partitioned_csv_config(file_schema, file_groups)?; config.projection = Some(vec![0, 2, 4]); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -423,7 +639,7 @@ mod tests { assert_eq!(100, batch.num_rows()); // slice of the first 5 lines - let expected = vec![ + let expected = [ "+----+-----+------------+", "| c1 | c3 | c5 |", "+----+-----+------------+", @@ -447,6 +663,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, @@ -456,6 +673,7 @@ mod tests { let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -463,12 +681,20 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), )?; let mut config = partitioned_csv_config(file_schema, file_groups)?; config.projection = Some(vec![4, 0, 2]); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -479,7 +705,7 @@ mod tests { assert_eq!(100, batch.num_rows()); // slice of the first 5 lines - let expected = vec![ + let expected = [ "+------------+----+-----+", "| c5 | c1 | c3 |", "+------------+----+-----+", @@ -503,6 +729,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn csv_exec_with_limit( file_compression_type: FileCompressionType, @@ -512,6 +739,7 @@ mod tests { let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -519,12 +747,20 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), )?; let mut config = partitioned_csv_config(file_schema, file_groups)?; config.limit = Some(5); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); @@ -534,8 +770,7 @@ mod tests { assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); - let expected = vec![ - "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+", + let expected = ["+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+", "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 |", "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+", "| c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW |", @@ -543,8 +778,7 @@ mod tests { "| b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz |", "| a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB |", "| b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd |", - "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+", - ]; + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+"]; crate::assert_batches_eq!(expected, &[batch]); @@ -559,6 +793,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, @@ -568,6 +803,7 @@ mod tests { let file_schema = aggr_test_schema_with_missing_col(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -575,19 +811,27 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), )?; let mut config = partitioned_csv_config(file_schema, file_groups)?; config.limit = Some(5); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(14, csv.base_config.file_schema.fields().len()); assert_eq!(14, csv.projected_schema.fields().len()); assert_eq!(14, csv.schema().fields().len()); // errors due to https://github.com/apache/arrow-datafusion/issues/4918 let mut it = csv.execute(0, task_ctx)?; - let err = it.next().await.unwrap().unwrap_err().to_string(); + let err = it.next().await.unwrap().unwrap_err().strip_backtrace(); assert_eq!( err, "Arrow error: Csv error: incorrect number of fields for line 1, expected 14 got 13" @@ -603,6 +847,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn csv_exec_with_partition( file_compression_type: FileCompressionType, @@ -612,6 +857,7 @@ mod tests { let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -619,14 +865,14 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), )?; let mut config = partitioned_csv_config(file_schema, file_groups)?; // Add partition columns - config.table_partition_cols = vec![("date".to_owned(), DataType::Utf8)]; - config.file_groups[0][0].partition_values = - vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; + config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; + config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column // Which is supposed to be after the file fields @@ -634,7 +880,14 @@ mod tests { // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); @@ -645,7 +898,7 @@ mod tests { assert_eq!(100, batch.num_rows()); // slice of the first 5 lines - let expected = vec![ + let expected = [ "+----+------------+", "| c1 | date |", "+----+------------+", @@ -699,7 +952,7 @@ mod tests { async fn test_additional_stores( file_compression_type: FileCompressionType, store: Arc, - ) { + ) -> Result<()> { let ctx = SessionContext::new(); let url = Url::parse("file://").unwrap(); ctx.runtime_env().register_object_store(&url, store.clone()); @@ -709,6 +962,7 @@ mod tests { let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( path.as_str(), @@ -716,11 +970,19 @@ mod tests { 1, FileType::CSV, file_compression_type.to_owned(), + tmp_dir.path(), ) .unwrap(); let config = partitioned_csv_config(file_schema, file_groups).unwrap(); - let csv = CsvExec::new(config, true, b',', file_compression_type.to_owned()); + let csv = CsvExec::new( + config, + true, + b',', + b'"', + None, + file_compression_type.to_owned(), + ); let it = csv.execute(0, task_ctx).unwrap(); let batches: Vec<_> = it.try_collect().await.unwrap(); @@ -728,6 +990,7 @@ mod tests { let total_rows = batches.iter().map(|b| b.num_rows()).sum::(); assert_eq!(total_rows, 100); + Ok(()) } #[rstest( @@ -738,11 +1001,12 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn test_chunked_csv( file_compression_type: FileCompressionType, #[values(10, 20, 30, 40)] chunk_size: usize, - ) { + ) -> Result<()> { test_additional_stores( file_compression_type, Arc::new(ChunkedStore::new( @@ -750,7 +1014,8 @@ mod tests { chunk_size, )), ) - .await; + .await?; + Ok(()) } #[tokio::test] @@ -774,7 +1039,7 @@ mod tests { let result = df.collect().await.unwrap(); - let expected = vec![ + let expected = [ "+---+---+", "| a | b |", "+---+---+", @@ -789,17 +1054,23 @@ mod tests { #[tokio::test] async fn write_csv_results_error_handling() -> Result<()> { let ctx = SessionContext::new(); + + // register a local file system object store + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); let options = CsvReadOptions::default() .schema_infer_max_records(2) .has_header(true); let df = ctx.read_csv("tests/data/corrupt.csv", options).await?; - let tmp_dir = TempDir::new()?; - let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + + let out_dir_url = "file://local/out"; let e = df - .write_csv(&out_dir) + .write_csv(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!("Arrow error: Parser error: Error while parsing value d for column 0 at line 4", format!("{e}")); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); Ok(()) } @@ -807,8 +1078,9 @@ mod tests { async fn write_csv_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; @@ -820,10 +1092,19 @@ mod tests { ) .await?; + // register a local file system object store + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + + ctx.runtime_env().register_object_store(&local_url, local); + // execute a simple query and write the results to CSV let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; let df = ctx.sql("SELECT c1, c2 FROM test").await?; - df.write_csv(&out_dir).await?; + df.write_csv(out_dir_url, DataFrameWriteOptions::new(), None) + .await?; // create a new context and verify that the results were saved to a partitioned csv file let ctx = SessionContext::new(); @@ -833,11 +1114,32 @@ mod tests { Field::new("c2", DataType::UInt64, false), ])); + // get name of first part + let paths = fs::read_dir(&out_dir).unwrap(); + let mut part_0_name: String = "".to_owned(); + for path in paths { + let path = path.unwrap(); + let name = path + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + if name.ends_with("_0.csv") { + part_0_name = name; + break; + } + } + + if part_0_name.is_empty() { + panic!("Did not find part_0 in csv output files!") + } // register each partition as well as the top level dir let csv_read_option = CsvReadOptions::new().schema(&schema); ctx.register_csv( "part0", - &format!("{out_dir}/part-0.csv"), + &format!("{out_dir}/{part_0_name}"), csv_read_option.clone(), ) .await?; @@ -870,4 +1172,20 @@ mod tests { } } } + + /// Get the schema for the aggregate_test_* csv files with an additional filed not present in the files. + fn aggr_test_schema_with_missing_col() -> SchemaRef { + let fields = + Fields::from_iter(aggr_test_schema().fields().iter().cloned().chain( + std::iter::once(Arc::new(Field::new( + "missing_col", + DataType::Int64, + true, + ))), + )); + + let schema = Schema::new(fields); + + Arc::new(schema) + } } diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs new file mode 100644 index 0000000000000..d308397ab6e2f --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -0,0 +1,811 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`FileScanConfig`] to configure scanning of possibly partitioned +//! file sources. + +use std::{ + borrow::Cow, cmp::min, collections::HashMap, fmt::Debug, marker::PhantomData, + sync::Arc, vec, +}; + +use super::get_projected_output_ordering; +use crate::datasource::{ + listing::{FileRange, PartitionedFile}, + object_store::ObjectStoreUrl, +}; +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +use arrow::array::{ArrayData, BufferBuilder}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowNativeType, UInt16Type}; +use arrow_array::{ArrayRef, DictionaryArray, RecordBatch, RecordBatchOptions}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::stats::Precision; +use datafusion_common::{exec_err, ColumnStatistics, Statistics}; +use datafusion_physical_expr::LexOrdering; + +use itertools::Itertools; +use log::warn; + +/// Convert type to a type suitable for use as a [`ListingTable`] +/// partition column. Returns `Dictionary(UInt16, val_type)`, which is +/// a reasonable trade off between a reasonable number of partition +/// values and space efficiency. +/// +/// This use this to specify types for partition columns. However +/// you MAY also choose not to dictionary-encode the data or to use a +/// different dictionary type. +/// +/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same say. +/// +/// [`ListingTable`]: crate::datasource::listing::ListingTable +pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) +} + +/// Convert a [`ScalarValue`] of partition columns to a type, as +/// decribed in the documentation of [`wrap_partition_type_in_dict`], +/// which can wrap the types. +pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) +} + +/// The base configurations to provide when creating a physical plan for +/// any given file format. +#[derive(Clone)] +pub struct FileScanConfig { + /// Object store URL, used to get an [`ObjectStore`] instance from + /// [`RuntimeEnv::object_store`] + /// + /// [`ObjectStore`]: object_store::ObjectStore + /// [`RuntimeEnv::object_store`]: datafusion_execution::runtime_env::RuntimeEnv::object_store + pub object_store_url: ObjectStoreUrl, + /// Schema before `projection` is applied. It contains the all columns that may + /// appear in the files. It does not include table partition columns + /// that may be added. + pub file_schema: SchemaRef, + /// List of files to be processed, grouped into partitions + /// + /// Each file must have a schema of `file_schema` or a subset. If + /// a particular file has a subset, the missing columns are + /// padded with NULLs. + /// + /// DataFusion may attempt to read each partition of files + /// concurrently, however files *within* a partition will be read + /// sequentially, one after the next. + pub file_groups: Vec>, + /// Estimated overall statistics of the files, taking `filters` into account. + pub statistics: Statistics, + /// Columns on which to project the data. Indexes that are higher than the + /// number of columns of `file_schema` refer to `table_partition_cols`. + pub projection: Option>, + /// The maximum number of records to read from this plan. If `None`, + /// all records after filtering are returned. + pub limit: Option, + /// The partitioning columns + pub table_partition_cols: Vec, + /// All equivalent lexicographical orderings that describe the schema. + pub output_ordering: Vec, + /// Indicates whether this plan may produce an infinite stream of records. + pub infinite_source: bool, +} + +impl FileScanConfig { + /// Project the schema and the statistics on the given column indices + pub fn project(&self) -> (SchemaRef, Statistics, Vec) { + if self.projection.is_none() && self.table_partition_cols.is_empty() { + return ( + Arc::clone(&self.file_schema), + self.statistics.clone(), + self.output_ordering.clone(), + ); + } + + let proj_iter: Box> = match &self.projection { + Some(proj) => Box::new(proj.iter().copied()), + None => Box::new( + 0..(self.file_schema.fields().len() + self.table_partition_cols.len()), + ), + }; + + let mut table_fields = vec![]; + let mut table_cols_stats = vec![]; + for idx in proj_iter { + if idx < self.file_schema.fields().len() { + let field = self.file_schema.field(idx); + table_fields.push(field.clone()); + table_cols_stats.push(self.statistics.column_statistics[idx].clone()) + } else { + let partition_idx = idx - self.file_schema.fields().len(); + table_fields.push(self.table_partition_cols[partition_idx].to_owned()); + // TODO provide accurate stat for partition column (#1186) + table_cols_stats.push(ColumnStatistics::new_unknown()) + } + } + + let table_stats = Statistics { + num_rows: self.statistics.num_rows.clone(), + // TODO correct byte size? + total_byte_size: Precision::Absent, + column_statistics: table_cols_stats, + }; + + let table_schema = Arc::new( + Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), + ); + let projected_output_ordering = + get_projected_output_ordering(self, &table_schema); + (table_schema, table_stats, projected_output_ordering) + } + + #[allow(unused)] // Only used by avro + pub(crate) fn projected_file_column_names(&self) -> Option> { + self.projection.as_ref().map(|p| { + p.iter() + .filter(|col_idx| **col_idx < self.file_schema.fields().len()) + .map(|col_idx| self.file_schema.field(*col_idx).name()) + .cloned() + .collect() + }) + } + + pub(crate) fn file_column_projection_indices(&self) -> Option> { + self.projection.as_ref().map(|p| { + p.iter() + .filter(|col_idx| **col_idx < self.file_schema.fields().len()) + .copied() + .collect() + }) + } + + /// Repartition all input files into `target_partitions` partitions, if total file size exceed + /// `repartition_file_min_size` + /// `target_partitions` and `repartition_file_min_size` directly come from configuration. + /// + /// This function only try to partition file byte range evenly, and let specific `FileOpener` to + /// do actual partition on specific data source type. (e.g. `CsvOpener` will only read lines + /// overlap with byte range but also handle boundaries to ensure all lines will be read exactly once) + pub fn repartition_file_groups( + file_groups: Vec>, + target_partitions: usize, + repartition_file_min_size: usize, + ) -> Option>> { + let flattened_files = file_groups.iter().flatten().collect::>(); + + // Perform redistribution only in case all files should be read from beginning to end + let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); + if has_ranges { + return None; + } + + let total_size = flattened_files + .iter() + .map(|f| f.object_meta.size as i64) + .sum::(); + if total_size < (repartition_file_min_size as i64) || total_size == 0 { + return None; + } + + let target_partition_size = + (total_size as usize + (target_partitions) - 1) / (target_partitions); + + let current_partition_index: usize = 0; + let current_partition_size: usize = 0; + + // Partition byte range evenly for all `PartitionedFile`s + let repartitioned_files = flattened_files + .into_iter() + .scan( + (current_partition_index, current_partition_size), + |state, source_file| { + let mut produced_files = vec![]; + let mut range_start = 0; + while range_start < source_file.object_meta.size { + let range_end = min( + range_start + (target_partition_size - state.1), + source_file.object_meta.size, + ); + + let mut produced_file = source_file.clone(); + produced_file.range = Some(FileRange { + start: range_start as i64, + end: range_end as i64, + }); + produced_files.push((state.0, produced_file)); + + if state.1 + (range_end - range_start) >= target_partition_size { + state.0 += 1; + state.1 = 0; + } else { + state.1 += range_end - range_start; + } + range_start = range_end; + } + Some(produced_files) + }, + ) + .flatten() + .group_by(|(partition_idx, _)| *partition_idx) + .into_iter() + .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) + .collect_vec(); + + Some(repartitioned_files) + } +} + +/// A helper that projects partition columns into the file record batches. +/// +/// One interesting trick is the usage of a cache for the key buffers of the partition column +/// dictionaries. Indeed, the partition columns are constant, so the dictionaries that represent them +/// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, +/// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). +pub struct PartitionColumnProjector { + /// An Arrow buffer initialized to zeros that represents the key array of all partition + /// columns (partition columns are materialized by dictionary arrays with only one + /// value in the dictionary, thus all the keys are equal to zero). + key_buffer_cache: ZeroBufferGenerators, + /// Mapping between the indexes in the list of partition columns and the target + /// schema. Sorted by index in the target schema so that we can iterate on it to + /// insert the partition columns in the target record batch. + projected_partition_indexes: Vec<(usize, usize)>, + /// The schema of the table once the projection was applied. + projected_schema: SchemaRef, +} + +impl PartitionColumnProjector { + // Create a projector to insert the partitioning columns into batches read from files + // - `projected_schema`: the target schema with both file and partitioning columns + // - `table_partition_cols`: all the partitioning column names + pub fn new(projected_schema: SchemaRef, table_partition_cols: &[String]) -> Self { + let mut idx_map = HashMap::new(); + for (partition_idx, partition_name) in table_partition_cols.iter().enumerate() { + if let Ok(schema_idx) = projected_schema.index_of(partition_name) { + idx_map.insert(partition_idx, schema_idx); + } + } + + let mut projected_partition_indexes: Vec<_> = idx_map.into_iter().collect(); + projected_partition_indexes.sort_by(|(_, a), (_, b)| a.cmp(b)); + + Self { + projected_partition_indexes, + key_buffer_cache: Default::default(), + projected_schema, + } + } + + // Transform the batch read from the file by inserting the partitioning columns + // to the right positions as deduced from `projected_schema` + // - `file_batch`: batch read from the file, with internal projection applied + // - `partition_values`: the list of partition values, one for each partition column + pub fn project( + &mut self, + file_batch: RecordBatch, + partition_values: &[ScalarValue], + ) -> Result { + let expected_cols = + self.projected_schema.fields().len() - self.projected_partition_indexes.len(); + + if file_batch.columns().len() != expected_cols { + return exec_err!( + "Unexpected batch schema from file, expected {} cols but got {}", + expected_cols, + file_batch.columns().len() + ); + } + let mut cols = file_batch.columns().to_vec(); + for &(pidx, sidx) in &self.projected_partition_indexes { + let mut partition_value = Cow::Borrowed(&partition_values[pidx]); + + // check if user forgot to dict-encode the partition value + let field = self.projected_schema.field(sidx); + let expected_data_type = field.data_type(); + let actual_data_type = partition_value.data_type(); + if let DataType::Dictionary(key_type, _) = expected_data_type { + if !matches!(actual_data_type, DataType::Dictionary(_, _)) { + warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); + partition_value = Cow::Owned(ScalarValue::Dictionary( + key_type.clone(), + Box::new(partition_value.as_ref().clone()), + )); + } + } + + cols.insert( + sidx, + create_output_array( + &mut self.key_buffer_cache, + partition_value.as_ref(), + file_batch.num_rows(), + )?, + ) + } + + RecordBatch::try_new_with_options( + Arc::clone(&self.projected_schema), + cols, + &RecordBatchOptions::new().with_row_count(Some(file_batch.num_rows())), + ) + .map_err(Into::into) + } +} + +#[derive(Debug, Default)] +struct ZeroBufferGenerators { + gen_i8: ZeroBufferGenerator, + gen_i16: ZeroBufferGenerator, + gen_i32: ZeroBufferGenerator, + gen_i64: ZeroBufferGenerator, + gen_u8: ZeroBufferGenerator, + gen_u16: ZeroBufferGenerator, + gen_u32: ZeroBufferGenerator, + gen_u64: ZeroBufferGenerator, +} + +/// Generate a arrow [`Buffer`] that contains zero values. +#[derive(Debug, Default)] +struct ZeroBufferGenerator +where + T: ArrowNativeType, +{ + cache: Option, + _t: PhantomData, +} + +impl ZeroBufferGenerator +where + T: ArrowNativeType, +{ + const SIZE: usize = std::mem::size_of::(); + + fn get_buffer(&mut self, n_vals: usize) -> Buffer { + match &mut self.cache { + Some(buf) if buf.len() >= n_vals * Self::SIZE => { + buf.slice_with_length(0, n_vals * Self::SIZE) + } + _ => { + let mut key_buffer_builder = BufferBuilder::::new(n_vals); + key_buffer_builder.advance(n_vals); // keys are all 0 + self.cache.insert(key_buffer_builder.finish()).clone() + } + } + } +} + +fn create_dict_array( + buffer_gen: &mut ZeroBufferGenerator, + dict_val: &ScalarValue, + len: usize, + data_type: DataType, +) -> Result +where + T: ArrowNativeType, +{ + let dict_vals = dict_val.to_array()?; + + let sliced_key_buffer = buffer_gen.get_buffer(len); + + // assemble pieces together + let mut builder = ArrayData::builder(data_type) + .len(len) + .add_buffer(sliced_key_buffer); + builder = builder.add_child_data(dict_vals.to_data()); + Ok(Arc::new(DictionaryArray::::from( + builder.build().unwrap(), + ))) +} + +fn create_output_array( + key_buffer_cache: &mut ZeroBufferGenerators, + val: &ScalarValue, + len: usize, +) -> Result { + if let ScalarValue::Dictionary(key_type, dict_val) = &val { + match key_type.as_ref() { + DataType::Int8 => { + return create_dict_array( + &mut key_buffer_cache.gen_i8, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int16 => { + return create_dict_array( + &mut key_buffer_cache.gen_i16, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int32 => { + return create_dict_array( + &mut key_buffer_cache.gen_i32, + dict_val, + len, + val.data_type(), + ); + } + DataType::Int64 => { + return create_dict_array( + &mut key_buffer_cache.gen_i64, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt8 => { + return create_dict_array( + &mut key_buffer_cache.gen_u8, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt16 => { + return create_dict_array( + &mut key_buffer_cache.gen_u16, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt32 => { + return create_dict_array( + &mut key_buffer_cache.gen_u32, + dict_val, + len, + val.data_type(), + ); + } + DataType::UInt64 => { + return create_dict_array( + &mut key_buffer_cache.gen_u64, + dict_val, + len, + val.data_type(), + ); + } + _ => {} + } + } + + val.to_array_of_size(len) +} + +#[cfg(test)] +mod tests { + use arrow_array::Int32Array; + + use super::*; + use crate::{test::columns, test_util::aggr_test_schema}; + + #[test] + fn physical_plan_config_no_projection() { + let file_schema = aggr_test_schema(); + let conf = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + to_partition_cols(vec![( + "date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )]), + ); + + let (proj_schema, proj_statistics, _) = conf.project(); + assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); + assert_eq!( + proj_schema.field(file_schema.fields().len()).name(), + "date", + "partition columns are the last columns" + ); + assert_eq!( + proj_statistics.column_statistics.len(), + file_schema.fields().len() + 1 + ); + // TODO implement tests for partition column statistics once implemented + + let col_names = conf.projected_file_column_names(); + assert_eq!(col_names, None); + + let col_indices = conf.file_column_projection_indices(); + assert_eq!(col_indices, None); + } + + #[test] + fn physical_plan_config_no_projection_tab_cols_as_field() { + let file_schema = aggr_test_schema(); + + // make a table_partition_col as a field + let table_partition_col = + Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), true) + .with_metadata(HashMap::from_iter(vec![( + "key_whatever".to_owned(), + "value_whatever".to_owned(), + )])); + + let conf = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + vec![table_partition_col.clone()], + ); + + // verify the proj_schema inlcudes the last column and exactly the same the field it is defined + let (proj_schema, _proj_statistics, _) = conf.project(); + assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); + assert_eq!( + *proj_schema.field(file_schema.fields().len()), + table_partition_col, + "partition columns are the last columns and ust have all values defined in created field" + ); + } + + #[test] + fn physical_plan_config_with_projection() { + let file_schema = aggr_test_schema(); + let conf = config_for_projection( + Arc::clone(&file_schema), + Some(vec![file_schema.fields().len(), 0]), + Statistics { + num_rows: Precision::Inexact(10), + // assign the column index to distinct_count to help assert + // the source statistic after the projection + column_statistics: (0..file_schema.fields().len()) + .map(|i| ColumnStatistics { + distinct_count: Precision::Inexact(i), + ..Default::default() + }) + .collect(), + total_byte_size: Precision::Absent, + }, + to_partition_cols(vec![( + "date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )]), + ); + + let (proj_schema, proj_statistics, _) = conf.project(); + assert_eq!( + columns(&proj_schema), + vec!["date".to_owned(), "c1".to_owned()] + ); + let proj_stat_cols = proj_statistics.column_statistics; + assert_eq!(proj_stat_cols.len(), 2); + // TODO implement tests for proj_stat_cols[0] once partition column + // statistics are implemented + assert_eq!(proj_stat_cols[1].distinct_count, Precision::Inexact(0)); + + let col_names = conf.projected_file_column_names(); + assert_eq!(col_names, Some(vec!["c1".to_owned()])); + + let col_indices = conf.file_column_projection_indices(); + assert_eq!(col_indices, Some(vec![0])); + } + + #[test] + fn partition_column_projector() { + let file_batch = build_table_i32( + ("a", &vec![0, 1, 2]), + ("b", &vec![-2, -1, 0]), + ("c", &vec![10, 11, 12]), + ); + let partition_cols = vec![ + ( + "year".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "month".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "day".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ]; + // create a projected schema + let conf = config_for_projection( + file_batch.schema(), + // keep all cols from file and 2 from partitioning + Some(vec![ + 0, + 1, + 2, + file_batch.schema().fields().len(), + file_batch.schema().fields().len() + 2, + ]), + Statistics::new_unknown(&file_batch.schema()), + to_partition_cols(partition_cols.clone()), + ); + let (proj_schema, ..) = conf.project(); + // created a projector for that projected schema + let mut proj = PartitionColumnProjector::new( + proj_schema, + &partition_cols + .iter() + .map(|x| x.0.clone()) + .collect::>(), + ); + + // project first batch + let projected_batch = proj + .project( + // file_batch is ok here because we kept all the file cols in the projection + file_batch, + &[ + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("26")), + ], + ) + .expect("Projection of partition columns into record batch failed"); + let expected = [ + "+---+----+----+------+-----+", + "| a | b | c | year | day |", + "+---+----+----+------+-----+", + "| 0 | -2 | 10 | 2021 | 26 |", + "| 1 | -1 | 11 | 2021 | 26 |", + "| 2 | 0 | 12 | 2021 | 26 |", + "+---+----+----+------+-----+", + ]; + crate::assert_batches_eq!(expected, &[projected_batch]); + + // project another batch that is larger than the previous one + let file_batch = build_table_i32( + ("a", &vec![5, 6, 7, 8, 9]), + ("b", &vec![-10, -9, -8, -7, -6]), + ("c", &vec![12, 13, 14, 15, 16]), + ); + let projected_batch = proj + .project( + // file_batch is ok here because we kept all the file cols in the projection + file_batch, + &[ + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("27")), + ], + ) + .expect("Projection of partition columns into record batch failed"); + let expected = [ + "+---+-----+----+------+-----+", + "| a | b | c | year | day |", + "+---+-----+----+------+-----+", + "| 5 | -10 | 12 | 2021 | 27 |", + "| 6 | -9 | 13 | 2021 | 27 |", + "| 7 | -8 | 14 | 2021 | 27 |", + "| 8 | -7 | 15 | 2021 | 27 |", + "| 9 | -6 | 16 | 2021 | 27 |", + "+---+-----+----+------+-----+", + ]; + crate::assert_batches_eq!(expected, &[projected_batch]); + + // project another batch that is smaller than the previous one + let file_batch = build_table_i32( + ("a", &vec![0, 1, 3]), + ("b", &vec![2, 3, 4]), + ("c", &vec![4, 5, 6]), + ); + let projected_batch = proj + .project( + // file_batch is ok here because we kept all the file cols in the projection + file_batch, + &[ + wrap_partition_value_in_dict(ScalarValue::from("2021")), + wrap_partition_value_in_dict(ScalarValue::from("10")), + wrap_partition_value_in_dict(ScalarValue::from("28")), + ], + ) + .expect("Projection of partition columns into record batch failed"); + let expected = [ + "+---+---+---+------+-----+", + "| a | b | c | year | day |", + "+---+---+---+------+-----+", + "| 0 | 2 | 4 | 2021 | 28 |", + "| 1 | 3 | 5 | 2021 | 28 |", + "| 3 | 4 | 6 | 2021 | 28 |", + "+---+---+---+------+-----+", + ]; + crate::assert_batches_eq!(expected, &[projected_batch]); + + // forgot to dictionary-wrap the scalar value + let file_batch = build_table_i32( + ("a", &vec![0, 1, 2]), + ("b", &vec![-2, -1, 0]), + ("c", &vec![10, 11, 12]), + ); + let projected_batch = proj + .project( + // file_batch is ok here because we kept all the file cols in the projection + file_batch, + &[ + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), + ], + ) + .expect("Projection of partition columns into record batch failed"); + let expected = [ + "+---+----+----+------+-----+", + "| a | b | c | year | day |", + "+---+----+----+------+-----+", + "| 0 | -2 | 10 | 2021 | 26 |", + "| 1 | -1 | 11 | 2021 | 26 |", + "| 2 | 0 | 12 | 2021 | 26 |", + "+---+----+----+------+-----+", + ]; + crate::assert_batches_eq!(expected, &[projected_batch]); + } + + // sets default for configs that play no role in projections + fn config_for_projection( + file_schema: SchemaRef, + projection: Option>, + statistics: Statistics, + table_partition_cols: Vec, + ) -> FileScanConfig { + FileScanConfig { + file_schema, + file_groups: vec![vec![]], + limit: None, + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + projection, + statistics, + table_partition_cols, + output_ordering: vec![], + infinite_source: false, + } + } + + /// Convert partition columns from Vec to Vec + fn to_partition_cols(table_partition_cols: Vec<(String, DataType)>) -> Vec { + table_partition_cols + .iter() + .map(|(name, dtype)| Field::new(name, dtype.clone(), false)) + .collect::>() + } + + /// returns record batch with 3 columns of i32 in memory + pub fn build_table_i32( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap() + } +} diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 2c4437de0a921..a715f6e8e3cde 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -112,7 +112,7 @@ enum FileStreamState { /// The idle state, no file is currently being read Idle, /// Currently performing asynchronous IO to obtain a stream of RecordBatch - /// for a given parquet file + /// for a given file Open { /// A [`FileOpenFuture`] returned by [`FileOpener::open`] future: FileOpenFuture, @@ -259,7 +259,7 @@ impl FileStream { &config .table_partition_cols .iter() - .map(|x| x.0.clone()) + .map(|x| x.name().clone()) .collect::>(), ); @@ -519,10 +519,12 @@ impl RecordBatchStream for FileStream { #[cfg(test)] mod tests { use arrow_schema::Schema; + use datafusion_common::internal_err; use datafusion_common::DataFusionError; + use datafusion_common::Statistics; use super::*; - use crate::datasource::file_format::BatchSerializer; + use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::FileMeta; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; @@ -557,12 +559,9 @@ mod tests { let idx = self.current_idx.fetch_add(1, Ordering::SeqCst); if self.error_opening_idx.contains(&idx) { - Ok(futures::future::ready(Err(DataFusionError::Internal( - "error opening".to_owned(), - ))) - .boxed()) + Ok(futures::future::ready(internal_err!("error opening")).boxed()) } else if self.error_scanning_idx.contains(&idx) { - let error = futures::future::ready(Err(ArrowError::IoError( + let error = futures::future::ready(Err(ArrowError::IpcError( "error scanning".to_owned(), ))); let stream = futures::stream::once(error).boxed(); @@ -661,9 +660,9 @@ mod tests { let config = FileScanConfig { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + statistics: Statistics::new_unknown(&file_schema), file_schema, file_groups: vec![file_group], - statistics: Default::default(), projection: None, limit: self.limit, table_partition_cols: vec![], diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 8340c282a01ee..9c3b523a652c9 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -16,37 +16,38 @@ // under the License. //! Execution plan for reading line-delimited JSON files -use crate::datasource::file_format::file_type::FileCompressionType; + +use std::any::Any; +use std::io::BufReader; +use std::sync::Arc; +use std::task::Poll; + +use super::FileScanConfig; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::FileMeta; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ - ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }; -use datafusion_execution::TaskContext; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; -use datafusion_physical_expr::{LexOrdering, OrderingEquivalenceProperties}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, stream, StreamExt, TryStreamExt}; -use object_store::{GetResult, ObjectStore}; -use std::any::Any; -use std::fs; -use std::io::BufReader; -use std::path::Path; -use std::sync::Arc; -use std::task::Poll; -use tokio::task::{self, JoinHandle}; - -use super::FileScanConfig; +use object_store; +use object_store::{GetResultPayload, ObjectStore}; +use tokio::io::AsyncWriteExt; +use tokio::task::JoinSet; /// Execution plan for scanning NdJson data source #[derive(Debug, Clone)] @@ -85,6 +86,17 @@ impl NdJsonExec { } } +impl DisplayAs for NdJsonExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "JsonExec: ")?; + self.base_config.fmt_as(t, f) + } +} + impl ExecutionPlan for NdJsonExec { fn as_any(&self) -> &dyn Any { self @@ -108,8 +120,8 @@ impl ExecutionPlan for NdJsonExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -150,20 +162,8 @@ impl ExecutionPlan for NdJsonExec { Ok(Box::pin(stream) as SendableRecordBatchStream) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "JsonExec: {}", self.base_config) - } - } - } - - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } fn metrics(&self) -> Option { @@ -204,15 +204,16 @@ impl FileOpener for JsonOpener { let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - match store.get(file_meta.location()).await? { - GetResult::File(file, _) => { + let r = store.get(file_meta.location()).await?; + match r.payload { + GetResultPayload::File(file, _) => { let bytes = file_compression_type.convert_read(file)?; let reader = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build(BufReader::new(bytes))?; Ok(futures::stream::iter(reader).boxed()) } - GetResult::Stream(s) => { + GetResultPayload::Stream(s) => { let mut decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; @@ -261,38 +262,49 @@ pub async fn plan_to_json( path: impl AsRef, ) -> Result<()> { let path = path.as_ref(); - // create directory to contain the CSV files (one per partition) - let fs_path = Path::new(path); - if let Err(e) = fs::create_dir(fs_path) { - return Err(DataFusionError::Execution(format!( - "Could not create directory {path}: {e:?}" - ))); - } - - let mut tasks = vec![]; + let parsed = ListingTableUrl::parse(path)?; + let object_store_url = parsed.object_store(); + let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let plan = plan.clone(); - let filename = format!("part-{i}.json"); - let path = fs_path.join(filename); - let file = fs::File::create(path)?; - let mut writer = json::LineDelimitedWriter::new(file); - let stream = plan.execute(i, task_ctx.clone())?; - let handle: JoinHandle> = task::spawn(async move { - stream - .map(|batch| writer.write(&batch?)) - .try_collect() + let storeref = store.clone(); + let plan: Arc = plan.clone(); + let filename = format!("{}/part-{i}.json", parsed.prefix()); + let file = object_store::path::Path::parse(filename)?; + + let mut stream = plan.execute(i, task_ctx.clone())?; + join_set.spawn(async move { + let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + + let mut buffer = Vec::with_capacity(1024); + while let Some(batch) = stream.next().await.transpose()? { + let mut writer = json::LineDelimitedWriter::new(buffer); + writer.write(&batch)?; + buffer = writer.into_inner(); + multipart_writer.write_all(&buffer).await?; + buffer.clear(); + } + + multipart_writer + .shutdown() .await .map_err(DataFusionError::from) }); - tasks.push(AbortOnDropSingle::new(handle)); } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? - })?; + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, // propagate DataFusion error + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + Ok(()) } @@ -304,17 +316,21 @@ mod tests { use object_store::local::LocalFileSystem; use crate::assert_batches_eq; - use crate::datasource::file_format::file_type::FileType; + use crate::dataframe::DataFrameWriteOptions; + use crate::datasource::file_format::file_compression_type::FileTypeExt; use crate::datasource::file_format::{json::JsonFormat, FileFormat}; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::chunked_store::ChunkedStore; use crate::execution::context::SessionState; use crate::prelude::NdJsonReadOptions; use crate::prelude::*; use crate::test::partitioned_file_groups; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; + use datafusion_common::FileType; + use object_store::chunked::ChunkedStore; use rstest::*; + use std::fs; + use std::path::Path; use tempfile::TempDir; use url::Url; @@ -325,6 +341,7 @@ mod tests { async fn prepare_store( state: &SessionState, file_compression_type: FileCompressionType, + work_dir: &Path, ) -> (ObjectStoreUrl, Vec>, SchemaRef) { let store_url = ObjectStoreUrl::local_filesystem(); let store = state.runtime_env().object_store(&store_url).unwrap(); @@ -336,12 +353,13 @@ mod tests { 1, FileType::JSON, file_compression_type.to_owned(), + work_dir, ) .unwrap(); let meta = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .clone() .object_meta; @@ -357,23 +375,25 @@ mod tests { async fn test_additional_stores( file_compression_type: FileCompressionType, store: Arc, - ) { + ) -> Result<()> { let ctx = SessionContext::new(); let url = Url::parse("file://").unwrap(); ctx.runtime_env().register_object_store(&url, store.clone()); let filename = "1.json"; + let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, FileType::JSON, file_compression_type.to_owned(), + tmp_dir.path(), ) .unwrap(); let path = file_groups - .get(0) + .first() .unwrap() - .get(0) + .first() .unwrap() .object_meta .location @@ -407,6 +427,7 @@ mod tests { ], &results ); + Ok(()) } #[rstest( @@ -417,6 +438,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn nd_json_exec_file_without_projection( file_compression_type: FileCompressionType, @@ -426,15 +448,16 @@ mod tests { let task_ctx = session_ctx.task_ctx(); use arrow::datatypes::DataType; + let tmp_dir = TempDir::new()?; let (object_store_url, file_groups, file_schema) = - prepare_store(&state, file_compression_type.to_owned()).await; + prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: Some(3), table_partition_cols: vec![], @@ -488,6 +511,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn nd_json_exec_file_with_missing_column( file_compression_type: FileCompressionType, @@ -496,8 +520,10 @@ mod tests { let state = session_ctx.state(); let task_ctx = session_ctx.task_ctx(); use arrow::datatypes::DataType; + + let tmp_dir = TempDir::new()?; let (object_store_url, file_groups, actual_schema) = - prepare_store(&state, file_compression_type.to_owned()).await; + prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let mut builder = SchemaBuilder::from(actual_schema.fields()); builder.push(Field::new("missing_col", DataType::Int32, true)); @@ -509,8 +535,8 @@ mod tests { FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: Some(3), table_partition_cols: vec![], @@ -541,6 +567,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn nd_json_exec_file_projection( file_compression_type: FileCompressionType, @@ -548,15 +575,16 @@ mod tests { let session_ctx = SessionContext::new(); let state = session_ctx.state(); let task_ctx = session_ctx.task_ctx(); + let tmp_dir = TempDir::new()?; let (object_store_url, file_groups, file_schema) = - prepare_store(&state, file_compression_type.to_owned()).await; + prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![0, 2]), limit: None, table_partition_cols: vec![], @@ -592,6 +620,7 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn nd_json_exec_file_mixed_order_projection( file_compression_type: FileCompressionType, @@ -599,15 +628,16 @@ mod tests { let session_ctx = SessionContext::new(); let state = session_ctx.state(); let task_ctx = session_ctx.task_ctx(); + let tmp_dir = TempDir::new()?; let (object_store_url, file_groups, file_schema) = - prepare_store(&state, file_compression_type.to_owned()).await; + prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( FileScanConfig { object_store_url, file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![3, 0, 2]), limit: None, table_partition_cols: vec![], @@ -644,9 +674,9 @@ mod tests { #[tokio::test] async fn write_json_results() -> Result<()> { // create partitioned input file and context - let tmp_dir = TempDir::new()?; - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let path = format!("{TEST_DATA_BASE}/1.json"); @@ -654,19 +684,49 @@ mod tests { ctx.register_json("test", path.as_str(), NdJsonReadOptions::default()) .await?; + // register a local file system object store for /tmp directory + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + // execute a simple query and write the results to CSV let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; let df = ctx.sql("SELECT a, b FROM test").await?; - df.write_json(&out_dir).await?; + df.write_json(out_dir_url, DataFrameWriteOptions::new()) + .await?; // create a new context and verify that the results were saved to a partitioned csv file let ctx = SessionContext::new(); + // get name of first part + let paths = fs::read_dir(&out_dir).unwrap(); + let mut part_0_name: String = "".to_owned(); + for path in paths { + let name = path + .unwrap() + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + if name.ends_with("_0.json") { + part_0_name = name; + break; + } + } + + if part_0_name.is_empty() { + panic!("Did not find part_0 in json output files!") + } + // register each partition as well as the top level dir let json_read_option = NdJsonReadOptions::default(); ctx.register_json( "part0", - &format!("{out_dir}/part-0.json"), + &format!("{out_dir}/{part_0_name}"), json_read_option.clone(), ) .await?; @@ -697,11 +757,12 @@ mod tests { case(FileCompressionType::XZ), case(FileCompressionType::ZSTD) )] + #[cfg(feature = "compression")] #[tokio::test] async fn test_chunked_json( file_compression_type: FileCompressionType, #[values(10, 20, 30, 40)] chunk_size: usize, - ) { + ) -> Result<()> { test_additional_stores( file_compression_type, Arc::new(ChunkedStore::new( @@ -709,23 +770,58 @@ mod tests { chunk_size, )), ) - .await; + .await?; + Ok(()) } #[tokio::test] async fn write_json_results_error_handling() -> Result<()> { let ctx = SessionContext::new(); + // register a local file system object store for /tmp directory + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); let options = CsvReadOptions::default() .schema_infer_max_records(2) .has_header(true); let df = ctx.read_csv("tests/data/corrupt.csv", options).await?; - let tmp_dir = TempDir::new()?; - let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; let e = df - .write_json(&out_dir) + .write_json(out_dir_url, DataFrameWriteOptions::new()) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!("Arrow error: Parser error: Error while parsing value d for column 0 at line 4", format!("{e}")); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); + Ok(()) + } + + #[tokio::test] + async fn ndjson_schema_infer_max_records() -> Result<()> { + async fn read_test_data(schema_infer_max_records: usize) -> Result { + let ctx = SessionContext::new(); + + let options = NdJsonReadOptions { + schema_infer_max_records, + ..Default::default() + }; + + let batches = ctx + .read_json("tests/data/4.json", options) + .await? + .collect() + .await?; + + Ok(batches[0].schema()) + } + + // Use only the first 2 rows to infer the schema, those have 2 fields. + let schema = read_test_data(2).await?; + assert_eq!(schema.fields().len(), 2); + + // Use all rows to infer the schema, those have 5 fields. + let schema = read_test_data(10).await?; + assert_eq!(schema.fields().len(), 5); + Ok(()) } } diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 54b916788c676..14e550eab1d55 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -19,241 +19,84 @@ mod arrow_file; mod avro; -#[cfg(test)] -mod chunked_store; mod csv; +mod file_scan_config; mod file_stream; mod json; +#[cfg(feature = "parquet")] pub mod parquet; pub(crate) use self::csv::plan_to_csv; pub use self::csv::{CsvConfig, CsvExec, CsvOpener}; -pub(crate) use self::parquet::plan_to_parquet; +pub(crate) use self::json::plan_to_json; +#[cfg(feature = "parquet")] pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory}; -use arrow::{ - array::{new_null_array, ArrayData, ArrayRef, BufferBuilder, DictionaryArray}, - buffer::Buffer, - compute::can_cast_types, - datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, - record_batch::{RecordBatch, RecordBatchOptions}, -}; + pub use arrow_file::ArrowExec; pub use avro::AvroExec; -use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use file_scan_config::PartitionColumnProjector; +pub use file_scan_config::{ + wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, +}; pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; -pub(crate) use json::plan_to_json; pub use json::{JsonOpener, NdJsonExec}; -use crate::datasource::file_format::FileWriterMode; -use crate::datasource::{ - listing::{FileRange, PartitionedFile}, - object_store::ObjectStoreUrl, +use std::{ + fmt::{Debug, Formatter, Result as FmtResult}, + sync::Arc, + vec, }; -use crate::physical_plan::ExecutionPlan; + +use super::listing::ListingTableUrl; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{DisplayAs, DisplayFormatType}; use crate::{ - error::{DataFusionError, Result}, - scalar::ScalarValue, + datasource::{ + listing::{FileRange, PartitionedFile}, + object_store::ObjectStoreUrl, + }, + physical_plan::display::{OutputOrderingDisplay, ProjectSchemaDisplay}, }; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use arrow::{ + array::new_null_array, + compute::{can_cast_types, cast}, + datatypes::{DataType, Schema, SchemaRef}, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use datafusion_common::{file_options::FileTypeWriterOptions, plan_err}; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_plan::ExecutionPlan; -use arrow::compute::cast; -use log::{debug, warn}; +use log::debug; use object_store::path::Path; use object_store::ObjectMeta; -use std::{ - borrow::Cow, - collections::HashMap, - fmt::{Debug, Display, Formatter, Result as FmtResult}, - marker::PhantomData, - sync::Arc, - vec, -}; - -use super::{ColumnStatistics, Statistics}; - -/// Convert type to a type suitable for use as a [`ListingTable`] -/// partition column. Returns `Dictionary(UInt16, val_type)`, which is -/// a reasonable trade off between a reasonable number of partition -/// values and space efficiency. -/// -/// This use this to specify types for partition columns. However -/// you MAY also choose not to dictionary-encode the data or to use a -/// different dictionary type. -/// -/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same say. -/// -/// [`ListingTable`]: crate::datasource::listing::ListingTable -pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) -} - -/// Convert a [`ScalarValue`] of partition columns to a type, as -/// decribed in the documentation of [`wrap_partition_type_in_dict`], -/// which can wrap the types. -pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) -} - -/// Get all of the [`PartitionedFile`] to be scanned for an [`ExecutionPlan`] -pub fn get_scan_files( - plan: Arc, -) -> Result>>> { - let mut collector: Vec>> = vec![]; - plan.apply(&mut |plan| { - let plan_any = plan.as_any(); - let file_groups = - if let Some(parquet_exec) = plan_any.downcast_ref::() { - parquet_exec.base_config().file_groups.clone() - } else if let Some(avro_exec) = plan_any.downcast_ref::() { - avro_exec.base_config().file_groups.clone() - } else if let Some(json_exec) = plan_any.downcast_ref::() { - json_exec.base_config().file_groups.clone() - } else if let Some(csv_exec) = plan_any.downcast_ref::() { - csv_exec.base_config().file_groups.clone() - } else { - return Ok(VisitRecursion::Continue); - }; - - collector.push(file_groups); - Ok(VisitRecursion::Skip) - })?; - Ok(collector) -} - -/// The base configurations to provide when creating a physical plan for -/// any given file format. -#[derive(Clone)] -pub struct FileScanConfig { - /// Object store URL, used to get an [`ObjectStore`] instance from - /// [`RuntimeEnv::object_store`] - /// - /// [`ObjectStore`]: object_store::ObjectStore - /// [`RuntimeEnv::object_store`]: datafusion_execution::runtime_env::RuntimeEnv::object_store - pub object_store_url: ObjectStoreUrl, - /// Schema before `projection` is applied. It contains the all columns that may - /// appear in the files. It does not include table partition columns - /// that may be added. - pub file_schema: SchemaRef, - /// List of files to be processed, grouped into partitions - /// - /// Each file must have a schema of `file_schema` or a subset. If - /// a particular file has a subset, the missing columns are - /// padded with with NULLs. - /// - /// DataFusion may attempt to read each partition of files - /// concurrently, however files *within* a partition will be read - /// sequentially, one after the next. - pub file_groups: Vec>, - /// Estimated overall statistics of the files, taking `filters` into account. - pub statistics: Statistics, - /// Columns on which to project the data. Indexes that are higher than the - /// number of columns of `file_schema` refer to `table_partition_cols`. - pub projection: Option>, - /// The maximum number of records to read from this plan. If `None`, - /// all records after filtering are returned. - pub limit: Option, - /// The partitioning columns - pub table_partition_cols: Vec<(String, DataType)>, - /// All equivalent lexicographical orderings that describe the schema. - pub output_ordering: Vec, - /// Indicates whether this plan may produce an infinite stream of records. - pub infinite_source: bool, -} - -impl FileScanConfig { - /// Project the schema and the statistics on the given column indices - fn project(&self) -> (SchemaRef, Statistics, Vec) { - if self.projection.is_none() && self.table_partition_cols.is_empty() { - return ( - Arc::clone(&self.file_schema), - self.statistics.clone(), - self.output_ordering.clone(), - ); - } - - let proj_iter: Box> = match &self.projection { - Some(proj) => Box::new(proj.iter().copied()), - None => Box::new( - 0..(self.file_schema.fields().len() + self.table_partition_cols.len()), - ), - }; - - let mut table_fields = vec![]; - let mut table_cols_stats = vec![]; - for idx in proj_iter { - if idx < self.file_schema.fields().len() { - table_fields.push(self.file_schema.field(idx).clone()); - if let Some(file_cols_stats) = &self.statistics.column_statistics { - table_cols_stats.push(file_cols_stats[idx].clone()) - } else { - table_cols_stats.push(ColumnStatistics::default()) - } - } else { - let partition_idx = idx - self.file_schema.fields().len(); - table_fields.push(Field::new( - &self.table_partition_cols[partition_idx].0, - self.table_partition_cols[partition_idx].1.to_owned(), - false, - )); - // TODO provide accurate stat for partition column (#1186) - table_cols_stats.push(ColumnStatistics::default()) - } - } - - let table_stats = Statistics { - num_rows: self.statistics.num_rows, - is_exact: self.statistics.is_exact, - // TODO correct byte size? - total_byte_size: None, - column_statistics: Some(table_cols_stats), - }; - - let table_schema = Arc::new( - Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), - ); - let projected_output_ordering = - get_projected_output_ordering(self, &table_schema); - (table_schema, table_stats, projected_output_ordering) - } - - #[allow(unused)] // Only used by avro - fn projected_file_column_names(&self) -> Option> { - self.projection.as_ref().map(|p| { - p.iter() - .filter(|col_idx| **col_idx < self.file_schema.fields().len()) - .map(|col_idx| self.file_schema.field(*col_idx).name()) - .cloned() - .collect() - }) - } - - fn file_column_projection_indices(&self) -> Option> { - self.projection.as_ref().map(|p| { - p.iter() - .filter(|col_idx| **col_idx < self.file_schema.fields().len()) - .copied() - .collect() - }) - } -} /// The base configurations to provide when creating a physical plan for /// writing to any given file format. -#[derive(Debug, Clone)] pub struct FileSinkConfig { /// Object store URL, used to get an ObjectStore instance pub object_store_url: ObjectStoreUrl, /// A vector of [`PartitionedFile`] structs, each representing a file partition pub file_groups: Vec, + /// Vector of partition paths + pub table_paths: Vec, /// The schema of the output file pub output_schema: SchemaRef, /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// A writer mode that determines how data is written to the file - pub writer_mode: FileWriterMode, + /// If true, it is assumed there is a single table_path which is a file to which all data should be written + /// regardless of input partitioning. Otherwise, each table path is assumed to be a directory + /// to which each output partition is written to its own output file. + pub single_file_output: bool, + /// If input is unbounded, tokio tasks need to yield to not block execution forever + pub unbounded_input: bool, + /// Controls whether existing data should be overwritten by this sink + pub overwrite: bool, + /// Contains settings specific to writing a given FileType, e.g. parquet max_row_group_size + pub file_type_writer_options: FileTypeWriterOptions, } impl FileSinkConfig { @@ -269,15 +112,16 @@ impl Debug for FileScanConfig { write!(f, "statistics={:?}, ", self.statistics)?; - Display::fmt(self, f) + DisplayAs::fmt_as(self, DisplayFormatType::Verbose, f) } } -impl Display for FileScanConfig { - fn fmt(&self, f: &mut Formatter) -> FmtResult { +impl DisplayAs for FileScanConfig { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { let (schema, _, orderings) = self.project(); - write!(f, "file_groups={}", FileGroupsDisplay(&self.file_groups))?; + write!(f, "file_groups=")?; + FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; if !schema.fields().is_empty() { write!(f, ", projection={}", ProjectSchemaDisplay(&schema))?; @@ -293,7 +137,22 @@ impl Display for FileScanConfig { if let Some(ordering) = orderings.first() { if !ordering.is_empty() { - write!(f, ", output_ordering={}", OutputOrderingDisplay(ordering))?; + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{}", start)?; + for (idx, ordering) in + orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) + { + match idx { + 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, + _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + } + } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{}", end)?; } } @@ -310,29 +169,30 @@ impl Display for FileScanConfig { #[derive(Debug)] struct FileGroupsDisplay<'a>(&'a [Vec]); -impl<'a> Display for FileGroupsDisplay<'a> { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - let n_group = self.0.len(); - let groups = if n_group == 1 { "group" } else { "groups" }; - write!(f, "{{{n_group} {groups}: [")?; - // To avoid showing too many partitions - let max_groups = 5; - for (idx, group) in self.0.iter().take(max_groups).enumerate() { - if idx > 0 { - write!(f, ", ")?; +impl<'a> DisplayAs for FileGroupsDisplay<'a> { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { + let n_groups = self.0.len(); + let groups = if n_groups == 1 { "group" } else { "groups" }; + write!(f, "{{{n_groups} {groups}: [")?; + match t { + DisplayFormatType::Default => { + // To avoid showing too many partitions + let max_groups = 5; + fmt_up_to_n_elements(self.0, max_groups, f, |group, f| { + FileGroupDisplay(group).fmt_as(t, f) + })?; + } + DisplayFormatType::Verbose => { + fmt_elements_split_by_commas(self.0.iter(), f, |group, f| { + FileGroupDisplay(group).fmt_as(t, f) + })? } - write!(f, "{}", FileGroupDisplay(group))?; - } - // Remaining elements are showed as `...` (to indicate there is more) - if n_group > max_groups { - write!(f, ", ...")?; } - write!(f, "]}}")?; - Ok(()) + write!(f, "]}}") } } -/// A wrapper to customize partitioned file display +/// A wrapper to customize partitioned group of files display /// /// Prints in the format: /// ```text @@ -341,55 +201,73 @@ impl<'a> Display for FileGroupsDisplay<'a> { #[derive(Debug)] pub(crate) struct FileGroupDisplay<'a>(pub &'a [PartitionedFile]); -impl<'a> Display for FileGroupDisplay<'a> { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - let group = self.0; +impl<'a> DisplayAs for FileGroupDisplay<'a> { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { write!(f, "[")?; - for (idx, pf) in group.iter().enumerate() { - if idx > 0 { - write!(f, ", ")?; + match t { + DisplayFormatType::Default => { + // To avoid showing too many files + let max_files = 5; + fmt_up_to_n_elements(self.0, max_files, f, |pf, f| { + write!(f, "{}", pf.object_meta.location.as_ref())?; + if let Some(range) = pf.range.as_ref() { + write!(f, ":{}..{}", range.start, range.end)?; + } + Ok(()) + })? } - write!(f, "{}", pf.object_meta.location.as_ref())?; - if let Some(range) = pf.range.as_ref() { - write!(f, ":{}..{}", range.start, range.end)?; + DisplayFormatType::Verbose => { + fmt_elements_split_by_commas(self.0.iter(), f, |pf, f| { + write!(f, "{}", pf.object_meta.location.as_ref())?; + if let Some(range) = pf.range.as_ref() { + write!(f, ":{}..{}", range.start, range.end)?; + } + Ok(()) + })? } } - write!(f, "]")?; - Ok(()) + write!(f, "]") } } -/// A wrapper to customize partitioned file display -#[derive(Debug)] -struct ProjectSchemaDisplay<'a>(&'a SchemaRef); - -impl<'a> Display for ProjectSchemaDisplay<'a> { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - let parts: Vec<_> = self - .0 - .fields() - .iter() - .map(|x| x.name().to_owned()) - .collect::>(); - write!(f, "[{}]", parts.join(", ")) +/// helper to format an array of up to N elements +fn fmt_up_to_n_elements( + elements: &[E], + n: usize, + f: &mut Formatter, + format_element: F, +) -> FmtResult +where + F: Fn(&E, &mut Formatter) -> FmtResult, +{ + let len = elements.len(); + fmt_elements_split_by_commas(elements.iter().take(n), f, |element, f| { + format_element(element, f) + })?; + // Remaining elements are showed as `...` (to indicate there is more) + if len > n { + write!(f, ", ...")?; } + Ok(()) } -/// A wrapper to customize output ordering display. -#[derive(Debug)] -struct OutputOrderingDisplay<'a>(&'a [PhysicalSortExpr]); - -impl<'a> Display for OutputOrderingDisplay<'a> { - fn fmt(&self, f: &mut Formatter) -> FmtResult { - write!(f, "[")?; - for (i, e) in self.0.iter().enumerate() { - if i > 0 { - write!(f, ", ")? - } - write!(f, "{e}")?; +/// helper formatting array elements with a comma and a space between them +fn fmt_elements_split_by_commas( + iter: I, + f: &mut Formatter, + format_element: F, +) -> FmtResult +where + I: Iterator, + F: Fn(E, &mut Formatter) -> FmtResult, +{ + for (idx, element) in iter.enumerate() { + if idx > 0 { + write!(f, ", ")?; } - write!(f, "]") + format_element(element, f)?; } + Ok(()) } /// A utility which can adapt file-level record batches to a table schema which may have a schema @@ -454,12 +332,12 @@ impl SchemaAdapter { projection.push(file_idx); } false => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", file_field.name(), file_field.data_type(), table_field.data_type() - ))) + ) } } } @@ -511,240 +389,6 @@ impl SchemaMapping { } } -/// A helper that projects partition columns into the file record batches. -/// -/// One interesting trick is the usage of a cache for the key buffers of the partition column -/// dictionaries. Indeed, the partition columns are constant, so the dictionaries that represent them -/// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, -/// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). -struct PartitionColumnProjector { - /// An Arrow buffer initialized to zeros that represents the key array of all partition - /// columns (partition columns are materialized by dictionary arrays with only one - /// value in the dictionary, thus all the keys are equal to zero). - key_buffer_cache: ZeroBufferGenerators, - /// Mapping between the indexes in the list of partition columns and the target - /// schema. Sorted by index in the target schema so that we can iterate on it to - /// insert the partition columns in the target record batch. - projected_partition_indexes: Vec<(usize, usize)>, - /// The schema of the table once the projection was applied. - projected_schema: SchemaRef, -} - -impl PartitionColumnProjector { - // Create a projector to insert the partitioning columns into batches read from files - // - `projected_schema`: the target schema with both file and partitioning columns - // - `table_partition_cols`: all the partitioning column names - fn new(projected_schema: SchemaRef, table_partition_cols: &[String]) -> Self { - let mut idx_map = HashMap::new(); - for (partition_idx, partition_name) in table_partition_cols.iter().enumerate() { - if let Ok(schema_idx) = projected_schema.index_of(partition_name) { - idx_map.insert(partition_idx, schema_idx); - } - } - - let mut projected_partition_indexes: Vec<_> = idx_map.into_iter().collect(); - projected_partition_indexes.sort_by(|(_, a), (_, b)| a.cmp(b)); - - Self { - projected_partition_indexes, - key_buffer_cache: Default::default(), - projected_schema, - } - } - - // Transform the batch read from the file by inserting the partitioning columns - // to the right positions as deduced from `projected_schema` - // - `file_batch`: batch read from the file, with internal projection applied - // - `partition_values`: the list of partition values, one for each partition column - fn project( - &mut self, - file_batch: RecordBatch, - partition_values: &[ScalarValue], - ) -> Result { - let expected_cols = - self.projected_schema.fields().len() - self.projected_partition_indexes.len(); - - if file_batch.columns().len() != expected_cols { - return Err(DataFusionError::Execution(format!( - "Unexpected batch schema from file, expected {} cols but got {}", - expected_cols, - file_batch.columns().len() - ))); - } - let mut cols = file_batch.columns().to_vec(); - for &(pidx, sidx) in &self.projected_partition_indexes { - let mut partition_value = Cow::Borrowed(&partition_values[pidx]); - - // check if user forgot to dict-encode the partition value - let field = self.projected_schema.field(sidx); - let expected_data_type = field.data_type(); - let actual_data_type = partition_value.get_datatype(); - if let DataType::Dictionary(key_type, _) = expected_data_type { - if !matches!(actual_data_type, DataType::Dictionary(_, _)) { - warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); - partition_value = Cow::Owned(ScalarValue::Dictionary( - key_type.clone(), - Box::new(partition_value.as_ref().clone()), - )); - } - } - - cols.insert( - sidx, - create_output_array( - &mut self.key_buffer_cache, - partition_value.as_ref(), - file_batch.num_rows(), - ), - ) - } - RecordBatch::try_new(Arc::clone(&self.projected_schema), cols).map_err(Into::into) - } -} - -#[derive(Debug, Default)] -struct ZeroBufferGenerators { - gen_i8: ZeroBufferGenerator, - gen_i16: ZeroBufferGenerator, - gen_i32: ZeroBufferGenerator, - gen_i64: ZeroBufferGenerator, - gen_u8: ZeroBufferGenerator, - gen_u16: ZeroBufferGenerator, - gen_u32: ZeroBufferGenerator, - gen_u64: ZeroBufferGenerator, -} - -/// Generate a arrow [`Buffer`] that contains zero values. -#[derive(Debug, Default)] -struct ZeroBufferGenerator -where - T: ArrowNativeType, -{ - cache: Option, - _t: PhantomData, -} - -impl ZeroBufferGenerator -where - T: ArrowNativeType, -{ - const SIZE: usize = std::mem::size_of::(); - - fn get_buffer(&mut self, n_vals: usize) -> Buffer { - match &mut self.cache { - Some(buf) if buf.len() >= n_vals * Self::SIZE => { - buf.slice_with_length(0, n_vals * Self::SIZE) - } - _ => { - let mut key_buffer_builder = BufferBuilder::::new(n_vals); - key_buffer_builder.advance(n_vals); // keys are all 0 - self.cache.insert(key_buffer_builder.finish()).clone() - } - } - } -} - -fn create_dict_array( - buffer_gen: &mut ZeroBufferGenerator, - dict_val: &ScalarValue, - len: usize, - data_type: DataType, -) -> ArrayRef -where - T: ArrowNativeType, -{ - let dict_vals = dict_val.to_array(); - - let sliced_key_buffer = buffer_gen.get_buffer(len); - - // assemble pieces together - let mut builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(sliced_key_buffer); - builder = builder.add_child_data(dict_vals.to_data()); - Arc::new(DictionaryArray::::from( - builder.build().unwrap(), - )) -} - -fn create_output_array( - key_buffer_cache: &mut ZeroBufferGenerators, - val: &ScalarValue, - len: usize, -) -> ArrayRef { - if let ScalarValue::Dictionary(key_type, dict_val) = &val { - match key_type.as_ref() { - DataType::Int8 => { - return create_dict_array( - &mut key_buffer_cache.gen_i8, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::Int16 => { - return create_dict_array( - &mut key_buffer_cache.gen_i16, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::Int32 => { - return create_dict_array( - &mut key_buffer_cache.gen_i32, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::Int64 => { - return create_dict_array( - &mut key_buffer_cache.gen_i64, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::UInt8 => { - return create_dict_array( - &mut key_buffer_cache.gen_u8, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::UInt16 => { - return create_dict_array( - &mut key_buffer_cache.gen_u16, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::UInt32 => { - return create_dict_array( - &mut key_buffer_cache.gen_u32, - dict_val, - len, - val.get_datatype(), - ); - } - DataType::UInt64 => { - return create_dict_array( - &mut key_buffer_cache.gen_u64, - dict_val, - len, - val.get_datatype(), - ); - } - _ => {} - } - } - - val.to_array_of_size(len) -} - /// A single file or part of a file that should be read, along with its schema, statistics pub struct FileMeta { /// Path for the file (e.g. URL, filesystem path, etc) @@ -859,11 +503,30 @@ fn get_projected_output_ordering( // since rest of the orderings are violated break; } - all_orderings.push(new_ordering); + // do not push empty entries + // otherwise we may have `Some(vec![])` at the output ordering. + if !new_ordering.is_empty() { + all_orderings.push(new_ordering); + } } all_orderings } +// Get output (un)boundedness information for the given `plan`. +pub(crate) fn is_plan_streaming(plan: &Arc) -> Result { + let result = if plan.children().is_empty() { + plan.unbounded_output(&[]) + } else { + let children_unbounded_output = plan + .children() + .iter() + .map(is_plan_streaming) + .collect::>>(); + plan.unbounded_output(&children_unbounded_output?) + }; + result +} + #[cfg(test)] mod tests { use arrow_array::cast::AsArray; @@ -872,270 +535,14 @@ mod tests { BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, StringArray, UInt64Array, }; + use arrow_schema::Field; use chrono::Utc; + use datafusion_common::config::ConfigOptions; - use crate::{ - test::{build_table_i32, columns}, - test_util::aggr_test_schema, - }; + use crate::physical_plan::{DefaultDisplay, VerboseDisplay}; use super::*; - #[test] - fn physical_plan_config_no_projection() { - let file_schema = aggr_test_schema(); - let conf = config_for_projection( - Arc::clone(&file_schema), - None, - Statistics::default(), - vec![( - "date".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - )], - ); - - let (proj_schema, proj_statistics, _) = conf.project(); - assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); - assert_eq!( - proj_schema.field(file_schema.fields().len()).name(), - "date", - "partition columns are the last columns" - ); - assert_eq!( - proj_statistics - .column_statistics - .expect("projection creates column statistics") - .len(), - file_schema.fields().len() + 1 - ); - // TODO implement tests for partition column statistics once implemented - - let col_names = conf.projected_file_column_names(); - assert_eq!(col_names, None); - - let col_indices = conf.file_column_projection_indices(); - assert_eq!(col_indices, None); - } - - #[test] - fn physical_plan_config_with_projection() { - let file_schema = aggr_test_schema(); - let conf = config_for_projection( - Arc::clone(&file_schema), - Some(vec![file_schema.fields().len(), 0]), - Statistics { - num_rows: Some(10), - // assign the column index to distinct_count to help assert - // the source statistic after the projection - column_statistics: Some( - (0..file_schema.fields().len()) - .map(|i| ColumnStatistics { - distinct_count: Some(i), - ..Default::default() - }) - .collect(), - ), - ..Default::default() - }, - vec![( - "date".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - )], - ); - - let (proj_schema, proj_statistics, _) = conf.project(); - assert_eq!( - columns(&proj_schema), - vec!["date".to_owned(), "c1".to_owned()] - ); - let proj_stat_cols = proj_statistics - .column_statistics - .expect("projection creates column statistics"); - assert_eq!(proj_stat_cols.len(), 2); - // TODO implement tests for proj_stat_cols[0] once partition column - // statistics are implemented - assert_eq!(proj_stat_cols[1].distinct_count, Some(0)); - - let col_names = conf.projected_file_column_names(); - assert_eq!(col_names, Some(vec!["c1".to_owned()])); - - let col_indices = conf.file_column_projection_indices(); - assert_eq!(col_indices, Some(vec![0])); - } - - #[test] - fn partition_column_projector() { - let file_batch = build_table_i32( - ("a", &vec![0, 1, 2]), - ("b", &vec![-2, -1, 0]), - ("c", &vec![10, 11, 12]), - ); - let partition_cols = vec![ - ( - "year".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "month".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "day".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ]; - // create a projected schema - let conf = config_for_projection( - file_batch.schema(), - // keep all cols from file and 2 from partitioning - Some(vec![ - 0, - 1, - 2, - file_batch.schema().fields().len(), - file_batch.schema().fields().len() + 2, - ]), - Statistics::default(), - partition_cols.clone(), - ); - let (proj_schema, ..) = conf.project(); - // created a projector for that projected schema - let mut proj = PartitionColumnProjector::new( - proj_schema, - &partition_cols - .iter() - .map(|x| x.0.clone()) - .collect::>(), - ); - - // project first batch - let projected_batch = proj - .project( - // file_batch is ok here because we kept all the file cols in the projection - file_batch, - &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "26".to_owned(), - ))), - ], - ) - .expect("Projection of partition columns into record batch failed"); - let expected = vec![ - "+---+----+----+------+-----+", - "| a | b | c | year | day |", - "+---+----+----+------+-----+", - "| 0 | -2 | 10 | 2021 | 26 |", - "| 1 | -1 | 11 | 2021 | 26 |", - "| 2 | 0 | 12 | 2021 | 26 |", - "+---+----+----+------+-----+", - ]; - crate::assert_batches_eq!(expected, &[projected_batch]); - - // project another batch that is larger than the previous one - let file_batch = build_table_i32( - ("a", &vec![5, 6, 7, 8, 9]), - ("b", &vec![-10, -9, -8, -7, -6]), - ("c", &vec![12, 13, 14, 15, 16]), - ); - let projected_batch = proj - .project( - // file_batch is ok here because we kept all the file cols in the projection - file_batch, - &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "27".to_owned(), - ))), - ], - ) - .expect("Projection of partition columns into record batch failed"); - let expected = vec![ - "+---+-----+----+------+-----+", - "| a | b | c | year | day |", - "+---+-----+----+------+-----+", - "| 5 | -10 | 12 | 2021 | 27 |", - "| 6 | -9 | 13 | 2021 | 27 |", - "| 7 | -8 | 14 | 2021 | 27 |", - "| 8 | -7 | 15 | 2021 | 27 |", - "| 9 | -6 | 16 | 2021 | 27 |", - "+---+-----+----+------+-----+", - ]; - crate::assert_batches_eq!(expected, &[projected_batch]); - - // project another batch that is smaller than the previous one - let file_batch = build_table_i32( - ("a", &vec![0, 1, 3]), - ("b", &vec![2, 3, 4]), - ("c", &vec![4, 5, 6]), - ); - let projected_batch = proj - .project( - // file_batch is ok here because we kept all the file cols in the projection - file_batch, - &[ - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "2021".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "10".to_owned(), - ))), - wrap_partition_value_in_dict(ScalarValue::Utf8(Some( - "28".to_owned(), - ))), - ], - ) - .expect("Projection of partition columns into record batch failed"); - let expected = vec![ - "+---+---+---+------+-----+", - "| a | b | c | year | day |", - "+---+---+---+------+-----+", - "| 0 | 2 | 4 | 2021 | 28 |", - "| 1 | 3 | 5 | 2021 | 28 |", - "| 3 | 4 | 6 | 2021 | 28 |", - "+---+---+---+------+-----+", - ]; - crate::assert_batches_eq!(expected, &[projected_batch]); - - // forgot to dictionary-wrap the scalar value - let file_batch = build_table_i32( - ("a", &vec![0, 1, 2]), - ("b", &vec![-2, -1, 0]), - ("c", &vec![10, 11, 12]), - ); - let projected_batch = proj - .project( - // file_batch is ok here because we kept all the file cols in the projection - file_batch, - &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), - ], - ) - .expect("Projection of partition columns into record batch failed"); - let expected = vec![ - "+---+----+----+------+-----+", - "| a | b | c | year | day |", - "+---+----+----+------+-----+", - "| 0 | -2 | 10 | 2021 | 26 |", - "| 1 | -1 | 11 | 2021 | 26 |", - "| 2 | 0 | 12 | 2021 | 26 |", - "+---+----+----+------+-----+", - ]; - crate::assert_batches_eq!(expected, &[projected_batch]); - } - #[test] fn schema_mapping_map_batch() { let table_schema = Arc::new(Schema::new(vec![ @@ -1239,9 +646,9 @@ mod tests { let c2 = mapped_batch.column(1).as_primitive::(); let c4 = mapped_batch.column(2).as_primitive::(); - assert_eq!(c1.value(0), "1"); - assert_eq!(c1.value(1), "0"); - assert_eq!(c1.value(2), "1"); + assert_eq!(c1.value(0), "true"); + assert_eq!(c1.value(1), "false"); + assert_eq!(c1.value(2), "true"); assert_eq!(c2.value(0), 2.0_f64); assert_eq!(c2.value(1), 7.0_f64); @@ -1252,30 +659,10 @@ mod tests { assert_eq!(c4.value(2), 3.0_f32); } - // sets default for configs that play no role in projections - fn config_for_projection( - file_schema: SchemaRef, - projection: Option>, - statistics: Statistics, - table_partition_cols: Vec<(String, DataType)>, - ) -> FileScanConfig { - FileScanConfig { - file_schema, - file_groups: vec![vec![]], - limit: None, - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - projection, - statistics, - table_partition_cols, - output_ordering: vec![], - infinite_source: false, - } - } - #[test] fn file_groups_display_empty() { let expected = "{0 groups: []}"; - assert_eq!(&FileGroupsDisplay(&[]).to_string(), expected); + assert_eq!(DefaultDisplay(FileGroupsDisplay(&[])).to_string(), expected); } #[test] @@ -1283,11 +670,14 @@ mod tests { let files = [vec![partitioned_file("foo"), partitioned_file("bar")]]; let expected = "{1 group: [[foo, bar]]}"; - assert_eq!(&FileGroupsDisplay(&files).to_string(), expected); + assert_eq!( + DefaultDisplay(FileGroupsDisplay(&files)).to_string(), + expected + ); } #[test] - fn file_groups_display_many() { + fn file_groups_display_many_default() { let files = [ vec![partitioned_file("foo"), partitioned_file("bar")], vec![partitioned_file("baz")], @@ -1295,15 +685,111 @@ mod tests { ]; let expected = "{3 groups: [[foo, bar], [baz], []]}"; - assert_eq!(&FileGroupsDisplay(&files).to_string(), expected); + assert_eq!( + DefaultDisplay(FileGroupsDisplay(&files)).to_string(), + expected + ); + } + + #[test] + fn file_groups_display_many_verbose() { + let files = [ + vec![partitioned_file("foo"), partitioned_file("bar")], + vec![partitioned_file("baz")], + vec![], + ]; + + let expected = "{3 groups: [[foo, bar], [baz], []]}"; + assert_eq!( + VerboseDisplay(FileGroupsDisplay(&files)).to_string(), + expected + ); + } + + #[test] + fn file_groups_display_too_many_default() { + let files = [ + vec![partitioned_file("foo"), partitioned_file("bar")], + vec![partitioned_file("baz")], + vec![partitioned_file("qux")], + vec![partitioned_file("quux")], + vec![partitioned_file("quuux")], + vec![partitioned_file("quuuux")], + vec![], + ]; + + let expected = "{7 groups: [[foo, bar], [baz], [qux], [quux], [quuux], ...]}"; + assert_eq!( + DefaultDisplay(FileGroupsDisplay(&files)).to_string(), + expected + ); } #[test] - fn file_group_display_many() { + fn file_groups_display_too_many_verbose() { + let files = [ + vec![partitioned_file("foo"), partitioned_file("bar")], + vec![partitioned_file("baz")], + vec![partitioned_file("qux")], + vec![partitioned_file("quux")], + vec![partitioned_file("quuux")], + vec![partitioned_file("quuuux")], + vec![], + ]; + + let expected = + "{7 groups: [[foo, bar], [baz], [qux], [quux], [quuux], [quuuux], []]}"; + assert_eq!( + VerboseDisplay(FileGroupsDisplay(&files)).to_string(), + expected + ); + } + + #[test] + fn file_group_display_many_default() { let files = vec![partitioned_file("foo"), partitioned_file("bar")]; let expected = "[foo, bar]"; - assert_eq!(&FileGroupDisplay(&files).to_string(), expected); + assert_eq!( + DefaultDisplay(FileGroupDisplay(&files)).to_string(), + expected + ); + } + + #[test] + fn file_group_display_too_many_default() { + let files = vec![ + partitioned_file("foo"), + partitioned_file("bar"), + partitioned_file("baz"), + partitioned_file("qux"), + partitioned_file("quux"), + partitioned_file("quuux"), + ]; + + let expected = "[foo, bar, baz, qux, quux, ...]"; + assert_eq!( + DefaultDisplay(FileGroupDisplay(&files)).to_string(), + expected + ); + } + + #[test] + fn file_group_display_too_many_verbose() { + let files = vec![ + partitioned_file("foo"), + partitioned_file("bar"), + partitioned_file("baz"), + partitioned_file("qux"), + partitioned_file("quux"), + partitioned_file("quuux"), + ]; + + let expected = "[foo, bar, baz, qux, quux, quuux]"; + assert_eq!( + VerboseDisplay(FileGroupDisplay(&files)).to_string(), + expected + ); } /// create a PartitionedFile for testing @@ -1313,6 +799,7 @@ mod tests { last_modified: Utc::now(), size: 42, e_tag: None, + version: None, }; PartitionedFile { @@ -1322,4 +809,345 @@ mod tests { extensions: None, } } + + /// Unit tests for `repartition_file_groups()` + #[cfg(feature = "parquet")] + mod repartition_file_groups_test { + use datafusion_common::Statistics; + use itertools::Itertools; + + use super::*; + + /// Empty file won't get partitioned + #[tokio::test] + async fn repartition_empty_file_only() { + let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); + let file_group = vec![vec![partitioned_file_empty]]; + + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: file_group, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let partitioned_file = repartition_with_size(&parquet_exec, 4, 0); + + assert!(partitioned_file[0][0].range.is_none()); + } + + // Repartition when there is a empty file in file groups + #[tokio::test] + async fn repartition_empty_files() { + let partitioned_file_a = PartitionedFile::new("a".to_string(), 10); + let partitioned_file_b = PartitionedFile::new("b".to_string(), 10); + let partitioned_file_empty = PartitionedFile::new("empty".to_string(), 0); + + let empty_first = vec![ + vec![partitioned_file_empty.clone()], + vec![partitioned_file_a.clone()], + vec![partitioned_file_b.clone()], + ]; + let empty_middle = vec![ + vec![partitioned_file_a.clone()], + vec![partitioned_file_empty.clone()], + vec![partitioned_file_b.clone()], + ]; + let empty_last = vec![ + vec![partitioned_file_a], + vec![partitioned_file_b], + vec![partitioned_file_empty], + ]; + + // Repartition file groups into x partitions + let expected_2 = + vec![(0, "a".to_string(), 0, 10), (1, "b".to_string(), 0, 10)]; + let expected_3 = vec![ + (0, "a".to_string(), 0, 7), + (1, "a".to_string(), 7, 10), + (1, "b".to_string(), 0, 4), + (2, "b".to_string(), 4, 10), + ]; + + //let file_groups_testset = [empty_first, empty_middle, empty_last]; + let file_groups_testset = [empty_first, empty_middle, empty_last]; + + for fg in file_groups_testset { + for (n_partition, expected) in [(2, &expected_2), (3, &expected_3)] { + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: fg.clone(), + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Arc::new( + Schema::empty(), + )), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = + repartition_with_size_to_vec(&parquet_exec, n_partition, 10); + + assert_eq!(expected, &actual); + } + } + } + + #[tokio::test] + async fn repartition_single_file() { + // Single file, single partition into multiple partitions + let partitioned_file = PartitionedFile::new("a".to_string(), 123); + let single_partition = vec![vec![partitioned_file]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: single_partition, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size_to_vec(&parquet_exec, 4, 10); + let expected = vec![ + (0, "a".to_string(), 0, 31), + (1, "a".to_string(), 31, 62), + (2, "a".to_string(), 62, 93), + (3, "a".to_string(), 93, 123), + ]; + assert_eq!(expected, actual); + } + + #[tokio::test] + async fn repartition_too_much_partitions() { + // Single file, single parittion into 96 partitions + let partitioned_file = PartitionedFile::new("a".to_string(), 8); + let single_partition = vec![vec![partitioned_file]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: single_partition, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size_to_vec(&parquet_exec, 96, 5); + let expected = vec![ + (0, "a".to_string(), 0, 1), + (1, "a".to_string(), 1, 2), + (2, "a".to_string(), 2, 3), + (3, "a".to_string(), 3, 4), + (4, "a".to_string(), 4, 5), + (5, "a".to_string(), 5, 6), + (6, "a".to_string(), 6, 7), + (7, "a".to_string(), 7, 8), + ]; + assert_eq!(expected, actual); + } + + #[tokio::test] + async fn repartition_multiple_partitions() { + // Multiple files in single partition after redistribution + let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); + let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); + let source_partitions = + vec![vec![partitioned_file_1], vec![partitioned_file_2]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: source_partitions, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size_to_vec(&parquet_exec, 3, 10); + let expected = vec![ + (0, "a".to_string(), 0, 34), + (1, "a".to_string(), 34, 40), + (1, "b".to_string(), 0, 28), + (2, "b".to_string(), 28, 60), + ]; + assert_eq!(expected, actual); + } + + #[tokio::test] + async fn repartition_same_num_partitions() { + // "Rebalance" files across partitions + let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); + let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); + let source_partitions = + vec![vec![partitioned_file_1], vec![partitioned_file_2]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: source_partitions, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size_to_vec(&parquet_exec, 2, 10); + let expected = vec![ + (0, "a".to_string(), 0, 40), + (0, "b".to_string(), 0, 10), + (1, "b".to_string(), 10, 60), + ]; + assert_eq!(expected, actual); + } + + #[tokio::test] + async fn repartition_no_action_ranges() { + // No action due to Some(range) in second file + let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); + let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); + partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); + + let source_partitions = + vec![vec![partitioned_file_1], vec![partitioned_file_2]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: source_partitions, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size(&parquet_exec, 65, 10); + assert_eq!(2, actual.len()); + } + + #[tokio::test] + async fn repartition_no_action_min_size() { + // No action due to target_partition_size + let partitioned_file = PartitionedFile::new("a".to_string(), 123); + let single_partition = vec![vec![partitioned_file]]; + let parquet_exec = ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: single_partition, + file_schema: Arc::new(Schema::empty()), + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + ); + + let actual = repartition_with_size(&parquet_exec, 65, 500); + assert_eq!(1, actual.len()); + } + + /// Calls `ParquetExec.repartitioned` with the specified + /// `target_partitions` and `repartition_file_min_size`, returning the + /// resulting `PartitionedFile`s + fn repartition_with_size( + parquet_exec: &ParquetExec, + target_partitions: usize, + repartition_file_min_size: usize, + ) -> Vec> { + let mut config = ConfigOptions::new(); + config.optimizer.repartition_file_min_size = repartition_file_min_size; + + parquet_exec + .repartitioned(target_partitions, &config) + .unwrap() // unwrap Result + .unwrap() // unwrap Option + .as_any() + .downcast_ref::() + .unwrap() + .base_config() + .file_groups + .clone() + } + + /// Calls `repartition_with_size` and returns a tuple for each output `PartitionedFile`: + /// + /// `(partition index, file path, start, end)` + fn repartition_with_size_to_vec( + parquet_exec: &ParquetExec, + target_partitions: usize, + repartition_file_min_size: usize, + ) -> Vec<(usize, String, i64, i64)> { + let file_groups = repartition_with_size( + parquet_exec, + target_partitions, + repartition_file_min_size, + ); + + file_groups + .iter() + .enumerate() + .flat_map(|(part_idx, files)| { + files + .iter() + .map(|f| { + ( + part_idx, + f.object_meta.location.to_string(), + f.range.as_ref().unwrap().start, + f.range.as_ref().unwrap().end, + ) + }) + .collect_vec() + }) + .collect_vec() + } + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs similarity index 78% rename from datafusion/core/src/datasource/physical_plan/parquet.rs rename to datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 48e4d49371704..718f9f820af17 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -17,66 +17,59 @@ //! Execution plan for reading Parquet files +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, }; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, FileMeta, FileScanConfig, SchemaAdapter, + parquet::page_filter::PagePruningPredicate, DisplayAs, FileMeta, FileScanConfig, + SchemaAdapter, }; use crate::{ config::ConfigOptions, - datasource::listing::FileRange, + datasource::listing::ListingTableUrl, error::{DataFusionError, Result}, execution::context::TaskContext, physical_optimizer::pruning::PruningPredicate, physical_plan::{ - common::AbortOnDropSingle, metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan, - Partitioning, SendableRecordBatchStream, Statistics, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }, }; -use datafusion_physical_expr::PhysicalSortExpr; -use fmt::Debug; -use std::any::Any; -use std::cmp::min; -use std::fmt; -use std::fs; -use std::ops::Range; -use std::sync::Arc; use arrow::datatypes::{DataType, SchemaRef}; use arrow::error::ArrowError; use datafusion_physical_expr::{ - LexOrdering, OrderingEquivalenceProperties, PhysicalExpr, + EquivalenceProperties, LexOrdering, PhysicalExpr, PhysicalSortExpr, }; use bytes::Bytes; use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; -use itertools::Itertools; use log::debug; +use object_store::path::Path; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask}; +use parquet::arrow::{AsyncArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask}; use parquet::basic::{ConvertedType, LogicalType}; use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties}; use parquet::schema::types::ColumnDescriptor; +use tokio::task::JoinSet; mod metrics; pub mod page_filter; mod row_filter; mod row_groups; +mod statistics; pub use metrics::ParquetFileMetrics; -#[derive(Default)] -struct RepartitionState { - current_partition_index: usize, - current_partition_size: usize, -} - /// Execution plan for scanning one or more Parquet partitions #[derive(Debug, Clone)] pub struct ParquetExec { @@ -89,6 +82,9 @@ pub struct ParquetExec { /// Override for `Self::with_enable_page_index`. If None, uses /// values from base_config enable_page_index: Option, + /// Override for `Self::with_enable_bloom_filter`. If None, uses + /// values from base_config + enable_bloom_filter: Option, /// Base configuration for this scan base_config: FileScanConfig, projected_statistics: Statistics, @@ -158,6 +154,7 @@ impl ParquetExec { pushdown_filters: None, reorder_filters: None, enable_page_index: None, + enable_bloom_filter: None, base_config, projected_schema, projected_statistics, @@ -251,76 +248,44 @@ impl ParquetExec { .unwrap_or(config_options.execution.parquet.enable_page_index) } - /// Redistribute files across partitions according to their size - pub fn get_repartitioned( - &self, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Self { - let flattened_files = self - .base_config() - .file_groups - .iter() - .flatten() - .collect::>(); - - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = flattened_files.iter().any(|f| f.range.is_some()); - if has_ranges { - return self.clone(); - } + /// If enabled, the reader will read by the bloom filter + pub fn with_enable_bloom_filter(mut self, enable_bloom_filter: bool) -> Self { + self.enable_bloom_filter = Some(enable_bloom_filter); + self + } - let total_size = flattened_files - .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) { - return self.clone(); - } + /// Return the value described in [`Self::with_enable_bloom_filter`] + fn enable_bloom_filter(&self, config_options: &ConfigOptions) -> bool { + self.enable_bloom_filter + .unwrap_or(config_options.execution.parquet.bloom_filter_enabled) + } +} - let target_partition_size = - (total_size as usize + (target_partitions) - 1) / (target_partitions); - - let repartitioned_files = flattened_files - .into_iter() - .scan(RepartitionState::default(), |state, source_file| { - let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { - let range_end = min( - range_start - + (target_partition_size - state.current_partition_size), - source_file.object_meta.size, - ); +impl DisplayAs for ParquetExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let predicate_string = self + .predicate + .as_ref() + .map(|p| format!(", predicate={p}")) + .unwrap_or_default(); - let mut produced_file = source_file.clone(); - produced_file.range = Some(FileRange { - start: range_start as i64, - end: range_end as i64, - }); - produced_files.push((state.current_partition_index, produced_file)); - - if state.current_partition_size + (range_end - range_start) - >= target_partition_size - { - state.current_partition_index += 1; - state.current_partition_size = 0; - } else { - state.current_partition_size += range_end - range_start; - } - range_start = range_end; - } - Some(produced_files) - }) - .flatten() - .group_by(|(partition_idx, _)| *partition_idx) - .into_iter() - .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) - .collect_vec(); + let pruning_predicate_string = self + .pruning_predicate + .as_ref() + .map(|pre| format!(", pruning_predicate={}", pre.predicate_expr())) + .unwrap_or_default(); - let mut new_parquet_exec = self.clone(); - new_parquet_exec.base_config.file_groups = repartitioned_files; - new_parquet_exec + write!(f, "ParquetExec: ")?; + self.base_config.fmt_as(t, f)?; + write!(f, "{}{}", predicate_string, pruning_predicate_string,) + } + } } } @@ -350,8 +315,8 @@ impl ExecutionPlan for ParquetExec { .map(|ordering| ordering.as_slice()) } - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - ordering_equivalence_properties_helper( + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( self.schema(), &self.projected_output_ordering, ) @@ -364,6 +329,27 @@ impl ExecutionPlan for ParquetExec { Ok(self) } + /// Redistribute files across partitions according to their size + /// See comments on `get_file_groups_repartitioned()` for more detail. + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result>> { + let repartition_file_min_size = config.optimizer.repartition_file_min_size; + let repartitioned_file_groups_option = FileScanConfig::repartition_file_groups( + self.base_config.file_groups.clone(), + target_partitions, + repartition_file_min_size, + ); + + let mut new_plan = self.clone(); + if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { + new_plan.base_config.file_groups = repartitioned_file_groups; + } + Ok(Some(Arc::new(new_plan))) + } + fn execute( &self, partition_index: usize, @@ -404,6 +390,7 @@ impl ExecutionPlan for ParquetExec { pushdown_filters: self.pushdown_filters(config_options), reorder_filters: self.reorder_filters(config_options), enable_page_index: self.enable_page_index(config_options), + enable_bloom_filter: self.enable_bloom_filter(config_options), }; let stream = @@ -412,40 +399,12 @@ impl ExecutionPlan for ParquetExec { Ok(Box::pin(stream)) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let predicate_string = self - .predicate - .as_ref() - .map(|p| format!(", predicate={p}")) - .unwrap_or_default(); - - let pruning_predicate_string = self - .pruning_predicate - .as_ref() - .map(|pre| format!(", pruning_predicate={}", pre.predicate_expr())) - .unwrap_or_default(); - - write!( - f, - "ParquetExec: {}{}{}", - self.base_config, predicate_string, pruning_predicate_string, - ) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.projected_statistics.clone() + fn statistics(&self) -> Result { + Ok(self.projected_statistics.clone()) } } @@ -465,6 +424,7 @@ struct ParquetOpener { pushdown_filters: bool, reorder_filters: bool, enable_page_index: bool, + enable_bloom_filter: bool, } impl FileOpener for ParquetOpener { @@ -499,6 +459,7 @@ impl FileOpener for ParquetOpener { self.enable_page_index, &self.page_pruning_predicate, ); + let enable_bloom_filter = self.enable_bloom_filter; let limit = self.limit; Ok(Box::pin(async move { @@ -541,16 +502,33 @@ impl FileOpener for ParquetOpener { }; }; - // Row group pruning: attempt to skip entire row_groups + // Row group pruning by statistics: attempt to skip entire row_groups // using metadata on the row groups - let file_metadata = builder.metadata(); - let row_groups = row_groups::prune_row_groups( + let file_metadata = builder.metadata().clone(); + let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); + let mut row_groups = row_groups::prune_row_groups_by_statistics( + builder.parquet_schema(), file_metadata.row_groups(), file_range, - pruning_predicate.as_ref().map(|p| p.as_ref()), + predicate, &file_metrics, ); + // Bloom filter pruning: if bloom filters are enabled and then attempt to skip entire row_groups + // using bloom filters on the row groups + if enable_bloom_filter && !row_groups.is_empty() { + if let Some(predicate) = predicate { + row_groups = row_groups::prune_row_groups_by_bloom_filters( + &mut builder, + &row_groups, + file_metadata.row_groups(), + predicate, + &file_metrics, + ) + .await; + } + } + // page index pruning: if all data on individual pages can // be ruled using page metadata, rows from other columns // with that range can be skipped as well @@ -626,7 +604,7 @@ impl DefaultParquetFileReaderFactory { } /// Implements [`AsyncFileReader`] for a parquet file in object storage -struct ParquetFileReader { +pub(crate) struct ParquetFileReader { file_metrics: ParquetFileMetrics, inner: ParquetObjectReader, } @@ -694,69 +672,54 @@ pub async fn plan_to_parquet( writer_properties: Option, ) -> Result<()> { let path = path.as_ref(); - // create directory to contain the Parquet files (one per partition) - let fs_path = std::path::Path::new(path); - if let Err(e) = fs::create_dir(fs_path) { - return Err(DataFusionError::Execution(format!( - "Could not create directory {path}: {e:?}" - ))); - } - - let mut tasks = vec![]; + let parsed = ListingTableUrl::parse(path)?; + let object_store_url = parsed.object_store(); + let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { - let plan = plan.clone(); - let filename = format!("part-{i}.parquet"); - let path = fs_path.join(filename); - let file = fs::File::create(path)?; - let mut writer = - ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?; - let stream = plan.execute(i, task_ctx.clone())?; - let handle: tokio::task::JoinHandle> = - tokio::task::spawn(async move { - stream - .map(|batch| { - writer.write(&batch?).map_err(DataFusionError::ParquetError) - }) - .try_collect() - .await - .map_err(DataFusionError::from)?; + let plan: Arc = plan.clone(); + let filename = format!("{}/part-{i}.parquet", parsed.prefix()); + let file = Path::parse(filename)?; + let propclone = writer_properties.clone(); + + let storeref = store.clone(); + let (_, multipart_writer) = storeref.put_multipart(&file).await?; + let mut stream = plan.execute(i, task_ctx.clone())?; + join_set.spawn(async move { + let mut writer = AsyncArrowWriter::try_new( + multipart_writer, + plan.schema(), + 10485760, + propclone, + )?; + while let Some(next_batch) = stream.next().await { + let batch = next_batch?; + writer.write(&batch).await?; + } + writer + .close() + .await + .map_err(DataFusionError::from) + .map(|_| ()) + }); + } - writer.close().map_err(DataFusionError::from).map(|_| ()) - }); - tasks.push(AbortOnDropSingle::new(handle)); + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } } - futures::future::join_all(tasks) - .await - .into_iter() - .try_for_each(|result| { - result.map_err(|e| DataFusionError::Execution(format!("{e}")))? - })?; Ok(()) } -// Copy from the arrow-rs -// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 -// Convert the byte slice to fixed length byte array with the length of 16 -fn sign_extend_be(b: &[u8]) -> [u8; 16] { - assert!(b.len() <= 16, "Array too large, expected less than 16"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; - for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { - *d = *s; - } - result -} - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be(b)) -} - // Convert parquet column schema to arrow data type, and just consider the // decimal data type. pub(crate) fn parquet_to_arrow_decimal_type( @@ -782,6 +745,7 @@ mod tests { // See also `parquet_exec` integration test use super::*; + use crate::dataframe::DataFrameWriteOptions; use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::file_format::parquet::test_util::store_parquet; use crate::datasource::file_format::test_util::scan_format; @@ -814,9 +778,10 @@ mod tests { use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::ObjectMeta; - use std::fs::File; + use std::fs::{self, File}; use std::io::Write; use tempfile::TempDir; + use url::Url; struct RoundTripResult { /// Data that was read back from ParquetFiles @@ -908,8 +873,8 @@ mod tests { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection, limit: None, table_partition_cols: vec![], @@ -965,17 +930,22 @@ mod tests { #[tokio::test] async fn write_parquet_results_error_handling() -> Result<()> { let ctx = SessionContext::new(); + // register a local file system object store for /tmp directory + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + let options = CsvReadOptions::default() .schema_infer_max_records(2) .has_header(true); let df = ctx.read_csv("tests/data/corrupt.csv", options).await?; - let tmp_dir = TempDir::new()?; - let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; let e = df - .write_parquet(&out_dir, None) + .write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!("Arrow error: Parser error: Error while parsing value d for column 0 at line 4", format!("{e}")); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); Ok(()) } @@ -1042,7 +1012,7 @@ mod tests { .round_trip_to_batches(vec![batch1, batch2]) .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+----+----+", "| c1 | c2 | c3 |", "+-----+----+----+", @@ -1077,7 +1047,7 @@ mod tests { .round_trip_to_batches(vec![batch1, batch2]) .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+----+----+", "| c1 | c3 | c2 |", "+-----+----+----+", @@ -1115,7 +1085,7 @@ mod tests { .round_trip_to_batches(vec![batch1, batch2]) .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+----+----+", "| c1 | c3 | c2 |", "+-----+----+----+", @@ -1154,7 +1124,7 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - let expected = vec![ + let expected = [ "+----+----+----+", "| c1 | c3 | c2 |", "+----+----+----+", @@ -1196,7 +1166,7 @@ mod tests { .round_trip_to_batches(vec![batch1, batch2]) .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+-----+", "| c1 | c4 |", "+-----+-----+", @@ -1270,7 +1240,7 @@ mod tests { // a null array, then the pruning predicate (currently) can not be applied. // In a real query where this predicate was pushed down from a filter stage instead of created directly in the `ParquetExec`, // the filter stage would be preserved as a separate execution plan stage so the actual query results would be as expected. - let expected = vec![ + let expected = [ "+-----+----+", "| c1 | c2 |", "+-----+----+", @@ -1307,7 +1277,7 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - let expected = vec![ + let expected = [ "+----+----+", "| c1 | c2 |", "+----+----+", @@ -1414,7 +1384,7 @@ mod tests { .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+----+", "| c1 | c2 |", "+-----+----+", @@ -1445,7 +1415,7 @@ mod tests { .await .unwrap(); - let expected = vec![ + let expected = [ "+-----+----+", "| c1 | c2 |", "+-----+----+", @@ -1560,8 +1530,8 @@ mod tests { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups, + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], @@ -1633,11 +1603,11 @@ mod tests { let partitioned_file = PartitionedFile { object_meta: meta, partition_values: vec![ - ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), - Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + Box::new(ScalarValue::from("26")), ), ], range: None, @@ -1663,20 +1633,21 @@ mod tests { FileScanConfig { object_store_url, file_groups: vec![vec![partitioned_file]], - file_schema: schema, - statistics: Statistics::default(), + file_schema: schema.clone(), + statistics: Statistics::new_unknown(&schema), // file has 10 cols so index 12 should be month and 13 should be day projection: Some(vec![0, 1, 2, 12, 13]), limit: None, table_partition_cols: vec![ - ("year".to_owned(), DataType::Utf8), - ("month".to_owned(), DataType::UInt8), - ( - "day".to_owned(), + Field::new("year", DataType::Utf8, false), + Field::new("month", DataType::UInt8, false), + Field::new( + "day", DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), ), + false, ), ], output_ordering: vec![], @@ -1691,7 +1662,7 @@ mod tests { let mut results = parquet_exec.execute(0, task_ctx)?; let batch = results.next().await.unwrap()?; assert_eq!(batch.schema().as_ref(), &expected_schema); - let expected = vec![ + let expected = [ "+----+----------+-------------+-------+-----+", "| id | bool_col | tinyint_col | month | day |", "+----+----------+-------------+-------+-----+", @@ -1727,6 +1698,7 @@ mod tests { last_modified: Utc.timestamp_nanos(0), size: 1337, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -1738,7 +1710,7 @@ mod tests { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![vec![partitioned_file]], file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), + statistics: Statistics::new_unknown(&Schema::empty()), projection: None, limit: None, table_partition_cols: vec![], @@ -1782,14 +1754,12 @@ mod tests { // assert the batches and some metrics #[rustfmt::skip] - let expected = vec![ - "+-----+", + let expected = ["+-----+", "| int |", "+-----+", "| 4 |", "| 5 |", - "+-----+", - ]; + "+-----+"]; assert_batches_sorted_eq!(expected, &rt.batches.unwrap()); assert_eq!(get_value(&metrics, "page_index_rows_filtered"), 4); assert!( @@ -1826,7 +1796,7 @@ mod tests { let metrics = rt.parquet_exec.metrics().unwrap(); // assert the batches and some metrics - let expected = vec![ + let expected = [ "+-----+", "| c1 |", "+-----+", "| Foo |", "| zzz |", "+-----+", ]; assert_batches_sorted_eq!(expected, &rt.batches.unwrap()); @@ -1869,7 +1839,9 @@ mod tests { assert!(pruning_predicate.is_some()); // convert to explain plan form - let display = displayable(rt.parquet_exec.as_ref()).indent().to_string(); + let display = displayable(rt.parquet_exec.as_ref()) + .indent(true) + .to_string(); assert_contains!( &display, @@ -1920,242 +1892,6 @@ mod tests { assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } - #[tokio::test] - async fn parquet_exec_repartition_single_file() { - // Single file, single partition into multiple partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(4, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 31), - (1, "a".to_string(), 31, 62), - (2, "a".to_string(), 62, 93), - (3, "a".to_string(), 93, 123), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn parquet_exec_repartition_too_much_partitions() { - // Single file, single parittion into 96 partitions - let partitioned_file = PartitionedFile::new("a".to_string(), 8); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(96, 5) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 1), - (1, "a".to_string(), 1, 2), - (2, "a".to_string(), 2, 3), - (3, "a".to_string(), 3, 4), - (4, "a".to_string(), 4, 5), - (5, "a".to_string(), 5, 6), - (6, "a".to_string(), 6, 7), - (7, "a".to_string(), 7, 8), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn parquet_exec_repartition_multiple_partitions() { - // Multiple files in single partition after redistribution - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(3, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 34), - (1, "a".to_string(), 34, 40), - (1, "b".to_string(), 0, 28), - (2, "b".to_string(), 28, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn parquet_exec_repartition_same_num_partitions() { - // "Rebalance" files across partitions - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 40); - let partitioned_file_2 = PartitionedFile::new("b".to_string(), 60); - let source_partitions = vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = file_groups_to_vec( - parquet_exec - .get_repartitioned(2, 10) - .base_config() - .file_groups - .clone(), - ); - let expected = vec![ - (0, "a".to_string(), 0, 40), - (0, "b".to_string(), 0, 10), - (1, "b".to_string(), 10, 60), - ]; - assert_eq!(expected, actual); - } - - #[tokio::test] - async fn parquet_exec_repartition_no_action_ranges() { - // No action due to Some(range) in second file - let partitioned_file_1 = PartitionedFile::new("a".to_string(), 123); - let mut partitioned_file_2 = PartitionedFile::new("b".to_string(), 144); - partitioned_file_2.range = Some(FileRange { start: 1, end: 50 }); - - let source_partitions = vec![vec![partitioned_file_1], vec![partitioned_file_2]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: source_partitions, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = parquet_exec - .get_repartitioned(65, 10) - .base_config() - .file_groups - .clone(); - assert_eq!(2, actual.len()); - } - - #[tokio::test] - async fn parquet_exec_repartition_no_action_min_size() { - // No action due to target_partition_size - let partitioned_file = PartitionedFile::new("a".to_string(), 123); - let single_partition = vec![vec![partitioned_file]]; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: single_partition, - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - ); - - let actual = parquet_exec - .get_repartitioned(65, 500) - .base_config() - .file_groups - .clone(); - assert_eq!(1, actual.len()); - } - - fn file_groups_to_vec( - file_groups: Vec>, - ) -> Vec<(usize, String, i64, i64)> { - file_groups - .iter() - .enumerate() - .flat_map(|(part_idx, files)| { - files - .iter() - .map(|f| { - ( - part_idx, - f.object_meta.location.to_string(), - f.range.as_ref().unwrap().start, - f.range.as_ref().unwrap().end, - ) - }) - .collect_vec() - }) - .collect_vec() - } - /// returns the sum of all the metrics with the specified name /// the returned set. /// @@ -2207,8 +1943,9 @@ mod tests { // create partitioned input file and context let tmp_dir = TempDir::new()?; // let mut ctx = create_ctx(&tmp_dir, 4).await?; - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; // register csv file with the execution context ctx.register_csv( @@ -2218,40 +1955,45 @@ mod tests { ) .await?; + // register a local file system object store for /tmp directory + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + ctx.runtime_env().register_object_store(&local_url, local); + // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; + let out_dir_url = "file://local/out"; let df = ctx.sql("SELECT c1, c2 FROM test").await?; - df.write_parquet(&out_dir, None).await?; + df.write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) + .await?; // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; - // create a new context and verify that the results were saved to a partitioned csv file + // create a new context and verify that the results were saved to a partitioned parquet file let ctx = SessionContext::new(); + // get write_id + let mut paths = fs::read_dir(&out_dir).unwrap(); + let path = paths.next(); + let name = path + .unwrap()? + .path() + .file_name() + .expect("Should be a file name") + .to_str() + .expect("Should be a str") + .to_owned(); + println!("{name}"); + let (parsed_id, _) = name.split_once('_').expect("File should contain _ !"); + let write_id = parsed_id.to_owned(); + // register each partition as well as the top level dir ctx.register_parquet( "part0", - &format!("{out_dir}/part-0.parquet"), - ParquetReadOptions::default(), - ) - .await?; - ctx.register_parquet( - "part1", - &format!("{out_dir}/part-1.parquet"), - ParquetReadOptions::default(), - ) - .await?; - ctx.register_parquet( - "part2", - &format!("{out_dir}/part-2.parquet"), - ParquetReadOptions::default(), - ) - .await?; - ctx.register_parquet( - "part3", - &format!("{out_dir}/part-3.parquet"), + &format!("{out_dir}/{write_id}_0.parquet"), ParquetReadOptions::default(), ) .await?; + ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default()) .await?; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index e5c1d8feb0abc..42bfef35996e9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -39,9 +39,8 @@ use parquet::{ }; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; +use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; +use crate::datasource::physical_plan::parquet::statistics::from_bytes_to_i128; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use super::metrics::ParquetFileMetrics; @@ -147,17 +146,19 @@ impl PagePruningPredicate { let file_offset_indexes = file_metadata.offset_index(); let file_page_indexes = file_metadata.column_index(); - let (file_offset_indexes, file_page_indexes) = - match (file_offset_indexes, file_page_indexes) { - (Some(o), Some(i)) => (o, i), - _ => { - trace!( - "skip page pruning due to lack of indexes. Have offset: {} file: {}", + let (file_offset_indexes, file_page_indexes) = match ( + file_offset_indexes, + file_page_indexes, + ) { + (Some(o), Some(i)) => (o, i), + _ => { + trace!( + "skip page pruning due to lack of indexes. Have offset: {}, column index: {}", file_offset_indexes.is_some(), file_page_indexes.is_some() ); - return Ok(None); - } - }; + return Ok(None); + } + }; let mut row_selections = Vec::with_capacity(page_index_predicates.len()); for predicate in page_index_predicates { diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 0f4b09caeded5..5fe0a0a13a736 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -126,7 +126,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { match self .physical_expr .evaluate(&batch) - .map(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array(batch.num_rows())) { Ok(array) => { let bool_arr = as_boolean_array(&array)?.clone(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 07ef28304cce6..65414f5619a5b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -15,28 +15,35 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Schema}, +use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_schema::FieldRef; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use parquet::file::metadata::ColumnChunkMetaData; +use parquet::schema::types::SchemaDescriptor; +use parquet::{ + arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, + bloom_filter::Sbbf, + file::metadata::RowGroupMetaData, }; -use datafusion_common::Column; -use datafusion_common::ScalarValue; -use log::debug; - -use parquet::file::{ - metadata::RowGroupMetaData, statistics::Statistics as ParquetStatistics, +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, }; -use crate::datasource::physical_plan::parquet::{ - from_bytes_to_i128, parquet_to_arrow_decimal_type, -}; -use crate::{ - datasource::listing::FileRange, - physical_optimizer::pruning::{PruningPredicate, PruningStatistics}, +use crate::datasource::listing::FileRange; +use crate::datasource::physical_plan::parquet::statistics::{ + max_statistics, min_statistics, parquet_column, }; +use crate::logical_expr::Operator; +use crate::physical_expr::expressions as phys_expr; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use crate::physical_plan::PhysicalExpr; use super::ParquetFileMetrics; +/// Prune row groups based on statistics +/// /// Returns a vector of indexes into `groups` which should be scanned. /// /// If an index is NOT present in the returned Vec it means the @@ -44,7 +51,11 @@ use super::ParquetFileMetrics; /// /// If an index IS present in the returned Vec it means the predicate /// did not filter out that row group. -pub(crate) fn prune_row_groups( +/// +/// Note: This method currently ignores ColumnOrder +/// +pub(crate) fn prune_row_groups_by_statistics( + parquet_schema: &SchemaDescriptor, groups: &[RowGroupMetaData], range: Option, predicate: Option<&PruningPredicate>, @@ -67,8 +78,9 @@ pub(crate) fn prune_row_groups( if let Some(predicate) = predicate { let pruning_stats = RowGroupPruningStatistics { + parquet_schema, row_group_metadata: metadata, - parquet_schema: predicate.schema().as_ref(), + arrow_schema: predicate.schema().as_ref(), }; match predicate.prune(&pruning_stats) { Ok(values) => { @@ -81,7 +93,7 @@ pub(crate) fn prune_row_groups( // stats filter array could not be built // return a closure which will not filter out any row groups Err(e) => { - debug!("Error evaluating row group predicate values {e}"); + log::debug!("Error evaluating row group predicate values {e}"); metrics.predicate_evaluation_errors.add(1); } } @@ -92,146 +104,230 @@ pub(crate) fn prune_row_groups( filtered } -/// Wraps parquet statistics in a way -/// that implements [`PruningStatistics`] -struct RowGroupPruningStatistics<'a> { - row_group_metadata: &'a RowGroupMetaData, - parquet_schema: &'a Schema, -} - -/// Extract the min/max statistics from a `ParquetStatistics` object -macro_rules! get_statistic { - ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ - if !$column_statistics.has_min_max_set() { - return None; +/// Prune row groups by bloom filters +/// +/// Returns a vector of indexes into `groups` which should be scanned. +/// +/// If an index is NOT present in the returned Vec it means the +/// predicate filtered all the row group. +/// +/// If an index IS present in the returned Vec it means the predicate +/// did not filter out that row group. +pub(crate) async fn prune_row_groups_by_bloom_filters< + T: AsyncFileReader + Send + 'static, +>( + builder: &mut ParquetRecordBatchStreamBuilder, + row_groups: &[usize], + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, +) -> Vec { + let bf_predicates = match BloomFilterPruningPredicate::try_new(predicate.orig_expr()) + { + Ok(predicates) => predicates, + Err(_) => { + return row_groups.to_vec(); } - match $column_statistics { - ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), - ParquetStatistics::Int32(s) => { - match $target_arrow_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) + }; + let mut filtered = Vec::with_capacity(groups.len()); + for idx in row_groups { + let rg_metadata = &groups[*idx]; + // get all columns bloom filter + let mut column_sbbf = + HashMap::with_capacity(bf_predicates.required_columns.len()); + for column_name in bf_predicates.required_columns.iter() { + let column_idx = match rg_metadata + .columns() + .iter() + .enumerate() + .find(|(_, column)| column.column_path().string().eq(column_name)) + { + Some((column_idx, _)) => column_idx, + None => continue, + }; + let bf = match builder + .get_row_group_column_bloom_filter(*idx, column_idx) + .await + { + Ok(bf) => match bf { + Some(bf) => bf, + None => { + continue; } - _ => Some(ScalarValue::Int32(Some(*s.$func()))), + }, + Err(e) => { + log::error!("Error evaluating row group predicate values when using BloomFilterPruningPredicate {e}"); + metrics.predicate_evaluation_errors.add(1); + continue; } + }; + column_sbbf.insert(column_name.to_owned(), bf); + } + if bf_predicates.prune(&column_sbbf) { + metrics.row_groups_pruned.add(1); + continue; + } + filtered.push(*idx); + } + filtered +} + +struct BloomFilterPruningPredicate { + /// Actual pruning predicate + predicate_expr: Option, + /// The statistics required to evaluate this predicate + required_columns: Vec, +} + +impl BloomFilterPruningPredicate { + fn try_new(expr: &Arc) -> Result { + let binary_expr = expr.as_any().downcast_ref::(); + match binary_expr { + Some(binary_expr) => { + let columns = Self::get_predicate_columns(expr); + Ok(Self { + predicate_expr: Some(binary_expr.clone()), + required_columns: columns.into_iter().collect(), + }) } - ParquetStatistics::Int64(s) => { - match $target_arrow_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - precision, - scale, - )) - } - _ => Some(ScalarValue::Int64(Some(*s.$func()))), + None => Err(DataFusionError::Execution( + "BloomFilterPruningPredicate only support binary expr".to_string(), + )), + } + } + + fn prune(&self, column_sbbf: &HashMap) -> bool { + Self::prune_expr_with_bloom_filter(self.predicate_expr.as_ref(), column_sbbf) + } + + /// Return true if the `expr` can be proved not `true` + /// based on the bloom filter. + /// + /// We only checked `BinaryExpr` but it also support `InList`, + /// Because of the `optimizer` will convert `InList` to `BinaryExpr`. + fn prune_expr_with_bloom_filter( + expr: Option<&phys_expr::BinaryExpr>, + column_sbbf: &HashMap, + ) -> bool { + let Some(expr) = expr else { + // unsupported predicate + return false; + }; + match expr.op() { + Operator::And | Operator::Or => { + let left = Self::prune_expr_with_bloom_filter( + expr.left().as_any().downcast_ref::(), + column_sbbf, + ); + let right = Self::prune_expr_with_bloom_filter( + expr.right() + .as_any() + .downcast_ref::(), + column_sbbf, + ); + match expr.op() { + Operator::And => left || right, + Operator::Or => left && right, + _ => false, } } - // 96 bit ints not supported - ParquetStatistics::Int96(_) => None, - ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), - ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), - ParquetStatistics::ByteArray(s) => { - match $target_arrow_type { - // decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) - .map(|s| s.to_string()) - .ok(); - Some(ScalarValue::Utf8(s)) + Operator::Eq => { + if let Some((col, val)) = Self::check_expr_is_col_equal_const(expr) { + if let Some(sbbf) = column_sbbf.get(col.name()) { + match val { + ScalarValue::Utf8(Some(v)) => !sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => !sbbf.check(&v), + ScalarValue::Float64(Some(v)) => !sbbf.check(&v), + ScalarValue::Float32(Some(v)) => !sbbf.check(&v), + ScalarValue::Int64(Some(v)) => !sbbf.check(&v), + ScalarValue::Int32(Some(v)) => !sbbf.check(&v), + ScalarValue::Int16(Some(v)) => !sbbf.check(&v), + ScalarValue::Int8(Some(v)) => !sbbf.check(&v), + _ => false, + } + } else { + false } + } else { + false } } - // type not supported yet - ParquetStatistics::FixedLenByteArray(s) => { - match $target_arrow_type { - // just support the decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - precision, - scale, - )) - } - _ => None, + _ => false, + } + } + + fn get_predicate_columns(expr: &Arc) -> HashSet { + let mut columns = HashSet::new(); + expr.apply(&mut |expr| { + if let Some(binary_expr) = + expr.as_any().downcast_ref::() + { + if let Some((column, _)) = + Self::check_expr_is_col_equal_const(binary_expr) + { + columns.insert(column.name().to_string()); } } - } - }}; -} + Ok(VisitRecursion::Continue) + }) + // no way to fail as only Ok(VisitRecursion::Continue) is returned + .unwrap(); -// Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate -macro_rules! get_min_max_values { - ($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{ - let (_column_index, field) = - if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) { - (v, f) - } else { - // Named column was not present - return None; - }; + columns + } - let data_type = field.data_type(); - // The result may be None, because DataFusion doesn't have support for ScalarValues of the column type - let null_scalar: ScalarValue = data_type.try_into().ok()?; + fn check_expr_is_col_equal_const( + exr: &phys_expr::BinaryExpr, + ) -> Option<(phys_expr::Column, ScalarValue)> { + if Operator::Eq.ne(exr.op()) { + return None; + } - $self.row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - .and_then(|c| if c.statistics().is_some() {Some((c.statistics().unwrap(), c.column_descr()))} else {None}) - .map(|(stats, column_descr)| - { - let target_data_type = parquet_to_arrow_decimal_type(column_descr); - get_statistic!(stats, $func, $bytes_func, target_data_type) - }) - .flatten() - // column either didn't have statistics at all or didn't have min/max values - .or_else(|| Some(null_scalar.clone())) - .map(|s| s.to_array()) - }} + let left_any = exr.left().as_any(); + let right_any = exr.right().as_any(); + if let (Some(col), Some(liter)) = ( + left_any.downcast_ref::(), + right_any.downcast_ref::(), + ) { + return Some((col.clone(), liter.value().clone())); + } + if let (Some(liter), Some(col)) = ( + left_any.downcast_ref::(), + right_any.downcast_ref::(), + ) { + return Some((col.clone(), liter.value().clone())); + } + None + } } -// Extract the null count value on the ParquetStatistics -macro_rules! get_null_count_values { - ($self:expr, $column:expr) => {{ - let value = ScalarValue::UInt64( - if let Some(col) = $self - .row_group_metadata - .columns() - .iter() - .find(|c| c.column_descr().name() == &$column.name) - { - col.statistics().map(|s| s.null_count()) - } else { - Some($self.row_group_metadata.num_rows() as u64) - }, - ); +/// Wraps [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] +/// +/// Note: This should be implemented for an array of [`RowGroupMetaData`] instead +/// of per row-group +struct RowGroupPruningStatistics<'a> { + parquet_schema: &'a SchemaDescriptor, + row_group_metadata: &'a RowGroupMetaData, + arrow_schema: &'a Schema, +} - Some(value.to_array()) - }}; +impl<'a> RowGroupPruningStatistics<'a> { + /// Lookups up the parquet column by name + fn column(&self, name: &str) -> Option<(&ColumnChunkMetaData, &FieldRef)> { + let (idx, field) = parquet_column(self.parquet_schema, self.arrow_schema, name)?; + Some((self.row_group_metadata.column(idx), field)) + } } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, min, min_bytes) + let (column, field) = self.column(&column.name)?; + min_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn max_values(&self, column: &Column) -> Option { - get_min_max_values!(self, column, max, max_bytes) + let (column, field) = self.column(&column.name)?; + max_statistics(field.data_type(), std::iter::once(column.statistics())).ok() } fn num_containers(&self) -> usize { @@ -239,21 +335,31 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { } fn null_counts(&self, column: &Column) -> Option { - get_null_count_values!(self, column) + let (c, _) = self.column(&column.name)?; + let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); + scalar.to_array().ok() } } #[cfg(test)] mod tests { use super::*; + use crate::datasource::physical_plan::parquet::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_common::ToDFSchema; - use datafusion_expr::{cast, col, lit, Expr}; + use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; + use datafusion_common::{DataFusionError, Result}; + use datafusion_expr::{ + builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, + TableSource, WindowUDF, + }; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; + use datafusion_sql::planner::ContextProvider; + use parquet::arrow::arrow_to_parquet_schema; + use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::{ @@ -329,7 +435,13 @@ mod tests { let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups(&[rgm1, rgm2], None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + &[rgm1, rgm2], + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -358,7 +470,13 @@ mod tests { // missing statistics for first row group mean that the result from the predicate expression // is null / undefined so the first row group can't be filtered out assert_eq!( - prune_row_groups(&[rgm1, rgm2], None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + &[rgm1, rgm2], + None, + Some(&pruning_predicate), + &metrics + ), vec![0, 1] ); } @@ -400,7 +518,13 @@ mod tests { // the first row group is still filtered out because the predicate expression can be partially evaluated // when conditions are joined using AND assert_eq!( - prune_row_groups(groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); @@ -413,7 +537,13 @@ mod tests { // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out assert_eq!( - prune_row_groups(groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![0, 1] ); } @@ -448,6 +578,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); @@ -456,7 +587,13 @@ mod tests { let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". assert_eq!( - prune_row_groups(&groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + &groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -471,6 +608,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let schema_descr = arrow_to_parquet_schema(&schema).unwrap(); let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); @@ -482,7 +620,13 @@ mod tests { // bool = NULL always evaluates to NULL (and thus will not // pass predicates. Ideally these should both be false assert_eq!( - prune_row_groups(&groups, None, Some(&pruning_predicate), &metrics), + prune_row_groups_by_statistics( + &schema_descr, + &groups, + None, + Some(&pruning_predicate), + &metrics + ), vec![1] ); } @@ -535,7 +679,8 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -598,7 +743,8 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3, rgm4], None, Some(&pruning_predicate), @@ -645,7 +791,8 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -715,7 +862,8 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -774,7 +922,8 @@ mod tests { ); let metrics = parquet_file_metrics(); assert_eq!( - prune_row_groups( + prune_row_groups_by_statistics( + &schema_descr, &[rgm1, rgm2, rgm3], None, Some(&pruning_predicate), @@ -788,7 +937,6 @@ mod tests { schema_descr: &SchemaDescPtr, column_statistics: Vec, ) -> RowGroupMetaData { - use parquet::file::metadata::ColumnChunkMetaData; let mut columns = vec![]; for (i, s) in column_statistics.iter().enumerate() { let column = ColumnChunkMetaData::builder(schema_descr.column(i)) @@ -806,8 +954,8 @@ mod tests { } fn get_test_schema_descr(fields: Vec) -> SchemaDescPtr { - use parquet::schema::types::{SchemaDescriptor, Type as SchemaType}; - let mut schema_fields = fields + use parquet::schema::types::Type as SchemaType; + let schema_fields = fields .iter() .map(|field| { let mut builder = @@ -829,7 +977,7 @@ mod tests { }) .collect::>(); let schema = SchemaType::group_type_builder("schema") - .with_fields(&mut schema_fields) + .with_fields(schema_fields) .build() .unwrap(); @@ -846,4 +994,391 @@ mod tests { let execution_props = ExecutionProps::new(); create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } + + // Note the values in the `String` column are: + // ❯ select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + // +-----------+ + // | String | + // +-----------+ + // | Hello | + // | This is | + // | a | + // | test | + // | How | + // | are you | + // | doing | + // | today | + // | the quick | + // | brown fox | + // | jumps | + // | over | + // | the lazy | + // | dog | + // +-----------+ + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello_Not_exists")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#).eq(lit("Hello_Not_Exists")); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert!(pruned_row_groups.is_empty()); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert!(pruned_row_groups.is_empty()); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate + let schema = Schema::new(vec![ + Field::new("String", DataType::Utf8, false), + Field::new("String3", DataType::Utf8, false), + ]); + let sql = + "SELECT * FROM tbl WHERE \"String\" IN ('Hello_Not_Exists', 'Hello_Not_Exists2')"; + let expr = sql_to_physical_plan(sql).unwrap(); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert!(pruned_row_groups.is_empty()); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#).eq(lit("Hello")); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate `(String = "foo") OR (String != "bar")` + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let expr = col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "alltypes_plain.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate on a column without a bloom filter + let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, false)]); + let expr = col(r#""string_col""#).eq(lit("0")); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let row_groups = vec![0]; + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + &row_groups, + ) + .await + .unwrap(); + assert_eq!(pruned_row_groups, row_groups); + } + + async fn test_row_group_bloom_filter_pruning_predicate( + file_name: &str, + data: bytes::Bytes, + pruning_predicate: &PruningPredicate, + row_groups: &[usize], + ) -> Result> { + use object_store::{ObjectMeta, ObjectStore}; + + let object_meta = ObjectMeta { + location: object_store::path::Path::parse(file_name).expect("creating path"), + last_modified: chrono::DateTime::from(std::time::SystemTime::now()), + size: data.len(), + e_tag: None, + version: None, + }; + let in_memory = object_store::memory::InMemory::new(); + in_memory + .put(&object_meta.location, data) + .await + .expect("put parquet file into in memory object store"); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = + ParquetFileMetrics::new(0, object_meta.location.as_ref(), &metrics); + let reader = ParquetFileReader { + inner: ParquetObjectReader::new(Arc::new(in_memory), object_meta), + file_metrics: file_metrics.clone(), + }; + let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); + + let metadata = builder.metadata().clone(); + let pruned_row_group = prune_row_groups_by_bloom_filters( + &mut builder, + row_groups, + metadata.row_groups(), + pruning_predicate, + &file_metrics, + ) + .await; + + Ok(pruned_row_group) + } + + fn sql_to_physical_plan(sql: &str) -> Result> { + use datafusion_optimizer::{ + analyzer::Analyzer, optimizer::Optimizer, OptimizerConfig, OptimizerContext, + }; + use datafusion_sql::{ + planner::SqlToRel, + sqlparser::{ast::Statement, parser::Parser}, + }; + use sqlparser::dialect::GenericDialect; + + // parse the SQL + let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... + let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); + let statement = &ast[0]; + + // create a logical query plan + let schema_provider = TestSchemaProvider::new(); + let sql_to_rel = SqlToRel::new(&schema_provider); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + + // hard code the return value of now() + let config = OptimizerContext::new().with_skip_failing_rules(false); + let analyzer = Analyzer::new(); + let optimizer = Optimizer::new(); + // analyze and optimize the logical plan + let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; + let plan = optimizer.optimize(&plan, &config, |_, _| {})?; + // convert the logical plan into a physical plan + let exprs = plan.expressions(); + let expr = &exprs[0]; + let df_schema = plan.schema().as_ref().to_owned(); + let tb_schema: Schema = df_schema.clone().into(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, &tb_schema, &execution_props) + } + + struct TestSchemaProvider { + options: ConfigOptions, + tables: HashMap>, + } + + impl TestSchemaProvider { + pub fn new() -> Self { + let mut tables = HashMap::new(); + tables.insert( + "tbl".to_string(), + create_table_source(vec![Field::new( + "String".to_string(), + DataType::Utf8, + false, + )]), + ); + + Self { + options: Default::default(), + tables, + } + } + } + + impl ContextProvider for TestSchemaProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + match self.tables.get(name.table()) { + Some(table) => Ok(table.clone()), + _ => datafusion_common::plan_err!("Table not found: {}", name.table()), + } + } + + fn get_function_meta(&self, _name: &str) -> Option> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + } + + fn create_table_source(fields: Vec) -> Arc { + Arc::new(LogicalTableSource::new(Arc::new(Schema::new(fields)))) + } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs new file mode 100644 index 0000000000000..4e472606da515 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -0,0 +1,899 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. + +// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 + +use arrow::{array::ArrayRef, datatypes::DataType}; +use arrow_array::new_empty_array; +use arrow_schema::{FieldRef, Schema}; +use datafusion_common::{Result, ScalarValue}; +use parquet::file::statistics::Statistics as ParquetStatistics; +use parquet::schema::types::SchemaDescriptor; + +// Convert the bytes array to i128. +// The endian of the input bytes array must be big-endian. +pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { + // The bytes array are from parquet file and must be the big-endian. + // The endian is defined by parquet format, and the reference document + // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 + i128::from_be_bytes(sign_extend_be(b)) +} + +// Copy from arrow-rs +// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 +// Convert the byte slice to fixed length byte array with the length of 16 +fn sign_extend_be(b: &[u8]) -> [u8; 16] { + assert!(b.len() <= 16, "Array too large, expected less than 16"); + let is_negative = (b[0] & 128u8) == 128u8; + let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; + for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { + *d = *s; + } + result +} + +/// Extract a single min/max statistics from a [`ParquetStatistics`] object +/// +/// * `$column_statistics` is the `ParquetStatistics` object +/// * `$func is the function` (`min`/`max`) to call to get the value +/// * `$bytes_func` is the function (`min_bytes`/`max_bytes`) to call to get the value as bytes +/// * `$target_arrow_type` is the [`DataType`] of the target statistics +macro_rules! get_statistic { + ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ + if !$column_statistics.has_min_max_set() { + return None; + } + match $column_statistics { + ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), + ParquetStatistics::Int32(s) => { + match $target_arrow_type { + // int32 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int32(Some(*s.$func()))), + } + } + ParquetStatistics::Int64(s) => { + match $target_arrow_type { + // int64 to decimal with the precision and scale + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(*s.$func() as i128), + *precision, + *scale, + )) + } + _ => Some(ScalarValue::Int64(Some(*s.$func()))), + } + } + // 96 bit ints not supported + ParquetStatistics::Int96(_) => None, + ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), + ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), + ParquetStatistics::ByteArray(s) => { + match $target_arrow_type { + // decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => { + let s = std::str::from_utf8(s.$bytes_func()) + .map(|s| s.to_string()) + .ok(); + Some(ScalarValue::Utf8(s)) + } + } + } + // type not supported yet + ParquetStatistics::FixedLenByteArray(s) => { + match $target_arrow_type { + // just support the decimal data type + Some(DataType::Decimal128(precision, scale)) => { + Some(ScalarValue::Decimal128( + Some(from_bytes_to_i128(s.$bytes_func())), + *precision, + *scale, + )) + } + _ => None, + } + } + } + }}; +} + +/// Lookups up the parquet column by name +/// +/// Returns the parquet column index and the corresponding arrow field +pub(crate) fn parquet_column<'a>( + parquet_schema: &SchemaDescriptor, + arrow_schema: &'a Schema, + name: &str, +) -> Option<(usize, &'a FieldRef)> { + let (root_idx, field) = arrow_schema.fields.find(name)?; + if field.data_type().is_nested() { + // Nested fields are not supported and require non-trivial logic + // to correctly walk the parquet schema accounting for the + // logical type rules - + // + // For example a ListArray could correspond to anything from 1 to 3 levels + // in the parquet schema + return None; + } + + // This could be made more efficient (#TBD) + let parquet_idx = (0..parquet_schema.columns().len()) + .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; + Some((parquet_idx, field)) +} + +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn min_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, min, min_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] +pub(crate) fn max_statistics<'a, I: Iterator>>( + data_type: &DataType, + iterator: I, +) -> Result { + let scalars = iterator + .map(|x| x.and_then(|s| get_statistic!(s, max, max_bytes, Some(data_type)))); + collect_scalars(data_type, scalars) +} + +/// Builds an array from an iterator of ScalarValue +fn collect_scalars>>( + data_type: &DataType, + iterator: I, +) -> Result { + let mut scalars = iterator.peekable(); + match scalars.peek().is_none() { + true => Ok(new_empty_array(data_type)), + false => { + let null = ScalarValue::try_from(data_type)?; + ScalarValue::iter_to_array(scalars.map(|x| x.unwrap_or_else(|| null.clone()))) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::{ + new_null_array, Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, + Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, + TimestampNanosecondArray, + }; + use arrow_schema::{Field, SchemaRef}; + use bytes::Bytes; + use datafusion_common::test_util::parquet_test_data; + use parquet::arrow::arrow_reader::ArrowReaderBuilder; + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; + use parquet::file::properties::{EnabledStatistics, WriterProperties}; + use std::path::PathBuf; + use std::sync::Arc; + + // TODO error cases (with parquet statistics that are mismatched in expected type) + + #[test] + fn roundtrip_empty() { + let empty_bool_array = new_empty_array(&DataType::Boolean); + Test { + input: empty_bool_array.clone(), + expected_min: empty_bool_array.clone(), + expected_max: empty_bool_array.clone(), + } + .run() + } + + #[test] + fn roundtrip_bool() { + Test { + input: bool_array([ + // row group 1 + Some(true), + None, + Some(true), + // row group 2 + Some(true), + Some(false), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: bool_array([Some(true), Some(false), None]), + expected_max: bool_array([Some(true), Some(true), None]), + } + .run() + } + + #[test] + fn roundtrip_int32() { + Test { + input: i32_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i32_array([Some(1), Some(0), None]), + expected_max: i32_array([Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_int64() { + Test { + input: i64_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(0), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: i64_array([Some(1), Some(0), None]), + expected_max: i64_array(vec![Some(3), Some(5), None]), + } + .run() + } + + #[test] + fn roundtrip_f32() { + Test { + input: f32_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f32_array([Some(1.0), Some(-1.0), None]), + expected_max: f32_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + fn roundtrip_f64() { + Test { + input: f64_array([ + // row group 1 + Some(1.0), + None, + Some(3.0), + // row group 2 + Some(-1.0), + Some(5.0), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: f64_array([Some(1.0), Some(-1.0), None]), + expected_max: f64_array([Some(3.0), Some(5.0), None]), + } + .run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Int64, got TimestampNanosecond(NULL, None)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_timestamp() { + Test { + input: timestamp_array([ + // row group 1 + Some(1), + None, + Some(3), + // row group 2 + Some(9), + Some(5), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: timestamp_array([Some(1), Some(5), None]), + expected_max: timestamp_array([Some(3), Some(9), None]), + } + .run() + } + + #[test] + fn roundtrip_decimal() { + Test { + input: Arc::new( + Decimal128Array::from(vec![ + // row group 1 + Some(100), + None, + Some(22000), + // row group 2 + Some(500000), + Some(330000), + None, + // row group 3 + None, + None, + None, + ]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_min: Arc::new( + Decimal128Array::from(vec![Some(100), Some(330000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(22000), Some(500000), None]) + .with_precision_and_scale(9, 2) + .unwrap(), + ), + } + .run() + } + + #[test] + fn roundtrip_utf8() { + Test { + input: utf8_array([ + // row group 1 + Some("A"), + None, + Some("Q"), + // row group 2 + Some("ZZ"), + Some("AA"), + None, + // row group 3 + None, + None, + None, + ]), + expected_min: utf8_array([Some("A"), Some("AA"), None]), + expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), + } + .run() + } + + #[test] + fn roundtrip_struct() { + let mut test = Test { + input: struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + // row group 2 + (Some(true), Some(0)), + (Some(false), Some(5)), + (None, None), + // row group 3 + (None, None), + (None, None), + (None, None), + ]), + expected_min: struct_array(vec![ + (Some(true), Some(1)), + (Some(true), Some(0)), + (None, None), + ]), + + expected_max: struct_array(vec![ + (Some(true), Some(3)), + (Some(true), Some(0)), + (None, None), + ]), + }; + // Due to https://github.com/apache/arrow-datafusion/issues/8334, + // statistics for struct arrays are not supported + test.expected_min = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.expected_max = + new_null_array(test.input.data_type(), test.expected_min.len()); + test.run() + } + + #[test] + #[should_panic( + expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Utf8, got Binary(NULL)" + )] + // Due to https://github.com/apache/arrow-datafusion/issues/8295 + fn roundtrip_binary() { + Test { + input: Arc::new(BinaryArray::from_opt_vec(vec![ + // row group 1 + Some(b"A"), + None, + Some(b"Q"), + // row group 2 + Some(b"ZZ"), + Some(b"AA"), + None, + // row group 3 + None, + None, + None, + ])), + expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"A"), + Some(b"AA"), + None, + ])), + expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"Q"), + Some(b"ZZ"), + None, + ])), + } + .run() + } + + #[test] + fn struct_and_non_struct() { + // Ensures that statistics for an array that appears *after* a struct + // array are not wrong + let struct_col = struct_array(vec![ + // row group 1 + (Some(true), Some(1)), + (None, None), + (Some(true), Some(3)), + ]); + let int_col = i32_array([Some(100), Some(200), Some(300)]); + let expected_min = i32_array([Some(100)]); + let expected_max = i32_array(vec![Some(300)]); + + // use a name that shadows a name in the struct column + match struct_col.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.get(1).unwrap().name(), "int_col") + } + _ => panic!("unexpected data type for struct column"), + }; + + let input_batch = RecordBatch::try_from_iter([ + ("struct_col", struct_col), + ("int_col", int_col), + ]) + .unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + // read the int_col statistics + let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); + assert_eq!(idx, 2); + + let row_groups = metadata.row_groups(); + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + + let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(&DataType::Int32, iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + + #[test] + fn nan_in_stats() { + // /parquet-testing/data/nan_in_stats.parquet + // row_groups: 1 + // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + + TestFile::new("nan_in_stats.parquet") + .with_column(ExpectedColumn { + name: "x", + expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), + expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), + }) + .run(); + } + + #[test] + fn alltypes_plain() { + // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet + // row_groups: 1 + // (has no statistics) + TestFile::new("alltypes_plain.parquet") + // No column statistics should be read as NULL, but with the right type + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([None]), + expected_max: i32_array([None]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([None]), + expected_max: bool_array([None]), + }) + .run(); + } + + #[test] + fn alltypes_tiny_pages() { + // /parquet-testing/data/alltypes_tiny_pages.parquet + // row_groups: 1 + // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) + TestFile::new("alltypes_tiny_pages.parquet") + .with_column(ExpectedColumn { + name: "id", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(7299)]), + }) + .with_column(ExpectedColumn { + name: "bool_col", + expected_min: bool_array([Some(false)]), + expected_max: bool_array([Some(true)]), + }) + .with_column(ExpectedColumn { + name: "tinyint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "smallint_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "int_col", + expected_min: i32_array([Some(0)]), + expected_max: i32_array([Some(9)]), + }) + .with_column(ExpectedColumn { + name: "bigint_col", + expected_min: i64_array([Some(0)]), + expected_max: i64_array([Some(90)]), + }) + .with_column(ExpectedColumn { + name: "float_col", + expected_min: f32_array([Some(0.0)]), + expected_max: f32_array([Some(9.9)]), + }) + .with_column(ExpectedColumn { + name: "double_col", + expected_min: f64_array([Some(0.0)]), + expected_max: f64_array([Some(90.89999999999999)]), + }) + .with_column(ExpectedColumn { + name: "date_string_col", + expected_min: utf8_array([Some("01/01/09")]), + expected_max: utf8_array([Some("12/31/10")]), + }) + .with_column(ExpectedColumn { + name: "string_col", + expected_min: utf8_array([Some("0")]), + expected_max: utf8_array([Some("9")]), + }) + // File has no min/max for timestamp_col + .with_column(ExpectedColumn { + name: "timestamp_col", + expected_min: timestamp_array([None]), + expected_max: timestamp_array([None]), + }) + .with_column(ExpectedColumn { + name: "year", + expected_min: i32_array([Some(2009)]), + expected_max: i32_array([Some(2010)]), + }) + .with_column(ExpectedColumn { + name: "month", + expected_min: i32_array([Some(1)]), + expected_max: i32_array([Some(12)]), + }) + .run(); + } + + #[test] + fn fixed_length_decimal_legacy() { + // /parquet-testing/data/fixed_length_decimal_legacy.parquet + // row_groups: 1 + // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) + + TestFile::new("fixed_length_decimal_legacy.parquet") + .with_column(ExpectedColumn { + name: "value", + expected_min: Arc::new( + Decimal128Array::from(vec![Some(200)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + expected_max: Arc::new( + Decimal128Array::from(vec![Some(2400)]) + .with_precision_and_scale(13, 2) + .unwrap(), + ), + }) + .run(); + } + + const ROWS_PER_ROW_GROUP: usize = 3; + + /// Writes the input batch into a parquet file, with every every three rows as + /// their own row group, and compares the min/maxes to the expected values + struct Test { + input: ArrayRef, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + impl Test { + fn run(self) { + let Self { + input, + expected_min, + expected_max, + } = self; + + let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); + + let schema = input_batch.schema(); + + let metadata = parquet_metadata(schema.clone(), input_batch); + let parquet_schema = metadata.file_metadata().schema_descr(); + + let row_groups = metadata.row_groups(); + + for field in schema.fields() { + if field.data_type().is_nested() { + let lookup = parquet_column(parquet_schema, &schema, field.name()); + assert_eq!(lookup, None); + continue; + } + + let (idx, f) = + parquet_column(parquet_schema, &schema, field.name()).unwrap(); + assert_eq!(f, field); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!( + &min, + &expected_min, + "Min. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + + let max = max_statistics(f.data_type(), iter).unwrap(); + assert_eq!( + &max, + &expected_max, + "Max. Statistics\n\n{}\n\n", + DisplayStats(row_groups) + ); + } + } + } + + /// Write the specified batches out as parquet and return the metadata + fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc { + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_max_row_group_size(ROWS_PER_ROW_GROUP) + .build(); + + let mut buffer = Vec::new(); + let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); + reader.metadata().clone() + } + + /// Formats the statistics nicely for display + struct DisplayStats<'a>(&'a [RowGroupMetaData]); + impl<'a> std::fmt::Display for DisplayStats<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row_groups = self.0; + writeln!(f, " row_groups: {}", row_groups.len())?; + for rg in row_groups { + for col in rg.columns() { + if let Some(statistics) = col.statistics() { + writeln!(f, " {}: {:?}", col.column_path(), statistics)?; + } + } + } + Ok(()) + } + } + + struct ExpectedColumn { + name: &'static str, + expected_min: ArrayRef, + expected_max: ArrayRef, + } + + /// Reads statistics out of the specified, and compares them to the expected values + struct TestFile { + file_name: &'static str, + expected_columns: Vec, + } + + impl TestFile { + fn new(file_name: &'static str) -> Self { + Self { + file_name, + expected_columns: Vec::new(), + } + } + + fn with_column(mut self, column: ExpectedColumn) -> Self { + self.expected_columns.push(column); + self + } + + /// Reads the specified parquet file and validates that the exepcted min/max + /// values for the specified columns are as expected. + fn run(self) { + let path = PathBuf::from(parquet_test_data()).join(self.file_name); + let file = std::fs::File::open(path).unwrap(); + let reader = ArrowReaderBuilder::try_new(file).unwrap(); + let arrow_schema = reader.schema(); + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + let parquet_schema = metadata.file_metadata().schema_descr(); + + for expected_column in self.expected_columns { + let ExpectedColumn { + name, + expected_min, + expected_max, + } = expected_column; + + let (idx, field) = + parquet_column(parquet_schema, arrow_schema, name).unwrap(); + + let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + assert_eq!(&expected_min, &actual_min, "column {name}"); + + let actual_max = max_statistics(field.data_type(), iter).unwrap(); + assert_eq!(&expected_max, &actual_max, "column {name}"); + } + } + } + + fn bool_array(input: impl IntoIterator>) -> ArrayRef { + let array: BooleanArray = input.into_iter().collect(); + Arc::new(array) + } + + fn i32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn i64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Int64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f32_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float32Array = input.into_iter().collect(); + Arc::new(array) + } + + fn f64_array(input: impl IntoIterator>) -> ArrayRef { + let array: Float64Array = input.into_iter().collect(); + Arc::new(array) + } + + fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { + let array: TimestampNanosecondArray = input.into_iter().collect(); + Arc::new(array) + } + + fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { + let array: StringArray = input + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect(); + Arc::new(array) + } + + // returns a struct array with columns "bool_col" and "int_col" with the specified values + fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { + let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); + let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); + + let nullable = true; + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("bool_col", DataType::Boolean, nullable)), + Arc::new(boolean) as ArrayRef, + ), + ( + Arc::new(Field::new("int_col", DataType::Int32, nullable)), + Arc::new(int) as ArrayRef, + ), + ]); + Arc::new(struct_array) + } +} diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs new file mode 100644 index 0000000000000..275523405a094 --- /dev/null +++ b/datafusion/core/src/datasource/provider.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Data source traits + +use std::any::Any; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_common::{not_impl_err, Constraints, DataFusionError, Statistics}; +use datafusion_expr::{CreateExternalTable, LogicalPlan}; +pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; + +use crate::arrow::datatypes::SchemaRef; +use crate::datasource::listing_table_factory::ListingTableFactory; +use crate::datasource::stream::StreamTableFactory; +use crate::error::Result; +use crate::execution::context::SessionState; +use crate::logical_expr::Expr; +use crate::physical_plan::ExecutionPlan; + +/// Source table +#[async_trait] +pub trait TableProvider: Sync + Send { + /// Returns the table provider as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Get a reference to the schema for this table + fn schema(&self) -> SchemaRef; + + /// Get a reference to the constraints of the table. + /// Returns: + /// - `None` for tables that do not support constraints. + /// - `Some(&Constraints)` for tables supporting constraints. + /// Therefore, a `Some(&Constraints::empty())` return value indicates that + /// this table supports constraints, but there are no constraints. + fn constraints(&self) -> Option<&Constraints> { + None + } + + /// Get the type of this table for metadata/catalog purposes. + fn table_type(&self) -> TableType; + + /// Get the create statement used to create this table, if available. + fn get_table_definition(&self) -> Option<&str> { + None + } + + /// Get the [`LogicalPlan`] of this table, if available + fn get_logical_plan(&self) -> Option<&LogicalPlan> { + None + } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } + + /// Create an [`ExecutionPlan`] for scanning the table with optionally + /// specified `projection`, `filter` and `limit`, described below. + /// + /// The `ExecutionPlan` is responsible scanning the datasource's + /// partitions in a streaming, parallelized fashion. + /// + /// # Projection + /// + /// If specified, only a subset of columns should be returned, in the order + /// specified. The projection is a set of indexes of the fields in + /// [`Self::schema`]. + /// + /// DataFusion provides the projection to scan only the columns actually + /// used in the query to improve performance, an optimization called + /// "Projection Pushdown". Some datasources, such as Parquet, can use this + /// information to go significantly faster when only a subset of columns is + /// required. + /// + /// # Filters + /// + /// A list of boolean filter [`Expr`]s to evaluate *during* the scan, in the + /// manner specified by [`Self::supports_filters_pushdown`]. Only rows for + /// which *all* of the `Expr`s evaluate to `true` must be returned (aka the + /// expressions are `AND`ed together). + /// + /// DataFusion pushes filtering into the scans whenever possible + /// ("Projection Pushdown"), and depending on the format and the + /// implementation of the format, evaluating the predicate during the scan + /// can increase performance significantly. + /// + /// ## Note: Some columns may appear *only* in Filters + /// + /// In certain cases, a query may only use a certain column in a Filter that + /// has been completely pushed down to the scan. In this case, the + /// projection will not contain all the columns found in the filter + /// expressions. + /// + /// For example, given the query `SELECT t.a FROM t WHERE t.b > 5`, + /// + /// ```text + /// ┌────────────────────┐ + /// │ Projection(t.a) │ + /// └────────────────────┘ + /// ▲ + /// │ + /// │ + /// ┌────────────────────┐ Filter ┌────────────────────┐ Projection ┌────────────────────┐ + /// │ Filter(t.b > 5) │────Pushdown──▶ │ Projection(t.a) │ ───Pushdown───▶ │ Projection(t.a) │ + /// └────────────────────┘ └────────────────────┘ └────────────────────┘ + /// ▲ ▲ ▲ + /// │ │ │ + /// │ │ ┌────────────────────┐ + /// ┌────────────────────┐ ┌────────────────────┐ │ Scan │ + /// │ Scan │ │ Scan │ │ filter=(t.b > 5) │ + /// └────────────────────┘ │ filter=(t.b > 5) │ │ projection=(t.a) │ + /// └────────────────────┘ └────────────────────┘ + /// + /// Initial Plan If `TableProviderFilterPushDown` Projection pushdown notes that + /// returns true, filter pushdown the scan only needs t.a + /// pushes the filter into the scan + /// BUT internally evaluating the + /// predicate still requires t.b + /// ``` + /// + /// # Limit + /// + /// If `limit` is specified, must only produce *at least* this many rows, + /// (though it may return more). Like Projection Pushdown and Filter + /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as + /// possible, called "Limit Pushdown" as some sources can use this + /// information to improve their performance. + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result>; + + /// Tests whether the table provider can make use of a filter expression + /// to optimise data retrieval. + #[deprecated(since = "20.0.0", note = "use supports_filters_pushdown instead")] + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> Result { + Ok(TableProviderFilterPushDown::Unsupported) + } + + /// Tests whether the table provider can make use of any or all filter expressions + /// to optimise data retrieval. + #[allow(deprecated)] + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + filters + .iter() + .map(|f| self.supports_filter_pushdown(f)) + .collect() + } + + /// Get statistics for this table, if available + fn statistics(&self) -> Option { + None + } + + /// Return an [`ExecutionPlan`] to insert data into this table, if + /// supported. + /// + /// The returned plan should return a single row in a UInt64 + /// column called "count" such as the following + /// + /// ```text + /// +-------+, + /// | count |, + /// +-------+, + /// | 6 |, + /// +-------+, + /// ``` + /// + /// # See Also + /// + /// See [`FileSinkExec`] for the common pattern of inserting a + /// streams of `RecordBatch`es as files to an ObjectStore. + /// + /// [`FileSinkExec`]: crate::physical_plan::insert::FileSinkExec + async fn insert_into( + &self, + _state: &SessionState, + _input: Arc, + _overwrite: bool, + ) -> Result> { + not_impl_err!("Insert into not implemented for this table") + } +} + +/// A factory which creates [`TableProvider`]s at runtime given a URL. +/// +/// For example, this can be used to create a table "on the fly" +/// from a directory of files only when that name is referenced. +#[async_trait] +pub trait TableProviderFactory: Sync + Send { + /// Create a TableProvider with the given url + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result>; +} + +/// The default [`TableProviderFactory`] +/// +/// If [`CreateExternalTable`] is unbounded calls [`StreamTableFactory::create`], +/// otherwise calls [`ListingTableFactory::create`] +#[derive(Debug, Default)] +pub struct DefaultTableFactory { + stream: StreamTableFactory, + listing: ListingTableFactory, +} + +impl DefaultTableFactory { + /// Creates a new [`DefaultTableFactory`] + pub fn new() -> Self { + Self::default() + } +} + +#[async_trait] +impl TableProviderFactory for DefaultTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let mut unbounded = cmd.unbounded; + for (k, v) in &cmd.options { + if k.eq_ignore_ascii_case("unbounded") && v.eq_ignore_ascii_case("true") { + unbounded = true + } + } + + match unbounded { + true => self.stream.create(state, cmd).await, + false => self.listing.create(state, cmd).await, + } + } +} diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs new file mode 100644 index 0000000000000..695e139517cff --- /dev/null +++ b/datafusion/core/src/datasource/statistics.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::listing::PartitionedFile; +use crate::arrow::datatypes::{Schema, SchemaRef}; +use crate::error::Result; +use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; +use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; + +use datafusion_common::stats::Precision; +use datafusion_common::ScalarValue; + +use futures::{Stream, StreamExt}; +use itertools::izip; +use itertools::multiunzip; + +/// Get all files as well as the file level summary statistics (no statistic for partition columns). +/// If the optional `limit` is provided, includes only sufficient files. +/// Needed to read up to `limit` number of rows. +pub async fn get_statistics_with_limit( + all_files: impl Stream>, + file_schema: SchemaRef, + limit: Option, +) -> Result<(Vec, Statistics)> { + let mut result_files = vec![]; + // These statistics can be calculated as long as at least one file provides + // useful information. If none of the files provides any information, then + // they will end up having `Precision::Absent` values. Throughout calculations, + // missing values will be imputed as: + // - zero for summations, and + // - neutral element for extreme points. + let size = file_schema.fields().len(); + let mut null_counts: Vec> = vec![Precision::Absent; size]; + let mut max_values: Vec> = vec![Precision::Absent; size]; + let mut min_values: Vec> = vec![Precision::Absent; size]; + let mut num_rows = Precision::::Absent; + let mut total_byte_size = Precision::::Absent; + + // Fusing the stream allows us to call next safely even once it is finished. + let mut all_files = Box::pin(all_files.fuse()); + + if let Some(first_file) = all_files.next().await { + let (file, file_stats) = first_file?; + result_files.push(file); + + // First file, we set them directly from the file statistics. + num_rows = file_stats.num_rows; + total_byte_size = file_stats.total_byte_size; + for (index, file_column) in file_stats.column_statistics.into_iter().enumerate() { + null_counts[index] = file_column.null_count; + max_values[index] = file_column.max_value; + min_values[index] = file_column.min_value; + } + + // If the number of rows exceeds the limit, we can stop processing + // files. This only applies when we know the number of rows. It also + // currently ignores tables that have no statistics regarding the + // number of rows. + let conservative_num_rows = match num_rows { + Precision::Exact(nr) => nr, + _ => usize::MIN, + }; + if conservative_num_rows <= limit.unwrap_or(usize::MAX) { + while let Some(current) = all_files.next().await { + let (file, file_stats) = current?; + result_files.push(file); + + // We accumulate the number of rows, total byte size and null + // counts across all the files in question. If any file does not + // provide any information or provides an inexact value, we demote + // the statistic precision to inexact. + num_rows = add_row_stats(file_stats.num_rows, num_rows); + + total_byte_size = + add_row_stats(file_stats.total_byte_size, total_byte_size); + + (null_counts, max_values, min_values) = multiunzip( + izip!( + file_stats.column_statistics.into_iter(), + null_counts.into_iter(), + max_values.into_iter(), + min_values.into_iter() + ) + .map( + |( + ColumnStatistics { + null_count: file_nc, + max_value: file_max, + min_value: file_min, + distinct_count: _, + }, + null_count, + max_value, + min_value, + )| { + ( + add_row_stats(file_nc, null_count), + set_max_if_greater(file_max, max_value), + set_min_if_lesser(file_min, min_value), + ) + }, + ), + ); + + // If the number of rows exceeds the limit, we can stop processing + // files. This only applies when we know the number of rows. It also + // currently ignores tables that have no statistics regarding the + // number of rows. + if num_rows.get_value().unwrap_or(&usize::MIN) + > &limit.unwrap_or(usize::MAX) + { + break; + } + } + } + }; + + let mut statistics = Statistics { + num_rows, + total_byte_size, + column_statistics: get_col_stats_vec(null_counts, max_values, min_values), + }; + if all_files.next().await.is_some() { + // If we still have files in the stream, it means that the limit kicked + // in, and the statistic could have been different had we processed the + // files in a different order. + statistics = statistics.into_inexact() + } + + Ok((result_files, statistics)) +} + +pub(crate) fn create_max_min_accs( + schema: &Schema, +) -> (Vec>, Vec>) { + let max_values: Vec> = schema + .fields() + .iter() + .map(|field| MaxAccumulator::try_new(field.data_type()).ok()) + .collect(); + let min_values: Vec> = schema + .fields() + .iter() + .map(|field| MinAccumulator::try_new(field.data_type()).ok()) + .collect(); + (max_values, min_values) +} + +fn add_row_stats( + file_num_rows: Precision, + num_rows: Precision, +) -> Precision { + match (file_num_rows, &num_rows) { + (Precision::Absent, _) => num_rows.to_inexact(), + (lhs, Precision::Absent) => lhs.to_inexact(), + (lhs, rhs) => lhs.add(rhs), + } +} + +pub(crate) fn get_col_stats_vec( + null_counts: Vec>, + max_values: Vec>, + min_values: Vec>, +) -> Vec { + izip!(null_counts, max_values, min_values) + .map(|(null_count, max_value, min_value)| ColumnStatistics { + null_count, + max_value, + min_value, + distinct_count: Precision::Absent, + }) + .collect() +} + +pub(crate) fn get_col_stats( + schema: &Schema, + null_counts: Vec>, + max_values: &mut [Option], + min_values: &mut [Option], +) -> Vec { + (0..schema.fields().len()) + .map(|i| { + let max_value = match &max_values[i] { + Some(max_value) => max_value.evaluate().ok(), + None => None, + }; + let min_value = match &min_values[i] { + Some(min_value) => min_value.evaluate().ok(), + None => None, + }; + ColumnStatistics { + null_count: null_counts[i].clone(), + max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent), + min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent), + distinct_count: Precision::Absent, + } + }) + .collect() +} + +/// If the given value is numerically greater than the original maximum value, +/// return the new maximum value with appropriate exactness information. +fn set_max_if_greater( + max_nominee: Precision, + max_values: Precision, +) -> Precision { + match (&max_values, &max_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 < val2 => max_nominee, + (Precision::Exact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Exact(val2)) + if val1 < val2 => + { + max_nominee.to_inexact() + } + (Precision::Exact(_), Precision::Absent) => max_values.to_inexact(), + (Precision::Absent, Precision::Exact(_)) => max_nominee.to_inexact(), + (Precision::Absent, Precision::Inexact(_)) => max_nominee, + (Precision::Absent, Precision::Absent) => Precision::Absent, + _ => max_values, + } +} + +/// If the given value is numerically lesser than the original minimum value, +/// return the new minimum value with appropriate exactness information. +fn set_min_if_lesser( + min_nominee: Precision, + min_values: Precision, +) -> Precision { + match (&min_values, &min_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 > val2 => min_nominee, + (Precision::Exact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Inexact(val2)) + | (Precision::Inexact(val1), Precision::Exact(val2)) + if val1 > val2 => + { + min_nominee.to_inexact() + } + (Precision::Exact(_), Precision::Absent) => min_values.to_inexact(), + (Precision::Absent, Precision::Exact(_)) => min_nominee.to_inexact(), + (Precision::Absent, Precision::Inexact(_)) => min_nominee, + (Precision::Absent, Precision::Absent) => Precision::Absent, + _ => min_values, + } +} diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs new file mode 100644 index 0000000000000..e7512499eb9d8 --- /dev/null +++ b/datafusion/core/src/datasource/stream.rs @@ -0,0 +1,358 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TableProvider for stream sources, such as FIFO files + +use std::any::Any; +use std::fmt::Formatter; +use std::fs::{File, OpenOptions}; +use std::io::BufReader; +use std::path::PathBuf; +use std::str::FromStr; +use std::sync::Arc; + +use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use futures::StreamExt; +use tokio::task::spawn_blocking; + +use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_physical_plan::common::AbortOnDropSingle; +use datafusion_physical_plan::insert::{DataSink, FileSinkExec}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use crate::datasource::provider::TableProviderFactory; +use crate::datasource::{create_ordering, TableProvider}; +use crate::execution::context::SessionState; + +/// A [`TableProviderFactory`] for [`StreamTable`] +#[derive(Debug, Default)] +pub struct StreamTableFactory {} + +#[async_trait] +impl TableProviderFactory for StreamTableFactory { + async fn create( + &self, + state: &SessionState, + cmd: &CreateExternalTable, + ) -> Result> { + let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let location = cmd.location.clone(); + let encoding = cmd.file_type.parse()?; + + let config = StreamConfig::new_file(schema, location.into()) + .with_encoding(encoding) + .with_order(cmd.order_exprs.clone()) + .with_header(cmd.has_header) + .with_batch_size(state.config().batch_size()); + + Ok(Arc::new(StreamTable(Arc::new(config)))) + } +} + +/// The data encoding for [`StreamTable`] +#[derive(Debug, Clone)] +pub enum StreamEncoding { + /// CSV records + Csv, + /// Newline-delimited JSON records + Json, +} + +impl FromStr for StreamEncoding { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "csv" => Ok(Self::Csv), + "json" => Ok(Self::Json), + _ => plan_err!("Unrecognised StreamEncoding {}", s), + } + } +} + +/// The configuration for a [`StreamTable`] +#[derive(Debug)] +pub struct StreamConfig { + schema: SchemaRef, + location: PathBuf, + batch_size: usize, + encoding: StreamEncoding, + header: bool, + order: Vec>, + constraints: Constraints, +} + +impl StreamConfig { + /// Stream data from the file at `location` + /// + /// * Data will be read sequentially from the provided `location` + /// * New data will be appended to the end of the file + /// + /// The encoding can be configured with [`Self::with_encoding`] and + /// defaults to [`StreamEncoding::Csv`] + pub fn new_file(schema: SchemaRef, location: PathBuf) -> Self { + Self { + schema, + location, + batch_size: 1024, + encoding: StreamEncoding::Csv, + order: vec![], + header: false, + constraints: Constraints::empty(), + } + } + + /// Specify a sort order for the stream + pub fn with_order(mut self, order: Vec>) -> Self { + self.order = order; + self + } + + /// Specify the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Specify whether the file has a header (only applicable for [`StreamEncoding::Csv`]) + pub fn with_header(mut self, header: bool) -> Self { + self.header = header; + self + } + + /// Specify an encoding for the stream + pub fn with_encoding(mut self, encoding: StreamEncoding) -> Self { + self.encoding = encoding; + self + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + fn reader(&self) -> Result> { + let file = File::open(&self.location)?; + let schema = self.schema.clone(); + match &self.encoding { + StreamEncoding::Csv => { + let reader = arrow::csv::ReaderBuilder::new(schema) + .with_header(self.header) + .with_batch_size(self.batch_size) + .build(file)?; + + Ok(Box::new(reader)) + } + StreamEncoding::Json => { + let reader = arrow::json::ReaderBuilder::new(schema) + .with_batch_size(self.batch_size) + .build(BufReader::new(file))?; + + Ok(Box::new(reader)) + } + } + } + + fn writer(&self) -> Result> { + match &self.encoding { + StreamEncoding::Csv => { + let header = self.header && !self.location.exists(); + let file = OpenOptions::new().append(true).open(&self.location)?; + let writer = arrow::csv::WriterBuilder::new() + .with_header(header) + .build(file); + + Ok(Box::new(writer)) + } + StreamEncoding::Json => { + let file = OpenOptions::new().append(true).open(&self.location)?; + Ok(Box::new(arrow::json::LineDelimitedWriter::new(file))) + } + } + } +} + +/// A [`TableProvider`] for an unbounded stream source +/// +/// Currently only reading from / appending to a single file in-place is supported, but +/// other stream sources and sinks may be added in future. +/// +/// Applications looking to read/write datasets comprising multiple files, e.g. [Hadoop]-style +/// data stored in object storage, should instead consider [`ListingTable`]. +/// +/// [Hadoop]: https://hadoop.apache.org/ +/// [`ListingTable`]: crate::datasource::listing::ListingTable +pub struct StreamTable(Arc); + +impl StreamTable { + /// Create a new [`StreamTable`] for the given [`StreamConfig`] + pub fn new(config: Arc) -> Self { + Self(config) + } +} + +#[async_trait] +impl TableProvider for StreamTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.0.schema.clone() + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.0.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let projected_schema = match projection { + Some(p) => { + let projected = self.0.schema.project(p)?; + create_ordering(&projected, &self.0.order)? + } + None => create_ordering(self.0.schema.as_ref(), &self.0.order)?, + }; + + Ok(Arc::new(StreamingTableExec::try_new( + self.0.schema.clone(), + vec![Arc::new(StreamRead(self.0.clone())) as _], + projection, + projected_schema, + true, + )?)) + } + + async fn insert_into( + &self, + _state: &SessionState, + input: Arc, + _overwrite: bool, + ) -> Result> { + let ordering = match self.0.order.first() { + Some(x) => { + let schema = self.0.schema.as_ref(); + let orders = create_ordering(schema, std::slice::from_ref(x))?; + let ordering = orders.into_iter().next().unwrap(); + Some(ordering.into_iter().map(Into::into).collect()) + } + None => None, + }; + + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(StreamWrite(self.0.clone())), + self.0.schema.clone(), + ordering, + ))) + } +} + +struct StreamRead(Arc); + +impl PartitionStream for StreamRead { + fn schema(&self) -> &SchemaRef { + &self.0.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let config = self.0.clone(); + let schema = self.0.schema.clone(); + let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); + let tx = builder.tx(); + builder.spawn_blocking(move || { + let reader = config.reader()?; + for b in reader { + if tx.blocking_send(b.map_err(Into::into)).is_err() { + break; + } + } + Ok(()) + }); + builder.build() + } +} + +#[derive(Debug)] +struct StreamWrite(Arc); + +impl DisplayAs for StreamWrite { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("StreamWrite") + .field("location", &self.0.location) + .field("batch_size", &self.0.batch_size) + .field("encoding", &self.0.encoding) + .field("header", &self.0.header) + .finish_non_exhaustive() + } +} + +#[async_trait] +impl DataSink for StreamWrite { + fn as_any(&self) -> &dyn Any { + self + } + + fn metrics(&self) -> Option { + None + } + + async fn write_all( + &self, + mut data: SendableRecordBatchStream, + _context: &Arc, + ) -> Result { + let config = self.0.clone(); + let (sender, mut receiver) = tokio::sync::mpsc::channel::(2); + // Note: FIFO Files support poll so this could use AsyncFd + let write = AbortOnDropSingle::new(spawn_blocking(move || { + let mut count = 0_u64; + let mut writer = config.writer()?; + while let Some(batch) = receiver.blocking_recv() { + count += batch.num_rows() as u64; + writer.write(&batch)?; + } + Ok(count) + })); + + while let Some(b) = data.next().await.transpose()? { + if sender.send(b).await.is_err() { + break; + } + } + drop(sender); + write.await.unwrap() + } +} diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index 4a234fbe138b7..3eb120653ce38 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -23,22 +23,14 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{Expr, TableType}; +use log::debug; use crate::datasource::TableProvider; -use crate::execution::context::{SessionState, TaskContext}; -use crate::physical_plan::streaming::StreamingTableExec; -use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; - -/// A partition that can be converted into a [`SendableRecordBatchStream`] -pub trait PartitionStream: Send + Sync { - /// Returns the schema of this partition - fn schema(&self) -> &SchemaRef; - - /// Returns a stream yielding this partitions values - fn execute(&self, ctx: Arc) -> SendableRecordBatchStream; -} +use crate::execution::context::SessionState; +use crate::physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use crate::physical_plan::ExecutionPlan; /// A [`TableProvider`] that streams a set of [`PartitionStream`] pub struct StreamingTable { @@ -53,10 +45,15 @@ impl StreamingTable { schema: SchemaRef, partitions: Vec>, ) -> Result { - if !partitions.iter().all(|x| schema.contains(x.schema())) { - return Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )); + for x in partitions.iter() { + let partition_schema = x.schema(); + if !schema.contains(partition_schema) { + debug!( + "target schema does not contain partition schema. \ + Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + ); + return plan_err!("Mismatch between schema and batches"); + } } Ok(Self { @@ -98,6 +95,7 @@ impl TableProvider for StreamingTable { self.schema.clone(), self.partitions.clone(), projection, + None, self.infinite, )?)) } diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 391e4b93c4e57..85fb8939886c5 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -108,12 +108,20 @@ impl TableProvider for ViewTable { filters: &[Expr], limit: Option, ) -> Result> { - let plan = if let Some(projection) = projection { + let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new)); + let plan = self.logical_plan().clone(); + let mut plan = LogicalPlanBuilder::from(plan); + + if let Some(filter) = filter { + plan = plan.filter(filter)?; + } + + let mut plan = if let Some(projection) = projection { // avoiding adding a redundant projection (e.g. SELECT * FROM view) let current_projection = - (0..self.logical_plan.schema().fields().len()).collect::>(); + (0..plan.schema().fields().len()).collect::>(); if projection == ¤t_projection { - self.logical_plan().clone() + plan } else { let fields: Vec = projection .iter() @@ -123,19 +131,11 @@ impl TableProvider for ViewTable { ) }) .collect(); - LogicalPlanBuilder::from(self.logical_plan.clone()) - .project(fields)? - .build()? + plan.project(fields)? } } else { - self.logical_plan().clone() + plan }; - let mut plan = LogicalPlanBuilder::from(plan); - let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new)); - - if let Some(filter) = filter { - plan = plan.filter(filter)?; - } if let Some(limit) = limit { plan = plan.limit(0, Some(limit))?; @@ -159,7 +159,7 @@ mod tests { #[tokio::test] async fn issue_3242() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/pull/3242 - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -175,7 +175,7 @@ mod tests { .collect() .await?; - let expected = vec!["+---+", "| b |", "+---+", "| 2 |", "+---+"]; + let expected = ["+---+", "| b |", "+---+", "| 2 |", "+---+"]; assert_batches_eq!(expected, &results); @@ -199,7 +199,7 @@ mod tests { #[tokio::test] async fn query_view() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -221,7 +221,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------+---------+---------+", "| column1 | column2 | column3 |", "+---------+---------+---------+", @@ -237,7 +237,7 @@ mod tests { #[tokio::test] async fn query_view_with_alias() -> Result<()> { - let session_ctx = SessionContext::with_config(SessionConfig::new()); + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); session_ctx .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") @@ -254,7 +254,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------------+", "| column1_alias |", "+---------------+", @@ -270,7 +270,7 @@ mod tests { #[tokio::test] async fn query_view_with_inline_alias() -> Result<()> { - let session_ctx = SessionContext::with_config(SessionConfig::new()); + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); session_ctx .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") @@ -287,7 +287,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------------+---------------+", "| column2_alias | column1_alias |", "+---------------+---------------+", @@ -303,7 +303,7 @@ mod tests { #[tokio::test] async fn query_view_with_projection() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -325,7 +325,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------+", "| column1 |", "+---------+", @@ -341,7 +341,7 @@ mod tests { #[tokio::test] async fn query_view_with_filter() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -363,7 +363,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------+", "| column1 |", "+---------+", @@ -378,7 +378,7 @@ mod tests { #[tokio::test] async fn query_join_views() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -403,7 +403,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------+---------+---------+", "| column2 | column1 | column3 |", "+---------+---------+---------+", @@ -439,6 +439,7 @@ mod tests { .select_columns(&["bool_col", "int_col"])?; let plan = df.explain(false, false)?.collect().await?; + // Filters all the way to Parquet let formatted = arrow::util::pretty::pretty_format_batches(&plan) .unwrap() @@ -480,7 +481,7 @@ mod tests { #[tokio::test] async fn create_view_plan() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -533,7 +534,7 @@ mod tests { #[tokio::test] async fn create_or_replace_view() -> Result<()> { - let session_ctx = SessionContext::with_config( + let session_ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -558,7 +559,7 @@ mod tests { .collect() .await?; - let expected = vec![ + let expected = [ "+---------+", "| column1 |", "+---------+", diff --git a/datafusion/core/src/error.rs b/datafusion/core/src/error.rs index 0a138c80df2a5..5a5faa7896e3e 100644 --- a/datafusion/core/src/error.rs +++ b/datafusion/core/src/error.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion error types +//! DataFusion error type [`DataFusionError`] and [`Result`]. pub use datafusion_common::{DataFusionError, Result, SharedResult}; diff --git a/datafusion/core/src/execution/context/avro.rs b/datafusion/core/src/execution/context/avro.rs new file mode 100644 index 0000000000000..d60e79862ef2d --- /dev/null +++ b/datafusion/core/src/execution/context/avro.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use super::super::options::{AvroReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading an Avro data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_avro( + &self, + table_paths: P, + options: AvroReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers an Avro file as a table that can be referenced from + /// SQL statements executed against this context. + pub async fn register_avro( + &self, + name: &str, + table_path: &str, + options: AvroReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_avro(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_avro(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_avro("dummy", AvroReadOptions::default()) + .await + .unwrap() + } + } +} diff --git a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs new file mode 100644 index 0000000000000..f3675422c7d5d --- /dev/null +++ b/datafusion/core/src/execution/context/csv.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::plan_to_csv; + +use super::super::options::{CsvReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading a CSV data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// Example usage is given below: + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// // You can read a single file using `read_csv` + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// // you can also read multiple files: + /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn read_csv( + &self, + table_paths: P, + options: CsvReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a CSV file as a table which can referenced from SQL + /// statements executed against this context. + pub async fn register_csv( + &self, + name: &str, + table_path: &str, + options: CsvReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + + Ok(()) + } + + /// Executes a query and writes the results to a partitioned CSV file. + pub async fn write_csv( + &self, + plan: Arc, + path: impl AsRef, + ) -> Result<()> { + plan_to_csv(self.task_ctx(), plan, path).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::assert_batches_eq; + use crate::test_util::{plan_and_collect, populate_csv_partitions}; + use async_trait::async_trait; + use tempfile::TempDir; + + #[tokio::test] + async fn query_csv_with_custom_partition_extension() -> Result<()> { + let tmp_dir = TempDir::new()?; + + // The main stipulation of this test: use a file extension that isn't .csv. + let file_extension = ".tst"; + + let ctx = SessionContext::new(); + let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .file_extension(file_extension), + ) + .await?; + let results = + plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; + + assert_eq!(results.len(), 1); + let expected = [ + "+--------------+--------------+----------+", + "| SUM(test.c1) | SUM(test.c2) | COUNT(*) |", + "+--------------+--------------+----------+", + "| 10 | 110 | 20 |", + "+--------------+--------------+----------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) + } + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_csv(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_csv(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() + } + } +} diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs new file mode 100644 index 0000000000000..f67693aa8f317 --- /dev/null +++ b/datafusion/core/src/execution/context/json.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::plan_to_json; + +use super::super::options::{NdJsonReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading an JSON data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_json( + &self, + table_paths: P, + options: NdJsonReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a JSON file as a table that it can be referenced + /// from SQL statements executed against this context. + pub async fn register_json( + &self, + name: &str, + table_path: &str, + options: NdJsonReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.copied_config()); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } + + /// Executes a query and writes the results to a partitioned JSON file. + pub async fn write_json( + &self, + plan: Arc, + path: impl AsRef, + ) -> Result<()> { + plan_to_json(self.task_ctx(), plan, path).await + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context/mod.rs similarity index 70% rename from datafusion/core/src/execution/context.rs rename to datafusion/core/src/execution/context/mod.rs index 6b81a39691d69..58a4f08341d64 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -16,25 +16,34 @@ // under the License. //! [`SessionContext`] contains methods for registering data sources and executing queries + +mod avro; +mod csv; +mod json; +#[cfg(feature = "parquet")] +mod parquet; + use crate::{ - catalog::catalog::{CatalogList, MemoryCatalogList}, + catalog::{CatalogList, MemoryCatalogList}, datasource::{ - datasource::TableProviderFactory, + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, - listing_table_factory::ListingTableFactory, + provider::TableProviderFactory, }, datasource::{MemTable, ViewTable}, logical_expr::{PlanType, ToStringifiedPlan}, optimizer::optimizer::Optimizer, - physical_optimizer::{ - aggregate_statistics::AggregateStatistics, join_selection::JoinSelection, - optimizer::PhysicalOptimizerRule, - }, + physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, +}; +use datafusion_common::{ + alias::AliasGenerator, + exec_err, not_impl_err, plan_datafusion_err, plan_err, + tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, StringifiedPlan, UserDefinedLogicalNode, + Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -48,15 +57,12 @@ use std::{ }; use std::{ops::ControlFlow, sync::Weak}; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow::{ - array::StringBuilder, - datatypes::{DataType, Field, Schema, SchemaRef}, -}; use crate::catalog::{ - catalog::{CatalogProvider, MemoryCatalogProvider}, schema::{MemorySchemaProvider, SchemaProvider}, + {CatalogProvider, MemoryCatalogProvider}, }; use crate::dataframe::DataFrame; use crate::datasource::{ @@ -77,18 +83,13 @@ use datafusion_sql::{ }; use sqlparser::dialect::dialect_from_str; -use crate::physical_optimizer::coalesce_batches::CoalesceBatches; -use crate::physical_optimizer::repartition::Repartition; - use crate::config::ConfigOptions; -use crate::datasource::physical_plan::{plan_to_csv, plan_to_json, plan_to_parquet}; use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry}; -use crate::physical_optimizer::dist_enforcement::EnforceDistribution; -use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udaf::AggregateUDF; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; -use crate::physical_plan::PhysicalPlanner; +use crate::physical_planner::DefaultPhysicalPlanner; +use crate::physical_planner::PhysicalPlanner; use crate::variable::{VarProvider, VarType}; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -97,16 +98,11 @@ use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, }; -use parquet::file::properties::WriterProperties; use url::Url; use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; -use crate::physical_optimizer::global_sort_selection::GlobalSortSelection; -use crate::physical_optimizer::pipeline_checker::PipelineChecker; -use crate::physical_optimizer::pipeline_fixer::PipelineFixer; -use crate::physical_optimizer::sort_enforcement::EnforceSorting; use datafusion_optimizer::{ analyzer::{Analyzer, AnalyzerRule}, OptimizerConfig, @@ -115,14 +111,12 @@ use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; // backwards compatibility +use crate::datasource::provider::DefaultTableFactory; use crate::execution::options::ArrowReadOptions; -use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; -use super::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, ReadOptions, -}; +use super::options::ReadOptions; /// DataFilePaths adds a method to convert strings and vector of strings to vector of [`ListingTableUrl`] URLs. /// This allows methods such [`SessionContext::read_csv`] and [`SessionContext::read_avro`] @@ -174,12 +168,14 @@ where /// * Register a custom data source that can be referenced from a SQL query. /// * Execution a SQL query /// +/// # Example: DataFrame API +/// /// The following example demonstrates how to use the context to execute a query against a CSV /// data source using the DataFrame API: /// /// ``` /// use datafusion::prelude::*; -/// # use datafusion::error::Result; +/// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -187,22 +183,49 @@ where /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? /// .limit(0, Some(100))?; -/// let results = df.collect(); +/// let results = df +/// .collect() +/// .await?; +/// assert_batches_eq!( +/// &[ +/// "+---+----------------+", +/// "| a | MIN(?table?.b) |", +/// "+---+----------------+", +/// "| 1 | 2 |", +/// "+---+----------------+", +/// ], +/// &results +/// ); /// # Ok(()) /// # } /// ``` /// +/// # Example: SQL API +/// /// The following example demonstrates how to execute the same query using SQL: /// /// ``` /// use datafusion::prelude::*; -/// -/// # use datafusion::error::Result; +/// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let mut ctx = SessionContext::new(); /// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; -/// let results = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?; +/// let results = ctx +/// .sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100") +/// .await? +/// .collect() +/// .await?; +/// assert_batches_eq!( +/// &[ +/// "+---+----------------+", +/// "| a | MIN(example.b) |", +/// "+---+----------------+", +/// "| 1 | 2 |", +/// "+---+----------------+", +/// ], +/// &results +/// ); /// # Ok(()) /// # } /// ``` @@ -239,7 +262,7 @@ impl Default for SessionContext { impl SessionContext { /// Creates a new `SessionContext` using the default [`SessionConfig`]. pub fn new() -> Self { - Self::with_config(SessionConfig::new()) + Self::new_with_config(SessionConfig::new()) } /// Finds any [`ListingSchemaProvider`]s and instructs them to reload tables from "disk" @@ -265,11 +288,18 @@ impl SessionContext { /// Creates a new `SessionContext` using the provided /// [`SessionConfig`] and a new [`RuntimeEnv`]. /// - /// See [`Self::with_config_rt`] for more details on resource + /// See [`Self::new_with_config_rt`] for more details on resource /// limits. - pub fn with_config(config: SessionConfig) -> Self { + pub fn new_with_config(config: SessionConfig) -> Self { let runtime = Arc::new(RuntimeEnv::default()); - Self::with_config_rt(config, runtime) + Self::new_with_config_rt(config, runtime) + } + + /// Creates a new `SessionContext` using the provided + /// [`SessionConfig`] and a new [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_config")] + pub fn with_config(config: SessionConfig) -> Self { + Self::new_with_config(config) } /// Creates a new `SessionContext` using the provided @@ -285,13 +315,20 @@ impl SessionContext { /// memory used) across all DataFusion queries in a process, /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + let state = SessionState::new_with_config_rt(config, runtime); + Self::new_with_state(state) + } + + /// Creates a new `SessionContext` using the provided + /// [`SessionConfig`] and a [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::with_config_rt(config, runtime); - Self::with_state(state) + Self::new_with_config_rt(config, runtime) } /// Creates a new `SessionContext` using the provided [`SessionState`] - pub fn with_state(state: SessionState) -> Self { + pub fn new_with_state(state: SessionState) -> Self { Self { session_id: state.session_id.clone(), session_start_time: Utc::now(), @@ -299,6 +336,11 @@ impl SessionContext { } } + /// Creates a new `SessionContext` using the provided [`SessionState`] + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] + pub fn with_state(state: SessionState) -> Self { + Self::new_with_state(state) + } /// Returns the time this `SessionContext` was created pub fn session_start_time(&self) -> DateTime { self.session_start_time @@ -353,22 +395,81 @@ impl SessionContext { self.state.read().config.clone() } - /// Creates a [`DataFrame`] that will execute a SQL query. + /// Creates a [`DataFrame`] from SQL query text. /// /// Note: This API implements DDL statements such as `CREATE TABLE` and /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory - /// default implementations. + /// default implementations. See [`Self::sql_with_options`]. + /// + /// # Example: Running SQL queries + /// + /// See the example on [`Self`] + /// + /// # Example: Creating a Table with SQL /// - /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which - /// does not mutate the state based on such statements. + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::{error::Result, assert_batches_eq}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let mut ctx = SessionContext::new(); + /// ctx + /// .sql("CREATE TABLE foo (x INTEGER)") + /// .await? + /// .collect() + /// .await?; + /// assert!(ctx.table_exist("foo").unwrap()); + /// # Ok(()) + /// # } + /// ``` pub async fn sql(&self, sql: &str) -> Result { - // create a query planner + self.sql_with_options(sql, SQLOptions::new()).await + } + + /// Creates a [`DataFrame`] from SQL query text, first validating + /// that the queries are allowed by `options` + /// + /// # Example: Preventing Creating a Table with SQL + /// + /// If you want to avoid creating tables, or modifying data or the + /// session, set [`SQLOptions`] appropriately: + /// + /// ``` + /// use datafusion::prelude::*; + /// # use datafusion::{error::Result}; + /// # use datafusion::physical_plan::collect; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let mut ctx = SessionContext::new(); + /// let options = SQLOptions::new() + /// .with_allow_ddl(false); + /// let err = ctx.sql_with_options("CREATE TABLE foo (x INTEGER)", options) + /// .await + /// .unwrap_err(); + /// assert!( + /// err.to_string().starts_with("Error during planning: DDL not supported: CreateMemoryTable") + /// ); + /// # Ok(()) + /// # } + /// ``` + pub async fn sql_with_options( + &self, + sql: &str, + options: SQLOptions, + ) -> Result { let plan = self.state().create_logical_plan(sql).await?; + options.verify_plan(&plan)?; self.execute_logical_plan(plan).await } - /// Execute the [`LogicalPlan`], return a [`DataFrame`] + /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API + /// is not featured limited (so all SQL such as `CREATE TABLE` and + /// `COPY` will be run). + /// + /// If you wish to limit the type of plan that can be run from + /// SQL, see [`Self::sql_with_options`] and + /// [`SQLOptions::verify_plan`]. pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result { match plan { LogicalPlan::Ddl(ddl) => match ddl { @@ -391,9 +492,6 @@ impl SessionContext { LogicalPlan::Statement(Statement::SetVariable(stmt)) => { self.set_variable(stmt).await } - LogicalPlan::DescribeTable(DescribeTable { schema, .. }) => { - self.return_describe_table_dataframe(schema).await - } plan => Ok(DataFrame::new(self.state(), plan)), } @@ -405,53 +503,6 @@ impl SessionContext { Ok(DataFrame::new(self.state(), plan)) } - // return an record_batch which describe table - async fn return_describe_table_record_batch( - &self, - schema: Arc, - ) -> Result { - let record_batch_schema = Arc::new(Schema::new(vec![ - Field::new("column_name", DataType::Utf8, false), - Field::new("data_type", DataType::Utf8, false), - Field::new("is_nullable", DataType::Utf8, false), - ])); - - let mut column_names = StringBuilder::new(); - let mut data_types = StringBuilder::new(); - let mut is_nullables = StringBuilder::new(); - for (_, field) in schema.fields().iter().enumerate() { - column_names.append_value(field.name()); - - // "System supplied type" --> Use debug format of the datatype - let data_type = field.data_type(); - data_types.append_value(format!("{data_type:?}")); - - // "YES if the column is possibly nullable, NO if it is known not nullable. " - let nullable_str = if field.is_nullable() { "YES" } else { "NO" }; - is_nullables.append_value(nullable_str); - } - - let record_batch = RecordBatch::try_new( - record_batch_schema, - vec![ - Arc::new(column_names.finish()), - Arc::new(data_types.finish()), - Arc::new(is_nullables.finish()), - ], - )?; - - Ok(record_batch) - } - - // return an dataframe which describe file - async fn return_describe_table_dataframe( - &self, - schema: Arc, - ) -> Result { - let record_batch = self.return_describe_table_record_batch(schema).await?; - self.read_batch(record_batch) - } - async fn create_external_table( &self, cmd: &CreateExternalTable, @@ -461,10 +512,7 @@ impl SessionContext { match cmd.if_not_exists { true => return self.return_empty_dataframe(), false => { - return Err(DataFusionError::Execution(format!( - "Table '{}' already exists", - cmd.name - ))); + return exec_err!("Table '{}' already exists", cmd.name); } } } @@ -481,19 +529,13 @@ impl SessionContext { input, if_not_exists, or_replace, - primary_key, + constraints, + column_defaults, } = cmd; - if !primary_key.is_empty() { - Err(DataFusionError::Execution( - "Primary keys on MemoryTables are not currently supported!".to_string(), - ))?; - } - let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); let input = self.state().optimize(&input)?; let table = self.table(&name).await; - match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { @@ -502,27 +544,36 @@ impl SessionContext { let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), + ); self.register_table(&name, table)?; self.return_empty_dataframe() } - (true, true, Ok(_)) => Err(DataFusionError::Execution( - "'IF NOT EXISTS' cannot coexist with 'REPLACE'".to_string(), - )), + (true, true, Ok(_)) => { + exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'") + } (_, _, Err(_)) => { - let schema = Arc::new(input.schema().as_ref().into()); + let df_schema = input.schema(); + let schema = Arc::new(df_schema.as_ref().into()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints and column defaults to the mem table. + MemTable::try_new(schema, batches)? + .with_constraints(constraints) + .with_column_defaults(column_defaults.into_iter().collect()), + ); self.register_table(&name, table)?; self.return_empty_dataframe() } - (false, false, Ok(_)) => Err(DataFusionError::Execution(format!( - "Table '{name}' already exists" - ))), + (false, false, Ok(_)) => exec_err!("Table '{name}' already exists"), } } @@ -550,9 +601,7 @@ impl SessionContext { self.register_table(&name, table)?; self.return_empty_dataframe() } - (false, Ok(_)) => Err(DataFusionError::Execution(format!( - "Table '{name}' already exists" - ))), + (false, Ok(_)) => exec_err!("Table '{name}' already exists"), } } @@ -584,11 +633,7 @@ impl SessionContext { })?; (catalog, tokens[1]) } - _ => { - return Err(DataFusionError::Execution(format!( - "Unable to parse catalog from {schema_name}" - ))) - } + _ => return exec_err!("Unable to parse catalog from {schema_name}"), }; let schema = catalog.schema(schema_name); @@ -599,9 +644,7 @@ impl SessionContext { catalog.register_schema(schema_name, schema)?; self.return_empty_dataframe() } - (false, Some(_)) => Err(DataFusionError::Execution(format!( - "Schema '{schema_name}' already exists" - ))), + (false, Some(_)) => exec_err!("Schema '{schema_name}' already exists"), } } @@ -623,9 +666,7 @@ impl SessionContext { .register_catalog(catalog_name, new_catalog); self.return_empty_dataframe() } - (false, Some(_)) => Err(DataFusionError::Execution(format!( - "Catalog '{catalog_name}' already exists" - ))), + (false, Some(_)) => exec_err!("Catalog '{catalog_name}' already exists"), } } @@ -637,9 +678,7 @@ impl SessionContext { match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), - (_, _) => Err(DataFusionError::Execution(format!( - "Table '{name}' doesn't exist." - ))), + (_, _) => exec_err!("Table '{name}' doesn't exist."), } } @@ -651,9 +690,7 @@ impl SessionContext { match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), - (_, _) => Err(DataFusionError::Execution(format!( - "View '{name}' doesn't exist." - ))), + (_, _) => exec_err!("View '{name}' doesn't exist."), } } @@ -692,9 +729,7 @@ impl SessionContext { &self, schemaref: SchemaReference<'_>, ) -> Result { - Err(DataFusionError::Execution(format!( - "Schema '{schemaref}' doesn't exist." - ))) + exec_err!("Schema '{schemaref}' doesn't exist.") } async fn set_variable(&self, stmt: SetVariable) -> Result { @@ -769,6 +804,14 @@ impl SessionContext { .add_var_provider(variable_type, provider); } + /// Register a table UDF with this context + pub fn register_udtf(&self, name: &str, fun: Arc) { + self.state.write().table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + /// Registers a scalar UDF within this context. /// /// Note in SQL queries, function names are looked up using @@ -776,11 +819,18 @@ impl SessionContext { /// /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + /// Any functions registered with the udf name or its aliases will be overwritten with this new function pub fn register_udf(&self, f: ScalarUDF) { - self.state - .write() + let mut state = self.state.write(); + let aliases = f.aliases(); + for alias in aliases { + state + .scalar_functions + .insert(alias.to_string(), Arc::new(f.clone())); + } + state .scalar_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers an aggregate UDF within this context. @@ -794,7 +844,21 @@ impl SessionContext { self.state .write() .aggregate_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); + } + + /// Registers a window UDF within this context. + /// + /// Note in SQL queries, window function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"` + /// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"` + pub fn register_udwf(&self, f: WindowUDF) { + self.state + .write() + .window_functions + .insert(f.name().to_string(), Arc::new(f)); } /// Creates a [`DataFrame`] for reading a data source. @@ -809,6 +873,25 @@ impl SessionContext { let table_paths = table_paths.to_urls()?; let session_config = self.copied_config(); let listing_options = options.to_listing_options(&session_config); + + let option_extension = listing_options.file_extension.clone(); + + if table_paths.is_empty() { + return exec_err!("No table paths were provided"); + } + + // check if the file extension matches the expected extension + for path in &table_paths { + let file_path = path.as_str(); + if !file_path.ends_with(option_extension.clone().as_str()) + && !path.is_collection() + { + return exec_err!( + "File path '{file_path}' does not match the expected extension '{option_extension}'" + ); + } + } + let resolved_schema = options .get_resolved_schema(&session_config, self.state(), table_paths[0].clone()) .await?; @@ -819,34 +902,6 @@ impl SessionContext { self.read_table(Arc::new(provider)) } - /// Creates a [`DataFrame`] for reading an Avro data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_avro( - &self, - table_paths: P, - options: AvroReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - - /// Creates a [`DataFrame`] for reading an JSON data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_json( - &self, - table_paths: P, - options: NdJsonReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - /// Creates a [`DataFrame`] for reading an Arrow data source. /// /// For more control such as reading multiple files, you can use @@ -869,48 +924,6 @@ impl SessionContext { )) } - /// Creates a [`DataFrame`] for reading a CSV data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// Example usage is given below: - /// - /// ``` - /// use datafusion::prelude::*; - /// # use datafusion::error::Result; - /// # #[tokio::main] - /// # async fn main() -> Result<()> { - /// let ctx = SessionContext::new(); - /// // You can read a single file using `read_csv` - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// // you can also read multiple files: - /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn read_csv( - &self, - table_paths: P, - options: CsvReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - - /// Creates a [`DataFrame`] for reading a Parquet data source. - /// - /// For more control such as reading multiple files, you can use - /// [`read_table`](Self::read_table) with a [`ListingTable`]. - /// - /// For an example, see [`read_csv`](Self::read_csv) - pub async fn read_parquet( - &self, - table_paths: P, - options: ParquetReadOptions<'_>, - ) -> Result { - self._read_type(table_paths, options).await - } - /// Creates a [`DataFrame`] for a [`TableProvider`] such as a /// [`ListingTable`] or a custom user defined provider. pub fn read_table(&self, provider: Arc) -> Result { @@ -955,10 +968,9 @@ impl SessionContext { (Some(s), _) => s, (None, false) => options.infer_schema(&self.state(), &table_path).await?, (None, true) => { - return Err(DataFusionError::Plan( + return plan_err!( "Schema inference for infinite data sources is not supported." - .to_string(), - )) + ) } }; let config = ListingTableConfig::new(table_path) @@ -972,85 +984,6 @@ impl SessionContext { Ok(()) } - /// Registers a CSV file as a table which can referenced from SQL - /// statements executed against this context. - pub async fn register_csv( - &self, - name: &str, - table_path: &str, - options: CsvReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - - Ok(()) - } - - /// Registers a JSON file as a table that it can be referenced - /// from SQL statements executed against this context. - pub async fn register_json( - &self, - name: &str, - table_path: &str, - options: NdJsonReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - Ok(()) - } - - /// Registers a Parquet file as a table that can be referenced from SQL - /// statements executed against this context. - pub async fn register_parquet( - &self, - name: &str, - table_path: &str, - options: ParquetReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.state.read().config); - - self.register_listing_table(name, table_path, listing_options, None, None) - .await?; - Ok(()) - } - - /// Registers an Avro file as a table that can be referenced from - /// SQL statements executed against this context. - pub async fn register_avro( - &self, - name: &str, - table_path: &str, - options: AvroReadOptions<'_>, - ) -> Result<()> { - let listing_options = options.to_listing_options(&self.copied_config()); - - self.register_listing_table( - name, - table_path, - listing_options, - options.schema.map(|s| Arc::new(s.to_owned())), - None, - ) - .await?; - Ok(()) - } - /// Registers an Arrow file as a table that can be referenced from /// SQL statements executed against this context. pub async fn register_arrow( @@ -1179,7 +1112,7 @@ impl SessionContext { let schema = self.state.read().schema_for_ref(table_ref)?; match schema.table(&table).await { Some(ref provider) => Ok(Arc::clone(provider)), - _ => Err(DataFusionError::Plan(format!("No table named '{table}'"))), + _ => plan_err!("No table named '{table}'"), } } @@ -1226,34 +1159,6 @@ impl SessionContext { self.state().create_physical_plan(logical_plan).await } - /// Executes a query and writes the results to a partitioned CSV file. - pub async fn write_csv( - &self, - plan: Arc, - path: impl AsRef, - ) -> Result<()> { - plan_to_csv(self.task_ctx(), plan, path).await - } - - /// Executes a query and writes the results to a partitioned JSON file. - pub async fn write_json( - &self, - plan: Arc, - path: impl AsRef, - ) -> Result<()> { - plan_to_json(self.task_ctx(), plan, path).await - } - - /// Executes a query and writes the results to a partitioned Parquet file. - pub async fn write_parquet( - &self, - plan: Arc, - path: impl AsRef, - writer_properties: Option, - ) -> Result<()> { - plan_to_parquet(self.task_ctx(), plan, path, writer_properties).await - } - /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) @@ -1290,12 +1195,16 @@ impl FunctionRegistry for SessionContext { fn udaf(&self, name: &str) -> Result> { self.state.read().udaf(name) } + + fn udwf(&self, name: &str) -> Result> { + self.state.read().udwf(name) + } } /// A planner used to add extensions to DataFusion logical and physical plans. #[async_trait] pub trait QueryPlanner { - /// Given a `LogicalPlan`, create an `ExecutionPlan` suitable for execution + /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1308,7 +1217,7 @@ struct DefaultQueryPlanner {} #[async_trait] impl QueryPlanner for DefaultQueryPlanner { - /// Given a `LogicalPlan`, create an `ExecutionPlan` suitable for execution + /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1321,7 +1230,12 @@ impl QueryPlanner for DefaultQueryPlanner { } } -/// Execution context for registering data sources and executing queries +/// Execution context for registering data sources and executing queries. +/// See [`SessionContext`] for a higher level API. +/// +/// Note that there is no `Default` or `new()` for SessionState, +/// to avoid accidentally running queries or other operations without passing through +/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionContext`]. #[derive(Clone)] pub struct SessionState { /// A unique UUID that identifies the session @@ -1331,15 +1245,19 @@ pub struct SessionState { /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan - physical_optimizers: Vec>, + physical_optimizers: PhysicalOptimizer, /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` query_planner: Arc, /// Collection of catalogs containing schemas and ultimately TableProviders catalog_list: Arc, + /// Table Functions + table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, + /// Window functions registered in the context + window_functions: HashMap>, /// Deserializer registry for extensions. serializer_registry: Arc, /// Session configuration @@ -1367,25 +1285,24 @@ impl Debug for SessionState { } } -/// Default session builder using the provided configuration -#[deprecated( - since = "23.0.0", - note = "See SessionContext::with_config() or SessionState::with_config_rt" -)] -pub fn default_session_builder(config: SessionConfig) -> SessionState { - SessionState::with_config_rt(config, Arc::new(RuntimeEnv::default())) -} - impl SessionState { /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. - pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; - Self::with_config_rt_and_catalog_list(config, runtime, catalog_list) + Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) } - /// Returns new SessionState using the provided configuration, runtime and catalog list. - pub fn with_config_rt_and_catalog_list( + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] + pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + Self::new_with_config_rt(config, runtime) + } + + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogList`] + pub fn new_with_config_rt_and_catalog_list( config: SessionConfig, runtime: Arc, catalog_list: Arc, @@ -1395,12 +1312,13 @@ impl SessionState { // Create table_factories for all default formats let mut table_factories: HashMap> = HashMap::new(); - table_factories.insert("PARQUET".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(ListingTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(ListingTableFactory::new())); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); if config.create_default_catalog_and_schema() { let default_catalog = MemoryCatalogProvider::new(); @@ -1425,64 +1343,17 @@ impl SessionState { ); } - // We need to take care of the rule ordering. They may influence each other. - let physical_optimizers: Vec> = vec![ - Arc::new(AggregateStatistics::new()), - // Statistics-based join selection will change the Auto mode to a real join implementation, - // like collect left, or hash join, or future sort merge join, which will influence the - // EnforceDistribution and EnforceSorting rules as they decide whether to add additional - // repartitioning and local sorting steps to meet distribution and ordering requirements. - // Therefore, it should run before EnforceDistribution and EnforceSorting. - Arc::new(JoinSelection::new()), - // If the query is processing infinite inputs, the PipelineFixer rule applies the - // necessary transformations to make the query runnable (if it is not already runnable). - // If the query can not be made runnable, the rule emits an error with a diagnostic message. - // Since the transformations it applies may alter output partitioning properties of operators - // (e.g. by swapping hash join sides), this rule runs before EnforceDistribution. - Arc::new(PipelineFixer::new()), - // In order to increase the parallelism, the Repartition rule will change the - // output partitioning of some operators in the plan tree, which will influence - // other rules. Therefore, it should run as soon as possible. It is optional because: - // - It's not used for the distributed engine, Ballista. - // - It's conflicted with some parts of the EnforceDistribution, since it will - // introduce additional repartitioning while EnforceDistribution aims to - // reduce unnecessary repartitioning. - Arc::new(Repartition::new()), - // - Currently it will depend on the partition number to decide whether to change the - // single node sort to parallel local sort and merge. Therefore, GlobalSortSelection - // should run after the Repartition. - // - Since it will change the output ordering of some operators, it should run - // before JoinSelection and EnforceSorting, which may depend on that. - Arc::new(GlobalSortSelection::new()), - // The EnforceDistribution rule is for adding essential repartition to satisfy the required - // distribution. Please make sure that the whole plan tree is determined before this rule. - Arc::new(EnforceDistribution::new()), - // The CombinePartialFinalAggregate rule should be applied after the EnforceDistribution rule - Arc::new(CombinePartialFinalAggregate::new()), - // The EnforceSorting rule is for adding essential local sorting to satisfy the required - // ordering. Please make sure that the whole plan tree is determined before this rule. - // Note that one should always run this rule after running the EnforceDistribution rule - // as the latter may break local sorting requirements. - Arc::new(EnforceSorting::new()), - // The CoalesceBatches rule will not influence the distribution and ordering of the - // whole plan tree. Therefore, to avoid influencing other rules, it should run last. - Arc::new(CoalesceBatches::new()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), - ]; - SessionState { session_id, analyzer: Analyzer::new(), optimizer: Optimizer::new(), - physical_optimizers, + physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, + table_functions: HashMap::new(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), config, execution_props: ExecutionProps::new(), @@ -1490,7 +1361,19 @@ impl SessionState { table_factories, } } - + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated( + since = "32.0.0", + note = "Use SessionState::new_with_config_rt_and_catalog_list" + )] + pub fn with_config_rt_and_catalog_list( + config: SessionConfig, + runtime: Arc, + catalog_list: Arc, + ) -> Self { + Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) + } fn register_default_schema( config: &SessionConfig, table_factories: &HashMap>, @@ -1561,17 +1444,14 @@ impl SessionState { self.catalog_list .catalog(&resolved_ref.catalog) .ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "failed to resolve catalog: {}", resolved_ref.catalog - )) + ) })? .schema(&resolved_ref.schema) .ok_or_else(|| { - DataFusionError::Plan(format!( - "failed to resolve schema: {}", - resolved_ref.schema - )) + plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema) }) } @@ -1613,7 +1493,7 @@ impl SessionState { mut self, physical_optimizers: Vec>, ) -> Self { - self.physical_optimizers = physical_optimizers; + self.physical_optimizers = PhysicalOptimizer::with_rules(physical_optimizers); self } @@ -1640,7 +1520,7 @@ impl SessionState { mut self, optimizer_rule: Arc, ) -> Self { - self.physical_optimizers.push(optimizer_rule); + self.physical_optimizers.rules.push(optimizer_rule); self } @@ -1665,24 +1545,25 @@ impl SessionState { &mut self.table_factories } - /// Convert a SQL string into an AST Statement + /// Parse an SQL string into an DataFusion specific AST + /// [`Statement`]. See [`SessionContext::sql`] for running queries. pub fn sql_to_statement( &self, sql: &str, dialect: &str, ) -> Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ MsSQL, ClickHouse, BigQuery, Ansi." - )) + ) })?; let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; if statements.len() > 1 { - return Err(DataFusionError::NotImplemented( - "The context currently only supports a single SQL statement".to_string(), - )); + return not_impl_err!( + "The context currently only supports a single SQL statement" + ); } let statement = statements.pop_front().ok_or_else(|| { DataFusionError::NotImplemented( @@ -1735,30 +1616,39 @@ impl SessionState { } let mut visitor = RelationVisitor(&mut relations); - match statement { - DFStatement::Statement(s) => { - let _ = s.as_ref().visit(&mut visitor); - } - DFStatement::CreateExternalTable(table) => { - visitor - .0 - .insert(ObjectName(vec![Ident::from(table.name.as_str())])); - } - DFStatement::DescribeTableStmt(table) => visitor.insert(&table.table_name), - DFStatement::CopyTo(CopyToStatement { - source, - target: _, - options: _, - }) => match source { - CopyToSource::Relation(table_name) => { - visitor.insert(table_name); + fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) { + match statement { + DFStatement::Statement(s) => { + let _ = s.as_ref().visit(visitor); } - CopyToSource::Query(query) => { - query.visit(&mut visitor); + DFStatement::CreateExternalTable(table) => { + visitor + .0 + .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - }, + DFStatement::DescribeTableStmt(table) => { + visitor.insert(&table.table_name) + } + DFStatement::CopyTo(CopyToStatement { + source, + target: _, + options: _, + }) => match source { + CopyToSource::Relation(table_name) => { + visitor.insert(table_name); + } + CopyToSource::Query(query) => { + query.visit(visitor); + } + }, + DFStatement::Explain(explain) => { + visit_statement(&explain.statement, visitor) + } + } } + visit_statement(statement, &mut visitor); + // Always include information_schema if available if self.config.information_schema() { for s in INFORMATION_SCHEMA_TABLES { @@ -1815,9 +1705,15 @@ impl SessionState { query.statement_to_plan(statement) } - /// Creates a [`LogicalPlan`] from the provided SQL string + /// Creates a [`LogicalPlan`] from the provided SQL string. This + /// interface will plan any SQL DataFusion supports, including DML + /// like `CREATE TABLE`, and `COPY` (which can write to local + /// files. /// - /// See [`SessionContext::sql`] for a higher-level interface that also handles DDL + /// See [`SessionContext::sql`] and + /// [`SessionContext::sql_with_options`] for a higher-level + /// interface that handles DDL and verification of allowed + /// statements. pub async fn create_logical_plan(&self, sql: &str) -> Result { let dialect = self.config.options().sql_parser.dialect.as_str(); let statement = self.sql_to_statement(sql, dialect)?; @@ -1898,7 +1794,11 @@ impl SessionState { /// Creates a physical plan from a logical plan. /// - /// Note: this first calls [`Self::optimize`] on the provided plan + /// Note: this first calls [`Self::optimize`] on the provided + /// plan. + /// + /// This function will error for [`LogicalPlan`]s such as catalog + /// DDL `CREATE TABLE` must be handled by another layer. pub async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1931,7 +1831,7 @@ impl SessionState { /// Return the physical optimizers pub fn physical_optimizers(&self) -> &[Arc] { - &self.physical_optimizers + &self.physical_optimizers.rules } /// return the configuration options @@ -1959,6 +1859,11 @@ impl SessionState { &self.aggregate_functions } + /// Return reference to window functions + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + /// Return [SerializerRegistry] for extensions pub fn serializer_registry(&self) -> Arc { self.serializer_registry.clone() @@ -1976,12 +1881,28 @@ struct SessionContextProvider<'a> { } impl<'a> ContextProvider for SessionContextProvider<'a> { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { let name = self.state.resolve_table_ref(name).to_string(); self.tables .get(&name) .cloned() - .ok_or_else(|| DataFusionError::Plan(format!("table '{name}' not found"))) + .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) + } + + fn get_table_function_source( + &self, + name: &str, + args: Vec, + ) -> Result> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) } fn get_function_meta(&self, name: &str) -> Option> { @@ -1992,6 +1913,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; @@ -2024,9 +1949,7 @@ impl FunctionRegistry for SessionState { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDF named \"{name}\" in the registry" - )) + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") }) } @@ -2034,9 +1957,15 @@ impl FunctionRegistry for SessionState { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Plan(format!( - "There is no UDAF named \"{name}\" in the registry" - )) + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") + }) + } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") }) } } @@ -2046,6 +1975,10 @@ impl OptimizerConfig for SessionState { self.execution_props.query_execution_start_time } + fn alias_generator(&self) -> Arc { + self.execution_props.alias_generator.clone() + } + fn options(&self) -> &ConfigOptions { self.config_options() } @@ -2068,6 +2001,7 @@ impl From<&SessionState> for TaskContext { state.config.clone(), state.scalar_functions.clone(), state.aggregate_functions.clone(), + state.window_functions.clone(), state.runtime_env.clone(), ) } @@ -2082,10 +2016,10 @@ impl SerializerRegistry for EmptySerializerRegistry { &self, node: &dyn UserDefinedLogicalNode, ) -> Result> { - Err(DataFusionError::NotImplemented(format!( + not_impl_err!( "Serializing user defined logical plan node `{}` is not supported", node.name() - ))) + ) } fn deserialize_logical_plan( @@ -2093,32 +2027,115 @@ impl SerializerRegistry for EmptySerializerRegistry { name: &str, _bytes: &[u8], ) -> Result> { - Err(DataFusionError::NotImplemented(format!( + not_impl_err!( "Deserializing user defined logical plan node `{name}` is not supported" - ))) + ) + } +} + +/// Describes which SQL statements can be run. +/// +/// See [`SessionContext::sql_with_options`] for more details. +#[derive(Clone, Debug, Copy)] +pub struct SQLOptions { + /// See [`Self::with_allow_ddl`] + allow_ddl: bool, + /// See [`Self::with_allow_dml`] + allow_dml: bool, + /// See [`Self::with_allow_statements`] + allow_statements: bool, +} + +impl Default for SQLOptions { + fn default() -> Self { + Self { + allow_ddl: true, + allow_dml: true, + allow_statements: true, + } + } +} + +impl SQLOptions { + /// Create a new `SQLOptions` with default values + pub fn new() -> Self { + Default::default() + } + + /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`. + pub fn with_allow_ddl(mut self, allow: bool) -> Self { + self.allow_ddl = allow; + self + } + + /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true` + pub fn with_allow_dml(mut self, allow: bool) -> Self { + self.allow_dml = allow; + self + } + + /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true` + pub fn with_allow_statements(mut self, allow: bool) -> Self { + self.allow_statements = allow; + self + } + + /// Return an error if the [`LogicalPlan`] has any nodes that are + /// incompatible with this [`SQLOptions`]. + pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> { + plan.visit(&mut BadPlanVisitor::new(self))?; + Ok(()) + } +} + +struct BadPlanVisitor<'a> { + options: &'a SQLOptions, +} +impl<'a> BadPlanVisitor<'a> { + fn new(options: &'a SQLOptions) -> Self { + Self { options } + } +} + +impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { + type N = LogicalPlan; + + fn pre_visit(&mut self, node: &Self::N) -> Result { + match node { + LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { + plan_err!("DDL not supported: {}", ddl.name()) + } + LogicalPlan::Dml(dml) if !self.options.allow_dml => { + plan_err!("DML not supported: {}", dml.op) + } + LogicalPlan::Copy(_) if !self.options.allow_dml => { + plan_err!("DML not supported: COPY") + } + LogicalPlan::Statement(stmt) if !self.options.allow_statements => { + plan_err!("Statement not supported: {}", stmt.name()) + } + _ => Ok(VisitRecursion::Continue), + } } } #[cfg(test)] mod tests { + use super::super::options::CsvReadOptions; use super::*; use crate::assert_batches_eq; use crate::execution::context::QueryPlanner; use crate::execution::memory_pool::MemoryConsumer; use crate::execution::runtime_env::RuntimeConfig; - use crate::physical_plan::expressions::AvgAccumulator; use crate::test; - use crate::test_util::parquet_test_data; + use crate::test_util::{plan_and_collect, populate_csv_partitions}; use crate::variable::VarType; - use arrow::array::ArrayRef; - use arrow::record_batch::RecordBatch; + use arrow_schema::Schema; use async_trait::async_trait; - use datafusion_expr::{create_udaf, create_udf, Expr, Volatility}; - use datafusion_physical_expr::functions::make_scalar_function; - use std::fs::File; + use datafusion_expr::Expr; + use std::env; use std::path::PathBuf; use std::sync::Weak; - use std::{env, io::prelude::*}; use tempfile::TempDir; #[tokio::test] @@ -2136,7 +2153,7 @@ mod tests { let disk_manager = ctx1.runtime_env().disk_manager.clone(); let ctx2 = - SessionContext::with_config_rt(SessionConfig::new(), ctx1.runtime_env()); + SessionContext::new_with_config_rt(SessionConfig::new(), ctx1.runtime_env()); assert_eq!(ctx1.runtime_env().memory_pool.reserved(), 100); assert_eq!(ctx2.runtime_env().memory_pool.reserved(), 100); @@ -2174,7 +2191,7 @@ mod tests { plan_and_collect(&ctx, "SELECT @@version, @name, @integer + 1 FROM dual") .await?; - let expected = vec![ + let expected = [ "+----------------------+------------------------+---------------------+", "| @@version | @name | @integer + Int64(1) |", "+----------------------+------------------------+---------------------+", @@ -2190,13 +2207,10 @@ mod tests { async fn create_variable_err() -> Result<()> { let ctx = SessionContext::new(); - let err = plan_and_collect(&ctx, "SElECT @= X#=?!~ 5") - .await - .unwrap_err(); - + let err = plan_and_collect(&ctx, "SElECT @= X3").await.unwrap_err(); assert_eq!( - err.to_string(), - "Execution error: variable [\"@\"] has no type information" + err.strip_backtrace(), + "Error during planning: variable [\"@=\"] has no type information" ); Ok(()) } @@ -2216,125 +2230,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); - let myfunc = make_scalar_function(myfunc); - - ctx.register_udf(create_udf( - "MY_FUNC", - vec![DataType::Int32], - Arc::new(DataType::Int32), - Volatility::Immutable, - myfunc, - )); - - // doesn't work as it was registered with non lowercase - let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_func\'")); - - // Can call it if you put quotes - let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; - - let expected = vec![ - "+--------------+", - "| MY_FUNC(t.i) |", - "+--------------+", - "| 1 |", - "+--------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - // Note capitalization - let my_avg = create_udaf( - "MY_AVG", - DataType::Float64, - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), - Arc::new(vec![DataType::UInt64, DataType::Float64]), - ); - - ctx.register_udaf(my_avg); - - // doesn't work as it was registered as non lowercase - let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_avg\'")); - - // Can call it if you put quotes - let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?; - - let expected = vec![ - "+-------------+", - "| MY_AVG(t.i) |", - "+-------------+", - "| 1.0 |", - "+-------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn query_csv_with_custom_partition_extension() -> Result<()> { - let tmp_dir = TempDir::new()?; - - // The main stipulation of this test: use a file extension that isn't .csv. - let file_extension = ".tst"; - - let ctx = SessionContext::new(); - let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; - ctx.register_csv( - "test", - tmp_dir.path().to_str().unwrap(), - CsvReadOptions::new() - .schema(&schema) - .file_extension(file_extension), - ) - .await?; - let results = - plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; - - assert_eq!(results.len(), 1); - let expected = vec![ - "+--------------+--------------+-----------------+", - "| SUM(test.c1) | SUM(test.c2) | COUNT(UInt8(1)) |", - "+--------------+--------------+-----------------+", - "| 10 | 110 | 20 |", - "+--------------+--------------+-----------------+", - ]; - assert_batches_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn send_context_to_threads() -> Result<()> { // ensure SessionContexts can be used in a multi-threaded @@ -2372,8 +2267,8 @@ mod tests { .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::with_config_rt(cfg, runtime); - let ctx = SessionContext::with_state(session_state); + let session_state = SessionState::new_with_config_rt(cfg, runtime); + let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; let result = @@ -2398,9 +2293,10 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = SessionState::with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); - let ctx = SessionContext::with_state(session_state); + let session_state = + SessionState::new_with_config_rt(SessionConfig::new(), runtime) + .with_query_planner(Arc::new(MyQueryPlanner {})); + let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; df.collect().await.expect_err("query not supported"); @@ -2409,7 +2305,7 @@ mod tests { #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_create_default_catalog_and_schema(false), ); @@ -2452,7 +2348,7 @@ mod tests { } async fn catalog_and_schema_test(config: SessionConfig) { - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema @@ -2471,7 +2367,7 @@ mod tests { .await .unwrap(); - let expected = vec![ + let expected = [ "+-------+", "| count |", "+-------+", @@ -2513,7 +2409,7 @@ mod tests { ) .await?; - let expected = vec![ + let expected = [ "+-----+-------+", "| cat | total |", "+-----+-------+", @@ -2529,7 +2425,7 @@ mod tests { #[tokio::test] async fn catalogs_not_leaked() { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2552,7 +2448,7 @@ mod tests { #[tokio::test] async fn sql_create_schema() -> Result<()> { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2575,7 +2471,7 @@ mod tests { #[tokio::test] async fn sql_create_catalog() -> Result<()> { // the information schema used to introduce cyclic Arcs - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new().with_information_schema(true), ); @@ -2598,97 +2494,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn read_with_glob_path() -> Result<()> { - let ctx = SessionContext::new(); - - let df = ctx - .read_parquet( - format!("{}/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - - #[tokio::test] - async fn read_with_glob_path_issue_2465() -> Result<()> { - let ctx = SessionContext::new(); - - let df = ctx - .read_parquet( - // it was reported that when a path contains // (two consecutive separator) no files were found - // in this test, regardless of parquet_test_data() value, our path now contains a // - format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - - #[tokio::test] - async fn read_from_registered_table_with_glob_path() -> Result<()> { - let ctx = SessionContext::new(); - - ctx.register_parquet( - "test", - &format!("{}/alltypes_plain*.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - let df = ctx.sql("SELECT * FROM test").await?; - let results = df.collect().await?; - let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); - // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows - assert_eq!(total_rows, 10); - Ok(()) - } - - #[tokio::test] - async fn unsupported_sql_returns_error() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("test", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - let state = ctx.state(); - - // create view - let sql = "create view test_view as select * from test"; - let plan = state.create_logical_plan(sql).await; - let physical_plan = state.create_physical_plan(&plan.unwrap()).await; - assert!(physical_plan.is_err()); - assert_eq!( - format!("{}", physical_plan.unwrap_err()), - "This feature is not implemented: Unsupported logical plan: CreateView" - ); - // // drop view - let sql = "drop view test_view"; - let plan = state.create_logical_plan(sql).await; - let physical_plan = state.create_physical_plan(&plan.unwrap()).await; - assert!(physical_plan.is_err()); - assert_eq!( - format!("{}", physical_plan.unwrap_err()), - "This feature is not implemented: Unsupported logical plan: DropView" - ); - // // drop table - let sql = "drop table test"; - let plan = state.create_logical_plan(sql).await; - let physical_plan = state.create_physical_plan(&plan.unwrap()).await; - assert!(physical_plan.is_err()); - assert_eq!( - format!("{}", physical_plan.unwrap_err()), - "This feature is not implemented: Unsupported logical plan: DropTable" - ); - Ok(()) - } - struct MyPhysicalPlanner {} #[async_trait] @@ -2698,9 +2503,7 @@ mod tests { _logical_plan: &LogicalPlan, _session_state: &SessionState, ) -> Result> { - Err(DataFusionError::NotImplemented( - "query not supported".to_string(), - )) + not_impl_err!("query not supported") } fn create_physical_expr( @@ -2730,50 +2533,14 @@ mod tests { } } - /// Execute SQL and return results - async fn plan_and_collect( - ctx: &SessionContext, - sql: &str, - ) -> Result> { - ctx.sql(sql).await?.collect().await - } - - /// Generate CSV partitions within the supplied directory - fn populate_csv_partitions( - tmp_dir: &TempDir, - partition_count: usize, - file_extension: &str, - ) -> Result { - // define schema for data source (csv file) - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - // generate a partitioned file - for partition in 0..partition_count { - let filename = format!("partition-{partition}.{file_extension}"); - let file_path = tmp_dir.path().join(filename); - let mut file = File::create(file_path)?; - - // generate some data - for i in 0..=10 { - let data = format!("{},{},{}\n", partition, i, i % 2 == 0); - file.write_all(data.as_bytes())?; - } - } - - Ok(schema) - } - /// Generate a partitioned CSV file and register it with an execution context async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = - SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(8), + ); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -2787,37 +2554,4 @@ mod tests { Ok(ctx) } - - // Test for compilation error when calling read_* functions from an #[async_trait] function. - // See https://github.com/apache/arrow-datafusion/issues/1154 - #[async_trait] - trait CallReadTrait { - async fn call_read_csv(&self) -> DataFrame; - async fn call_read_avro(&self) -> DataFrame; - async fn call_read_parquet(&self) -> DataFrame; - } - - struct CallRead {} - - #[async_trait] - impl CallReadTrait for CallRead { - async fn call_read_csv(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() - } - - async fn call_read_avro(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_avro("dummy", AvroReadOptions::default()) - .await - .unwrap() - } - - async fn call_read_parquet(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_parquet("dummy", ParquetReadOptions::default()) - .await - .unwrap() - } - } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs new file mode 100644 index 0000000000000..5d649d3e6df8e --- /dev/null +++ b/datafusion/core/src/execution/context/parquet.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::datasource::physical_plan::parquet::plan_to_parquet; +use parquet::file::properties::WriterProperties; + +use super::super::options::{ParquetReadOptions, ReadOptions}; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; + +impl SessionContext { + /// Creates a [`DataFrame`] for reading a Parquet data source. + /// + /// For more control such as reading multiple files, you can use + /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. + /// + /// For an example, see [`read_csv`](Self::read_csv) + pub async fn read_parquet( + &self, + table_paths: P, + options: ParquetReadOptions<'_>, + ) -> Result { + self._read_type(table_paths, options).await + } + + /// Registers a Parquet file as a table that can be referenced from SQL + /// statements executed against this context. + pub async fn register_parquet( + &self, + name: &str, + table_path: &str, + options: ParquetReadOptions<'_>, + ) -> Result<()> { + let listing_options = options.to_listing_options(&self.state.read().config); + + self.register_listing_table( + name, + table_path, + listing_options, + options.schema.map(|s| Arc::new(s.to_owned())), + None, + ) + .await?; + Ok(()) + } + + /// Executes a query and writes the results to a partitioned Parquet file. + pub async fn write_parquet( + &self, + plan: Arc, + path: impl AsRef, + writer_properties: Option, + ) -> Result<()> { + plan_to_parquet(self.task_ctx(), plan, path, writer_properties).await + } +} + +#[cfg(test)] +mod tests { + use async_trait::async_trait; + + use crate::arrow::array::{Float32Array, Int32Array}; + use crate::arrow::datatypes::{DataType, Field, Schema}; + use crate::arrow::record_batch::RecordBatch; + use crate::dataframe::DataFrameWriteOptions; + use crate::parquet::basic::Compression; + use crate::test_util::parquet_test_data; + use tempfile::tempdir; + + use super::*; + + #[tokio::test] + async fn read_with_glob_path() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx + .read_parquet( + format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_with_glob_path_issue_2465() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx + .read_parquet( + // it was reported that when a path contains // (two consecutive separator) no files were found + // in this test, regardless of parquet_test_data() value, our path now contains a // + format!("{}/..//*/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_from_registered_table_with_glob_path() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_parquet( + "test", + &format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx.sql("SELECT * FROM test").await?; + let results = df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + // alltypes_plain.parquet = 8 rows, alltypes_plain.snappy.parquet = 2 rows, alltypes_dictionary.parquet = 2 rows + assert_eq!(total_rows, 10); + Ok(()) + } + + #[tokio::test] + async fn read_from_different_file_extension() -> Result<()> { + let ctx = SessionContext::new(); + let sep = std::path::MAIN_SEPARATOR.to_string(); + + // Make up a new dataframe. + let write_df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("purchase_id", DataType::Int32, false), + Field::new("price", DataType::Float32, false), + Field::new("quantity", DataType::Int32, false), + ])), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])), + Arc::new(Int32Array::from(vec![1, 3, 2, 4, 3])), + ], + )?)?; + + let temp_dir = tempdir()?; + let temp_dir_path = temp_dir.path(); + let path1 = temp_dir_path + .join("output1.parquet") + .to_str() + .unwrap() + .to_string(); + let path2 = temp_dir_path + .join("output2.parquet.snappy") + .to_str() + .unwrap() + .to_string(); + let path3 = temp_dir_path + .join("output3.parquet.snappy.parquet") + .to_str() + .unwrap() + .to_string(); + + let path4 = temp_dir_path + .join("output4.parquet".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + + let path5 = temp_dir_path + .join("bbb..bbb") + .join("filename.parquet") + .to_str() + .unwrap() + .to_string(); + let dir = temp_dir_path + .join("bbb..bbb".to_owned() + &sep) + .to_str() + .unwrap() + .to_string(); + std::fs::create_dir(dir).expect("create dir failed"); + + // Write the dataframe to a parquet file named 'output1.parquet' + write_df + .clone() + .write_parquet( + &path1, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'output2.parquet.snappy' + write_df + .clone() + .write_parquet( + &path2, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet' + write_df + .clone() + .write_parquet( + &path3, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet' + write_df + .write_parquet( + &path5, + DataFrameWriteOptions::new().with_single_file_output(true), + Some( + WriterProperties::builder() + .set_compression(Compression::SNAPPY) + .build(), + ), + ) + .await?; + + // Read the dataframe from 'output1.parquet' with the default file extension. + let read_df = ctx + .read_parquet( + &path1, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output2.parquet.snappy' with the correct file extension. + let read_df = ctx + .read_parquet( + &path2, + ParquetReadOptions { + file_extension: "snappy", + ..Default::default() + }, + ) + .await?; + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output3.parquet.snappy.parquet' with the wrong file extension. + let read_df = ctx + .read_parquet( + &path2, + ParquetReadOptions { + ..Default::default() + }, + ) + .await; + let binding = DataFilePaths::to_urls(&path2).unwrap(); + let expexted_path = binding[0].as_str(); + assert_eq!( + read_df.unwrap_err().strip_backtrace(), + format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path) + ); + + // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. + let read_df = ctx + .read_parquet( + &path3, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + + // Read the dataframe from 'output4/' + std::fs::create_dir(&path4)?; + let read_df = ctx + .read_parquet( + &path4, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 0); + + // Read the datafram from doule dot folder; + let read_df = ctx + .read_parquet( + &path5, + ParquetReadOptions { + ..Default::default() + }, + ) + .await?; + + let results = read_df.collect().await?; + let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum(); + assert_eq!(total_rows, 5); + Ok(()) + } + + // Test for compilation error when calling read_* functions from an #[async_trait] function. + // See https://github.com/apache/arrow-datafusion/issues/1154 + #[async_trait] + trait CallReadTrait { + async fn call_read_parquet(&self) -> DataFrame; + } + + struct CallRead {} + + #[async_trait] + impl CallReadTrait for CallRead { + async fn call_read_parquet(&self) -> DataFrame { + let ctx = SessionContext::new(); + ctx.read_parquet("dummy", ParquetReadOptions::default()) + .await + .unwrap() + } + } +} diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 3e58923c3aad4..b3ebbc6e3637e 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -17,9 +17,9 @@ #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that -//! uses [Apache Arrow] as its in-memory format. DataFusion's [use -//! cases] include building very fast database and analytic systems, -//! customized to particular workloads. +//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use +//! cases] help developers build very fast and feature rich database +//! and analytic systems, customized to particular workloads. //! //! "Out of the box," DataFusion quickly runs complex [SQL] and //! [`DataFrame`] queries using a sophisticated query planner, a columnar, @@ -132,23 +132,30 @@ //! //! ## Customization and Extension //! -//! DataFusion supports extension at many points: +//! DataFusion is a "disaggregated" query engine. This +//! means developers can start with a working, full featured engine, and then +//! extend the parts of DataFusion they need to specialize for their usecase. For example, +//! some projects may add custom [`ExecutionPlan`] operators, or create their own +//! query language that directly creates [`LogicalPlan`] rather than using the +//! built in SQL planner, [`SqlToRel`]. +//! +//! In order to achieve this, DataFusion supports extension at many points: //! //! * read from any datasource ([`TableProvider`]) //! * define your own catalogs, schemas, and table lists ([`CatalogProvider`]) -//! * build your own query langue or plans using the ([`LogicalPlanBuilder`]) -//! * declare and use user-defined scalar functions ([`ScalarUDF`]) -//! * declare and use user-defined aggregate functions ([`AggregateUDF`]) +//! * build your own query language or plans ([`LogicalPlanBuilder`]) +//! * declare and use user-defined functions ([`ScalarUDF`], and [`AggregateUDF`], [`WindowUDF`]) //! * add custom optimizer rewrite passes ([`OptimizerRule`] and [`PhysicalOptimizerRule`]) //! * extend the planner to use user-defined logical and physical nodes ([`QueryPlanner`]) //! //! You can find examples of each of them in the [datafusion-examples] directory. //! //! [`TableProvider`]: crate::datasource::TableProvider -//! [`CatalogProvider`]: crate::catalog::catalog::CatalogProvider +//! [`CatalogProvider`]: crate::catalog::CatalogProvider //! [`LogicalPlanBuilder`]: datafusion_expr::logical_plan::builder::LogicalPlanBuilder -//! [`ScalarUDF`]: physical_plan::udf::ScalarUDF -//! [`AggregateUDF`]: physical_plan::udaf::AggregateUDF +//! [`ScalarUDF`]: crate::logical_expr::ScalarUDF +//! [`AggregateUDF`]: crate::logical_expr::AggregateUDF +//! [`WindowUDF`]: crate::logical_expr::WindowUDF //! [`QueryPlanner`]: execution::context::QueryPlanner //! [`OptimizerRule`]: datafusion_optimizer::optimizer::OptimizerRule //! [`PhysicalOptimizerRule`]: crate::physical_optimizer::optimizer::PhysicalOptimizerRule @@ -274,14 +281,22 @@ //! [`MemTable`]: crate::datasource::memory::MemTable //! [`StreamingTable`]: crate::datasource::streaming::StreamingTable //! -//! ## Plans +//! ## Plan Representations //! -//! Logical planning yields [`LogicalPlan`]s nodes and [`Expr`] +//! ### Logical Plans +//! Logical planning yields [`LogicalPlan`] nodes and [`Expr`] //! expressions which are [`Schema`] aware and represent statements //! independent of how they are physically executed. //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! +//! Examples of working with and executing `Expr`s can be found in the +//! [`expr_api`.rs] example +//! +//! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! +//! ### Physical Plans +//! //! An [`ExecutionPlan`] (sometimes referred to as a "physical plan") //! is a plan that can be executed against data. It a DAG of other //! [`ExecutionPlan`]s each potentially containing expressions of the @@ -325,13 +340,17 @@ //! ``` //! //! [`ExecutionPlan`]s process data using the [Apache Arrow] memory -//! format, largely with functions from the [arrow] crate. When -//! [`execute`] is called, a [`SendableRecordBatchStream`] is returned -//! that produces the desired output as a [`Stream`] of [`RecordBatch`]es. +//! format, making heavy use of functions from the [arrow] +//! crate. Calling [`execute`] produces 1 or more partitions of data, +//! consisting an operator that implements +//! [`SendableRecordBatchStream`]. //! -//! Values are -//! represented with [`ColumnarValue`], which are either single -//! constant values ([`ScalarValue`]) or Arrow Arrays ([`ArrayRef`]). +//! Values are represented with [`ColumnarValue`], which are either +//! [`ScalarValue`] (single constant values) or [`ArrayRef`] (Arrow +//! Arrays). +//! +//! Balanced parallelism is achieved using [`RepartitionExec`], which +//! implements a [Volcano style] "Exchange". //! //! [`execute`]: physical_plan::ExecutionPlan::execute //! [`SendableRecordBatchStream`]: crate::physical_plan::SendableRecordBatchStream @@ -340,9 +359,10 @@ //! [`ArrayRef`]: arrow::array::ArrayRef //! [`Stream`]: futures::stream::Stream //! -//! //! See the [implementors of `ExecutionPlan`] for a list of physical operators available. //! +//! [`RepartitionExec`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/repartition/struct.RepartitionExec.html +//! [Volcano style]: https://w6113.github.io/files/papers/volcanoparallelism-89.pdf //! [implementors of `ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#implementors //! //! ## State Management and Configuration @@ -384,11 +404,11 @@ //! and improve compilation times. The crates are: //! //! * [datafusion_common]: Common traits and types -//! * [datafusion_execution]: State needed for execution //! * [datafusion_expr]: [`LogicalPlan`], [`Expr`] and related logical planning structure +//! * [datafusion_execution]: State and structures needed for execution //! * [datafusion_optimizer]: [`OptimizerRule`]s and [`AnalyzerRule`]s //! * [datafusion_physical_expr]: [`PhysicalExpr`] and related expressions -//! * [datafusion_sql]: [`SqlToRel`] SQL planner +//! * [datafusion_sql]: SQL planner ([`SqlToRel`]) //! //! [sqlparser]: https://docs.rs/sqlparser/latest/sqlparser //! [`SqlToRel`]: sql::planner::SqlToRel @@ -397,7 +417,7 @@ //! [`AnalyzerRule`]: datafusion_optimizer::analyzer::AnalyzerRule //! [`OptimizerRule`]: optimizer::optimizer::OptimizerRule //! [`ExecutionPlan`]: physical_plan::ExecutionPlan -//! [`PhysicalPlanner`]: physical_plan::PhysicalPlanner +//! [`PhysicalPlanner`]: physical_planner::PhysicalPlanner //! [`PhysicalOptimizerRule`]: datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule //! [`Schema`]: arrow::datatypes::Schema //! [`PhysicalExpr`]: physical_plan::PhysicalExpr @@ -412,30 +432,64 @@ pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); extern crate core; extern crate sqlparser; -pub mod avro_to_arrow; pub mod catalog; pub mod dataframe; pub mod datasource; pub mod error; pub mod execution; pub mod physical_optimizer; -pub mod physical_plan; +pub mod physical_planner; pub mod prelude; pub mod scalar; pub mod variable; -// re-export dependencies from arrow-rs to minimise version maintenance for crate users +// re-export dependencies from arrow-rs to minimize version maintenance for crate users pub use arrow; +#[cfg(feature = "parquet")] pub use parquet; -// re-export DataFusion crates -pub use datafusion_common as common; -pub use datafusion_common::config; -pub use datafusion_expr as logical_expr; -pub use datafusion_optimizer as optimizer; -pub use datafusion_physical_expr as physical_expr; -pub use datafusion_row as row; -pub use datafusion_sql as sql; +// re-export DataFusion sub-crates at the top level. Use `pub use *` +// so that the contents of the subcrates appears in rustdocs +// for details, see https://github.com/apache/arrow-datafusion/issues/6648 + +/// re-export of [`datafusion_common`] crate +pub mod common { + pub use datafusion_common::*; +} + +// Backwards compatibility +pub use common::config; + +// NB datafusion execution is re-exported in the `execution` module + +/// re-export of [`datafusion_expr`] crate +pub mod logical_expr { + pub use datafusion_expr::*; +} + +/// re-export of [`datafusion_optimizer`] crate +pub mod optimizer { + pub use datafusion_optimizer::*; +} + +/// re-export of [`datafusion_physical_expr`] crate +pub mod physical_expr { + pub use datafusion_physical_expr::*; +} + +/// re-export of [`datafusion_physical_plan`] crate +pub mod physical_plan { + pub use datafusion_physical_plan::*; +} + +// Reexport testing macros for compatibility +pub use datafusion_common::assert_batches_eq; +pub use datafusion_common::assert_batches_sorted_eq; + +/// re-export of [`datafusion_sql`] crate +pub mod sql { + pub use datafusion_sql::*; +} #[cfg(test)] pub mod test; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 396e66972f304..795857b10ef5b 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -18,27 +18,25 @@ //! Utilizing exact statistics from sources to avoid scanning data use std::sync::Arc; +use super::optimizer::PhysicalOptimizerRule; use crate::config::ConfigOptions; -use datafusion_common::tree_node::TreeNode; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; - -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; -use crate::physical_plan::empty::EmptyExec; +use crate::error::Result; +use crate::physical_plan::aggregates::AggregateExec; use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{ - expressions, AggregateExpr, ColumnStatistics, ExecutionPlan, Statistics, -}; +use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; -use super::optimizer::PhysicalOptimizerRule; -use crate::error::Result; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNode; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; /// Optimizer that uses available statistics for aggregate functions #[derive(Default)] pub struct AggregateStatistics {} /// The name of the column corresponding to [`COUNT_STAR_EXPANSION`] -const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))"; +const COUNT_STAR_NAME: &str = "COUNT(*)"; impl AggregateStatistics { #[allow(missing_docs)] @@ -58,7 +56,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics(); + let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { if let Some((non_null_rows, name)) = @@ -84,7 +82,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { // input can be entirely removed Ok(Arc::new(ProjectionExec::try_new( projections, - Arc::new(EmptyExec::new(true, plan.schema())), + Arc::new(PlaceholderRowExec::new(plan.schema())), )?)) } else { plan.map_children(|child| self.optimize(child, _config)) @@ -107,13 +105,12 @@ impl PhysicalOptimizerRule for AggregateStatistics { /// assert if the node passed as argument is a final `AggregateExec` node that can be optimized: /// - its child (with possible intermediate layers) is a partial `AggregateExec` node /// - they both have no grouping expression -/// - the statistics are exact /// If this is the case, return a ref to the partial `AggregateExec`, else `None`. /// We would have preferred to return a casted ref to AggregateExec but the recursion requires /// the `ExecutionPlan.children()` method that returns an owned reference. fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { if let Some(final_agg_exec) = node.as_any().downcast_ref::() { - if final_agg_exec.mode() == &AggregateMode::Final + if !final_agg_exec.mode().is_first_stage() && final_agg_exec.group_expr().is_empty() { let mut child = Arc::clone(final_agg_exec.input()); @@ -121,14 +118,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> if let Some(partial_agg_exec) = child.as_any().downcast_ref::() { - if partial_agg_exec.mode() == &AggregateMode::Partial + if partial_agg_exec.mode().is_first_stage() && partial_agg_exec.group_expr().is_empty() && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) { - let stats = partial_agg_exec.input().statistics(); - if stats.is_exact { - return Some(child); - } + return Some(child); } } if let [ref childrens_child] = child.children().as_slice() { @@ -142,13 +136,13 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that is defined in the statistics, return it +/// If this agg_expr is a count that is exactly defined in the statistics, return it. fn take_optimizable_table_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, &'static str)> { - if let (Some(num_rows), Some(casted_expr)) = ( - stats.num_rows, + if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { // TODO implementing Eq on PhysicalExpr would help a lot here @@ -169,14 +163,14 @@ fn take_optimizable_table_count( None } -/// If this agg_expr is a count that can be derived from the statistics, return it +/// If this agg_expr is a count that can be exactly derived from the statistics, return it. fn take_optimizable_column_count( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(num_rows), Some(col_stats), Some(casted_expr)) = ( - stats.num_rows, - &stats.column_statistics, + let col_stats = &stats.column_statistics; + if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( + &stats.num_rows, agg_expr.as_any().downcast_ref::(), ) { if casted_expr.expressions().len() == 1 { @@ -185,11 +179,8 @@ fn take_optimizable_column_count( .as_any() .downcast_ref::() { - if let ColumnStatistics { - null_count: Some(val), - .. - } = &col_stats[col_expr.index()] - { + let current_val = &col_stats[col_expr.index()].null_count; + if let &Precision::Exact(val) = current_val { return Some(( ScalarValue::Int64(Some((num_rows - val) as i64)), casted_expr.name().to_string(), @@ -201,27 +192,23 @@ fn take_optimizable_column_count( None } -/// If this agg_expr is a min that is defined in the statistics, return it +/// If this agg_expr is a min that is exactly defined in the statistics, return it. fn take_optimizable_min( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(col_stats), Some(casted_expr)) = ( - &stats.column_statistics, - agg_expr.as_any().downcast_ref::(), - ) { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] .as_any() .downcast_ref::() { - if let ColumnStatistics { - min_value: Some(val), - .. - } = &col_stats[col_expr.index()] - { - return Some((val.clone(), casted_expr.name().to_string())); + if let Precision::Exact(val) = &col_stats[col_expr.index()].min_value { + if !val.is_null() { + return Some((val.clone(), casted_expr.name().to_string())); + } } } } @@ -229,27 +216,23 @@ fn take_optimizable_min( None } -/// If this agg_expr is a max that is defined in the statistics, return it +/// If this agg_expr is a max that is exactly defined in the statistics, return it. fn take_optimizable_max( agg_expr: &dyn AggregateExpr, stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let (Some(col_stats), Some(casted_expr)) = ( - &stats.column_statistics, - agg_expr.as_any().downcast_ref::(), - ) { + let col_stats = &stats.column_statistics; + if let Some(casted_expr) = agg_expr.as_any().downcast_ref::() { if casted_expr.expressions().len() == 1 { // TODO optimize with exprs other than Column if let Some(col_expr) = casted_expr.expressions()[0] .as_any() .downcast_ref::() { - if let ColumnStatistics { - max_value: Some(val), - .. - } = &col_stats[col_expr.index()] - { - return Some((val.clone(), casted_expr.name().to_string())); + if let Precision::Exact(val) = &col_stats[col_expr.index()].max_value { + if !val.is_null() { + return Some((val.clone(), casted_expr.name().to_string())); + } } } } @@ -258,17 +241,10 @@ fn take_optimizable_max( } #[cfg(test)] -mod tests { - use super::*; +pub(crate) mod tests { use std::sync::Arc; - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_int64_array; - use datafusion_physical_expr::expressions::cast; - use datafusion_physical_expr::PhysicalExpr; - + use super::*; use crate::error::Result; use crate::logical_expr::Operator; use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; @@ -279,6 +255,14 @@ mod tests { use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateMode; + /// Mock data using a MemoryExec which has an exact count statistic fn mock_data() -> Result> { let schema = Arc::new(Schema::new(vec![ @@ -308,7 +292,8 @@ mod tests { ) -> Result<()> { let session_ctx = SessionContext::new(); let state = session_ctx.state(); - let plan = Arc::new(plan) as _; + let plan: Arc = Arc::new(plan); + let optimized = AggregateStatistics::new() .optimize(Arc::clone(&plan), state.config_options())?; @@ -349,7 +334,7 @@ mod tests { } /// Describe the type of aggregate being tested - enum TestAggregate { + pub(crate) enum TestAggregate { /// Testing COUNT(*) type aggregates CountStar, @@ -358,7 +343,7 @@ mod tests { } impl TestAggregate { - fn new_count_star() -> Self { + pub(crate) fn new_count_star() -> Self { Self::CountStar } @@ -367,7 +352,7 @@ mod tests { } /// Return appropriate expr depending if COUNT is for col or table (*) - fn count_expr(&self) -> Arc { + pub(crate) fn count_expr(&self) -> Arc { Arc::new(Count::new( self.column(), self.column_name(), diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 3ec9e9bbd0481..0948445de20dc 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -17,13 +17,15 @@ //! CombinePartialFinalAggregate optimizer rule checks the adjacent Partial and Final AggregateExecs //! and try to combine them if necessary + +use std::sync::Arc; + use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::ExecutionPlan; -use datafusion_common::config::ConfigOptions; -use std::sync::Arc; +use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -50,68 +52,62 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { _config: &ConfigOptions, ) -> Result> { plan.transform_down(&|plan| { - let transformed = plan.as_any().downcast_ref::().and_then( - |AggregateExec { - mode: final_mode, - input: final_input, - group_by: final_group_by, - aggr_expr: final_aggr_expr, - filter_expr: final_filter_expr, - .. - }| { - if matches!( - final_mode, - AggregateMode::Final | AggregateMode::FinalPartitioned - ) { - final_input - .as_any() - .downcast_ref::() - .and_then( - |AggregateExec { - mode: input_mode, - input: partial_input, - group_by: input_group_by, - aggr_expr: input_aggr_expr, - filter_expr: input_filter_expr, - order_by_expr: input_order_by_expr, - input_schema, - .. - }| { - if matches!(input_mode, AggregateMode::Partial) - && can_combine( - ( - final_group_by, - final_aggr_expr, - final_filter_expr, - ), - ( - input_group_by, - input_aggr_expr, - input_filter_expr, - ), - ) - { + let transformed = + plan.as_any() + .downcast_ref::() + .and_then(|agg_exec| { + if matches!( + agg_exec.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) { + agg_exec + .input() + .as_any() + .downcast_ref::() + .and_then(|input_agg_exec| { + if matches!( + input_agg_exec.mode(), + AggregateMode::Partial + ) && can_combine( + ( + agg_exec.group_by(), + agg_exec.aggr_expr(), + agg_exec.filter_expr(), + ), + ( + input_agg_exec.group_by(), + input_agg_exec.aggr_expr(), + input_agg_exec.filter_expr(), + ), + ) { + let mode = + if agg_exec.mode() == &AggregateMode::Final { + AggregateMode::Single + } else { + AggregateMode::SinglePartitioned + }; AggregateExec::try_new( - AggregateMode::Single, - input_group_by.clone(), - input_aggr_expr.to_vec(), - input_filter_expr.to_vec(), - input_order_by_expr.to_vec(), - partial_input.clone(), - input_schema.clone(), + mode, + input_agg_exec.group_by().clone(), + input_agg_exec.aggr_expr().to_vec(), + input_agg_exec.filter_expr().to_vec(), + input_agg_exec.order_by_expr().to_vec(), + input_agg_exec.input().clone(), + input_agg_exec.input_schema(), ) + .map(|combined_agg| { + combined_agg.with_limit(agg_exec.limit()) + }) .ok() .map(Arc::new) } else { None } - }, - ) - } else { - None - } - }, - ); + }) + } else { + None + } + }); Ok(if let Some(transformed) = transformed { Transformed::Yes(transformed) @@ -200,10 +196,6 @@ fn discard_column_index(group_expr: Arc) -> Arc { @@ -225,7 +221,7 @@ mod tests { let config = ConfigOptions::new(); let optimized = optimizer.optimize($PLAN, &config)?; // Now format correctly - let plan = displayable(optimized.as_ref()).indent().to_string(); + let plan = displayable(optimized.as_ref()).indent(true).to_string(); let actual_lines = trim_plan_display(&plan); assert_eq!( @@ -257,7 +253,7 @@ mod tests { object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], @@ -435,4 +431,49 @@ mod tests { assert_optimized!(expected, plan); Ok(()) } + + #[test] + fn aggregations_with_limit_combined() -> Result<()> { + let schema = schema(); + let aggr_expr = vec![]; + + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let schema = partial_agg.schema(); + let final_agg = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + final_group_by, + aggr_expr, + vec![], + vec![], + partial_agg, + schema, + ) + .unwrap() + .with_limit(Some(5)), + ); + let plan: Arc = final_agg; + // should combine the Partial/Final AggregateExecs to a Single AggregateExec + // with the final limit preserved + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs deleted file mode 100644 index 4e456450bcb2f..0000000000000 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ /dev/null @@ -1,2192 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! EnforceDistribution optimizer rule inspects the physical plan with respect -//! to distribution requirements and adds [RepartitionExec]s to satisfy them -//! when necessary. -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; -use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, -}; -use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortOptions; -use crate::physical_plan::union::{can_interleave, InterleaveExec, UnionExec}; -use crate::physical_plan::windows::WindowAggExec; -use crate::physical_plan::Partitioning; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_expr::logical_plan::JoinType; -use datafusion_physical_expr::equivalence::EquivalenceProperties; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::expressions::NoOp; -use datafusion_physical_expr::utils::map_columns_before_projection; -use datafusion_physical_expr::{ - expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, PhysicalExpr, -}; -use std::sync::Arc; - -/// The EnforceDistribution rule ensures that distribution requirements are met -/// in the strictest way. It might add additional [RepartitionExec] to the plan tree -/// and give a non-optimal plan, but it can avoid the possible data skew in joins. -/// -/// For example for a HashJoin with keys(a, b, c), the required Distribution(a, b, c) can be satisfied by -/// several alternative partitioning ways: [(a, b, c), (a, b), (a, c), (b, c), (a), (b), (c), ( )]. -/// -/// This rule only chooses the exactly match and satisfies the Distribution(a, b, c) by a HashPartition(a, b, c). -#[derive(Default)] -pub struct EnforceDistribution {} - -impl EnforceDistribution { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for EnforceDistribution { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - let target_partitions = config.execution.target_partitions; - let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering; - let new_plan = if top_down_join_key_reordering { - // Run a top-down process to adjust input key ordering recursively - let plan_requirements = PlanWithKeyRequirements::new(plan); - let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; - adjusted.plan - } else { - plan - }; - // Distribution enforcement needs to be applied bottom-up. - new_plan.transform_up(&|plan| { - let adjusted = if !top_down_join_key_reordering { - reorder_join_keys_to_inputs(plan)? - } else { - plan - }; - ensure_distribution(adjusted, target_partitions) - }) - } - - fn name(&self) -> &str { - "EnforceDistribution" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// When the physical planner creates the Joins, the ordering of join keys is from the original query. -/// That might not match with the output partitioning of the join node's children -/// A Top-Down process will use this method to adjust children's output partitioning based on the parent key reordering requirements: -/// -/// Example: -/// TopJoin on (a, b, c) -/// bottom left join on(b, a, c) -/// bottom right join on(c, b, a) -/// -/// Will be adjusted to: -/// TopJoin on (a, b, c) -/// bottom left join on(a, b, c) -/// bottom right join on(a, b, c) -/// -/// Example: -/// TopJoin on (a, b, c) -/// Agg1 group by (b, a, c) -/// Agg2 group by (c, b, a) -/// -/// Will be adjusted to: -/// TopJoin on (a, b, c) -/// Projection(b, a, c) -/// Agg1 group by (a, b, c) -/// Projection(c, b, a) -/// Agg2 group by (a, b, c) -/// -/// Following is the explanation of the reordering process: -/// -/// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: -/// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. -/// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. -/// Requirements can be satisfied by adjusting keys ordering, clear the current requiements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. -/// -/// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: -/// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. -/// Requirements is already satisfied, clear all the requirements, return the unchanged plan. -/// Requirements can be satisfied by adjusting keys ordering, clear all the requirements, return the changed plan. -/// -/// 3) If the current plan is RepartitionExec, CoalescePartitionsExec or WindowAggExec, clear all the requirements, return the unchanged plan -/// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements -/// 5) For other types of operators, by default, pushdown the parent requirements to children. -/// -fn adjust_input_keys_ordering( - requirements: PlanWithKeyRequirements, -) -> Result> { - let parent_required = requirements.required_key_ordering.clone(); - let plan_any = requirements.plan.as_any(); - let transformed = if let Some(HashJoinExec { - left, - right, - on, - filter, - join_type, - mode, - null_equals_null, - .. - }) = plan_any.downcast_ref::() - { - match mode { - PartitionMode::Partitioned => { - let join_constructor = - |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - new_conditions.0, - filter.clone(), - join_type, - PartitionMode::Partitioned, - *null_equals_null, - )?) as Arc) - }; - Some(reorder_partitioned_join_keys( - requirements.plan.clone(), - &parent_required, - on, - vec![], - &join_constructor, - )?) - } - PartitionMode::CollectLeft => { - let new_right_request = match join_type { - JoinType::Inner | JoinType::Right => shift_right_required( - &parent_required, - left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => { - Some(parent_required.clone()) - } - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full => None, - }; - - // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![None, new_right_request], - }) - } - PartitionMode::Auto => { - // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) - } - } - } else if let Some(CrossJoinExec { left, .. }) = - plan_any.downcast_ref::() - { - let left_columns_len = left.schema().fields().len(); - // Push down requirements to the right side - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![ - None, - shift_right_required(&parent_required, left_columns_len), - ], - }) - } else if let Some(SortMergeJoinExec { - left, - right, - on, - join_type, - sort_options, - null_equals_null, - .. - }) = plan_any.downcast_ref::() - { - let join_constructor = - |new_conditions: (Vec<(Column, Column)>, Vec)| { - Ok(Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right.clone(), - new_conditions.0, - *join_type, - new_conditions.1, - *null_equals_null, - )?) as Arc) - }; - Some(reorder_partitioned_join_keys( - requirements.plan.clone(), - &parent_required, - on, - sort_options.clone(), - &join_constructor, - )?) - } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { - if !parent_required.is_empty() { - match aggregate_exec.mode { - AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( - requirements.plan.clone(), - &parent_required, - aggregate_exec, - )?), - _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), - } - } else { - // Keep everything unchanged - None - } - } else if let Some(ProjectionExec { expr, .. }) = - plan_any.downcast_ref::() - { - // For Projection, we need to transform the requirements to the columns before the Projection - // And then to push down the requirements - // Construct a mapping from new name to the the orginal Column - let new_required = map_columns_before_projection(&parent_required, expr); - if new_required.len() == parent_required.len() { - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(new_required.clone())], - }) - } else { - // Can not satisfy, clear the current requirements and generate new empty requirements - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) - } - } else if plan_any.downcast_ref::().is_some() - || plan_any.downcast_ref::().is_some() - || plan_any.downcast_ref::().is_some() - { - Some(PlanWithKeyRequirements::new(requirements.plan.clone())) - } else { - // By default, push down the parent requirements to children - let children_len = requirements.plan.children().len(); - Some(PlanWithKeyRequirements { - plan: requirements.plan.clone(), - required_key_ordering: vec![], - request_key_ordering: vec![Some(parent_required.clone()); children_len], - }) - }; - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(requirements) - }) -} - -fn reorder_partitioned_join_keys( - join_plan: Arc, - parent_required: &[Arc], - on: &[(Column, Column)], - sort_options: Vec, - join_constructor: &F, -) -> Result -where - F: Fn((Vec<(Column, Column)>, Vec)) -> Result>, -{ - let join_key_pairs = extract_join_keys(on); - if let Some(( - JoinKeyPairs { - left_keys, - right_keys, - }, - new_positions, - )) = try_reorder( - join_key_pairs.clone(), - parent_required, - &join_plan.equivalence_properties(), - ) { - if !new_positions.is_empty() { - let new_join_on = new_join_conditions(&left_keys, &right_keys); - let mut new_sort_options: Vec = vec![]; - for idx in 0..sort_options.len() { - new_sort_options.push(sort_options[new_positions[idx]]) - } - - Ok(PlanWithKeyRequirements { - plan: join_constructor((new_join_on, new_sort_options))?, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) - } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![Some(left_keys), Some(right_keys)], - }) - } - } else { - Ok(PlanWithKeyRequirements { - plan: join_plan, - required_key_ordering: vec![], - request_key_ordering: vec![ - Some(join_key_pairs.left_keys), - Some(join_key_pairs.right_keys), - ], - }) - } -} - -fn reorder_aggregate_keys( - agg_plan: Arc, - parent_required: &[Arc], - agg_exec: &AggregateExec, -) -> Result { - let out_put_columns = agg_exec - .group_by - .expr() - .iter() - .enumerate() - .map(|(index, (_col, name))| Column::new(name, index)) - .collect::>(); - - let out_put_exprs = out_put_columns - .iter() - .map(|c| Arc::new(c.clone()) as Arc) - .collect::>(); - - if parent_required.len() != out_put_exprs.len() - || !agg_exec.group_by.null_expr().is_empty() - || expr_list_eq_strict_order(&out_put_exprs, parent_required) - { - Ok(PlanWithKeyRequirements::new(agg_plan)) - } else { - let new_positions = expected_expr_positions(&out_put_exprs, parent_required); - match new_positions { - None => Ok(PlanWithKeyRequirements::new(agg_plan)), - Some(positions) => { - let new_partial_agg = if let Some(AggregateExec { - mode, - group_by, - aggr_expr, - filter_expr, - order_by_expr, - input, - input_schema, - .. - }) = - agg_exec.input.as_any().downcast_ref::() - { - if matches!(mode, AggregateMode::Partial) { - let mut new_group_exprs = vec![]; - for idx in positions.iter() { - new_group_exprs.push(group_by.expr()[*idx].clone()); - } - let new_partial_group_by = - PhysicalGroupBy::new_single(new_group_exprs); - // new Partial AggregateExec - Some(Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - new_partial_group_by, - aggr_expr.clone(), - filter_expr.clone(), - order_by_expr.clone(), - input.clone(), - input_schema.clone(), - )?)) - } else { - None - } - } else { - None - }; - if let Some(partial_agg) = new_partial_agg { - // Build new group expressions that correspond to the output of partial_agg - let new_final_group: Vec> = - partial_agg.output_group_expr(); - let new_group_by = PhysicalGroupBy::new_single( - new_final_group - .iter() - .enumerate() - .map(|(i, expr)| { - ( - expr.clone(), - partial_agg.group_expr().expr()[i].1.clone(), - ) - }) - .collect(), - ); - - let new_final_agg = Arc::new(AggregateExec::try_new( - AggregateMode::FinalPartitioned, - new_group_by, - agg_exec.aggr_expr.to_vec(), - agg_exec.filter_expr.to_vec(), - agg_exec.order_by_expr.to_vec(), - partial_agg, - agg_exec.input_schema.clone(), - )?); - - // Need to create a new projection to change the expr ordering back - let mut proj_exprs = out_put_columns - .iter() - .map(|col| { - ( - Arc::new(Column::new( - col.name(), - new_final_agg.schema().index_of(col.name()).unwrap(), - )) - as Arc, - col.name().to_owned(), - ) - }) - .collect::>(); - let agg_schema = new_final_agg.schema(); - let agg_fields = agg_schema.fields(); - for (idx, field) in - agg_fields.iter().enumerate().skip(out_put_columns.len()) - { - proj_exprs.push(( - Arc::new(Column::new(field.name().as_str(), idx)) - as Arc, - field.name().clone(), - )) - } - // TODO merge adjacent Projections if there are - Ok(PlanWithKeyRequirements::new(Arc::new( - ProjectionExec::try_new(proj_exprs, new_final_agg)?, - ))) - } else { - Ok(PlanWithKeyRequirements::new(agg_plan)) - } - } - } - } -} - -fn shift_right_required( - parent_required: &[Arc], - left_columns_len: usize, -) -> Option>> { - let new_right_required: Vec> = parent_required - .iter() - .filter_map(|r| { - if let Some(col) = r.as_any().downcast_ref::() { - if col.index() >= left_columns_len { - Some( - Arc::new(Column::new(col.name(), col.index() - left_columns_len)) - as Arc, - ) - } else { - None - } - } else { - None - } - }) - .collect::>(); - - // if the parent required are all comming from the right side, the requirements can be pushdown - if new_right_required.len() != parent_required.len() { - None - } else { - Some(new_right_required) - } -} - -/// When the physical planner creates the Joins, the ordering of join keys is from the original query. -/// That might not match with the output partitioning of the join node's children -/// This method will try to change the ordering of the join keys to match with the -/// partitioning of the join nodes' children. If it can not match with both sides, it will try to -/// match with one, either the left side or the right side. -/// -/// Example: -/// TopJoin on (a, b, c) -/// bottom left join on(b, a, c) -/// bottom right join on(c, b, a) -/// -/// Will be adjusted to: -/// TopJoin on (b, a, c) -/// bottom left join on(b, a, c) -/// bottom right join on(c, b, a) -/// -/// Compared to the Top-Down reordering process, this Bottom-Up approach is much simpler, but might not reach a best result. -/// The Bottom-Up approach will be useful in future if we plan to support storage partition-wised Joins. -/// In that case, the datasources/tables might be pre-partitioned and we can't adjust the key ordering of the datasources -/// and then can't apply the Top-Down reordering process. -fn reorder_join_keys_to_inputs( - plan: Arc, -) -> Result> { - let plan_any = plan.as_any(); - if let Some(HashJoinExec { - left, - right, - on, - filter, - join_type, - mode, - null_equals_null, - .. - }) = plan_any.downcast_ref::() - { - match mode { - PartitionMode::Partitioned => { - let join_key_pairs = extract_join_keys(on); - if let Some(( - JoinKeyPairs { - left_keys, - right_keys, - }, - new_positions, - )) = reorder_current_join_keys( - join_key_pairs, - Some(left.output_partitioning()), - Some(right.output_partitioning()), - &left.equivalence_properties(), - &right.equivalence_properties(), - ) { - if !new_positions.is_empty() { - let new_join_on = new_join_conditions(&left_keys, &right_keys); - Ok(Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - new_join_on, - filter.clone(), - join_type, - PartitionMode::Partitioned, - *null_equals_null, - )?)) - } else { - Ok(plan) - } - } else { - Ok(plan) - } - } - _ => Ok(plan), - } - } else if let Some(SortMergeJoinExec { - left, - right, - on, - join_type, - sort_options, - null_equals_null, - .. - }) = plan_any.downcast_ref::() - { - let join_key_pairs = extract_join_keys(on); - if let Some(( - JoinKeyPairs { - left_keys, - right_keys, - }, - new_positions, - )) = reorder_current_join_keys( - join_key_pairs, - Some(left.output_partitioning()), - Some(right.output_partitioning()), - &left.equivalence_properties(), - &right.equivalence_properties(), - ) { - if !new_positions.is_empty() { - let new_join_on = new_join_conditions(&left_keys, &right_keys); - let mut new_sort_options = vec![]; - for idx in 0..sort_options.len() { - new_sort_options.push(sort_options[new_positions[idx]]) - } - Ok(Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right.clone(), - new_join_on, - *join_type, - new_sort_options, - *null_equals_null, - )?)) - } else { - Ok(plan) - } - } else { - Ok(plan) - } - } else { - Ok(plan) - } -} - -/// Reorder the current join keys ordering based on either left partition or right partition -fn reorder_current_join_keys( - join_keys: JoinKeyPairs, - left_partition: Option, - right_partition: Option, - left_equivalence_properties: &EquivalenceProperties, - right_equivalence_properties: &EquivalenceProperties, -) -> Option<(JoinKeyPairs, Vec)> { - match (left_partition, right_partition.clone()) { - (Some(Partitioning::Hash(left_exprs, _)), _) => { - try_reorder(join_keys.clone(), &left_exprs, left_equivalence_properties) - .or_else(|| { - reorder_current_join_keys( - join_keys, - None, - right_partition, - left_equivalence_properties, - right_equivalence_properties, - ) - }) - } - (_, Some(Partitioning::Hash(right_exprs, _))) => { - try_reorder(join_keys, &right_exprs, right_equivalence_properties) - } - _ => None, - } -} - -fn try_reorder( - join_keys: JoinKeyPairs, - expected: &[Arc], - equivalence_properties: &EquivalenceProperties, -) -> Option<(JoinKeyPairs, Vec)> { - let mut normalized_expected = vec![]; - let mut normalized_left_keys = vec![]; - let mut normalized_right_keys = vec![]; - if join_keys.left_keys.len() != expected.len() { - return None; - } - if expr_list_eq_strict_order(expected, &join_keys.left_keys) - || expr_list_eq_strict_order(expected, &join_keys.right_keys) - { - return Some((join_keys, vec![])); - } else if !equivalence_properties.classes().is_empty() { - normalized_expected = expected - .iter() - .map(|e| { - normalize_expr_with_equivalence_properties( - e.clone(), - equivalence_properties.classes(), - ) - }) - .collect::>(); - assert_eq!(normalized_expected.len(), expected.len()); - - normalized_left_keys = join_keys - .left_keys - .iter() - .map(|e| { - normalize_expr_with_equivalence_properties( - e.clone(), - equivalence_properties.classes(), - ) - }) - .collect::>(); - assert_eq!(join_keys.left_keys.len(), normalized_left_keys.len()); - - normalized_right_keys = join_keys - .right_keys - .iter() - .map(|e| { - normalize_expr_with_equivalence_properties( - e.clone(), - equivalence_properties.classes(), - ) - }) - .collect::>(); - assert_eq!(join_keys.right_keys.len(), normalized_right_keys.len()); - - if expr_list_eq_strict_order(&normalized_expected, &normalized_left_keys) - || expr_list_eq_strict_order(&normalized_expected, &normalized_right_keys) - { - return Some((join_keys, vec![])); - } - } - - let new_positions = expected_expr_positions(&join_keys.left_keys, expected) - .or_else(|| expected_expr_positions(&join_keys.right_keys, expected)) - .or_else(|| expected_expr_positions(&normalized_left_keys, &normalized_expected)) - .or_else(|| { - expected_expr_positions(&normalized_right_keys, &normalized_expected) - }); - - if let Some(positions) = new_positions { - let mut new_left_keys = vec![]; - let mut new_right_keys = vec![]; - for pos in positions.iter() { - new_left_keys.push(join_keys.left_keys[*pos].clone()); - new_right_keys.push(join_keys.right_keys[*pos].clone()); - } - Some(( - JoinKeyPairs { - left_keys: new_left_keys, - right_keys: new_right_keys, - }, - positions, - )) - } else { - None - } -} - -/// Return the expected expressions positions. -/// For example, the current expressions are ['c', 'a', 'a', b'], the expected expressions are ['b', 'c', 'a', 'a'], -/// -/// This method will return a Vec [3, 0, 1, 2] -fn expected_expr_positions( - current: &[Arc], - expected: &[Arc], -) -> Option> { - if current.is_empty() || expected.is_empty() { - return None; - } - let mut indexes: Vec = vec![]; - let mut current = current.to_vec(); - for expr in expected.iter() { - // Find the position of the expected expr in the current expressions - if let Some(expected_position) = current.iter().position(|e| e.eq(expr)) { - current[expected_position] = Arc::new(NoOp::new()); - indexes.push(expected_position); - } else { - return None; - } - } - Some(indexes) -} - -fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs { - let (left_keys, right_keys) = on - .iter() - .map(|(l, r)| { - ( - Arc::new(l.clone()) as Arc, - Arc::new(r.clone()) as Arc, - ) - }) - .unzip(); - JoinKeyPairs { - left_keys, - right_keys, - } -} - -fn new_join_conditions( - new_left_keys: &[Arc], - new_right_keys: &[Arc], -) -> Vec<(Column, Column)> { - let new_join_on = new_left_keys - .iter() - .zip(new_right_keys.iter()) - .map(|(l_key, r_key)| { - ( - l_key.as_any().downcast_ref::().unwrap().clone(), - r_key.as_any().downcast_ref::().unwrap().clone(), - ) - }) - .collect::>(); - new_join_on -} - -/// This function checks whether we need to add additional data exchange -/// operators to satisfy distribution requirements. Since this function -/// takes care of such requirements, we should avoid manually adding data -/// exchange operators in other places. -fn ensure_distribution( - plan: Arc, - target_partitions: usize, -) -> Result>> { - if plan.children().is_empty() { - return Ok(Transformed::No(plan)); - } - - // special case for UnionExec: We want to "bubble up" hash-partitioned data. So instead of: - // - // Agg: - // Repartition (hash): - // Union: - // - Agg: - // Repartition (hash): - // Data - // - Agg: - // Repartition (hash): - // Data - // - // We can use: - // - // Agg: - // Interleave: - // - Agg: - // Repartition (hash): - // Data - // - Agg: - // Repartition (hash): - // Data - if let Some(union_exec) = plan.as_any().downcast_ref::() { - if can_interleave(union_exec.inputs()) { - let plan = InterleaveExec::try_new(union_exec.inputs().clone())?; - return Ok(Transformed::Yes(Arc::new(plan))); - } - } - - let required_input_distributions = plan.required_input_distribution(); - let children: Vec> = plan.children(); - assert_eq!(children.len(), required_input_distributions.len()); - - // Add RepartitionExec to guarantee output partitioning - let new_children: Result>> = children - .into_iter() - .zip(required_input_distributions.into_iter()) - .map(|(child, required)| { - if child - .output_partitioning() - .satisfy(required.clone(), || child.equivalence_properties()) - { - Ok(child) - } else { - let new_child: Result> = match required { - Distribution::SinglePartition - if child.output_partitioning().partition_count() > 1 => - { - Ok(Arc::new(CoalescePartitionsExec::new(child.clone()))) - } - _ => { - let partition = required.create_partitioning(target_partitions); - Ok(Arc::new(RepartitionExec::try_new(child, partition)?)) - } - }; - new_child - } - }) - .collect(); - with_new_children_if_necessary(plan, new_children?) -} - -#[derive(Debug, Clone)] -struct JoinKeyPairs { - left_keys: Vec>, - right_keys: Vec>, -} - -#[derive(Debug, Clone)] -struct PlanWithKeyRequirements { - plan: Arc, - /// Parent required key ordering - required_key_ordering: Vec>, - /// The request key ordering to children - request_key_ordering: Vec>>>, -} - -impl PlanWithKeyRequirements { - pub fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithKeyRequirements { - plan, - required_key_ordering: vec![], - request_key_ordering: vec![None; children_len], - } - } - - pub fn children(&self) -> Vec { - let plan_children = self.plan.children(); - assert_eq!(plan_children.len(), self.request_key_ordering.len()); - plan_children - .into_iter() - .zip(self.request_key_ordering.clone().into_iter()) - .map(|(child, required)| { - let from_parent = required.unwrap_or_default(); - let length = child.children().len(); - PlanWithKeyRequirements { - plan: child, - required_key_ordering: from_parent.clone(), - request_key_ordering: vec![None; length], - } - }) - .collect() - } -} - -impl TreeNode for PlanWithKeyRequirements { - fn apply_children(&self, op: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let children = self.children(); - for child in children { - match op(&child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } - } - - Ok(VisitRecursion::Continue) - } - - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - - let children_plans = new_children? - .into_iter() - .map(|child| child.plan) - .collect::>(); - let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithKeyRequirements { - plan: new_plan.into(), - required_key_ordering: self.required_key_ordering, - request_key_ordering: self.request_key_ordering, - }) - } else { - Ok(self) - } - } -} - -#[cfg(test)] -mod tests { - use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_expr::logical_plan::JoinType; - use datafusion_expr::Operator; - use datafusion_physical_expr::{ - expressions, expressions::binary, expressions::lit, expressions::Column, - PhysicalExpr, PhysicalSortExpr, - }; - use std::ops::Deref; - - use super::*; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; - use crate::physical_optimizer::sort_enforcement::EnforceSorting; - use crate::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, - }; - use crate::physical_plan::expressions::col; - use crate::physical_plan::joins::{ - utils::JoinOn, HashJoinExec, PartitionMode, SortMergeJoinExec, - }; - use crate::physical_plan::projection::ProjectionExec; - use crate::physical_plan::{displayable, Statistics}; - - fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Boolean, true), - ])) - } - - fn parquet_exec() -> Arc { - parquet_exec_with_sort(vec![]) - } - - fn parquet_exec_with_sort( - output_ordering: Vec>, - ) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - infinite_source: false, - }, - None, - None, - )) - } - - // Created a sorted parquet exec with multiple files - fn parquet_exec_multiple_sorted( - output_ordering: Vec>, - ) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![ - vec![PartitionedFile::new("x".to_string(), 100)], - vec![PartitionedFile::new("y".to_string(), 100)], - ], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - infinite_source: false, - }, - None, - None, - )) - } - - fn projection_exec_with_alias( - input: Arc, - alias_pairs: Vec<(String, String)>, - ) -> Arc { - let mut exprs = vec![]; - for (column, alias) in alias_pairs.iter() { - exprs.push((col(column, &input.schema()).unwrap(), alias.to_string())); - } - Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) - } - - fn aggregate_exec_with_alias( - input: Arc, - alias_pairs: Vec<(String, String)>, - ) -> Arc { - let schema = schema(); - let mut group_by_expr: Vec<(Arc, String)> = vec![]; - for (column, alias) in alias_pairs.iter() { - group_by_expr - .push((col(column, &input.schema()).unwrap(), alias.to_string())); - } - let group_by = PhysicalGroupBy::new_single(group_by_expr.clone()); - - let final_group_by_expr = group_by_expr - .iter() - .enumerate() - .map(|(index, (_col, name))| { - ( - Arc::new(expressions::Column::new(name, index)) - as Arc, - name.clone(), - ) - }) - .collect::>(); - let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); - - Arc::new( - AggregateExec::try_new( - AggregateMode::FinalPartitioned, - final_grouping, - vec![], - vec![], - vec![], - Arc::new( - AggregateExec::try_new( - AggregateMode::Partial, - group_by, - vec![], - vec![], - vec![], - input, - schema.clone(), - ) - .unwrap(), - ), - schema, - ) - .unwrap(), - ) - } - - fn hash_join_exec( - left: Arc, - right: Arc, - join_on: &JoinOn, - join_type: &JoinType, - ) -> Arc { - Arc::new( - HashJoinExec::try_new( - left, - right, - join_on.clone(), - None, - join_type, - PartitionMode::Partitioned, - false, - ) - .unwrap(), - ) - } - - fn sort_merge_join_exec( - left: Arc, - right: Arc, - join_on: &JoinOn, - join_type: &JoinType, - ) -> Arc { - Arc::new( - SortMergeJoinExec::try_new( - left, - right, - join_on.clone(), - *join_type, - vec![SortOptions::default(); join_on.len()], - false, - ) - .unwrap(), - ) - } - - fn trim_plan_display(plan: &str) -> Vec<&str> { - plan.split('\n') - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .collect() - } - - /// Runs the repartition optimizer and asserts the plan against the expected - macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - - let mut config = ConfigOptions::new(); - config.execution.target_partitions = 10; - - // run optimizer - let optimizer = EnforceDistribution {}; - let optimized = optimizer.optimize($PLAN, &config)?; - // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade - // because they were written prior to the separation of `BasicEnforcement` into - // `EnforceSorting` and `EnfoceDistribution`. - // TODO: Orthogonalize the tests here just to verify `EnforceDistribution` and create - // new tests for the cascade. - let optimizer = EnforceSorting::new(); - let optimized = optimizer.optimize(optimized, &config)?; - - // Now format correctly - let plan = displayable(optimized.as_ref()).indent().to_string(); - let actual_lines = trim_plan_display(&plan); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; - } - - macro_rules! assert_plan_txt { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - // Now format correctly - let plan = displayable($PLAN.as_ref()).indent().to_string(); - let actual_lines = trim_plan_display(&plan); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; - } - - #[test] - fn multi_hash_joins() -> Result<()> { - let left = parquet_exec(); - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ("c".to_string(), "c1".to_string()), - ("d".to_string(), "d1".to_string()), - ("e".to_string(), "e1".to_string()), - ]; - let right = projection_exec_with_alias(parquet_exec(), alias_pairs); - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightSemi, - JoinType::RightAnti, - ]; - - // Join on (a == b1) - let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - )]; - - for join_type in join_types { - let join = hash_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]"); - - match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::LeftSemi - | JoinType::LeftAnti => { - // Join on (a == c) - let top_join_on = vec![( - Column::new_with_schema("a", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - let top_join = hash_join_exec( - join.clone(), - parquet_exec(), - &top_join_on, - &join_type, - ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]"); - - let expected = match join_type { - // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ - top_join_plan.as_str(), - join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - // Should include 4 RepartitionExecs - _ => vec![ - top_join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=10", - join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - }; - assert_optimized!(expected, top_join); - } - JoinType::RightSemi | JoinType::RightAnti => {} - } - - match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::RightSemi - | JoinType::RightAnti => { - // This time we use (b1 == c) for top join - // Join on (b1 == c) - let top_join_on = vec![( - Column::new_with_schema("b1", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - - let top_join = - hash_join_exec(join, parquet_exec(), &top_join_on, &join_type); - let top_join_plan = match join_type { - JoinType::RightSemi | JoinType::RightAnti => - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(Column {{ name: \"b1\", index: 1 }}, Column {{ name: \"c\", index: 2 }})]"), - _ => - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]"), - }; - - let expected = match join_type { - // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => - vec![ - top_join_plan.as_str(), - join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - // Should include 4 RepartitionExecs - _ => - vec![ - top_join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10), input_partitions=10", - join_plan.as_str(), - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - }; - assert_optimized!(expected, top_join); - } - JoinType::LeftSemi | JoinType::LeftAnti => {} - } - } - - Ok(()) - } - - #[test] - fn multi_joins_after_alias() -> Result<()> { - let left = parquet_exec(); - let right = parquet_exec(); - - // Join on (a == b) - let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b", &schema()).unwrap(), - )]; - let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Projection(a as a1, a as a2) - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("a".to_string(), "a2".to_string()), - ]; - let projection = projection_exec_with_alias(join, alias_pairs); - - // Join on (a1 == c) - let top_join_on = vec![( - Column::new_with_schema("a1", &projection.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - - let top_join = hash_join_exec( - projection.clone(), - right.clone(), - &top_join_on, - &JoinType::Inner, - ); - - // Output partition need to respect the Alias and should not introduce additional RepartitionExec - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"c\", index: 2 })]", - "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, top_join); - - // Join on (a2 == c) - let top_join_on = vec![( - Column::new_with_schema("a2", &projection.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - - let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner); - - // Output partition need to respect the Alias and should not introduce additional RepartitionExec - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a2\", index: 1 }, Column { name: \"c\", index: 2 })]", - "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - - assert_optimized!(expected, top_join); - Ok(()) - } - - #[test] - fn multi_joins_after_multi_alias() -> Result<()> { - let left = parquet_exec(); - let right = parquet_exec(); - - // Join on (a == b) - let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b", &schema()).unwrap(), - )]; - - let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Projection(c as c1) - let alias_pairs: Vec<(String, String)> = - vec![("c".to_string(), "c1".to_string())]; - let projection = projection_exec_with_alias(join, alias_pairs); - - // Projection(c1 as a) - let alias_pairs: Vec<(String, String)> = - vec![("c1".to_string(), "a".to_string())]; - let projection2 = projection_exec_with_alias(projection, alias_pairs); - - // Join on (a == c) - let top_join_on = vec![( - Column::new_with_schema("a", &projection2.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - - let top_join = hash_join_exec(projection2, right, &top_join_on, &JoinType::Inner); - - // The Column 'a' has different meaning now after the two Projections - // The original Output partition can not satisfy the Join requirements and need to add an additional RepartitionExec - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"c\", index: 2 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=10", - "ProjectionExec: expr=[c1@0 as a]", - "ProjectionExec: expr=[c@2 as c1]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - - assert_optimized!(expected, top_join); - Ok(()) - } - - #[test] - fn join_after_agg_alias() -> Result<()> { - // group by (a as a1) - let left = aggregate_exec_with_alias( - parquet_exec(), - vec![("a".to_string(), "a1".to_string())], - ); - // group by (a as a2) - let right = aggregate_exec_with_alias( - parquet_exec(), - vec![("a".to_string(), "a2".to_string())], - ); - - // Join on (a1 == a2) - let join_on = vec![( - Column::new_with_schema("a1", &left.schema()).unwrap(), - Column::new_with_schema("a2", &right.schema()).unwrap(), - )]; - let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Only two RepartitionExecs added - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"a2\", index: 0 })]", - "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 0 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, join); - Ok(()) - } - - #[test] - fn hash_join_key_ordering() -> Result<()> { - // group by (a as a1, b as b1) - let left = aggregate_exec_with_alias( - parquet_exec(), - vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ], - ); - // group by (b, a) - let right = aggregate_exec_with_alias( - parquet_exec(), - vec![ - ("b".to_string(), "b".to_string()), - ("a".to_string(), "a".to_string()), - ], - ); - - // Join on (b1 == b && a1 == a) - let join_on = vec![ - ( - Column::new_with_schema("b1", &left.schema()).unwrap(), - Column::new_with_schema("b", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("a1", &left.schema()).unwrap(), - Column::new_with_schema("a", &right.schema()).unwrap(), - ), - ]; - let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Only two RepartitionExecs added - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: \"a\", index: 1 })]", - "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, join); - Ok(()) - } - - #[test] - fn multi_hash_join_key_ordering() -> Result<()> { - let left = parquet_exec(); - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ("c".to_string(), "c1".to_string()), - ]; - let right = projection_exec_with_alias(parquet_exec(), alias_pairs); - - // Join on (a == a1 and b == b1 and c == c1) - let join_on = vec![ - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), - ), - ]; - let bottom_left_join = - hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner); - - // Projection(a as A, a as AA, b as B, c as C) - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "A".to_string()), - ("a".to_string(), "AA".to_string()), - ("b".to_string(), "B".to_string()), - ("c".to_string(), "C".to_string()), - ]; - let bottom_left_projection = - projection_exec_with_alias(bottom_left_join, alias_pairs); - - // Join on (c == c1 and b == b1 and a == a1) - let join_on = vec![ - ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ]; - let bottom_right_join = - hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Join on (B == b1 and C == c and AA = a1) - let top_join_on = vec![ - ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), - ), - ]; - - let top_join = hash_join_exec( - bottom_left_projection.clone(), - bottom_right_join, - &top_join_on, - &JoinType::Inner, - ); - - let predicate: Arc = binary( - col("c", top_join.schema().deref())?, - Operator::Gt, - lit(1i64), - top_join.schema().deref(), - )?; - - let filter_top_join: Arc = - Arc::new(FilterExec::try_new(predicate, top_join)?); - - // The bottom joins' join key ordering is adjusted based on the top join. And the top join should not introduce additional RepartitionExec - let expected = &[ - "FilterExec: c@6 > 1", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"B\", index: 2 }, Column { name: \"b1\", index: 6 }), (Column { name: \"C\", index: 3 }, Column { name: \"c\", index: 2 }), (Column { name: \"AA\", index: 1 }, Column { name: \"a1\", index: 5 })]", - "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, filter_top_join); - Ok(()) - } - - #[test] - fn reorder_join_keys_to_left_input() -> Result<()> { - let left = parquet_exec(); - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ("c".to_string(), "c1".to_string()), - ]; - let right = projection_exec_with_alias(parquet_exec(), alias_pairs); - - // Join on (a == a1 and b == b1 and c == c1) - let join_on = vec![ - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), - ), - ]; - let bottom_left_join = ensure_distribution( - hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), - 10, - )? - .into(); - - // Projection(a as A, a as AA, b as B, c as C) - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "A".to_string()), - ("a".to_string(), "AA".to_string()), - ("b".to_string(), "B".to_string()), - ("c".to_string(), "C".to_string()), - ]; - let bottom_left_projection = - projection_exec_with_alias(bottom_left_join, alias_pairs); - - // Join on (c == c1 and b == b1 and a == a1) - let join_on = vec![ - ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ]; - let bottom_right_join = ensure_distribution( - hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), - 10, - )? - .into(); - - // Join on (B == b1 and C == c and AA = a1) - let top_join_on = vec![ - ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), - ), - ]; - - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightSemi, - JoinType::RightAnti, - ]; - - for join_type in join_types { - let top_join = hash_join_exec( - bottom_left_projection.clone(), - bottom_right_join.clone(), - &top_join_on, - &join_type, - ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(Column {{ name: \"AA\", index: 1 }}, Column {{ name: \"a1\", index: 5 }}), (Column {{ name: \"B\", index: 2 }}, Column {{ name: \"b1\", index: 6 }}), (Column {{ name: \"C\", index: 3 }}, Column {{ name: \"c\", index: 2 }})]", &join_type); - - let reordered = reorder_join_keys_to_inputs(top_join)?; - - // The top joins' join key ordering is adjusted based on the children inputs. - let expected = &[ - top_join_plan.as_str(), - "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }, Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }, Column { name: \"b\", index: 1 }, Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 2 }, Column { name: \"b1\", index: 1 }, Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - - assert_plan_txt!(expected, reordered); - } - - Ok(()) - } - - #[test] - fn reorder_join_keys_to_right_input() -> Result<()> { - let left = parquet_exec(); - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ("c".to_string(), "c1".to_string()), - ]; - let right = projection_exec_with_alias(parquet_exec(), alias_pairs); - - // Join on (a == a1 and b == b1) - let join_on = vec![ - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ]; - let bottom_left_join = ensure_distribution( - hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), - 10, - )? - .into(); - - // Projection(a as A, a as AA, b as B, c as C) - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "A".to_string()), - ("a".to_string(), "AA".to_string()), - ("b".to_string(), "B".to_string()), - ("c".to_string(), "C".to_string()), - ]; - let bottom_left_projection = - projection_exec_with_alias(bottom_left_join, alias_pairs); - - // Join on (c == c1 and b == b1 and a == a1) - let join_on = vec![ - ( - Column::new_with_schema("c", &schema()).unwrap(), - Column::new_with_schema("c1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("b", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("a1", &right.schema()).unwrap(), - ), - ]; - let bottom_right_join = ensure_distribution( - hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), - 10, - )? - .into(); - - // Join on (B == b1 and C == c and AA = a1) - let top_join_on = vec![ - ( - Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), - ), - ( - Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), - Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), - ), - ]; - - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightSemi, - JoinType::RightAnti, - ]; - - for join_type in join_types { - let top_join = hash_join_exec( - bottom_left_projection.clone(), - bottom_right_join.clone(), - &top_join_on, - &join_type, - ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(Column {{ name: \"C\", index: 3 }}, Column {{ name: \"c\", index: 2 }}), (Column {{ name: \"B\", index: 2 }}, Column {{ name: \"b1\", index: 6 }}), (Column {{ name: \"AA\", index: 1 }}, Column {{ name: \"a1\", index: 5 }})]", &join_type); - - let reordered = reorder_join_keys_to_inputs(top_join)?; - - // The top joins' join key ordering is adjusted based on the children inputs. - let expected = &[ - top_join_plan.as_str(), - "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }, Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }, Column { name: \"b\", index: 1 }, Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 2 }, Column { name: \"b1\", index: 1 }, Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - - assert_plan_txt!(expected, reordered); - } - - Ok(()) - } - - #[test] - fn multi_smj_joins() -> Result<()> { - let left = parquet_exec(); - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ("c".to_string(), "c1".to_string()), - ("d".to_string(), "d1".to_string()), - ("e".to_string(), "e1".to_string()), - ]; - let right = projection_exec_with_alias(parquet_exec(), alias_pairs); - - // SortMergeJoin does not support RightSemi and RightAnti join now - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; - - // Join on (a == b1) - let join_on = vec![( - Column::new_with_schema("a", &schema()).unwrap(), - Column::new_with_schema("b1", &right.schema()).unwrap(), - )]; - - for join_type in join_types { - let join = - sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]"); - - // Top join on (a == c) - let top_join_on = vec![( - Column::new_with_schema("a", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - let top_join = sort_merge_join_exec( - join.clone(), - parquet_exec(), - &top_join_on, - &join_type, - ); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]"); - - let expected = match join_type { - // Should include 3 RepartitionExecs 3 SortExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => - vec![ - top_join_plan.as_str(), - join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - // Should include 4 RepartitionExecs - _ => vec![ - top_join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=10", - join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - }; - assert_optimized!(expected, top_join); - - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - // This time we use (b1 == c) for top join - // Join on (b1 == c) - let top_join_on = vec![( - Column::new_with_schema("b1", &join.schema()).unwrap(), - Column::new_with_schema("c", &schema()).unwrap(), - )]; - let top_join = sort_merge_join_exec( - join, - parquet_exec(), - &top_join_on, - &join_type, - ); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]"); - - let expected = match join_type { - // Should include 3 RepartitionExecs and 3 SortExecs - JoinType::Inner | JoinType::Right => vec![ - top_join_plan.as_str(), - join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - // Should include 4 RepartitionExecs and 4 SortExecs - _ => vec![ - top_join_plan.as_str(), - "SortExec: expr=[b1@6 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10), input_partitions=10", - join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10), input_partitions=1", - "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", - "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ], - }; - assert_optimized!(expected, top_join); - } - _ => {} - } - } - - Ok(()) - } - - #[test] - fn smj_join_key_ordering() -> Result<()> { - // group by (a as a1, b as b1) - let left = aggregate_exec_with_alias( - parquet_exec(), - vec![ - ("a".to_string(), "a1".to_string()), - ("b".to_string(), "b1".to_string()), - ], - ); - //Projection(a1 as a3, b1 as b3) - let alias_pairs: Vec<(String, String)> = vec![ - ("a1".to_string(), "a3".to_string()), - ("b1".to_string(), "b3".to_string()), - ]; - let left = projection_exec_with_alias(left, alias_pairs); - - // group by (b, a) - let right = aggregate_exec_with_alias( - parquet_exec(), - vec![ - ("b".to_string(), "b".to_string()), - ("a".to_string(), "a".to_string()), - ], - ); - - //Projection(a as a2, b as b2) - let alias_pairs: Vec<(String, String)> = vec![ - ("a".to_string(), "a2".to_string()), - ("b".to_string(), "b2".to_string()), - ]; - let right = projection_exec_with_alias(right, alias_pairs); - - // Join on (b3 == b2 && a3 == a2) - let join_on = vec![ - ( - Column::new_with_schema("b3", &left.schema()).unwrap(), - Column::new_with_schema("b2", &right.schema()).unwrap(), - ), - ( - Column::new_with_schema("a3", &left.schema()).unwrap(), - Column::new_with_schema("a2", &right.schema()).unwrap(), - ), - ]; - let join = sort_merge_join_exec(left, right.clone(), &join_on, &JoinType::Inner); - - // Only two RepartitionExecs added - let expected = &[ - "SortMergeJoin: join_type=Inner, on=[(Column { name: \"b3\", index: 1 }, Column { name: \"b2\", index: 1 }), (Column { name: \"a3\", index: 0 }, Column { name: \"a2\", index: 0 })]", - "SortExec: expr=[b3@1 ASC,a3@0 ASC]", - "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", - "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b2@1 ASC,a2@0 ASC]", - "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", - "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, join); - Ok(()) - } - - #[test] - fn merge_does_not_need_sort() -> Result<()> { - // see https://github.com/apache/arrow-datafusion/issues/4331 - let schema = schema(); - let sort_key = vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), - options: SortOptions::default(), - }]; - - // Scan some sorted parquet files - let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); - - // CoalesceBatchesExec to mimic behavior after a filter - let exec = Arc::new(CoalesceBatchesExec::new(exec, 4096)); - - // Merge from multiple parquet files and keep the data sorted - let exec = Arc::new(SortPreservingMergeExec::new(sort_key, exec)); - - // The optimizer should not add an additional SortExec as the - // data is already sorted - let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - "CoalesceBatchesExec: target_batch_size=4096", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", - ]; - assert_optimized!(expected, exec); - Ok(()) - } - - #[test] - fn union_to_interleave() -> Result<()> { - // group by (a as a1) - let left = aggregate_exec_with_alias( - parquet_exec(), - vec![("a".to_string(), "a1".to_string())], - ); - // group by (a as a2) - let right = aggregate_exec_with_alias( - parquet_exec(), - vec![("a".to_string(), "a1".to_string())], - ); - - // Union - let plan = Arc::new(UnionExec::new(vec![left, right])); - - // final agg - let plan = - aggregate_exec_with_alias(plan, vec![("a1".to_string(), "a2".to_string())]); - - // Only two RepartitionExecs added, no final RepartionExec required - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", - "AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[]", - "InterleaveExec", - "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }], 10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - ]; - assert_optimized!(expected, plan); - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs new file mode 100644 index 0000000000000..4befea741c8c8 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -0,0 +1,4649 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! EnforceDistribution optimizer rule inspects the physical plan with respect +//! to distribution requirements and adds [`RepartitionExec`]s to satisfy them +//! when necessary. If increasing parallelism is beneficial (and also desirable +//! according to the configuration), this rule increases partition counts in +//! the physical plan. + +use std::fmt; +use std::fmt::Formatter; +use std::sync::Arc; + +use crate::config::ConfigOptions; +use crate::error::Result; +use crate::physical_optimizer::utils::{ + add_sort_above, get_children_exectrees, is_coalesce_partitions, is_repartition, + is_sort_preserving_merge, ExecTree, +}; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, +}; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::union::{can_interleave, InterleaveExec, UnionExec}; +use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, Partitioning, +}; + +use arrow::compute::SortOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_expr::logical_plan::JoinType; +use datafusion_physical_expr::expressions::{Column, NoOp}; +use datafusion_physical_expr::utils::map_columns_before_projection; +use datafusion_physical_expr::{ + physical_exprs_equal, EquivalenceProperties, PhysicalExpr, +}; +use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; +use datafusion_physical_plan::{get_plan_string, unbounded_output}; + +use itertools::izip; + +/// The `EnforceDistribution` rule ensures that distribution requirements are +/// met. In doing so, this rule will increase the parallelism in the plan by +/// introducing repartitioning operators to the physical plan. +/// +/// For example, given an input such as: +/// +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// │ │ +/// │ ExecutionPlan │ +/// │ │ +/// └─────────────────────────────────┘ +/// ▲ ▲ +/// │ │ +/// ┌─────┘ └─────┐ +/// │ │ +/// │ │ +/// │ │ +/// ┌───────────┐ ┌───────────┐ +/// │ │ │ │ +/// │ batch A1 │ │ batch B1 │ +/// │ │ │ │ +/// ├───────────┤ ├───────────┤ +/// │ │ │ │ +/// │ batch A2 │ │ batch B2 │ +/// │ │ │ │ +/// ├───────────┤ ├───────────┤ +/// │ │ │ │ +/// │ batch A3 │ │ batch B3 │ +/// │ │ │ │ +/// └───────────┘ └───────────┘ +/// +/// Input Input +/// A B +/// ``` +/// +/// This rule will attempt to add a `RepartitionExec` to increase parallelism +/// (to 3, in this case) and create the following arrangement: +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// │ │ +/// │ ExecutionPlan │ +/// │ │ +/// └─────────────────────────────────┘ +/// ▲ ▲ ▲ Input now has 3 +/// │ │ │ partitions +/// ┌───────┘ │ └───────┐ +/// │ │ │ +/// │ │ │ +/// ┌───────────┐ ┌───────────┐ ┌───────────┐ +/// │ │ │ │ │ │ +/// │ batch A1 │ │ batch A3 │ │ batch B3 │ +/// │ │ │ │ │ │ +/// ├───────────┤ ├───────────┤ ├───────────┤ +/// │ │ │ │ │ │ +/// │ batch B2 │ │ batch B1 │ │ batch A2 │ +/// │ │ │ │ │ │ +/// └───────────┘ └───────────┘ └───────────┘ +/// ▲ ▲ ▲ +/// │ │ │ +/// └─────────┐ │ ┌──────────┘ +/// │ │ │ +/// │ │ │ +/// ┌─────────────────────────────────┐ batches are +/// │ RepartitionExec(3) │ repartitioned +/// │ RoundRobin │ +/// │ │ +/// └─────────────────────────────────┘ +/// ▲ ▲ +/// │ │ +/// ┌─────┘ └─────┐ +/// │ │ +/// │ │ +/// │ │ +/// ┌───────────┐ ┌───────────┐ +/// │ │ │ │ +/// │ batch A1 │ │ batch B1 │ +/// │ │ │ │ +/// ├───────────┤ ├───────────┤ +/// │ │ │ │ +/// │ batch A2 │ │ batch B2 │ +/// │ │ │ │ +/// ├───────────┤ ├───────────┤ +/// │ │ │ │ +/// │ batch A3 │ │ batch B3 │ +/// │ │ │ │ +/// └───────────┘ └───────────┘ +/// +/// +/// Input Input +/// A B +/// ``` +/// +/// The `EnforceDistribution` rule +/// - is idempotent; i.e. it can be applied multiple times, each time producing +/// the same result. +/// - always produces a valid plan in terms of distribution requirements. Its +/// input plan can be valid or invalid with respect to distribution requirements, +/// but the output plan will always be valid. +/// - produces a valid plan in terms of ordering requirements, *if* its input is +/// a valid plan in terms of ordering requirements. If the input plan is invalid, +/// this rule does not attempt to fix it as doing so is the responsibility of the +/// `EnforceSorting` rule. +/// +/// Note that distribution requirements are met in the strictest way. This may +/// result in more than strictly necessary [`RepartitionExec`]s in the plan, but +/// meeting the requirements in the strictest way may help avoid possible data +/// skew in joins. +/// +/// For example for a hash join with keys (a, b, c), the required Distribution(a, b, c) +/// can be satisfied by several alternative partitioning ways: (a, b, c), (a, b), +/// (a, c), (b, c), (a), (b), (c) and ( ). +/// +/// This rule only chooses the exact match and satisfies the Distribution(a, b, c) +/// by a HashPartition(a, b, c). +#[derive(Default)] +pub struct EnforceDistribution {} + +impl EnforceDistribution { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for EnforceDistribution { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering; + + let adjusted = if top_down_join_key_reordering { + // Run a top-down process to adjust input key ordering recursively + let plan_requirements = PlanWithKeyRequirements::new(plan); + let adjusted = + plan_requirements.transform_down(&adjust_input_keys_ordering)?; + adjusted.plan + } else { + // Run a bottom-up process + plan.transform_up(&|plan| { + Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) + })? + }; + + let distribution_context = DistributionContext::new(adjusted); + // Distribution enforcement needs to be applied bottom-up. + let distribution_context = + distribution_context.transform_up(&|distribution_context| { + ensure_distribution(distribution_context, config) + })?; + Ok(distribution_context.plan) + } + + fn name(&self) -> &str { + "EnforceDistribution" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// When the physical planner creates the Joins, the ordering of join keys is from the original query. +/// That might not match with the output partitioning of the join node's children +/// A Top-Down process will use this method to adjust children's output partitioning based on the parent key reordering requirements: +/// +/// Example: +/// TopJoin on (a, b, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (a, b, c) +/// bottom left join on(a, b, c) +/// bottom right join on(a, b, c) +/// +/// Example: +/// TopJoin on (a, b, c) +/// Agg1 group by (b, a, c) +/// Agg2 group by (c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (a, b, c) +/// Projection(b, a, c) +/// Agg1 group by (a, b, c) +/// Projection(c, b, a) +/// Agg2 group by (a, b, c) +/// +/// Following is the explanation of the reordering process: +/// +/// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: +/// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. +/// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. +/// Requirements can be satisfied by adjusting keys ordering, clear the current requiements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. +/// +/// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: +/// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. +/// Requirements is already satisfied, clear all the requirements, return the unchanged plan. +/// Requirements can be satisfied by adjusting keys ordering, clear all the requirements, return the changed plan. +/// +/// 3) If the current plan is RepartitionExec, CoalescePartitionsExec or WindowAggExec, clear all the requirements, return the unchanged plan +/// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements +/// 5) For other types of operators, by default, pushdown the parent requirements to children. +/// +fn adjust_input_keys_ordering( + requirements: PlanWithKeyRequirements, +) -> Result> { + let parent_required = requirements.required_key_ordering.clone(); + let plan_any = requirements.plan.as_any(); + let transformed = if let Some(HashJoinExec { + left, + right, + on, + filter, + join_type, + mode, + null_equals_null, + .. + }) = plan_any.downcast_ref::() + { + match mode { + PartitionMode::Partitioned => { + let join_constructor = + |new_conditions: (Vec<(Column, Column)>, Vec)| { + Ok(Arc::new(HashJoinExec::try_new( + left.clone(), + right.clone(), + new_conditions.0, + filter.clone(), + join_type, + PartitionMode::Partitioned, + *null_equals_null, + )?) as Arc) + }; + Some(reorder_partitioned_join_keys( + requirements.plan.clone(), + &parent_required, + on, + vec![], + &join_constructor, + )?) + } + PartitionMode::CollectLeft => { + let new_right_request = match join_type { + JoinType::Inner | JoinType::Right => shift_right_required( + &parent_required, + left.schema().fields().len(), + ), + JoinType::RightSemi | JoinType::RightAnti => { + Some(parent_required.clone()) + } + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full => None, + }; + + // Push down requirements to the right side + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![None, new_right_request], + }) + } + PartitionMode::Auto => { + // Can not satisfy, clear the current requirements and generate new empty requirements + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + } + } + } else if let Some(CrossJoinExec { left, .. }) = + plan_any.downcast_ref::() + { + let left_columns_len = left.schema().fields().len(); + // Push down requirements to the right side + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![ + None, + shift_right_required(&parent_required, left_columns_len), + ], + }) + } else if let Some(SortMergeJoinExec { + left, + right, + on, + join_type, + sort_options, + null_equals_null, + .. + }) = plan_any.downcast_ref::() + { + let join_constructor = + |new_conditions: (Vec<(Column, Column)>, Vec)| { + Ok(Arc::new(SortMergeJoinExec::try_new( + left.clone(), + right.clone(), + new_conditions.0, + *join_type, + new_conditions.1, + *null_equals_null, + )?) as Arc) + }; + Some(reorder_partitioned_join_keys( + requirements.plan.clone(), + &parent_required, + on, + sort_options.clone(), + &join_constructor, + )?) + } else if let Some(aggregate_exec) = plan_any.downcast_ref::() { + if !parent_required.is_empty() { + match aggregate_exec.mode() { + AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys( + requirements.plan.clone(), + &parent_required, + aggregate_exec, + )?), + _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())), + } + } else { + // Keep everything unchanged + None + } + } else if let Some(proj) = plan_any.downcast_ref::() { + let expr = proj.expr(); + // For Projection, we need to transform the requirements to the columns before the Projection + // And then to push down the requirements + // Construct a mapping from new name to the the orginal Column + let new_required = map_columns_before_projection(&parent_required, expr); + if new_required.len() == parent_required.len() { + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![Some(new_required.clone())], + }) + } else { + // Can not satisfy, clear the current requirements and generate new empty requirements + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + } + } else if plan_any.downcast_ref::().is_some() + || plan_any.downcast_ref::().is_some() + || plan_any.downcast_ref::().is_some() + { + Some(PlanWithKeyRequirements::new(requirements.plan.clone())) + } else { + // By default, push down the parent requirements to children + let children_len = requirements.plan.children().len(); + Some(PlanWithKeyRequirements { + plan: requirements.plan.clone(), + required_key_ordering: vec![], + request_key_ordering: vec![Some(parent_required.clone()); children_len], + }) + }; + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(requirements) + }) +} + +fn reorder_partitioned_join_keys( + join_plan: Arc, + parent_required: &[Arc], + on: &[(Column, Column)], + sort_options: Vec, + join_constructor: &F, +) -> Result +where + F: Fn((Vec<(Column, Column)>, Vec)) -> Result>, +{ + let join_key_pairs = extract_join_keys(on); + if let Some(( + JoinKeyPairs { + left_keys, + right_keys, + }, + new_positions, + )) = try_reorder( + join_key_pairs.clone(), + parent_required, + &join_plan.equivalence_properties(), + ) { + if !new_positions.is_empty() { + let new_join_on = new_join_conditions(&left_keys, &right_keys); + let mut new_sort_options: Vec = vec![]; + for idx in 0..sort_options.len() { + new_sort_options.push(sort_options[new_positions[idx]]) + } + + Ok(PlanWithKeyRequirements { + plan: join_constructor((new_join_on, new_sort_options))?, + required_key_ordering: vec![], + request_key_ordering: vec![Some(left_keys), Some(right_keys)], + }) + } else { + Ok(PlanWithKeyRequirements { + plan: join_plan, + required_key_ordering: vec![], + request_key_ordering: vec![Some(left_keys), Some(right_keys)], + }) + } + } else { + Ok(PlanWithKeyRequirements { + plan: join_plan, + required_key_ordering: vec![], + request_key_ordering: vec![ + Some(join_key_pairs.left_keys), + Some(join_key_pairs.right_keys), + ], + }) + } +} + +fn reorder_aggregate_keys( + agg_plan: Arc, + parent_required: &[Arc], + agg_exec: &AggregateExec, +) -> Result { + let output_columns = agg_exec + .group_by() + .expr() + .iter() + .enumerate() + .map(|(index, (_col, name))| Column::new(name, index)) + .collect::>(); + + let output_exprs = output_columns + .iter() + .map(|c| Arc::new(c.clone()) as _) + .collect::>(); + + if parent_required.len() != output_exprs.len() + || !agg_exec.group_by().null_expr().is_empty() + || physical_exprs_equal(&output_exprs, parent_required) + { + Ok(PlanWithKeyRequirements::new(agg_plan)) + } else { + let new_positions = expected_expr_positions(&output_exprs, parent_required); + match new_positions { + None => Ok(PlanWithKeyRequirements::new(agg_plan)), + Some(positions) => { + let new_partial_agg = if let Some(agg_exec) = + agg_exec.input().as_any().downcast_ref::() + { + if matches!(agg_exec.mode(), &AggregateMode::Partial) { + let group_exprs = agg_exec.group_by().expr(); + let new_group_exprs = positions + .into_iter() + .map(|idx| group_exprs[idx].clone()) + .collect(); + let new_partial_group_by = + PhysicalGroupBy::new_single(new_group_exprs); + Some(Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + new_partial_group_by, + agg_exec.aggr_expr().to_vec(), + agg_exec.filter_expr().to_vec(), + agg_exec.order_by_expr().to_vec(), + agg_exec.input().clone(), + agg_exec.input_schema.clone(), + )?)) + } else { + None + } + } else { + None + }; + if let Some(partial_agg) = new_partial_agg { + // Build new group expressions that correspond to the output of partial_agg + let group_exprs = partial_agg.group_expr().expr(); + let new_final_group = partial_agg.output_group_expr(); + let new_group_by = PhysicalGroupBy::new_single( + new_final_group + .iter() + .enumerate() + .map(|(idx, expr)| (expr.clone(), group_exprs[idx].1.clone())) + .collect(), + ); + + let new_final_agg = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + new_group_by, + agg_exec.aggr_expr().to_vec(), + agg_exec.filter_expr().to_vec(), + agg_exec.order_by_expr().to_vec(), + partial_agg, + agg_exec.input_schema(), + )?); + + // Need to create a new projection to change the expr ordering back + let agg_schema = new_final_agg.schema(); + let mut proj_exprs = output_columns + .iter() + .map(|col| { + let name = col.name(); + ( + Arc::new(Column::new( + name, + agg_schema.index_of(name).unwrap(), + )) as _, + name.to_owned(), + ) + }) + .collect::>(); + let agg_fields = agg_schema.fields(); + for (idx, field) in + agg_fields.iter().enumerate().skip(output_columns.len()) + { + let name = field.name(); + proj_exprs + .push((Arc::new(Column::new(name, idx)) as _, name.clone())) + } + // TODO merge adjacent Projections if there are + Ok(PlanWithKeyRequirements::new(Arc::new( + ProjectionExec::try_new(proj_exprs, new_final_agg)?, + ))) + } else { + Ok(PlanWithKeyRequirements::new(agg_plan)) + } + } + } + } +} + +fn shift_right_required( + parent_required: &[Arc], + left_columns_len: usize, +) -> Option>> { + let new_right_required = parent_required + .iter() + .filter_map(|r| { + if let Some(col) = r.as_any().downcast_ref::() { + let idx = col.index(); + if idx >= left_columns_len { + let result = Column::new(col.name(), idx - left_columns_len); + Some(Arc::new(result) as _) + } else { + None + } + } else { + None + } + }) + .collect::>(); + + // if the parent required are all comming from the right side, the requirements can be pushdown + (new_right_required.len() == parent_required.len()).then_some(new_right_required) +} + +/// When the physical planner creates the Joins, the ordering of join keys is from the original query. +/// That might not match with the output partitioning of the join node's children +/// This method will try to change the ordering of the join keys to match with the +/// partitioning of the join nodes' children. If it can not match with both sides, it will try to +/// match with one, either the left side or the right side. +/// +/// Example: +/// TopJoin on (a, b, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (b, a, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Compared to the Top-Down reordering process, this Bottom-Up approach is much simpler, but might not reach a best result. +/// The Bottom-Up approach will be useful in future if we plan to support storage partition-wised Joins. +/// In that case, the datasources/tables might be pre-partitioned and we can't adjust the key ordering of the datasources +/// and then can't apply the Top-Down reordering process. +pub(crate) fn reorder_join_keys_to_inputs( + plan: Arc, +) -> Result> { + let plan_any = plan.as_any(); + if let Some(HashJoinExec { + left, + right, + on, + filter, + join_type, + mode, + null_equals_null, + .. + }) = plan_any.downcast_ref::() + { + if matches!(mode, PartitionMode::Partitioned) { + let join_key_pairs = extract_join_keys(on); + if let Some(( + JoinKeyPairs { + left_keys, + right_keys, + }, + new_positions, + )) = reorder_current_join_keys( + join_key_pairs, + Some(left.output_partitioning()), + Some(right.output_partitioning()), + &left.equivalence_properties(), + &right.equivalence_properties(), + ) { + if !new_positions.is_empty() { + let new_join_on = new_join_conditions(&left_keys, &right_keys); + return Ok(Arc::new(HashJoinExec::try_new( + left.clone(), + right.clone(), + new_join_on, + filter.clone(), + join_type, + PartitionMode::Partitioned, + *null_equals_null, + )?)); + } + } + } + } else if let Some(SortMergeJoinExec { + left, + right, + on, + join_type, + sort_options, + null_equals_null, + .. + }) = plan_any.downcast_ref::() + { + let join_key_pairs = extract_join_keys(on); + if let Some(( + JoinKeyPairs { + left_keys, + right_keys, + }, + new_positions, + )) = reorder_current_join_keys( + join_key_pairs, + Some(left.output_partitioning()), + Some(right.output_partitioning()), + &left.equivalence_properties(), + &right.equivalence_properties(), + ) { + if !new_positions.is_empty() { + let new_join_on = new_join_conditions(&left_keys, &right_keys); + let new_sort_options = (0..sort_options.len()) + .map(|idx| sort_options[new_positions[idx]]) + .collect(); + return Ok(Arc::new(SortMergeJoinExec::try_new( + left.clone(), + right.clone(), + new_join_on, + *join_type, + new_sort_options, + *null_equals_null, + )?)); + } + } + } + Ok(plan) +} + +/// Reorder the current join keys ordering based on either left partition or right partition +fn reorder_current_join_keys( + join_keys: JoinKeyPairs, + left_partition: Option, + right_partition: Option, + left_equivalence_properties: &EquivalenceProperties, + right_equivalence_properties: &EquivalenceProperties, +) -> Option<(JoinKeyPairs, Vec)> { + match (left_partition, right_partition.clone()) { + (Some(Partitioning::Hash(left_exprs, _)), _) => { + try_reorder(join_keys.clone(), &left_exprs, left_equivalence_properties) + .or_else(|| { + reorder_current_join_keys( + join_keys, + None, + right_partition, + left_equivalence_properties, + right_equivalence_properties, + ) + }) + } + (_, Some(Partitioning::Hash(right_exprs, _))) => { + try_reorder(join_keys, &right_exprs, right_equivalence_properties) + } + _ => None, + } +} + +fn try_reorder( + join_keys: JoinKeyPairs, + expected: &[Arc], + equivalence_properties: &EquivalenceProperties, +) -> Option<(JoinKeyPairs, Vec)> { + let eq_groups = equivalence_properties.eq_group(); + let mut normalized_expected = vec![]; + let mut normalized_left_keys = vec![]; + let mut normalized_right_keys = vec![]; + if join_keys.left_keys.len() != expected.len() { + return None; + } + if physical_exprs_equal(expected, &join_keys.left_keys) + || physical_exprs_equal(expected, &join_keys.right_keys) + { + return Some((join_keys, vec![])); + } else if !equivalence_properties.eq_group().is_empty() { + normalized_expected = expected + .iter() + .map(|e| eq_groups.normalize_expr(e.clone())) + .collect::>(); + assert_eq!(normalized_expected.len(), expected.len()); + + normalized_left_keys = join_keys + .left_keys + .iter() + .map(|e| eq_groups.normalize_expr(e.clone())) + .collect::>(); + assert_eq!(join_keys.left_keys.len(), normalized_left_keys.len()); + + normalized_right_keys = join_keys + .right_keys + .iter() + .map(|e| eq_groups.normalize_expr(e.clone())) + .collect::>(); + assert_eq!(join_keys.right_keys.len(), normalized_right_keys.len()); + + if physical_exprs_equal(&normalized_expected, &normalized_left_keys) + || physical_exprs_equal(&normalized_expected, &normalized_right_keys) + { + return Some((join_keys, vec![])); + } + } + + let new_positions = expected_expr_positions(&join_keys.left_keys, expected) + .or_else(|| expected_expr_positions(&join_keys.right_keys, expected)) + .or_else(|| expected_expr_positions(&normalized_left_keys, &normalized_expected)) + .or_else(|| { + expected_expr_positions(&normalized_right_keys, &normalized_expected) + }); + + if let Some(positions) = new_positions { + let mut new_left_keys = vec![]; + let mut new_right_keys = vec![]; + for pos in positions.iter() { + new_left_keys.push(join_keys.left_keys[*pos].clone()); + new_right_keys.push(join_keys.right_keys[*pos].clone()); + } + Some(( + JoinKeyPairs { + left_keys: new_left_keys, + right_keys: new_right_keys, + }, + positions, + )) + } else { + None + } +} + +/// Return the expected expressions positions. +/// For example, the current expressions are ['c', 'a', 'a', b'], the expected expressions are ['b', 'c', 'a', 'a'], +/// +/// This method will return a Vec [3, 0, 1, 2] +fn expected_expr_positions( + current: &[Arc], + expected: &[Arc], +) -> Option> { + if current.is_empty() || expected.is_empty() { + return None; + } + let mut indexes: Vec = vec![]; + let mut current = current.to_vec(); + for expr in expected.iter() { + // Find the position of the expected expr in the current expressions + if let Some(expected_position) = current.iter().position(|e| e.eq(expr)) { + current[expected_position] = Arc::new(NoOp::new()); + indexes.push(expected_position); + } else { + return None; + } + } + Some(indexes) +} + +fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs { + let (left_keys, right_keys) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + JoinKeyPairs { + left_keys, + right_keys, + } +} + +fn new_join_conditions( + new_left_keys: &[Arc], + new_right_keys: &[Arc], +) -> Vec<(Column, Column)> { + new_left_keys + .iter() + .zip(new_right_keys.iter()) + .map(|(l_key, r_key)| { + ( + l_key.as_any().downcast_ref::().unwrap().clone(), + r_key.as_any().downcast_ref::().unwrap().clone(), + ) + }) + .collect() +} + +/// Updates `dist_onward` such that, to keep track of +/// `input` in the `exec_tree`. +/// +/// # Arguments +/// +/// * `input`: Current execution plan +/// * `dist_onward`: It keeps track of executors starting from a distribution +/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) +/// until child of `input` (`input` should have single child). +/// * `input_idx`: index of the `input`, for its parent. +/// +fn update_distribution_onward( + input: Arc, + dist_onward: &mut Option, + input_idx: usize, +) { + // Update the onward tree if there is an active branch + if let Some(exec_tree) = dist_onward { + // When we add a new operator to change distribution + // we add RepartitionExec, SortPreservingMergeExec, CoalescePartitionsExec + // in this case, we need to update exec tree idx such that exec tree is now child of these + // operators (change the 0, since all of the operators have single child). + exec_tree.idx = 0; + *exec_tree = ExecTree::new(input, input_idx, vec![exec_tree.clone()]); + } else { + *dist_onward = Some(ExecTree::new(input, input_idx, vec![])); + } +} + +/// Adds RoundRobin repartition operator to the plan increase parallelism. +/// +/// # Arguments +/// +/// * `input`: Current execution plan +/// * `n_target`: desired target partition number, if partition number of the +/// current executor is less than this value. Partition number will be increased. +/// * `dist_onward`: It keeps track of executors starting from a distribution +/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) +/// until `input` plan. +/// * `input_idx`: index of the `input`, for its parent. +/// +/// # Returns +/// +/// A [Result] object that contains new execution plan, where desired partition number +/// is achieved by adding RoundRobin Repartition. +fn add_roundrobin_on_top( + input: Arc, + n_target: usize, + dist_onward: &mut Option, + input_idx: usize, +) -> Result> { + // Adding repartition is helpful + if input.output_partitioning().partition_count() < n_target { + // When there is an existing ordering, we preserve ordering + // during repartition. This will be un-done in the future + // If any of the following conditions is true + // - Preserving ordering is not helpful in terms of satisfying ordering requirements + // - Usage of order preserving variants is not desirable + // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + let partitioning = Partitioning::RoundRobinBatch(n_target); + let repartition = + RepartitionExec::try_new(input, partitioning)?.with_preserve_order(); + + // update distribution onward with new operator + let new_plan = Arc::new(repartition) as Arc; + update_distribution_onward(new_plan.clone(), dist_onward, input_idx); + Ok(new_plan) + } else { + // Partition is not helpful, we already have desired number of partitions. + Ok(input) + } +} + +/// Adds a hash repartition operator: +/// - to increase parallelism, and/or +/// - to satisfy requirements of the subsequent operators. +/// Repartition(Hash) is added on top of operator `input`. +/// +/// # Arguments +/// +/// * `input`: Current execution plan +/// * `hash_exprs`: Stores Physical Exprs that are used during hashing. +/// * `n_target`: desired target partition number, if partition number of the +/// current executor is less than this value. Partition number will be increased. +/// * `dist_onward`: It keeps track of executors starting from a distribution +/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) +/// until `input` plan. +/// * `input_idx`: index of the `input`, for its parent. +/// +/// # Returns +/// +/// A [`Result`] object that contains new execution plan, where desired distribution is +/// satisfied by adding Hash Repartition. +fn add_hash_on_top( + input: Arc, + hash_exprs: Vec>, + // Repartition(Hash) will have `n_target` partitions at the output. + n_target: usize, + // Stores executors starting from Repartition(RoundRobin) until + // current executor. When Repartition(Hash) is added, `dist_onward` + // is updated such that it stores connection from Repartition(RoundRobin) + // until Repartition(Hash). + dist_onward: &mut Option, + input_idx: usize, + repartition_beneficial_stats: bool, +) -> Result> { + if n_target == input.output_partitioning().partition_count() && n_target == 1 { + // In this case adding a hash repartition is unnecessary as the hash + // requirement is implicitly satisfied. + return Ok(input); + } + let satisfied = input + .output_partitioning() + .satisfy(Distribution::HashPartitioned(hash_exprs.clone()), || { + input.equivalence_properties() + }); + // Add hash repartitioning when: + // - The hash distribution requirement is not satisfied, or + // - We can increase parallelism by adding hash partitioning. + if !satisfied || n_target > input.output_partitioning().partition_count() { + // When there is an existing ordering, we preserve ordering during + // repartition. This will be rolled back in the future if any of the + // following conditions is true: + // - Preserving ordering is not helpful in terms of satisfying ordering + // requirements. + // - Usage of order preserving variants is not desirable (per the flag + // `config.optimizer.bounded_order_preserving_variants`). + let mut new_plan = if repartition_beneficial_stats { + // Since hashing benefits from partitioning, add a round-robin repartition + // before it: + add_roundrobin_on_top(input, n_target, dist_onward, 0)? + } else { + input + }; + let partitioning = Partitioning::Hash(hash_exprs, n_target); + let repartition = RepartitionExec::try_new(new_plan, partitioning)? + // preserve any ordering if possible + .with_preserve_order(); + new_plan = Arc::new(repartition) as _; + + // update distribution onward with new operator + update_distribution_onward(new_plan.clone(), dist_onward, input_idx); + Ok(new_plan) + } else { + Ok(input) + } +} + +/// Adds a `SortPreservingMergeExec` operator on top of input executor: +/// - to satisfy single distribution requirement. +/// +/// # Arguments +/// +/// * `input`: Current execution plan +/// * `dist_onward`: It keeps track of executors starting from a distribution +/// changing operator (e.g Repartition, SortPreservingMergeExec, etc.) +/// until `input` plan. +/// * `input_idx`: index of the `input`, for its parent. +/// +/// # Returns +/// +/// New execution plan, where desired single +/// distribution is satisfied by adding `SortPreservingMergeExec`. +fn add_spm_on_top( + input: Arc, + dist_onward: &mut Option, + input_idx: usize, +) -> Arc { + // Add SortPreservingMerge only when partition count is larger than 1. + if input.output_partitioning().partition_count() > 1 { + // When there is an existing ordering, we preserve ordering + // during decreasıng partıtıons. This will be un-done in the future + // If any of the following conditions is true + // - Preserving ordering is not helpful in terms of satisfying ordering requirements + // - Usage of order preserving variants is not desirable + // (determined by flag `config.optimizer.bounded_order_preserving_variants`) + let should_preserve_ordering = input.output_ordering().is_some(); + let new_plan: Arc = if should_preserve_ordering { + let existing_ordering = input.output_ordering().unwrap_or(&[]); + Arc::new(SortPreservingMergeExec::new( + existing_ordering.to_vec(), + input, + )) as _ + } else { + Arc::new(CoalescePartitionsExec::new(input)) as _ + }; + + // update repartition onward with new operator + update_distribution_onward(new_plan.clone(), dist_onward, input_idx); + new_plan + } else { + input + } +} + +/// Updates the physical plan inside `distribution_context` so that distribution +/// changing operators are removed from the top. If they are necessary, they will +/// be added in subsequent stages. +/// +/// Assume that following plan is given: +/// ```text +/// "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", +/// ``` +/// +/// Since `RepartitionExec`s change the distribution, this function removes +/// them and returns following plan: +/// +/// ```text +/// "ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", +/// ``` +fn remove_dist_changing_operators( + distribution_context: DistributionContext, +) -> Result { + let DistributionContext { + mut plan, + mut distribution_onwards, + } = distribution_context; + + // Remove any distribution changing operators at the beginning: + // Note that they will be re-inserted later on if necessary or helpful. + while is_repartition(&plan) + || is_coalesce_partitions(&plan) + || is_sort_preserving_merge(&plan) + { + // All of above operators have a single child. When we remove the top + // operator, we take the first child. + plan = plan.children().swap_remove(0); + distribution_onwards = + get_children_exectrees(plan.children().len(), &distribution_onwards[0]); + } + + // Create a plan with the updated children: + Ok(DistributionContext { + plan, + distribution_onwards, + }) +} + +/// Updates the physical plan `input` by using `dist_onward` replace order preserving operator variants +/// with their corresponding operators that do not preserve order. It is a wrapper for `replace_order_preserving_variants_helper` +fn replace_order_preserving_variants( + input: &mut Arc, + dist_onward: &mut Option, +) -> Result<()> { + if let Some(dist_onward) = dist_onward { + *input = replace_order_preserving_variants_helper(dist_onward)?; + } + *dist_onward = None; + Ok(()) +} + +/// Updates the physical plan inside `ExecTree` if preserving ordering while changing partitioning +/// is not helpful or desirable. +/// +/// Assume that following plan is given: +/// ```text +/// "SortPreservingMergeExec: \[a@0 ASC]" +/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", +/// " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", +/// ``` +/// +/// This function converts plan above (inside `ExecTree`) to the following: +/// +/// ```text +/// "CoalescePartitionsExec" +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " ParquetExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC]", +/// ``` +fn replace_order_preserving_variants_helper( + exec_tree: &ExecTree, +) -> Result> { + let mut updated_children = exec_tree.plan.children(); + for child in &exec_tree.children { + updated_children[child.idx] = replace_order_preserving_variants_helper(child)?; + } + if is_sort_preserving_merge(&exec_tree.plan) { + return Ok(Arc::new(CoalescePartitionsExec::new( + updated_children.swap_remove(0), + ))); + } + if let Some(repartition) = exec_tree.plan.as_any().downcast_ref::() { + if repartition.preserve_order() { + return Ok(Arc::new( + // new RepartitionExec don't preserve order + RepartitionExec::try_new( + updated_children.swap_remove(0), + repartition.partitioning().clone(), + )?, + )); + } + } + exec_tree.plan.clone().with_new_children(updated_children) +} + +/// This function checks whether we need to add additional data exchange +/// operators to satisfy distribution requirements. Since this function +/// takes care of such requirements, we should avoid manually adding data +/// exchange operators in other places. +fn ensure_distribution( + dist_context: DistributionContext, + config: &ConfigOptions, +) -> Result> { + let target_partitions = config.execution.target_partitions; + // When `false`, round robin repartition will not be added to increase parallelism + let enable_round_robin = config.optimizer.enable_round_robin_repartition; + let repartition_file_scans = config.optimizer.repartition_file_scans; + let batch_size = config.execution.batch_size; + let is_unbounded = unbounded_output(&dist_context.plan); + // Use order preserving variants either of the conditions true + // - it is desired according to config + // - when plan is unbounded + let order_preserving_variants_desirable = + is_unbounded || config.optimizer.prefer_existing_sort; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::No(dist_context)); + } + + // Remove unnecessary repartition from the physical plan if any + let DistributionContext { + mut plan, + mut distribution_onwards, + } = remove_dist_changing_operators(dist_context)?; + + if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys, + )? { + plan = updated_window; + } + } else if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys, + )? { + plan = updated_window; + } + }; + let n_children = plan.children().len(); + // This loop iterates over all the children to: + // - Increase parallelism for every child if it is beneficial. + // - Satisfy the distribution requirements of every child, if it is not + // already satisfied. + // We store the updated children in `new_children`. + let new_children = izip!( + plan.children().into_iter(), + plan.required_input_distribution().iter(), + plan.required_input_ordering().iter(), + distribution_onwards.iter_mut(), + plan.benefits_from_input_partitioning(), + plan.maintains_input_order(), + 0..n_children + ) + .map( + |( + mut child, + requirement, + required_input_ordering, + dist_onward, + would_benefit, + maintains, + child_idx, + )| { + // Don't need to apply when the returned row count is not greater than 1: + let num_rows = child.statistics()?.num_rows; + let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { + num_rows + .get_value() + .map(|value| value > &batch_size) + .unwrap_or(true) + } else { + true + }; + if enable_round_robin + // Operator benefits from partitioning (e.g. filter): + && (would_benefit && repartition_beneficial_stats) + // Unless partitioning doesn't increase the partition count, it is not beneficial: + && child.output_partitioning().partition_count() < target_partitions + { + // When `repartition_file_scans` is set, attempt to increase + // parallelism at the source. + if repartition_file_scans { + if let Some(new_child) = + child.repartitioned(target_partitions, config)? + { + child = new_child; + } + } + // Increase parallelism by adding round-robin repartitioning + // on top of the operator. Note that we only do this if the + // partition count is not already equal to the desired partition + // count. + child = add_roundrobin_on_top( + child, + target_partitions, + dist_onward, + child_idx, + )?; + } + + // Satisfy the distribution requirement if it is unmet. + match requirement { + Distribution::SinglePartition => { + child = add_spm_on_top(child, dist_onward, child_idx); + } + Distribution::HashPartitioned(exprs) => { + child = add_hash_on_top( + child, + exprs.to_vec(), + target_partitions, + dist_onward, + child_idx, + repartition_beneficial_stats, + )?; + } + Distribution::UnspecifiedDistribution => {} + }; + + // There is an ordering requirement of the operator: + if let Some(required_input_ordering) = required_input_ordering { + // Either: + // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or + // - using order preserving variant is not desirable. + let ordering_satisfied = child + .equivalence_properties() + .ordering_satisfy_requirement(required_input_ordering); + if !ordering_satisfied || !order_preserving_variants_desirable { + replace_order_preserving_variants(&mut child, dist_onward)?; + // If ordering requirements were satisfied before repartitioning, + // make sure ordering requirements are still satisfied after. + if ordering_satisfied { + // Make sure to satisfy ordering requirement: + add_sort_above(&mut child, required_input_ordering, None); + } + } + // Stop tracking distribution changing operators + *dist_onward = None; + } else { + // no ordering requirement + match requirement { + // Operator requires specific distribution. + Distribution::SinglePartition | Distribution::HashPartitioned(_) => { + // Since there is no ordering requirement, preserving ordering is pointless + replace_order_preserving_variants(&mut child, dist_onward)?; + } + Distribution::UnspecifiedDistribution => { + // Since ordering is lost, trying to preserve ordering is pointless + if !maintains { + replace_order_preserving_variants(&mut child, dist_onward)?; + } + } + } + } + Ok(child) + }, + ) + .collect::>>()?; + + let new_distribution_context = DistributionContext { + plan: if plan.as_any().is::() && can_interleave(&new_children) { + // Add a special case for [`UnionExec`] since we want to "bubble up" + // hash-partitioned data. So instead of + // + // Agg: + // Repartition (hash): + // Union: + // - Agg: + // Repartition (hash): + // Data + // - Agg: + // Repartition (hash): + // Data + // + // we can use: + // + // Agg: + // Interleave: + // - Agg: + // Repartition (hash): + // Data + // - Agg: + // Repartition (hash): + // Data + Arc::new(InterleaveExec::try_new(new_children)?) + } else { + plan.with_new_children(new_children)? + }, + distribution_onwards, + }; + Ok(Transformed::Yes(new_distribution_context)) +} + +/// A struct to keep track of distribution changing executors +/// (`RepartitionExec`, `SortPreservingMergeExec`, `CoalescePartitionsExec`), +/// and their associated parents inside `plan`. Using this information, +/// we can optimize distribution of the plan if/when necessary. +#[derive(Debug, Clone)] +struct DistributionContext { + plan: Arc, + /// Keep track of associations for each child of the plan. If `None`, + /// there is no distribution changing operator in its descendants. + distribution_onwards: Vec>, +} + +impl DistributionContext { + /// Creates an empty context. + fn new(plan: Arc) -> Self { + let length = plan.children().len(); + DistributionContext { + plan, + distribution_onwards: vec![None; length], + } + } + + /// Constructs a new context from children contexts. + fn new_from_children_nodes( + children_nodes: Vec, + parent_plan: Arc, + ) -> Result { + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect(); + let distribution_onwards = children_nodes + .into_iter() + .enumerate() + .map(|(idx, context)| { + let DistributionContext { + plan, + // The `distribution_onwards` tree keeps track of operators + // that change distribution, or preserves the existing + // distribution (starting from an operator that change distribution). + distribution_onwards, + } = context; + if plan.children().is_empty() { + // Plan has no children, there is nothing to propagate. + None + } else if distribution_onwards[0].is_none() { + if let Some(repartition) = + plan.as_any().downcast_ref::() + { + match repartition.partitioning() { + Partitioning::RoundRobinBatch(_) + | Partitioning::Hash(_, _) => { + // Start tracking operators starting from this repartition (either roundrobin or hash): + return Some(ExecTree::new(plan, idx, vec![])); + } + _ => {} + } + } else if plan.as_any().is::() + || plan.as_any().is::() + { + // Start tracking operators starting from this sort preserving merge: + return Some(ExecTree::new(plan, idx, vec![])); + } + None + } else { + // Propagate children distribution tracking to the above + let new_distribution_onwards = izip!( + plan.required_input_distribution().iter(), + distribution_onwards.into_iter() + ) + .flat_map(|(required_dist, distribution_onwards)| { + if let Some(distribution_onwards) = distribution_onwards { + // Operator can safely propagate the distribution above. + // This is similar to maintaining order in the EnforceSorting rule. + if let Distribution::UnspecifiedDistribution = required_dist { + return Some(distribution_onwards); + } + } + None + }) + .collect::>(); + // Either: + // - None of the children has a connection to an operator that modifies distribution, or + // - The current operator requires distribution at its input so doesn't propagate it above. + if new_distribution_onwards.is_empty() { + None + } else { + Some(ExecTree::new(plan, idx, new_distribution_onwards)) + } + } + }) + .collect(); + Ok(DistributionContext { + plan: with_new_children_if_necessary(parent_plan, children_plans)?.into(), + distribution_onwards, + }) + } + + /// Computes distribution tracking contexts for every child of the plan. + fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(DistributionContext::new) + .collect() + } +} + +impl TreeNode for DistributionContext { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_nodes = children + .into_iter() + .map(transform) + .collect::>>()?; + DistributionContext::new_from_children_nodes(children_nodes, self.plan) + } + } +} + +/// implement Display method for `DistributionContext` struct. +impl fmt::Display for DistributionContext { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let plan_string = get_plan_string(&self.plan); + write!(f, "plan: {:?}", plan_string)?; + for (idx, child) in self.distribution_onwards.iter().enumerate() { + if let Some(child) = child { + write!(f, "idx:{:?}, exec_tree:{}", idx, child)?; + } + } + write!(f, "") + } +} + +#[derive(Debug, Clone)] +struct JoinKeyPairs { + left_keys: Vec>, + right_keys: Vec>, +} + +#[derive(Debug, Clone)] +struct PlanWithKeyRequirements { + plan: Arc, + /// Parent required key ordering + required_key_ordering: Vec>, + /// The request key ordering to children + request_key_ordering: Vec>>>, +} + +impl PlanWithKeyRequirements { + fn new(plan: Arc) -> Self { + let children_len = plan.children().len(); + PlanWithKeyRequirements { + plan, + required_key_ordering: vec![], + request_key_ordering: vec![None; children_len], + } + } + + fn children(&self) -> Vec { + let plan_children = self.plan.children(); + assert_eq!(plan_children.len(), self.request_key_ordering.len()); + plan_children + .into_iter() + .zip(self.request_key_ordering.clone()) + .map(|(child, required)| { + let from_parent = required.unwrap_or_default(); + let length = child.children().len(); + PlanWithKeyRequirements { + plan: child, + required_key_ordering: from_parent, + request_key_ordering: vec![None; length], + } + }) + .collect() + } +} + +impl TreeNode for PlanWithKeyRequirements { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + let children = self.children(); + for child in children { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + + Ok(VisitRecursion::Continue) + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + + let children_plans = new_children? + .into_iter() + .map(|child| child.plan) + .collect::>(); + let new_plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithKeyRequirements { + plan: new_plan.into(), + required_key_ordering: self.required_key_ordering, + request_key_ordering: self.request_key_ordering, + }) + } else { + Ok(self) + } + } +} + +/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on +#[cfg(feature = "parquet")] +#[cfg(test)] +pub(crate) mod tests { + use std::ops::Deref; + + use super::*; + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::object_store::ObjectStoreUrl; + use crate::datasource::physical_plan::ParquetExec; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_optimizer::enforce_sorting::EnforceSorting; + use crate::physical_optimizer::output_requirements::OutputRequirements; + use crate::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; + use crate::physical_plan::expressions::col; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::{ + utils::JoinOn, HashJoinExec, PartitionMode, SortMergeJoinExec, + }; + use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; + + use crate::physical_optimizer::test_utils::{ + coalesce_partitions_exec, repartition_exec, + }; + use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use crate::physical_plan::sorts::sort::SortExec; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ScalarValue; + use datafusion_expr::logical_plan::JoinType; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_physical_expr::{ + expressions, expressions::binary, expressions::lit, expressions::Column, + LexOrdering, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + }; + + /// Models operators like BoundedWindowExec that require an input + /// ordering but is easy to construct + #[derive(Debug)] + struct SortRequiredExec { + input: Arc, + expr: LexOrdering, + } + + impl SortRequiredExec { + fn new(input: Arc) -> Self { + let expr = input.output_ordering().unwrap_or(&[]).to_vec(); + Self { input, expr } + } + + fn new_with_requirement( + input: Arc, + requirement: Vec, + ) -> Self { + Self { + input, + expr: requirement, + } + } + } + + impl DisplayAs for SortRequiredExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!( + f, + "SortRequiredExec: [{}]", + PhysicalSortExpr::format_list(&self.expr) + ) + } + } + + impl ExecutionPlan for SortRequiredExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> crate::physical_plan::Partitioning { + self.input.output_partitioning() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + // model that it requires the output ordering of its input + fn required_input_ordering(&self) -> Vec>> { + vec![self + .output_ordering() + .map(PhysicalSortRequirement::from_sort_exprs)] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + let child = children.pop().unwrap(); + Ok(Arc::new(Self::new_with_requirement( + child, + self.expr.clone(), + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!(); + } + + fn statistics(&self) -> Result { + self.input.statistics() + } + } + + pub(crate) fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])) + } + + fn parquet_exec() -> Arc { + parquet_exec_with_sort(vec![]) + } + + pub(crate) fn parquet_exec_with_sort( + output_ordering: Vec>, + ) -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering, + infinite_source: false, + }, + None, + None, + )) + } + + fn parquet_exec_multiple() -> Arc { + parquet_exec_multiple_sorted(vec![]) + } + + // Created a sorted parquet exec with multiple files + fn parquet_exec_multiple_sorted( + output_ordering: Vec>, + ) -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![ + vec![PartitionedFile::new("x".to_string(), 100)], + vec![PartitionedFile::new("y".to_string(), 100)], + ], + statistics: Statistics::new_unknown(&schema()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering, + infinite_source: false, + }, + None, + None, + )) + } + + fn csv_exec() -> Arc { + csv_exec_with_sort(vec![]) + } + + fn csv_exec_with_sort(output_ordering: Vec>) -> Arc { + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering, + infinite_source: false, + }, + false, + b',', + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn csv_exec_multiple() -> Arc { + csv_exec_multiple_sorted(vec![]) + } + + // Created a sorted parquet exec with multiple files + fn csv_exec_multiple_sorted( + output_ordering: Vec>, + ) -> Arc { + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![ + vec![PartitionedFile::new("x".to_string(), 100)], + vec![PartitionedFile::new("y".to_string(), 100)], + ], + statistics: Statistics::new_unknown(&schema()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering, + infinite_source: false, + }, + false, + b',', + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn projection_exec_with_alias( + input: Arc, + alias_pairs: Vec<(String, String)>, + ) -> Arc { + let mut exprs = vec![]; + for (column, alias) in alias_pairs.iter() { + exprs.push((col(column, &input.schema()).unwrap(), alias.to_string())); + } + Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) + } + + fn aggregate_exec_with_alias( + input: Arc, + alias_pairs: Vec<(String, String)>, + ) -> Arc { + let schema = schema(); + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for (column, alias) in alias_pairs.iter() { + group_by_expr + .push((col(column, &input.schema()).unwrap(), alias.to_string())); + } + let group_by = PhysicalGroupBy::new_single(group_by_expr.clone()); + + let final_group_by_expr = group_by_expr + .iter() + .enumerate() + .map(|(index, (_col, name))| { + ( + Arc::new(expressions::Column::new(name, index)) + as Arc, + name.clone(), + ) + }) + .collect::>(); + let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + + Arc::new( + AggregateExec::try_new( + AggregateMode::FinalPartitioned, + final_grouping, + vec![], + vec![], + vec![], + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![], + vec![], + vec![], + input, + schema.clone(), + ) + .unwrap(), + ), + schema, + ) + .unwrap(), + ) + } + + fn hash_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, + ) -> Arc { + Arc::new( + HashJoinExec::try_new( + left, + right, + join_on.clone(), + None, + join_type, + PartitionMode::Partitioned, + false, + ) + .unwrap(), + ) + } + + fn sort_merge_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, + ) -> Arc { + Arc::new( + SortMergeJoinExec::try_new( + left, + right, + join_on.clone(), + *join_type, + vec![SortOptions::default(); join_on.len()], + false, + ) + .unwrap(), + ) + } + + fn filter_exec(input: Arc) -> Arc { + let predicate = Arc::new(BinaryExpr::new( + col("c", &schema()).unwrap(), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + Arc::new(FilterExec::try_new(predicate, input).unwrap()) + } + + fn sort_exec( + sort_exprs: Vec, + input: Arc, + preserve_partitioning: bool, + ) -> Arc { + let new_sort = SortExec::new(sort_exprs, input) + .with_preserve_partitioning(preserve_partitioning); + Arc::new(new_sort) + } + + fn sort_preserving_merge_exec( + sort_exprs: Vec, + input: Arc, + ) -> Arc { + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + } + + fn limit_exec(input: Arc) -> Arc { + Arc::new(GlobalLimitExec::new( + Arc::new(LocalLimitExec::new(input, 100)), + 0, + Some(100), + )) + } + + fn union_exec(input: Vec>) -> Arc { + Arc::new(UnionExec::new(input)) + } + + fn sort_required_exec(input: Arc) -> Arc { + Arc::new(SortRequiredExec::new(input)) + } + + fn sort_required_exec_with_req( + input: Arc, + sort_exprs: LexOrdering, + ) -> Arc { + Arc::new(SortRequiredExec::new_with_requirement(input, sort_exprs)) + } + + pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() + } + + fn ensure_distribution_helper( + plan: Arc, + target_partitions: usize, + bounded_order_preserving_variants: bool, + ) -> Result> { + let distribution_context = DistributionContext::new(plan); + let mut config = ConfigOptions::new(); + config.execution.target_partitions = target_partitions; + config.optimizer.enable_round_robin_repartition = false; + config.optimizer.repartition_file_scans = false; + config.optimizer.repartition_file_min_size = 1024; + config.optimizer.prefer_existing_sort = bounded_order_preserving_variants; + ensure_distribution(distribution_context, &config).map(|item| item.into().plan) + } + + /// Test whether plan matches with expected plan + macro_rules! plans_matches_expected { + ($EXPECTED_LINES: expr, $PLAN: expr) => { + let physical_plan = $PLAN; + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + } + } + + /// Runs the repartition optimizer and asserts the plan against the expected + macro_rules! assert_optimized { + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024); + }; + + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $BOUNDED_ORDER_PRESERVING_VARIANTS, 10, false, 1024); + }; + + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $BOUNDED_ORDER_PRESERVING_VARIANTS: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { + let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); + + let mut config = ConfigOptions::new(); + config.execution.target_partitions = $TARGET_PARTITIONS; + config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; + config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; + config.optimizer.prefer_existing_sort = $BOUNDED_ORDER_PRESERVING_VARIANTS; + + // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade + // because they were written prior to the separation of `BasicEnforcement` into + // `EnforceSorting` and `EnforceDistribution`. + // TODO: Orthogonalize the tests here just to verify `EnforceDistribution` and create + // new tests for the cascade. + + // Add the ancillary output requirements operator at the start: + let optimizer = OutputRequirements::new_add_mode(); + let optimized = optimizer.optimize($PLAN.clone(), &config)?; + + let optimized = if $FIRST_ENFORCE_DIST { + // Run enforce distribution rule first: + let optimizer = EnforceDistribution::new(); + let optimized = optimizer.optimize(optimized, &config)?; + // The rule should be idempotent. + // Re-running this rule shouldn't introduce unnecessary operators. + let optimizer = EnforceDistribution::new(); + let optimized = optimizer.optimize(optimized, &config)?; + // Run the enforce sorting rule: + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(optimized, &config)?; + optimized + } else { + // Run the enforce sorting rule first: + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(optimized, &config)?; + // Run enforce distribution rule: + let optimizer = EnforceDistribution::new(); + let optimized = optimizer.optimize(optimized, &config)?; + // The rule should be idempotent. + // Re-running this rule shouldn't introduce unnecessary operators. + let optimizer = EnforceDistribution::new(); + let optimized = optimizer.optimize(optimized, &config)?; + optimized + }; + + // Remove the ancillary output requirements operator when done: + let optimizer = OutputRequirements::new_remove_mode(); + let optimized = optimizer.optimize(optimized, &config)?; + + // Now format correctly + let plan = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&plan); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; + } + + macro_rules! assert_plan_txt { + ($EXPECTED_LINES: expr, $PLAN: expr) => { + let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); + // Now format correctly + let plan = displayable($PLAN.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&plan); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; + } + + #[test] + fn multi_hash_joins() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ("d".to_string(), "d1".to_string()), + ("e".to_string(), "e1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + ]; + + // Join on (a == b1) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + for join_type in join_types { + let join = hash_join_exec(left.clone(), right.clone(), &join_on, &join_type); + let join_plan = format!( + "HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(a@0, b1@1)]" + ); + + match join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftSemi + | JoinType::LeftAnti => { + // Join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + let top_join = hash_join_exec( + join.clone(), + parquet_exec(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(a@0, c@2)]"); + + let expected = match join_type { + // Should include 3 RepartitionExecs + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs + _ => vec![ + top_join_plan.as_str(), + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join.clone(), true); + assert_optimized!(expected, top_join, false); + } + JoinType::RightSemi | JoinType::RightAnti => {} + } + + match join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightAnti => { + // This time we use (b1 == c) for top join + // Join on (b1 == c) + let top_join_on = vec![( + Column::new_with_schema("b1", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = + hash_join_exec(join, parquet_exec(), &top_join_on, &join_type); + let top_join_plan = match join_type { + JoinType::RightSemi | JoinType::RightAnti => + format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(b1@1, c@2)]"), + _ => + format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(b1@6, c@2)]"), + }; + + let expected = match join_type { + // Should include 3 RepartitionExecs + JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => + vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 4 RepartitionExecs + _ => + vec![ + top_join_plan.as_str(), + "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", + join_plan.as_str(), + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join.clone(), true); + assert_optimized!(expected, top_join, false); + } + JoinType::LeftSemi | JoinType::LeftAnti => {} + } + } + + Ok(()) + } + + #[test] + fn multi_joins_after_alias() -> Result<()> { + let left = parquet_exec(); + let right = parquet_exec(); + + // Join on (a == b) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b", &schema()).unwrap(), + )]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Projection(a as a1, a as a2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("a".to_string(), "a2".to_string()), + ]; + let projection = projection_exec_with_alias(join, alias_pairs); + + // Join on (a1 == c) + let top_join_on = vec![( + Column::new_with_schema("a1", &projection.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec( + projection.clone(), + right.clone(), + &top_join_on, + &JoinType::Inner, + ); + + // Output partition need to respect the Alias and should not introduce additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, c@2)]", + "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, top_join.clone(), true); + assert_optimized!(expected, top_join, false); + + // Join on (a2 == c) + let top_join_on = vec![( + Column::new_with_schema("a2", &projection.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner); + + // Output partition need to respect the Alias and should not introduce additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a2@1, c@2)]", + "ProjectionExec: expr=[a@0 as a1, a@0 as a2]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, top_join.clone(), true); + assert_optimized!(expected, top_join, false); + Ok(()) + } + + #[test] + fn multi_joins_after_multi_alias() -> Result<()> { + let left = parquet_exec(); + let right = parquet_exec(); + + // Join on (a == b) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b", &schema()).unwrap(), + )]; + + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Projection(c as c1) + let alias_pairs: Vec<(String, String)> = + vec![("c".to_string(), "c1".to_string())]; + let projection = projection_exec_with_alias(join, alias_pairs); + + // Projection(c1 as a) + let alias_pairs: Vec<(String, String)> = + vec![("c1".to_string(), "a".to_string())]; + let projection2 = projection_exec_with_alias(projection, alias_pairs); + + // Join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &projection2.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + + let top_join = hash_join_exec(projection2, right, &top_join_on, &JoinType::Inner); + + // The Column 'a' has different meaning now after the two Projections + // The original Output partition can not satisfy the Join requirements and need to add an additional RepartitionExec + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, c@2)]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "ProjectionExec: expr=[c1@0 as a]", + "ProjectionExec: expr=[c@2 as c1]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, top_join.clone(), true); + assert_optimized!(expected, top_join, false); + Ok(()) + } + + #[test] + fn join_after_agg_alias() -> Result<()> { + // group by (a as a1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + // group by (a as a2) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a2".to_string())], + ); + + // Join on (a1 == a2) + let join_on = vec![( + Column::new_with_schema("a1", &left.schema()).unwrap(), + Column::new_with_schema("a2", &right.schema()).unwrap(), + )]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, a2@0)]", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", + "RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join.clone(), true); + assert_optimized!(expected, join, false); + Ok(()) + } + + #[test] + fn hash_join_key_ordering() -> Result<()> { + // group by (a as a1, b as b1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ], + ); + // group by (b, a) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("b".to_string(), "b".to_string()), + ("a".to_string(), "a".to_string()), + ], + ); + + // Join on (b1 == b && a1 == a) + let join_on = vec![ + ( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a1", &left.schema()).unwrap(), + Column::new_with_schema("a", &right.schema()).unwrap(), + ), + ]; + let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b1@1, b@0), (a1@0, a@1)]", + "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join.clone(), true); + assert_optimized!(expected, join, false); + Ok(()) + } + + #[test] + fn multi_hash_join_key_ordering() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + + // Join on (a == a1 and b == b1 and c == c1) + let join_on = vec![ + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ]; + let bottom_left_join = + hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner); + + // Projection(a as A, a as AA, b as B, c as C) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "A".to_string()), + ("a".to_string(), "AA".to_string()), + ("b".to_string(), "B".to_string()), + ("c".to_string(), "C".to_string()), + ]; + let bottom_left_projection = + projection_exec_with_alias(bottom_left_join, alias_pairs); + + // Join on (c == c1 and b == b1 and a == a1) + let join_on = vec![ + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ]; + let bottom_right_join = + hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Join on (B == b1 and C == c and AA = a1) + let top_join_on = vec![ + ( + Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ), + ]; + + let top_join = hash_join_exec( + bottom_left_projection.clone(), + bottom_right_join, + &top_join_on, + &JoinType::Inner, + ); + + let predicate: Arc = binary( + col("c", top_join.schema().deref())?, + Operator::Gt, + lit(1i64), + top_join.schema().deref(), + )?; + + let filter_top_join: Arc = + Arc::new(FilterExec::try_new(predicate, top_join)?); + + // The bottom joins' join key ordering is adjusted based on the top join. And the top join should not introduce additional RepartitionExec + let expected = &[ + "FilterExec: c@6 > 1", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(B@2, b1@6), (C@3, c@2), (AA@1, a1@5)]", + "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)]", + "RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)]", + "RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, filter_top_join.clone(), true); + assert_optimized!(expected, filter_top_join, false); + Ok(()) + } + + #[test] + fn reorder_join_keys_to_left_input() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + + // Join on (a == a1 and b == b1 and c == c1) + let join_on = vec![ + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ]; + + let bottom_left_join = ensure_distribution_helper( + hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), + 10, + true, + )?; + + // Projection(a as A, a as AA, b as B, c as C) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "A".to_string()), + ("a".to_string(), "AA".to_string()), + ("b".to_string(), "B".to_string()), + ("c".to_string(), "C".to_string()), + ]; + let bottom_left_projection = + projection_exec_with_alias(bottom_left_join, alias_pairs); + + // Join on (c == c1 and b == b1 and a == a1) + let join_on = vec![ + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ]; + let bottom_right_join = ensure_distribution_helper( + hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), + 10, + true, + )?; + + // Join on (B == b1 and C == c and AA = a1) + let top_join_on = vec![ + ( + Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ), + ]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + ]; + + for join_type in join_types { + let top_join = hash_join_exec( + bottom_left_projection.clone(), + bottom_right_join.clone(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)]", &join_type); + + let reordered = reorder_join_keys_to_inputs(top_join)?; + + // The top joins' join key ordering is adjusted based on the children inputs. + let expected = &[ + top_join_plan.as_str(), + "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)]", + "RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)]", + "RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_plan_txt!(expected, reordered); + } + + Ok(()) + } + + #[test] + fn reorder_join_keys_to_right_input() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + + // Join on (a == a1 and b == b1) + let join_on = vec![ + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ]; + let bottom_left_join = ensure_distribution_helper( + hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner), + 10, + true, + )?; + + // Projection(a as A, a as AA, b as B, c as C) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "A".to_string()), + ("a".to_string(), "AA".to_string()), + ("b".to_string(), "B".to_string()), + ("c".to_string(), "C".to_string()), + ]; + let bottom_left_projection = + projection_exec_with_alias(bottom_left_join, alias_pairs); + + // Join on (c == c1 and b == b1 and a == a1) + let join_on = vec![ + ( + Column::new_with_schema("c", &schema()).unwrap(), + Column::new_with_schema("c1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("b", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("a1", &right.schema()).unwrap(), + ), + ]; + let bottom_right_join = ensure_distribution_helper( + hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner), + 10, + true, + )?; + + // Join on (B == b1 and C == c and AA = a1) + let top_join_on = vec![ + ( + Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(), + ), + ( + Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(), + Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(), + ), + ]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + ]; + + for join_type in join_types { + let top_join = hash_join_exec( + bottom_left_projection.clone(), + bottom_right_join.clone(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)]", &join_type); + + let reordered = reorder_join_keys_to_inputs(top_join)?; + + // The top joins' join key ordering is adjusted based on the children inputs. + let expected = &[ + top_join_plan.as_str(), + "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)]", + "RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)]", + "RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_plan_txt!(expected, reordered); + } + + Ok(()) + } + + #[test] + fn multi_smj_joins() -> Result<()> { + let left = parquet_exec(); + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ("c".to_string(), "c1".to_string()), + ("d".to_string(), "d1".to_string()), + ("e".to_string(), "e1".to_string()), + ]; + let right = projection_exec_with_alias(parquet_exec(), alias_pairs); + + // SortMergeJoin does not support RightSemi and RightAnti join now + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Join on (a == b1) + let join_on = vec![( + Column::new_with_schema("a", &schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + for join_type in join_types { + let join = + sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); + let join_plan = + format!("SortMergeJoin: join_type={join_type}, on=[(a@0, b1@1)]"); + + // Top join on (a == c) + let top_join_on = vec![( + Column::new_with_schema("a", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + let top_join = sort_merge_join_exec( + join.clone(), + parquet_exec(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("SortMergeJoin: join_type={join_type}, on=[(a@0, c@2)]"); + + let expected = match join_type { + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => + vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortExec: expr=[a@0 ASC]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[b1@1 ASC]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[c@2 ASC]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoin + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test + // cases when mode is Inner, Left, LeftSemi, LeftAnti + _ => vec![ + top_join_plan.as_str(), + // Below 2 operators are differences introduced, when join mode is changed + "SortExec: expr=[a@0 ASC]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + join_plan.as_str(), + "SortExec: expr=[a@0 ASC]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[b1@1 ASC]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[c@2 ASC]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected, top_join.clone(), true, true); + + let expected_first_sort_enforcement = match join_type { + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => + vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[a@0 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b1@1 ASC]", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoin + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RgihtSemi, RightAnti, Full + // - We need to add one additional Hash Repartition and Roundrobin repartition after + // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti + _ => vec![ + top_join_plan.as_str(), + // Below 4 operators are differences introduced, when join mode is changed + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + join_plan.as_str(), + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[a@0 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b1@1 ASC]", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + }; + assert_optimized!(expected_first_sort_enforcement, top_join, false, true); + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + // This time we use (b1 == c) for top join + // Join on (b1 == c) + let top_join_on = vec![( + Column::new_with_schema("b1", &join.schema()).unwrap(), + Column::new_with_schema("c", &schema()).unwrap(), + )]; + let top_join = sort_merge_join_exec( + join, + parquet_exec(), + &top_join_on, + &join_type, + ); + let top_join_plan = + format!("SortMergeJoin: join_type={join_type}, on=[(b1@6, c@2)]"); + + let expected = match join_type { + // Should include 6 RepartitionExecs(3 hash, 3 round-robin) and 3 SortExecs + JoinType::Inner | JoinType::Right => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortExec: expr=[a@0 ASC]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[b1@1 ASC]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[c@2 ASC]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs + JoinType::Left | JoinType::Full => vec![ + top_join_plan.as_str(), + "SortExec: expr=[b1@6 ASC]", + "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", + join_plan.as_str(), + "SortExec: expr=[a@0 ASC]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[b1@1 ASC]", + "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[c@2 ASC]", + "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // this match arm cannot be reached + _ => unreachable!() + }; + assert_optimized!(expected, top_join.clone(), true, true); + + let expected_first_sort_enforcement = match join_type { + // Should include 6 RepartitionExecs (3 of them preserves order) and 3 SortExecs + JoinType::Inner | JoinType::Right => vec![ + top_join_plan.as_str(), + join_plan.as_str(), + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[a@0 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b1@1 ASC]", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs + JoinType::Left | JoinType::Full => vec![ + top_join_plan.as_str(), + "SortPreservingRepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, sort_exprs=b1@6 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b1@6 ASC]", + "CoalescePartitionsExec", + join_plan.as_str(), + "SortPreservingRepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, sort_exprs=a@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[a@0 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, sort_exprs=b1@1 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b1@1 ASC]", + "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=c@2 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ], + // this match arm cannot be reached + _ => unreachable!() + }; + assert_optimized!( + expected_first_sort_enforcement, + top_join, + false, + true + ); + } + _ => {} + } + } + + Ok(()) + } + + #[test] + fn smj_join_key_ordering() -> Result<()> { + // group by (a as a1, b as b1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("a".to_string(), "a1".to_string()), + ("b".to_string(), "b1".to_string()), + ], + ); + //Projection(a1 as a3, b1 as b3) + let alias_pairs: Vec<(String, String)> = vec![ + ("a1".to_string(), "a3".to_string()), + ("b1".to_string(), "b3".to_string()), + ]; + let left = projection_exec_with_alias(left, alias_pairs); + + // group by (b, a) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![ + ("b".to_string(), "b".to_string()), + ("a".to_string(), "a".to_string()), + ], + ); + + //Projection(a as a2, b as b2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a2".to_string()), + ("b".to_string(), "b2".to_string()), + ]; + let right = projection_exec_with_alias(right, alias_pairs); + + // Join on (b3 == b2 && a3 == a2) + let join_on = vec![ + ( + Column::new_with_schema("b3", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + ), + ( + Column::new_with_schema("a3", &left.schema()).unwrap(), + Column::new_with_schema("a2", &right.schema()).unwrap(), + ), + ]; + let join = sort_merge_join_exec(left, right.clone(), &join_on, &JoinType::Inner); + + // Only two RepartitionExecs added + let expected = &[ + "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", + "SortExec: expr=[b3@1 ASC,a3@0 ASC]", + "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", + "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortExec: expr=[b2@1 ASC,a2@0 ASC]", + "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", + "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, join.clone(), true, true); + + let expected_first_sort_enforcement = &[ + "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", + "SortPreservingRepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, sort_exprs=b3@1 ASC,a3@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b3@1 ASC,a3@0 ASC]", + "CoalescePartitionsExec", + "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", + "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "SortPreservingRepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, sort_exprs=b2@1 ASC,a2@0 ASC", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[b2@1 ASC,a2@0 ASC]", + "CoalescePartitionsExec", + "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", + "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected_first_sort_enforcement, join, false, true); + Ok(()) + } + + #[test] + fn merge_does_not_need_sort() -> Result<()> { + // see https://github.com/apache/arrow-datafusion/issues/4331 + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + + // Scan some sorted parquet files + let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + + // CoalesceBatchesExec to mimic behavior after a filter + let exec = Arc::new(CoalesceBatchesExec::new(exec, 4096)); + + // Merge from multiple parquet files and keep the data sorted + let exec = Arc::new(SortPreservingMergeExec::new(sort_key, exec)); + + // The optimizer should not add an additional SortExec as the + // data is already sorted + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + "CoalesceBatchesExec: target_batch_size=4096", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_optimized!(expected, exec, true); + // In this case preserving ordering through order preserving operators is not desirable + // (according to flag: bounded_order_preserving_variants) + // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with + // SortExec at the top. + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "CoalesceBatchesExec: target_batch_size=4096", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_optimized!(expected, exec, false); + Ok(()) + } + + #[test] + fn union_to_interleave() -> Result<()> { + // group by (a as a1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + // group by (a as a2) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + + // Union + let plan = Arc::new(UnionExec::new(vec![left, right])); + + // final agg + let plan = + aggregate_exec_with_alias(plan, vec![("a1".to_string(), "a2".to_string())]); + + // Only two RepartitionExecs added, no final RepartitionExec required + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", + "AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[]", + "InterleaveExec", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn added_repartition_to_single_partition() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan = aggregate_exec_with_alias(parquet_exec(), alias); + + let expected = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_deepest_node() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan = aggregate_exec_with_alias(filter_exec(parquet_exec()), alias); + + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + + fn repartition_unsorted_limit() -> Result<()> { + let plan = limit_exec(filter_exec(parquet_exec())); + + let expected = &[ + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // nothing sorts the data, so the local limit doesn't require sorted data either + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_sorted_limit() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); + + let expected = &[ + "GlobalLimitExec: skip=0, fetch=100", + "LocalLimitExec: fetch=100", + // data is sorted so can't repartition here + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_sorted_limit_with_filter() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = + sort_required_exec(filter_exec(sort_exec(sort_key, parquet_exec(), false))); + + let expected = &[ + "SortRequiredExec: [c@2 ASC]", + "FilterExec: c@2 = 0", + // We can use repartition here, ordering requirement by SortRequiredExec + // is still satisfied. + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_ignores_limit() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan = aggregate_exec_with_alias( + limit_exec(filter_exec(limit_exec(parquet_exec()))), + alias, + ); + + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // repartition should happen prior to the filter to maximize parallelism + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + "LocalLimitExec: fetch=100", + // Expect no repartition to happen for local limit + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_ignores_union() -> Result<()> { + let plan = union_exec(vec![parquet_exec(); 5]); + + let expected = &[ + "UnionExec", + // Expect no repartition of ParquetExec + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_through_sort_preserving_merge() -> Result<()> { + // sort preserving merge with non-sorted input + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); + + // need resort as the data was not sorted correctly + let expected = &[ + "SortExec: expr=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + + Ok(()) + } + + #[test] + fn repartition_ignores_sort_preserving_merge() -> Result<()> { + // sort preserving merge already sorted input, + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = sort_preserving_merge_exec( + sort_key.clone(), + parquet_exec_multiple_sorted(vec![sort_key]), + ); + + // should not sort (as the data was already sorted) + // should not repartition, since increased parallelism is not beneficial for SortPReservingMerge + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, plan.clone(), true); + + let expected = &[ + "SortExec: expr=[c@2 ASC]", + "CoalescePartitionsExec", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { + // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + let plan = sort_preserving_merge_exec(sort_key, input); + + // should not repartition / sort (as the data was already sorted) + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + "UnionExec", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, plan.clone(), true); + + let expected = &[ + "SortExec: expr=[c@2 ASC]", + "CoalescePartitionsExec", + "UnionExec", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_does_not_destroy_sort() -> Result<()> { + // SortRequired + // Parquet(sorted) + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = + sort_required_exec(filter_exec(parquet_exec_with_sort(vec![sort_key]))); + + // during repartitioning ordering is preserved + let expected = &[ + "SortRequiredExec: [c@2 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, plan.clone(), true, true); + assert_optimized!(expected, plan, false, true); + Ok(()) + } + + #[test] + fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { + // model a more complicated scenario where one child of a union can be repartitioned for performance + // but the other can not be + // + // Union + // SortRequired + // Parquet(sorted) + // Filter + // Parquet(unsorted) + + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input1 = sort_required_exec(parquet_exec_with_sort(vec![sort_key])); + let input2 = filter_exec(parquet_exec()); + let plan = union_exec(vec![input1, input2]); + + // should not repartition below the SortRequired as that + // branch doesn't benefit from increased parallelism + let expected = &[ + "UnionExec", + // union input 1: no repartitioning + "SortRequiredExec: [c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + // union input 2: should repartition + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_transitively_with_projection() -> Result<()> { + let schema = schema(); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + col("b", &schema).unwrap(), + )) as Arc, + "sum".to_string(), + )]; + // non sorted input + let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); + let sort_key = vec![PhysicalSortExpr { + expr: col("sum", &proj.schema()).unwrap(), + options: SortOptions::default(), + }]; + let plan = sort_preserving_merge_exec(sort_key, proj); + + let expected = &[ + "SortPreservingMergeExec: [sum@0 ASC]", + "SortExec: expr=[sum@0 ASC]", + // Since this projection is not trivial, increasing parallelism is beneficial + "ProjectionExec: expr=[a@0 + b@1 as sum]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + + let expected_first_sort_enforcement = &[ + "SortExec: expr=[sum@0 ASC]", + "CoalescePartitionsExec", + // Since this projection is not trivial, increasing parallelism is beneficial + "ProjectionExec: expr=[a@0 + b@1 as sum]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected_first_sort_enforcement, plan, false); + Ok(()) + } + + #[test] + fn repartition_ignores_transitively_with_projection() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let alias = vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ]; + // sorted input + let plan = sort_required_exec(projection_exec_with_alias( + parquet_exec_multiple_sorted(vec![sort_key]), + alias, + )); + + let expected = &[ + "SortRequiredExec: [c@2 ASC]", + // Since this projection is trivial, increasing parallelism is not beneficial + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_transitively_past_sort_with_projection() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let alias = vec![("a".to_string(), "a".to_string())]; + let plan = sort_preserving_merge_exec( + sort_key.clone(), + sort_exec( + sort_key, + projection_exec_with_alias(parquet_exec(), alias), + true, + ), + ); + + let expected = &[ + "SortExec: expr=[c@2 ASC]", + // Since this projection is trivial, increasing parallelism is not beneficial + "ProjectionExec: expr=[a@0 as a]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + assert_optimized!(expected, plan, false); + Ok(()) + } + + #[test] + fn repartition_transitively_past_sort_with_filter() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); + + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + "SortExec: expr=[a@0 ASC]", + // Expect repartition on the input to the sort (as it can benefit from additional parallelism) + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + + let expected_first_sort_enforcement = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "FilterExec: c@2 = 0", + // Expect repartition on the input of the filter (as it can benefit from additional parallelism) + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected_first_sort_enforcement, plan, false); + Ok(()) + } + + #[test] + #[cfg(feature = "parquet")] + fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan = sort_exec( + sort_key, + projection_exec_with_alias( + filter_exec(parquet_exec()), + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ], + ), + false, + ); + + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + // Expect repartition on the input to the sort (as it can benefit from additional parallelism) + "SortExec: expr=[a@0 ASC]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", + "FilterExec: c@2 = 0", + // repartition is lowest down + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, plan.clone(), true); + + let expected_first_sort_enforcement = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_optimized!(expected_first_sort_enforcement, plan, false); + Ok(()) + } + + #[test] + fn parallelization_single_partition() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan_parquet = aggregate_exec_with_alias(parquet_exec(), alias.clone()); + let plan_csv = aggregate_exec_with_alias(csv_exec(), alias); + + let expected_parquet = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "ParquetExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "CsvExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true, false, 2, true, 10); + assert_optimized!(expected_csv, plan_csv, true, false, 2, true, 10); + Ok(()) + } + + #[test] + /// CsvExec on compressed csv file will not be partitioned + /// (Not able to decompress chunked csv file) + fn parallelization_compressed_csv() -> Result<()> { + let compression_types = [ + FileCompressionType::GZIP, + FileCompressionType::BZIP2, + FileCompressionType::XZ, + FileCompressionType::ZSTD, + FileCompressionType::UNCOMPRESSED, + ]; + + let expected_not_partitioned = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + let expected_partitioned = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "CsvExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + for compression_type in compression_types { + let expected = if compression_type.is_compressed() { + &expected_not_partitioned[..] + } else { + &expected_partitioned[..] + }; + + let plan = aggregate_exec_with_alias( + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new( + "x".to_string(), + 100, + )]], + statistics: Statistics::new_unknown(&schema()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + false, + b',', + b'"', + None, + compression_type, + )), + vec![("a".to_string(), "a".to_string())], + ); + + assert_optimized!(expected, plan, true, false, 2, true, 10); + } + Ok(()) + } + + #[test] + fn parallelization_two_partitions() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan_parquet = + aggregate_exec_with_alias(parquet_exec_multiple(), alias.clone()); + let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias.clone()); + + let expected_parquet = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + // Plan already has two partitions + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + // Plan already has two partitions + "CsvExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true, false, 2, true, 10); + assert_optimized!(expected_csv, plan_csv, true, false, 2, true, 10); + Ok(()) + } + + #[test] + fn parallelization_two_partitions_into_four() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan_parquet = + aggregate_exec_with_alias(parquet_exec_multiple(), alias.clone()); + let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias.clone()); + + let expected_parquet = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + // Multiple source files splitted across partitions + "ParquetExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = [ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + // Multiple source files splitted across partitions + "CsvExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true, false, 4, true, 10); + assert_optimized!(expected_csv, plan_csv, true, false, 4, true, 10); + Ok(()) + } + + #[test] + fn parallelization_sorted_limit() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); + let plan_csv = limit_exec(sort_exec(sort_key.clone(), csv_exec(), false)); + + let expected_parquet = &[ + "GlobalLimitExec: skip=0, fetch=100", + "LocalLimitExec: fetch=100", + // data is sorted so can't repartition here + "SortExec: expr=[c@2 ASC]", + // Doesn't parallelize for SortExec without preserve_partitioning + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = &[ + "GlobalLimitExec: skip=0, fetch=100", + "LocalLimitExec: fetch=100", + // data is sorted so can't repartition here + "SortExec: expr=[c@2 ASC]", + // Doesn't parallelize for SortExec without preserve_partitioning + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_limit_with_filter() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let plan_parquet = limit_exec(filter_exec(sort_exec( + sort_key.clone(), + parquet_exec(), + false, + ))); + let plan_csv = + limit_exec(filter_exec(sort_exec(sort_key.clone(), csv_exec(), false))); + + let expected_parquet = &[ + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // even though data is sorted, we can use repartition here. Since + // ordering is not used in subsequent stages anyway. + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + // SortExec doesn't benefit from input partitioning + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = &[ + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // even though data is sorted, we can use repartition here. Since + // ordering is not used in subsequent stages anyway. + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "SortExec: expr=[c@2 ASC]", + // SortExec doesn't benefit from input partitioning + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_ignores_limit() -> Result<()> { + let alias = vec![("a".to_string(), "a".to_string())]; + let plan_parquet = aggregate_exec_with_alias( + limit_exec(filter_exec(limit_exec(parquet_exec()))), + alias.clone(), + ); + let plan_csv = aggregate_exec_with_alias( + limit_exec(filter_exec(limit_exec(csv_exec()))), + alias.clone(), + ); + + let expected_parquet = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // repartition should happen prior to the filter to maximize parallelism + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + // Limit doesn't benefit from input partitioning - no parallelism + "LocalLimitExec: fetch=100", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + "CoalescePartitionsExec", + "LocalLimitExec: fetch=100", + "FilterExec: c@2 = 0", + // repartition should happen prior to the filter to maximize parallelism + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "GlobalLimitExec: skip=0, fetch=100", + // Limit doesn't benefit from input partitioning - no parallelism + "LocalLimitExec: fetch=100", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_union_inputs() -> Result<()> { + let plan_parquet = union_exec(vec![parquet_exec(); 5]); + let plan_csv = union_exec(vec![csv_exec(); 5]); + + let expected_parquet = &[ + "UnionExec", + // Union doesn't benefit from input partitioning - no parallelism + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + let expected_csv = &[ + "UnionExec", + // Union doesn't benefit from input partitioning - no parallelism + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + // sort preserving merge already sorted input, + let plan_parquet = sort_preserving_merge_exec( + sort_key.clone(), + parquet_exec_with_sort(vec![sort_key.clone()]), + ); + let plan_csv = sort_preserving_merge_exec( + sort_key.clone(), + csv_exec_with_sort(vec![sort_key.clone()]), + ); + + // parallelization is not beneficial for SortPreservingMerge + let expected_parquet = &[ + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + let expected_csv = &[ + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_sort_preserving_merge_with_union() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) + let input_parquet = + union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); + let plan_parquet = sort_preserving_merge_exec(sort_key.clone(), input_parquet); + let plan_csv = sort_preserving_merge_exec(sort_key.clone(), input_csv); + + // should not repartition (union doesn't benefit from increased parallelism) + // should not sort (as the data was already sorted) + let expected_parquet = &[ + "SortPreservingMergeExec: [c@2 ASC]", + "UnionExec", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + let expected_csv = &[ + "SortPreservingMergeExec: [c@2 ASC]", + "UnionExec", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_does_not_benefit() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + // SortRequired + // Parquet(sorted) + let plan_parquet = + sort_required_exec(parquet_exec_with_sort(vec![sort_key.clone()])); + let plan_csv = sort_required_exec(csv_exec_with_sort(vec![sort_key])); + + // no parallelization, because SortRequiredExec doesn't benefit from increased parallelism + let expected_parquet = &[ + "SortRequiredExec: [c@2 ASC]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + let expected_csv = &[ + "SortRequiredExec: [c@2 ASC]", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { + // sorted input + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + + //Projection(a as a2, b as b2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a2".to_string()), + ("c".to_string(), "c2".to_string()), + ]; + let proj_parquet = projection_exec_with_alias( + parquet_exec_with_sort(vec![sort_key.clone()]), + alias_pairs.clone(), + ); + let sort_key_after_projection = vec![PhysicalSortExpr { + expr: col("c2", &proj_parquet.schema()).unwrap(), + options: SortOptions::default(), + }]; + let plan_parquet = + sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); + let expected = &[ + "SortPreservingMergeExec: [c2@1 ASC]", + " ProjectionExec: expr=[a@0 as a2, c@2 as c2]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + plans_matches_expected!(expected, &plan_parquet); + + // data should not be repartitioned / resorted + let expected_parquet = &[ + "ProjectionExec: expr=[a@0 as a2, c@2 as c2]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected_parquet, plan_parquet, true); + Ok(()) + } + + #[test] + fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { + // sorted input + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + + //Projection(a as a2, b as b2) + let alias_pairs: Vec<(String, String)> = vec![ + ("a".to_string(), "a2".to_string()), + ("c".to_string(), "c2".to_string()), + ]; + + let proj_csv = + projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); + let sort_key_after_projection = vec![PhysicalSortExpr { + expr: col("c2", &proj_csv.schema()).unwrap(), + options: SortOptions::default(), + }]; + let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); + let expected = &[ + "SortPreservingMergeExec: [c2@1 ASC]", + " ProjectionExec: expr=[a@0 as a2, c@2 as c2]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + ]; + plans_matches_expected!(expected, &plan_csv); + + // data should not be repartitioned / resorted + let expected_csv = &[ + "ProjectionExec: expr=[a@0 as a2, c@2 as c2]", + "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], has_header=false", + ]; + + assert_optimized!(expected_csv, plan_csv, true); + Ok(()) + } + + #[test] + fn remove_redundant_roundrobins() -> Result<()> { + let input = parquet_exec(); + let repartition = repartition_exec(repartition_exec(input)); + let physical_plan = repartition_exec(filter_exec(repartition)); + let expected = &[ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", + " FilterExec: c@2 = 0", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + plans_matches_expected!(expected, &physical_plan); + + let expected = &[ + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn preserve_ordering_through_repartition() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); + + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + "FilterExec: c@2 = 0", + "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c@2 ASC", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + // last flag sets config.optimizer.bounded_order_preserving_variants + assert_optimized!(expected, physical_plan.clone(), true, true); + assert_optimized!(expected, physical_plan, false, true); + + Ok(()) + } + + #[test] + fn do_not_preserve_ordering_through_repartition() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); + + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn no_need_for_sort_after_filter() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); + + let expected = &[ + // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. + "CoalescePartitionsExec", + // Since after this stage c is constant. c@2 ASC ordering is already satisfied. + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn do_not_preserve_ordering_through_repartition2() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key]); + + let sort_req = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); + + let expected = &[ + "SortPreservingMergeExec: [a@0 ASC]", + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + + let expected = &[ + "SortExec: expr=[a@0 ASC]", + "CoalescePartitionsExec", + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn do_not_preserve_ordering_through_repartition3() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key]); + let physical_plan = filter_exec(input); + + let expected = &[ + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn do_not_put_sort_when_input_is_invalid() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec(); + let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); + let expected = &[ + // Ordering requirement of sort required exec is NOT satisfied + // by existing ordering at the source. + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + assert_plan_txt!(expected, physical_plan); + + let expected = &[ + "SortRequiredExec: [a@0 ASC]", + // Since at the start of the rule ordering requirement is not satisfied + // EnforceDistribution rule doesn't satisfy this requirement either. + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + let mut config = ConfigOptions::new(); + config.execution.target_partitions = 10; + config.optimizer.enable_round_robin_repartition = true; + config.optimizer.prefer_existing_sort = false; + let distribution_plan = + EnforceDistribution::new().optimize(physical_plan, &config)?; + assert_plan_txt!(expected, distribution_plan); + + Ok(()) + } + + #[test] + fn put_sort_when_input_is_valid() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]; + let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); + let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); + + let expected = &[ + // Ordering requirement of sort required exec is satisfied + // by existing ordering at the source. + "SortRequiredExec: [a@0 ASC]", + "FilterExec: c@2 = 0", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + assert_plan_txt!(expected, physical_plan); + + let expected = &[ + "SortRequiredExec: [a@0 ASC]", + // Since at the start of the rule ordering requirement is satisfied + // EnforceDistribution rule satisfy this requirement also. + // ordering is re-satisfied by introduction of SortExec. + "SortExec: expr=[a@0 ASC]", + "FilterExec: c@2 = 0", + // ordering is lost here + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + + let mut config = ConfigOptions::new(); + config.execution.target_partitions = 10; + config.optimizer.enable_round_robin_repartition = true; + config.optimizer.prefer_existing_sort = false; + let distribution_plan = + EnforceDistribution::new().optimize(physical_plan, &config)?; + assert_plan_txt!(expected, distribution_plan); + + Ok(()) + } + + #[test] + fn do_not_add_unnecessary_hash() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let alias = vec![("a".to_string(), "a".to_string())]; + let input = parquet_exec_with_sort(vec![sort_key]); + let physical_plan = aggregate_exec_with_alias(input, alias.clone()); + + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + // Make sure target partition number is 1. In this case hash repartition is unnecessary + assert_optimized!(expected, physical_plan.clone(), true, false, 1, false, 1024); + assert_optimized!(expected, physical_plan, false, false, 1, false, 1024); + + Ok(()) + } + + #[test] + fn do_not_add_unnecessary_hash2() -> Result<()> { + let schema = schema(); + let sort_key = vec![PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + let alias = vec![("a".to_string(), "a".to_string())]; + let input = parquet_exec_multiple_sorted(vec![sort_key]); + let aggregate = aggregate_exec_with_alias(input, alias.clone()); + let physical_plan = aggregate_exec_with_alias(aggregate, alias.clone()); + + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + // Since hash requirements of this operator is satisfied. There shouldn't be + // a hash repartition here + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2", + "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", + ]; + + // Make sure target partition number is larger than 2 (e.g partition number at the source). + assert_optimized!(expected, physical_plan.clone(), true, false, 4, false, 1024); + assert_optimized!(expected, physical_plan, false, false, 4, false, 1024); + + Ok(()) + } + + #[test] + fn optimize_away_unnecessary_repartition() -> Result<()> { + let physical_plan = coalesce_partitions_exec(repartition_exec(parquet_exec())); + let expected = &[ + "CoalescePartitionsExec", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + plans_matches_expected!(expected, physical_plan.clone()); + + let expected = + &["ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]"]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } + + #[test] + fn optimize_away_unnecessary_repartition2() -> Result<()> { + let physical_plan = filter_exec(repartition_exec(coalesce_partitions_exec( + filter_exec(repartition_exec(parquet_exec())), + ))); + let expected = &[ + "FilterExec: c@2 = 0", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CoalescePartitionsExec", + " FilterExec: c@2 = 0", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + plans_matches_expected!(expected, physical_plan.clone()); + + let expected = &[ + "FilterExec: c@2 = 0", + "FilterExec: c@2 = 0", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + + assert_optimized!(expected, physical_plan.clone(), true); + assert_optimized!(expected, physical_plan, false); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs similarity index 62% rename from datafusion/core/src/physical_optimizer/sort_enforcement.rs rename to datafusion/core/src/physical_optimizer/enforce_sorting.rs index a79552de49de2..14715ede500ad 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -17,8 +17,8 @@ //! EnforceSorting optimizer rule inspects the physical plan with respect //! to local sorting requirements and does the following: -//! - Adds a [SortExec] when a requirement is not met, -//! - Removes an already-existing [SortExec] if it is possible to prove +//! - Adds a [`SortExec`] when a requirement is not met, +//! - Removes an already-existing [`SortExec`] if it is possible to prove //! that this sort is unnecessary //! The rule can work on valid *and* invalid physical plans with respect to //! sorting requirements, but always produces a valid physical plan in this sense. @@ -34,33 +34,35 @@ //! in the physical plan. The first sort is unnecessary since its result is overwritten //! by another [`SortExec`]. Therefore, this rule removes it from the physical plan. +use std::sync::Arc; + use crate::config::ConfigOptions; use crate::error::Result; +use crate::physical_optimizer::replace_with_order_preserving_variants::{ + replace_with_order_preserving_variants, OrderPreservationContext, +}; use crate::physical_optimizer::sort_pushdown::{pushdown_sorts, SortPushDown}; use crate::physical_optimizer::utils::{ - add_sort_above, find_indices, is_coalesce_partitions, is_limit, is_repartition, - is_sort, is_sort_preserving_merge, is_sorted, is_union, is_window, - merge_and_order_indices, set_difference, + add_sort_above, is_coalesce_partitions, is_limit, is_repartition, is_sort, + is_sort_preserving_merge, is_union, is_window, ExecTree, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, }; -use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; -use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::utils::{get_at_indices, longest_consecutive_prefix}; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::utils::{ - convert_to_expr, get_indices_of_matching_exprs, ordering_satisfy, - ordering_satisfy_requirement_concrete, +use crate::physical_plan::{ + with_new_children_if_necessary, Distribution, ExecutionPlan, InputOrderMode, }; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement}; -use itertools::{concat, izip, Itertools}; -use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::{plan_err, DataFusionError}; +use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; + +use datafusion_physical_plan::repartition::RepartitionExec; +use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the /// ones it can prove unnecessary. @@ -74,42 +76,6 @@ impl EnforceSorting { } } -/// This object implements a tree that we use while keeping track of paths -/// leading to [`SortExec`]s. -#[derive(Debug, Clone)] -struct ExecTree { - /// The `ExecutionPlan` associated with this node - pub plan: Arc, - /// Child index of the plan in its parent - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors - pub children: Vec, -} - -impl ExecTree { - /// Create new Exec tree - pub fn new( - plan: Arc, - idx: usize, - children: Vec, - ) -> Self { - ExecTree { - plan, - idx, - children, - } - } - - /// This function returns the executors at the leaves of the tree. - fn get_leaves(&self) -> Vec> { - if self.children.is_empty() { - vec![self.plan.clone()] - } else { - concat(self.children.iter().map(|e| e.get_leaves())) - } - } -} - /// This object is used within the [`EnforceSorting`] rule to track the closest /// [`SortExec`] descendant(s) for every child of a plan. #[derive(Debug, Clone)] @@ -123,7 +89,7 @@ struct PlanWithCorrespondingSort { } impl PlanWithCorrespondingSort { - pub fn new(plan: Arc) -> Self { + fn new(plan: Arc) -> Self { let length = plan.children().len(); PlanWithCorrespondingSort { plan, @@ -131,7 +97,7 @@ impl PlanWithCorrespondingSort { } } - pub fn new_from_children_nodes( + fn new_from_children_nodes( children_nodes: Vec, parent_plan: Arc, ) -> Result { @@ -181,11 +147,11 @@ impl PlanWithCorrespondingSort { Ok(PlanWithCorrespondingSort { plan, sort_onwards }) } - pub fn children(&self) -> Vec { + fn children(&self) -> Vec { self.plan .children() .into_iter() - .map(|child| PlanWithCorrespondingSort::new(child)) + .map(PlanWithCorrespondingSort::new) .collect() } } @@ -238,7 +204,7 @@ struct PlanWithCorrespondingCoalescePartitions { } impl PlanWithCorrespondingCoalescePartitions { - pub fn new(plan: Arc) -> Self { + fn new(plan: Arc) -> Self { let length = plan.children().len(); PlanWithCorrespondingCoalescePartitions { plan, @@ -246,7 +212,7 @@ impl PlanWithCorrespondingCoalescePartitions { } } - pub fn new_from_children_nodes( + fn new_from_children_nodes( children_nodes: Vec, parent_plan: Arc, ) -> Result { @@ -297,11 +263,11 @@ impl PlanWithCorrespondingCoalescePartitions { }) } - pub fn children(&self) -> Vec { + fn children(&self) -> Vec { self.plan .children() .into_iter() - .map(|child| PlanWithCorrespondingCoalescePartitions::new(child)) + .map(PlanWithCorrespondingCoalescePartitions::new) .collect() } } @@ -366,9 +332,20 @@ impl PhysicalOptimizerRule for EnforceSorting { } else { adjusted.plan }; + let plan_with_pipeline_fixer = OrderPreservationContext::new(new_plan); + let updated_plan = + plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + replace_with_order_preserving_variants( + plan_with_pipeline_fixer, + false, + true, + config, + ) + })?; + // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: - let sort_pushdown = SortPushDown::init(new_plan); + let sort_pushdown = SortPushDown::init(updated_plan.plan); let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; Ok(adjusted.plan) } @@ -420,15 +397,20 @@ fn parallelize_sorts( // SortPreservingMergeExec cascade to parallelize sorting. let mut prev_layer = plan.clone(); update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; - let sort_exprs = get_sort_exprs(&plan)?; - add_sort_above(&mut prev_layer, sort_exprs.to_vec())?; - let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer); + let (sort_exprs, fetch) = get_sort_exprs(&plan)?; + add_sort_above( + &mut prev_layer, + &PhysicalSortRequirement::from_sort_exprs(sort_exprs), + fetch, + ); + let spm = SortPreservingMergeExec::new(sort_exprs.to_vec(), prev_layer) + .with_fetch(fetch); return Ok(Transformed::Yes(PlanWithCorrespondingCoalescePartitions { plan: Arc::new(spm), coalesce_onwards: vec![None], })); } else if is_coalesce_partitions(&plan) { - // There is an unnecessary `CoalescePartitionExec` in the plan. + // There is an unnecessary `CoalescePartitionsExec` in the plan. let mut prev_layer = plan.clone(); update_child_to_remove_coalesce(&mut prev_layer, &mut coalesce_onwards[0])?; let new_plan = plan.with_new_children(vec![prev_layer])?; @@ -468,18 +450,14 @@ fn ensure_sorting( { let physical_ordering = child.output_ordering(); match (required_ordering, physical_ordering) { - (Some(required_ordering), Some(physical_ordering)) => { - if !ordering_satisfy_requirement_concrete( - physical_ordering, - &required_ordering, - || child.equivalence_properties(), - || child.ordering_equivalence_properties(), - ) { + (Some(required_ordering), Some(_)) => { + if !child + .equivalence_properties() + .ordering_satisfy_requirement(&required_ordering) + { // Make sure we preserve the ordering requirements: update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; - let sort_expr = - PhysicalSortRequirement::to_sort_exprs(required_ordering); - add_sort_above(child, sort_expr)?; + add_sort_above(child, &required_ordering, None); if is_sort(child) { *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); } else { @@ -489,8 +467,7 @@ fn ensure_sorting( } (Some(required), None) => { // Ordering requirement is not met, we should add a `SortExec` to the plan. - let sort_expr = PhysicalSortRequirement::to_sort_exprs(required); - add_sort_above(child, sort_expr)?; + add_sort_above(child, &required, None); *sort_onwards = Some(ExecTree::new(child.clone(), idx, vec![])); } (None, Some(_)) => { @@ -500,7 +477,9 @@ fn ensure_sorting( update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; } } - (None, None) => {} + (None, None) => { + update_child_to_remove_unnecessary_sort(child, sort_onwards, &plan)?; + } } } // For window expressions, we can remove some sorts when we can @@ -516,9 +495,10 @@ fn ensure_sorting( { // This SortPreservingMergeExec is unnecessary, input already has a // single partition. + sort_onwards.truncate(1); return Ok(Transformed::Yes(PlanWithCorrespondingSort { - plan: children[0].clone(), - sort_onwards: vec![sort_onwards[0].clone()], + plan: children.swap_remove(0), + sort_onwards, })); } Ok(Transformed::Yes(PlanWithCorrespondingSort { @@ -535,13 +515,12 @@ fn analyze_immediate_sort_removal( ) -> Option { if let Some(sort_exec) = plan.as_any().downcast_ref::() { let sort_input = sort_exec.input().clone(); + // If this sort is unnecessary, we should remove it: - if ordering_satisfy( - sort_input.output_ordering(), - sort_exec.output_ordering(), - || sort_input.equivalence_properties(), - || sort_input.ordering_equivalence_properties(), - ) { + if sort_input + .equivalence_properties() + .ordering_satisfy(sort_exec.output_ordering().unwrap_or(&[])) + { // Since we know that a `SortExec` has exactly one child, // we can use the zero index safely: return Some( @@ -582,107 +561,68 @@ fn analyze_window_sort_removal( sort_tree: &mut ExecTree, window_exec: &Arc, ) -> Result> { - let (window_expr, partition_keys) = if let Some(exec) = - window_exec.as_any().downcast_ref::() - { - (exec.window_expr(), &exec.partition_keys) - } else if let Some(exec) = window_exec.as_any().downcast_ref::() { - (exec.window_expr(), &exec.partition_keys) - } else { - return Err(DataFusionError::Plan( - "Expects to receive either WindowAggExec of BoundedWindowAggExec".to_string(), - )); - }; + let requires_single_partition = matches!( + window_exec.required_input_distribution()[sort_tree.idx], + Distribution::SinglePartition + ); + let mut window_child = + remove_corresponding_sort_from_sub_plan(sort_tree, requires_single_partition)?; + let (window_expr, new_window) = + if let Some(exec) = window_exec.as_any().downcast_ref::() { + ( + exec.window_expr(), + get_best_fitting_window( + exec.window_expr(), + &window_child, + &exec.partition_keys, + )?, + ) + } else if let Some(exec) = window_exec.as_any().downcast_ref::() { + ( + exec.window_expr(), + get_best_fitting_window( + exec.window_expr(), + &window_child, + &exec.partition_keys, + )?, + ) + } else { + return plan_err!( + "Expects to receive either WindowAggExec of BoundedWindowAggExec" + ); + }; let partitionby_exprs = window_expr[0].partition_by(); - let orderby_sort_keys = window_expr[0].order_by(); - - // search_flags stores return value of the can_skip_sort. - // `None` case represents `SortExec` cannot be removed. - // `PartitionSearch` mode stores at which mode executor should work to remove - // `SortExec` before it, - // `bool` stores whether or not we need to reverse window expressions to remove `SortExec`. - let mut search_flags = None; - for sort_any in sort_tree.get_leaves() { - // Variable `sort_any` will either be a `SortExec` or a - // `SortPreservingMergeExec`, and both have a single child. - // Therefore, we can use the 0th index without loss of generality. - let sort_input = &sort_any.children()[0]; - let flags = can_skip_sort(partitionby_exprs, orderby_sort_keys, sort_input)?; - if flags.is_some() && (search_flags.is_none() || search_flags == flags) { - search_flags = flags; - continue; - } - // We can not skip the sort, or window reversal requirements are not - // uniform; then sort removal is not possible -- we immediately return. - return Ok(None); - } - let (should_reverse, partition_search_mode) = if let Some(search_flags) = search_flags - { - search_flags - } else { - // We can not skip the sort return: - return Ok(None); - }; - let is_unbounded = unbounded_output(window_exec); - if !is_unbounded && partition_search_mode != PartitionSearchMode::Sorted { - // Executor has bounded input and `partition_search_mode` is not `PartitionSearchMode::Sorted` - // in this case removing the sort is not helpful, return: - return Ok(None); - }; - let new_window_expr = if should_reverse { - window_expr - .iter() - .map(|e| e.get_reverse_expr()) - .collect::>>() + if let Some(new_window) = new_window { + // We were able to change the window to accommodate the input, use it: + Ok(Some(PlanWithCorrespondingSort::new(new_window))) } else { - Some(window_expr.to_vec()) - }; - if let Some(window_expr) = new_window_expr { - let requires_single_partition = matches!( - window_exec.required_input_distribution()[sort_tree.idx], - Distribution::SinglePartition - ); - let mut new_child = remove_corresponding_sort_from_sub_plan( - sort_tree, - requires_single_partition, - )?; - let new_schema = new_child.schema(); + // We were unable to change the window to accommodate the input, so we + // will insert a sort. + let reqs = window_exec + .required_input_ordering() + .swap_remove(0) + .unwrap_or_default(); + // Satisfy the ordering requirement so that the window can run: + add_sort_above(&mut window_child, &reqs, None); let uses_bounded_memory = window_expr.iter().all(|e| e.uses_bounded_memory()); - // If all window expressions can run with bounded memory, choose the - // bounded window variant: - let new_plan = if uses_bounded_memory { + let new_window = if uses_bounded_memory { Arc::new(BoundedWindowAggExec::try_new( - window_expr, - new_child, - new_schema, - partition_keys.to_vec(), - partition_search_mode, + window_expr.to_vec(), + window_child, + partitionby_exprs.to_vec(), + InputOrderMode::Sorted, )?) as _ } else { - if partition_search_mode != PartitionSearchMode::Sorted { - // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. - // Hence, if `partition_search_mode` is not `PartitionSearchMode::Sorted` we should convert - // input ordering such that it can work with PartitionSearchMode::Sorted (add `SortExec`). - // Effectively `WindowAggExec` works only in PartitionSearchMode::Sorted mode. - let reqs = window_exec - .required_input_ordering() - .swap_remove(0) - .unwrap_or(vec![]); - let sort_expr = PhysicalSortRequirement::to_sort_exprs(reqs); - add_sort_above(&mut new_child, sort_expr)?; - }; Arc::new(WindowAggExec::try_new( - window_expr, - new_child, - new_schema, - partition_keys.to_vec(), + window_expr.to_vec(), + window_child, + partitionby_exprs.to_vec(), )?) as _ }; - return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); + Ok(Some(PlanWithCorrespondingSort::new(new_window))) } - Ok(None) } /// Updates child to remove the unnecessary [`CoalescePartitionsExec`] below it. @@ -708,7 +648,7 @@ fn remove_corresponding_coalesce_in_sub_plan( && is_repartition(&new_plan) && is_repartition(parent) { - new_plan = new_plan.children()[0].clone() + new_plan = new_plan.children().swap_remove(0) } new_plan } else { @@ -748,7 +688,7 @@ fn remove_corresponding_sort_from_sub_plan( ) -> Result> { // A `SortExec` is always at the bottom of the tree. let mut updated_plan = if is_sort(&sort_onwards.plan) { - sort_onwards.plan.children()[0].clone() + sort_onwards.plan.children().swap_remove(0) } else { let plan = &sort_onwards.plan; let mut children = plan.children(); @@ -760,8 +700,18 @@ fn remove_corresponding_sort_from_sub_plan( children[item.idx] = remove_corresponding_sort_from_sub_plan(item, requires_single_partition)?; } + // Replace with variants that do not preserve order. if is_sort_preserving_merge(plan) { - children[0].clone() + children.swap_remove(0) + } else if let Some(repartition) = plan.as_any().downcast_ref::() + { + Arc::new( + // By default, RepartitionExec does not preserve order + RepartitionExec::try_new( + children.swap_remove(0), + repartition.partitioning().clone(), + )?, + ) } else { plan.clone().with_new_children(children)? } @@ -779,222 +729,53 @@ fn remove_corresponding_sort_from_sub_plan( updated_plan, )); } else { - updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan.clone())); + updated_plan = Arc::new(CoalescePartitionsExec::new(updated_plan)); } } Ok(updated_plan) } /// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. -fn get_sort_exprs(sort_any: &Arc) -> Result<&[PhysicalSortExpr]> { +fn get_sort_exprs( + sort_any: &Arc, +) -> Result<(&[PhysicalSortExpr], Option)> { if let Some(sort_exec) = sort_any.as_any().downcast_ref::() { - Ok(sort_exec.expr()) + Ok((sort_exec.expr(), sort_exec.fetch())) } else if let Some(sort_preserving_merge_exec) = sort_any.as_any().downcast_ref::() { - Ok(sort_preserving_merge_exec.expr()) - } else { - Err(DataFusionError::Plan( - "Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec" - .to_string(), + Ok(( + sort_preserving_merge_exec.expr(), + sort_preserving_merge_exec.fetch(), )) - } -} - -/// Compares physical ordering (output ordering of input executor) with -/// `partitionby_exprs` and `orderby_keys` -/// to decide whether existing ordering is sufficient to run current window executor. -/// A `None` return value indicates that we can not remove the sort in question (input ordering is not -/// sufficient to run current window executor). -/// A `Some((bool, PartitionSearchMode))` value indicates window executor can be run with existing input ordering -/// (Hence we can remove [`SortExec`] before it). -/// `bool` represents whether we should reverse window executor to remove [`SortExec`] before it. -/// `PartitionSearchMode` represents, in which mode Window Executor should work with existing ordering. -fn can_skip_sort( - partitionby_exprs: &[Arc], - orderby_keys: &[PhysicalSortExpr], - input: &Arc, -) -> Result> { - let physical_ordering = if let Some(physical_ordering) = input.output_ordering() { - physical_ordering - } else { - // If there is no physical ordering, there is no way to remove a - // sort, so immediately return. - return Ok(None); - }; - let orderby_exprs = convert_to_expr(orderby_keys); - let physical_ordering_exprs = convert_to_expr(physical_ordering); - let equal_properties = || input.equivalence_properties(); - // indices of the order by expressions among input ordering expressions - let ob_indices = get_indices_of_matching_exprs( - &orderby_exprs, - &physical_ordering_exprs, - equal_properties, - ); - if ob_indices.len() != orderby_exprs.len() { - // If all order by expressions are not in the input ordering, - // there is no way to remove a sort -- immediately return: - return Ok(None); - } - // indices of the partition by expressions among input ordering expressions - let pb_indices = get_indices_of_matching_exprs( - partitionby_exprs, - &physical_ordering_exprs, - equal_properties, - ); - let ordered_merged_indices = merge_and_order_indices(&pb_indices, &ob_indices); - // Indices of order by columns that doesn't seen in partition by - // Equivalently (Order by columns) ∖ (Partition by columns) where `∖` represents set difference. - let unique_ob_indices = set_difference(&ob_indices, &pb_indices); - if !is_sorted(&unique_ob_indices) { - // ORDER BY indices should be ascending ordered - return Ok(None); - } - let first_n = longest_consecutive_prefix(ordered_merged_indices); - let furthest_ob_index = *unique_ob_indices.last().unwrap_or(&0); - // Cannot skip sort if last order by index is not within consecutive prefix. - // For instance, if input is ordered by a,b,c,d - // for expression `PARTITION BY a, ORDER BY b, d`, `first_n` would be 2 (meaning a, b defines a prefix for input ordering) - // Whereas `furthest_ob_index` would be 3 (column d occurs at the 3rd index of the existing ordering.) - // Hence existing ordering is not sufficient to run current Executor. - // However, for expression `PARTITION BY a, ORDER BY b, c, d`, `first_n` would be 4 (meaning a, b, c, d defines a prefix for input ordering) - // Similarly, `furthest_ob_index` would be 3 (column d occurs at the 3rd index of the existing ordering.) - // Hence existing ordering would be sufficient to run current Executor. - if first_n <= furthest_ob_index { - return Ok(None); - } - let input_orderby_columns = get_at_indices(physical_ordering, &unique_ob_indices)?; - let expected_orderby_columns = - get_at_indices(orderby_keys, find_indices(&ob_indices, &unique_ob_indices)?)?; - let should_reverse = if let Some(should_reverse) = check_alignments( - &input.schema(), - &input_orderby_columns, - &expected_orderby_columns, - )? { - should_reverse - } else { - // If ordering directions are not aligned. We cannot calculate result without changing existing ordering. - return Ok(None); - }; - - let ordered_pb_indices = pb_indices.iter().copied().sorted().collect::>(); - // Determine how many elements in the partition by columns defines a consecutive range from zero. - let first_n = longest_consecutive_prefix(&ordered_pb_indices); - let mode = if first_n == partitionby_exprs.len() { - // All of the partition by columns defines a consecutive range from zero. - PartitionSearchMode::Sorted - } else if first_n > 0 { - // All of the partition by columns defines a consecutive range from zero. - let ordered_range = &ordered_pb_indices[0..first_n]; - let input_pb_exprs = get_at_indices(&physical_ordering_exprs, ordered_range)?; - let partially_ordered_indices = get_indices_of_matching_exprs( - &input_pb_exprs, - partitionby_exprs, - equal_properties, - ); - PartitionSearchMode::PartiallySorted(partially_ordered_indices) - } else { - // None of the partition by columns defines a consecutive range from zero. - PartitionSearchMode::Linear - }; - - Ok(Some((should_reverse, mode))) -} - -fn check_alignments( - schema: &SchemaRef, - physical_ordering: &[PhysicalSortExpr], - required: &[PhysicalSortExpr], -) -> Result> { - let res = izip!(physical_ordering, required) - .map(|(lhs, rhs)| check_alignment(schema, lhs, rhs)) - .collect::>>>()?; - Ok(if let Some(res) = res { - if !res.is_empty() { - let first = res[0]; - let all_same = res.into_iter().all(|elem| elem == first); - all_same.then_some(first) - } else { - Some(false) - } - } else { - // Cannot skip some of the requirements in the input. - None - }) -} - -/// Compares `physical_ordering` and `required` ordering, decides whether -/// alignments match. A `None` return value indicates that current column is -/// not aligned. A `Some(bool)` value indicates otherwise, and signals whether -/// we should reverse the window expression in order to avoid sorting. -fn check_alignment( - input_schema: &SchemaRef, - physical_ordering: &PhysicalSortExpr, - required: &PhysicalSortExpr, -) -> Result> { - Ok(if required.expr.eq(&physical_ordering.expr) { - let physical_opts = physical_ordering.options; - let required_opts = required.options; - if required.expr.nullable(input_schema)? { - let reverse = physical_opts == !required_opts; - (reverse || physical_opts == required_opts).then_some(reverse) - } else { - // If the column is not nullable, NULLS FIRST/LAST is not important. - Some(physical_opts.descending != required_opts.descending) - } } else { - None - }) -} - -// Get unbounded_output information for the executor -fn unbounded_output(plan: &Arc) -> bool { - let res = if plan.children().is_empty() { - plan.unbounded_output(&[]) - } else { - let children_unbounded_output = plan - .children() - .iter() - .map(unbounded_output) - .collect::>(); - plan.unbounded_output(&children_unbounded_output) - }; - res.unwrap_or(true) + plan_err!("Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec") + } } #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; - use crate::physical_optimizer::dist_enforcement::EnforceDistribution; - use crate::physical_plan::aggregates::PhysicalGroupBy; - use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::joins::utils::JoinOn; - use crate::physical_plan::joins::SortMergeJoinExec; - use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; - use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::repartition::RepartitionExec; - use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::union::UnionExec; - use crate::physical_plan::windows::create_window_expr; - use crate::physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, + use crate::physical_optimizer::enforce_distribution::EnforceDistribution; + use crate::physical_optimizer::test_utils::{ + aggregate_exec, bounded_window_exec, coalesce_batches_exec, + coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, + limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, + repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, spr_repartition_exec, union_exec, }; - use crate::physical_plan::{displayable, Partitioning}; - use crate::prelude::SessionContext; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::csv_exec_sorted; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; + use datafusion_common::Result; use datafusion_expr::JoinType; - use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::expressions::{col, NotExpr}; - use datafusion_physical_expr::PhysicalSortExpr; - use std::sync::Arc; + use datafusion_physical_expr::expressions::{col, Column, NotExpr}; fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); @@ -1022,385 +803,22 @@ mod tests { Ok(schema) } - // Util function to get string representation of a physical plan - fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() - } - - #[tokio::test] - async fn test_is_column_aligned_nullable() -> Result<()> { - let schema = create_test_schema()?; - let params = vec![ - ((true, true), (false, false), Some(true)), - ((true, true), (false, true), None), - ((true, true), (true, false), None), - ((true, false), (false, true), Some(true)), - ((true, false), (false, false), None), - ((true, false), (true, true), None), - ]; - for ( - (physical_desc, physical_nulls_first), - (req_desc, req_nulls_first), - expected, - ) in params - { - let physical_ordering = PhysicalSortExpr { - expr: col("nullable_col", &schema)?, - options: SortOptions { - descending: physical_desc, - nulls_first: physical_nulls_first, - }, - }; - let required_ordering = PhysicalSortExpr { - expr: col("nullable_col", &schema)?, - options: SortOptions { - descending: req_desc, - nulls_first: req_nulls_first, - }, - }; - let res = check_alignment(&schema, &physical_ordering, &required_ordering)?; - assert_eq!(res, expected); - } - - Ok(()) - } - - #[tokio::test] - async fn test_is_column_aligned_non_nullable() -> Result<()> { - let schema = create_test_schema()?; - - let params = vec![ - ((true, true), (false, false), Some(true)), - ((true, true), (false, true), Some(true)), - ((true, true), (true, false), Some(false)), - ((true, false), (false, true), Some(true)), - ((true, false), (false, false), Some(true)), - ((true, false), (true, true), Some(false)), - ]; - for ( - (physical_desc, physical_nulls_first), - (req_desc, req_nulls_first), - expected, - ) in params - { - let physical_ordering = PhysicalSortExpr { - expr: col("non_nullable_col", &schema)?, - options: SortOptions { - descending: physical_desc, - nulls_first: physical_nulls_first, - }, - }; - let required_ordering = PhysicalSortExpr { - expr: col("non_nullable_col", &schema)?, - options: SortOptions { - descending: req_desc, - nulls_first: req_nulls_first, - }, - }; - let res = check_alignment(&schema, &physical_ordering, &required_ordering)?; - assert_eq!(res, expected); - } - - Ok(()) - } - - #[tokio::test] - async fn test_can_skip_ordering_exhaustive() -> Result<()> { - let test_schema = create_test_schema3()?; - // Columns a,c are nullable whereas b,d are not nullable. - // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST - // Column e is not ordered. - let sort_exprs = vec![ - sort_expr("a", &test_schema), - sort_expr("b", &test_schema), - sort_expr("c", &test_schema), - sort_expr("d", &test_schema), - ]; - let exec_unbounded = csv_exec_sorted(&test_schema, sort_exprs, true); - - // test cases consists of vector of tuples. Where each tuple represents a single test case. - // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns - // For instance `vec!["a", "b"]` corresponds to PARTITION BY a, b - // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns - // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check - // for reversibility in this test). - // Third field in the tuple is Option, which corresponds to expected algorithm mode. - // None represents that existing ordering is not sufficient to run executor with any one of the algorithms - // (We need to add SortExec to be able to run it). - // Some(PartitionSearchMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // PartitionSearchMode. - let test_cases = vec![ - (vec!["a"], vec!["a"], Some(Sorted)), - (vec!["a"], vec!["b"], Some(Sorted)), - (vec!["a"], vec!["c"], None), - (vec!["a"], vec!["a", "b"], Some(Sorted)), - (vec!["a"], vec!["b", "c"], Some(Sorted)), - (vec!["a"], vec!["a", "c"], None), - (vec!["a"], vec!["a", "b", "c"], Some(Sorted)), - (vec!["b"], vec!["a"], Some(Linear)), - (vec!["b"], vec!["b"], None), - (vec!["b"], vec!["c"], None), - (vec!["b"], vec!["a", "b"], Some(Linear)), - (vec!["b"], vec!["b", "c"], None), - (vec!["b"], vec!["a", "c"], Some(Linear)), - (vec!["b"], vec!["a", "b", "c"], Some(Linear)), - (vec!["c"], vec!["a"], Some(Linear)), - (vec!["c"], vec!["b"], None), - (vec!["c"], vec!["c"], None), - (vec!["c"], vec!["a", "b"], Some(Linear)), - (vec!["c"], vec!["b", "c"], None), - (vec!["c"], vec!["a", "c"], Some(Linear)), - (vec!["c"], vec!["a", "b", "c"], Some(Linear)), - (vec!["b", "a"], vec!["a"], Some(Sorted)), - (vec!["b", "a"], vec!["b"], Some(Sorted)), - (vec!["b", "a"], vec!["c"], Some(Sorted)), - (vec!["b", "a"], vec!["a", "b"], Some(Sorted)), - (vec!["b", "a"], vec!["b", "c"], Some(Sorted)), - (vec!["b", "a"], vec!["a", "c"], Some(Sorted)), - (vec!["b", "a"], vec!["a", "b", "c"], Some(Sorted)), - (vec!["c", "b"], vec!["a"], Some(Linear)), - (vec!["c", "b"], vec!["b"], None), - (vec!["c", "b"], vec!["c"], None), - (vec!["c", "b"], vec!["a", "b"], Some(Linear)), - (vec!["c", "b"], vec!["b", "c"], None), - (vec!["c", "b"], vec!["a", "c"], Some(Linear)), - (vec!["c", "b"], vec!["a", "b", "c"], Some(Linear)), - (vec!["c", "a"], vec!["a"], Some(PartiallySorted(vec![1]))), - (vec!["c", "a"], vec!["b"], Some(PartiallySorted(vec![1]))), - (vec!["c", "a"], vec!["c"], Some(PartiallySorted(vec![1]))), - ( - vec!["c", "a"], - vec!["a", "b"], - Some(PartiallySorted(vec![1])), - ), - ( - vec!["c", "a"], - vec!["b", "c"], - Some(PartiallySorted(vec![1])), - ), - ( - vec!["c", "a"], - vec!["a", "c"], - Some(PartiallySorted(vec![1])), - ), - ( - vec!["c", "a"], - vec!["a", "b", "c"], - Some(PartiallySorted(vec![1])), - ), - (vec!["c", "b", "a"], vec!["a"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["b"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["c"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["a", "b"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["b", "c"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["a", "c"], Some(Sorted)), - (vec!["c", "b", "a"], vec!["a", "b", "c"], Some(Sorted)), - ]; - for (case_idx, test_case) in test_cases.iter().enumerate() { - let (partition_by_columns, order_by_params, expected) = &test_case; - let mut partition_by_exprs = vec![]; - for col_name in partition_by_columns { - partition_by_exprs.push(col(col_name, &test_schema)?); - } - - let mut order_by_exprs = vec![]; - for col_name in order_by_params { - let expr = col(col_name, &test_schema)?; - // Give default ordering, this is same with input ordering direction - // In this test we do check for reversibility. - let options = SortOptions::default(); - order_by_exprs.push(PhysicalSortExpr { expr, options }); - } - let res = - can_skip_sort(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; - // Since reversibility is not important in this test. Convert Option<(bool, PartitionSearchMode)> to Option - let res = res.map(|(_, mode)| mode); - assert_eq!( - res, *expected, - "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" - ); - } - - Ok(()) - } - - #[tokio::test] - async fn test_can_skip_ordering() -> Result<()> { - let test_schema = create_test_schema3()?; - // Columns a,c are nullable whereas b,d are not nullable. - // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST - // Column e is not ordered. - let sort_exprs = vec![ - sort_expr("a", &test_schema), - sort_expr("b", &test_schema), - sort_expr("c", &test_schema), - sort_expr("d", &test_schema), - ]; - let exec_unbounded = csv_exec_sorted(&test_schema, sort_exprs, true); - - // test cases consists of vector of tuples. Where each tuple represents a single test case. - // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns - // For instance `vec!["a", "b"]` corresponds to PARTITION BY a, b - // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns - // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, - // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, - // Third field in the tuple is Option<(bool, PartitionSearchMode)>, which corresponds to expected result. - // None represents that existing ordering is not sufficient to run executor with any one of the algorithms - // (We need to add SortExec to be able to run it). - // Some((bool, PartitionSearchMode)) represents, we can run algorithm with existing ordering. Algorithm should work in - // PartitionSearchMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. - // For instance, `Some((false, PartitionSearchMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm - // should work in Sorted mode to work with existing ordering. - let test_cases = vec![ - // PARTITION BY a, b ORDER BY c ASC NULLS LAST - (vec!["a", "b"], vec![("c", false, false)], None), - // ORDER BY c ASC NULLS FIRST - (vec![], vec![("c", false, true)], None), - // PARTITION BY b, ORDER BY c ASC NULLS FIRST - (vec!["b"], vec![("c", false, true)], None), - // PARTITION BY a, ORDER BY c ASC NULLS FIRST - (vec!["a"], vec![("c", false, true)], None), - // PARTITION BY b, ORDER BY c ASC NULLS FIRST - ( - vec!["a", "b"], - vec![("c", false, true), ("e", false, true)], - None, - ), - // PARTITION BY a, ORDER BY b ASC NULLS FIRST - (vec!["a"], vec![("b", false, true)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY a ASC NULLS FIRST - (vec!["a"], vec![("a", false, true)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY a ASC NULLS LAST - (vec!["a"], vec![("a", false, false)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY a DESC NULLS FIRST - (vec!["a"], vec![("a", true, true)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY a DESC NULLS LAST - (vec!["a"], vec![("a", true, false)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY b ASC NULLS LAST - (vec!["a"], vec![("b", false, false)], Some((false, Sorted))), - // PARTITION BY a, ORDER BY b DESC NULLS LAST - (vec!["a"], vec![("b", true, false)], Some((true, Sorted))), - // PARTITION BY a, b ORDER BY c ASC NULLS FIRST - ( - vec!["a", "b"], - vec![("c", false, true)], - Some((false, Sorted)), - ), - // PARTITION BY b, a ORDER BY c ASC NULLS FIRST - ( - vec!["b", "a"], - vec![("c", false, true)], - Some((false, Sorted)), - ), - // PARTITION BY a, b ORDER BY c DESC NULLS LAST - ( - vec!["a", "b"], - vec![("c", true, false)], - Some((true, Sorted)), - ), - // PARTITION BY e ORDER BY a ASC NULLS FIRST - ( - vec!["e"], - vec![("a", false, true)], - // For unbounded, expects to work in Linear mode. Shouldn't reverse window function. - Some((false, Linear)), - ), - // PARTITION BY b, c ORDER BY a ASC NULLS FIRST, c ASC NULLS FIRST - ( - vec!["b", "c"], - vec![("a", false, true), ("c", false, true)], - Some((false, Linear)), - ), - // PARTITION BY b ORDER BY a ASC NULLS FIRST - (vec!["b"], vec![("a", false, true)], Some((false, Linear))), - // PARTITION BY a, e ORDER BY b ASC NULLS FIRST - ( - vec!["a", "e"], - vec![("b", false, true)], - Some((false, PartiallySorted(vec![0]))), - ), - // PARTITION BY a, c ORDER BY b ASC NULLS FIRST - ( - vec!["a", "c"], - vec![("b", false, true)], - Some((false, PartiallySorted(vec![0]))), - ), - // PARTITION BY c, a ORDER BY b ASC NULLS FIRST - ( - vec!["c", "a"], - vec![("b", false, true)], - Some((false, PartiallySorted(vec![1]))), - ), - // PARTITION BY d, b, a ORDER BY c ASC NULLS FIRST - ( - vec!["d", "b", "a"], - vec![("c", false, true)], - Some((false, PartiallySorted(vec![2, 1]))), - ), - // PARTITION BY e, b, a ORDER BY c ASC NULLS FIRST - ( - vec!["e", "b", "a"], - vec![("c", false, true)], - Some((false, PartiallySorted(vec![2, 1]))), - ), - // PARTITION BY d, a ORDER BY b ASC NULLS FIRST - ( - vec!["d", "a"], - vec![("b", false, true)], - Some((false, PartiallySorted(vec![1]))), - ), - // PARTITION BY b, ORDER BY b, a ASC NULLS FIRST - ( - vec!["a"], - vec![("b", false, true), ("a", false, true)], - Some((false, Sorted)), - ), - // ORDER BY b, a ASC NULLS FIRST - (vec![], vec![("b", false, true), ("a", false, true)], None), - ]; - for (case_idx, test_case) in test_cases.iter().enumerate() { - let (partition_by_columns, order_by_params, expected) = &test_case; - let mut partition_by_exprs = vec![]; - for col_name in partition_by_columns { - partition_by_exprs.push(col(col_name, &test_schema)?); - } - - let mut order_by_exprs = vec![]; - for (col_name, descending, nulls_first) in order_by_params { - let expr = col(col_name, &test_schema)?; - let options = SortOptions { - descending: *descending, - nulls_first: *nulls_first, - }; - order_by_exprs.push(PhysicalSortExpr { expr, options }); - } - - assert_eq!( - can_skip_sort(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, - *expected, - "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" - ); - } - - Ok(()) - } - /// Runs the sort enforcement optimizer and asserts the plan /// against the original and expected plans /// /// `$EXPECTED_PLAN_LINES`: input plan /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan /// `$PLAN`: the plan to optimized + /// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. /// macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { - let session_ctx = SessionContext::new(); + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr) => { + let config = SessionConfig::new().with_repartition_sorts($REPARTITION_SORTS); + let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES @@ -1435,16 +853,16 @@ mod tests { let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); - let expected_input = vec![ + let expected_input = [ "SortExec: expr=[nullable_col@0 ASC]", " SortExec: expr=[non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1462,8 +880,11 @@ mod tests { }, )]; let sort = sort_exec(sort_exprs.clone(), source); + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + let coalesce_batches = coalesce_batches_exec(sort); - let window_agg = bounded_window_exec("non_nullable_col", sort_exprs, sort); + let window_agg = + bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches); let sort_exprs = vec![sort_expr_options( "non_nullable_col", @@ -1486,23 +907,21 @@ mod tests { let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); - let expected_input = vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " FilterExec: NOT non_nullable_col@1", " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[non_nullable_col@1 DESC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; + " CoalesceBatchesExec: target_batch_size=128", + " SortExec: expr=[non_nullable_col@1 DESC]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", " FilterExec: NOT non_nullable_col@1", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[non_nullable_col@1 DESC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " CoalesceBatchesExec: target_batch_size=128", + " SortExec: expr=[non_nullable_col@1 DESC]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1515,15 +934,15 @@ mod tests { let physical_plan = sort_preserving_merge_exec(sort_exprs, source); - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1538,18 +957,18 @@ mod tests { let sort_exprs = vec![sort_expr("nullable_col", &schema)]; let sort = sort_exec(sort_exprs.clone(), spm); let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1572,7 +991,7 @@ mod tests { let sort3 = sort_exec(sort_exprs, spm2); let physical_plan = repartition_exec(repartition_exec(sort3)); - let expected_input = vec![ + let expected_input = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortExec: expr=[nullable_col@0 ASC]", @@ -1580,15 +999,15 @@ mod tests { " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1616,23 +1035,23 @@ mod tests { // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = vec![ + let expected_input = [ "AggregateExec: mode=Final, gby=[], aggr=[]", " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "AggregateExec: mode=Final, gby=[], aggr=[]", " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1667,29 +1086,52 @@ mod tests { // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " FilterExec: NOT non_nullable_col@1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC]", " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " FilterExec: NOT non_nullable_col@1", " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort5() -> Result<()> { + let left_schema = create_test_schema2()?; + let right_schema = create_test_schema3()?; + let left_input = memory_exec(&left_schema); + let parquet_sort_exprs = vec![sort_expr("a", &right_schema)]; + let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); + + let on = vec![( + Column::new_with_schema("col_a", &left_schema)?, + Column::new_with_schema("c", &right_schema)?, + )]; + let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; + let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); + + let expected_input = ["SortExec: expr=[a@2 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]"]; + + let expected_optimized = ["HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1708,17 +1150,17 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_expr("nullable_col", &schema)], input2); - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1741,20 +1183,17 @@ mod tests { let repartition = repartition_exec(union); let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // We should keep the bottom `SortExec`. - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", @@ -1762,9 +1201,8 @@ mod tests { " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1778,16 +1216,16 @@ mod tests { ]; let sort = sort_exec(vec![sort_exprs[0].clone()], source); let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1804,17 +1242,17 @@ mod tests { let physical_plan = sort_preserving_merge_exec(vec![sort_exprs[1].clone()], sort2); - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC]", " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = vec![ + let expected_optimized = [ "SortExec: expr=[non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1841,7 +1279,7 @@ mod tests { ]; // should not add a sort at the output of the union, input plan should not be changed let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1872,7 +1310,7 @@ mod tests { ]; // should not add a sort at the output of the union, input plan should not be changed let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1896,23 +1334,19 @@ mod tests { // Input is an invalid plan. In this case rule should add required sorting in appropriate places. // First ParquetExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1938,26 +1372,22 @@ mod tests { // First input to the union is not Sorted (SortExec is finer than required ordering by the SortPreservingMergeExec above). // Second input to the union is already Sorted (matches with the required ordering by the SortPreservingMergeExec above). // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // should adjust sorting in the first input of the union such that it is not unnecessarily fine - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -1983,26 +1413,22 @@ mod tests { // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2036,23 +1462,19 @@ mod tests { // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. However, we should be able to change the unnecessarily // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2082,29 +1504,25 @@ mod tests { // At the same time, this ordering requirement is unnecessarily fine. // The final plan should be valid AND the ordering of the third child // shouldn't be finer than necessary. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2125,24 +1543,20 @@ mod tests { let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec - let expected_output = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_output, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_output, physical_plan, true); Ok(()) } @@ -2180,21 +1594,17 @@ mod tests { // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. - let expected_input = vec![ - "UnionExec", + let expected_input = ["UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[nullable_col@0 DESC NULLS LAST,non_nullable_col@1 DESC NULLS LAST]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. - let expected_optimized = vec![ - "UnionExec", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", + let expected_optimized = ["UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2229,23 +1639,21 @@ mod tests { // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. - let expected_input = vec![ + let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", " UnionExec", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - ]; - let expected_optimized = vec![ + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; + let expected_optimized = [ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2269,23 +1677,19 @@ mod tests { // The `WindowAggExec` can get its required sorting from the leaf nodes directly. // The unnecessary SortExecs should be removed - let expected_input = vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - ]; - let expected_optimized = vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; + let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2320,27 +1724,23 @@ mod tests { let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2375,16 +1775,16 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs.clone(), join); - let join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"nullable_col\", index: 0 }}, Column {{ name: \"col_a\", index: 0 }})]"); - let join_plan2 = - format!(" SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"nullable_col\", index: 0 }}, Column {{ name: \"col_a\", index: 0 }})]"); - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let join_plan = format!( + "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + ); + let join_plan2 = format!( + " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + ); + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", join_plan2.as_str(), " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; let expected_optimized = match join_type { JoinType::Inner | JoinType::Left @@ -2411,7 +1811,7 @@ mod tests { ] } }; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); } Ok(()) } @@ -2446,22 +1846,22 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs, join); - let join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"nullable_col\", index: 0 }}, Column {{ name: \"col_a\", index: 0 }})]"); + let join_plan = format!( + "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + ); let spm_plan = match join_type { JoinType::RightAnti => { "SortPreservingMergeExec: [col_a@0 ASC,col_b@1 ASC]" } _ => "SortPreservingMergeExec: [col_a@2 ASC,col_b@3 ASC]", }; - let join_plan2 = - format!(" SortMergeJoin: join_type={join_type}, on=[(Column {{ name: \"nullable_col\", index: 0 }}, Column {{ name: \"col_a\", index: 0 }})]"); - let expected_input = vec![ - spm_plan, + let join_plan2 = format!( + " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + ); + let expected_input = [spm_plan, join_plan2.as_str(), " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; let expected_optimized = match join_type { JoinType::Inner | JoinType::Right | JoinType::RightAnti => { // can push down the sort requirements and save 1 SortExec @@ -2485,7 +1885,7 @@ mod tests { ] } }; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); } Ok(()) } @@ -2513,23 +1913,19 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); - let expected_input = vec![ - "SortPreservingMergeExec: [col_b@3 ASC,col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(Column { name: \"nullable_col\", index: 0 }, Column { name: \"col_a\", index: 0 })]", + let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC,col_a@2 ASC]", + " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = vec![ - "SortExec: expr=[col_b@3 ASC,col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(Column { name: \"nullable_col\", index: 0 }, Column { name: \"col_a\", index: 0 })]", + let expected_optimized = ["SortExec: expr=[col_b@3 ASC,col_a@2 ASC]", + " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[col_a@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) let sort_exprs2 = vec![ @@ -2539,23 +1935,19 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(Column { name: \"nullable_col\", index: 0 }, Column { name: \"col_a\", index: 0 })]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", + " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = vec![ - "SortExec: expr=[nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(Column { name: \"nullable_col\", index: 0 }, Column { name: \"col_a\", index: 0 })]", + let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", + " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " SortExec: expr=[col_a@0 ASC]", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2580,22 +1972,18 @@ mod tests { let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - let expected_input = vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; + " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2619,21 +2007,17 @@ mod tests { // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize Sorting also (given that executors in between don't require) // single partition. - let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC]", " FilterExec: NOT non_nullable_col@1", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - let expected_optimized = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", " FilterExec: NOT non_nullable_col@1", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -2656,6 +2040,17 @@ mod tests { let orig_plan = Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; + let actual = get_plan_string(&orig_plan); + let expected_input = vec![ + "SortExec: expr=[nullable_col@0 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_eq!( + expected_input, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_input:#?}\nactual:\n\n{actual:#?}\n\n" + ); let mut plan = orig_plan.clone(); let rules = vec![ @@ -2701,195 +2096,187 @@ mod tests { let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort // Also input plan is not valid as it is. We need to add SortExec before SortPreservingMergeExec. - let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", - ]; - let expected_optimized = vec![ + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + let expected_optimized = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=0", - " MemoryExec: partitions=0, partition_sizes=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected_optimized, physical_plan); + assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } - /// make PhysicalSortExpr with default options - fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { - sort_expr_options(name, schema, SortOptions::default()) - } - - /// PhysicalSortExpr with specified options - fn sort_expr_options( - name: &str, - schema: &Schema, - options: SortOptions, - ) -> PhysicalSortExpr { - PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options, - } - } - - fn memory_exec(schema: &SchemaRef) -> Arc { - Arc::new(MemoryExec::try_new(&[], schema.clone(), None).unwrap()) - } - - fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, - ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) - } - - fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, - ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) - } - - fn filter_exec( - predicate: Arc, - input: Arc, - ) -> Arc { - Arc::new(FilterExec::try_new(predicate, input).unwrap()) + #[tokio::test] + async fn test_with_lost_ordering_bounded() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, false); + let repartition_rr = repartition_exec(source); + let repartition_hash = Arc::new(RepartitionExec::try_new( + repartition_rr, + Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + )?) as _; + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + + let expected_input = ["SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=false"]; + let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=false"]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) } - fn bounded_window_exec( - col_name: &str, - sort_exprs: impl IntoIterator, - input: Arc, - ) -> Arc { - let sort_exprs: Vec<_> = sort_exprs.into_iter().collect(); - let schema = input.schema(); - - Arc::new( - BoundedWindowAggExec::try_new( - vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), - "count".to_owned(), - &[col(col_name, &schema).unwrap()], - &[], - &sort_exprs, - Arc::new(WindowFrame::new(true)), - schema.as_ref(), - ) - .unwrap()], - input.clone(), - input.schema(), - vec![], - Sorted, - ) - .unwrap(), - ) - } - - /// Create a non sorted parquet exec - fn parquet_exec(schema: &SchemaRef) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - )) + #[tokio::test] + async fn test_with_lost_ordering_unbounded() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec(source); + let repartition_hash = Arc::new(RepartitionExec::try_new( + repartition_rr, + Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + )?) as _; + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + + let expected_input = [ + "SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) } - // Created a sorted parquet exec - fn parquet_exec_sorted( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, - ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - infinite_source: false, - }, - None, - None, - )) + #[tokio::test] + async fn test_with_lost_ordering_unbounded_parallelize_off() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec(source); + let repartition_hash = Arc::new(RepartitionExec::try_new( + repartition_rr, + Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + )?) as _; + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + + let expected_input = ["SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false" + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) } - fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) + #[tokio::test] + async fn test_do_not_pushdown_through_spm() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let repartition_rr = repartition_exec(source); + let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); + let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); + + let expected_input = ["SortExec: expr=[b@1 ASC]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; + let expected_optimized = ["SortExec: expr=[b@1 ASC]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) } - fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) - } + #[tokio::test] + async fn test_pushdown_through_spm() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs.clone(), false); + let repartition_rr = repartition_exec(source); + let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); + let physical_plan = sort_exec( + vec![ + sort_expr("a", &schema), + sort_expr("b", &schema), + sort_expr("c", &schema), + ], + spm, + ); - fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) + let expected_input = ["SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; + let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) } - fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, Some(100))) - } + #[tokio::test] + async fn test_window_multi_layer_requirement() -> Result<()> { + let schema = create_test_schema3()?; + let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let source = csv_exec_sorted(&schema, vec![], false); + let sort = sort_exec(sort_exprs.clone(), source); + let repartition = repartition_exec(sort); + let repartition = spr_repartition_exec(repartition); + let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); - fn repartition_exec(input: Arc) -> Arc { - Arc::new( - RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap(), - ) - } + let physical_plan = bounded_window_exec("a", sort_exprs, spm); - fn aggregate_exec(input: Arc) -> Arc { - let schema = input.schema(); - Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![], - vec![], - vec![], - input, - schema, - ) - .unwrap(), - ) - } - - fn sort_merge_join_exec( - left: Arc, - right: Arc, - join_on: &JoinOn, - join_type: &JoinType, - ) -> Arc { - Arc::new( - SortMergeJoinExec::try_new( - left, - right, - join_on.clone(), - *join_type, - vec![SortOptions::default(); join_on.len()], - false, - ) - .unwrap(), - ) + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + " SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, sort_exprs=a@0 ASC,b@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[a@0 ASC,b@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, false); + Ok(()) } } diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs deleted file mode 100644 index 9466297d24d00..0000000000000 --- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs +++ /dev/null @@ -1,94 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Select the efficient global sort implementation based on sort details. - -use std::sync::Arc; - -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::{Transformed, TreeNode}; - -/// Currently for a sort operator, if -/// - there are more than one input partitions -/// - and there's some limit which can be pushed down to each of its input partitions -/// then [SortPreservingMergeExec] with local sort with a limit pushed down will be preferred; -/// Otherwise, the normal global sort [SortExec] will be used. -/// Later more intelligent statistics-based decision can also be introduced. -/// For example, for a small data set, the global sort may be efficient enough -#[derive(Default)] -pub struct GlobalSortSelection {} - -impl GlobalSortSelection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for GlobalSortSelection { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - plan.transform_up(&|plan| { - let transformed = - plan.as_any() - .downcast_ref::() - .and_then(|sort_exec| { - if sort_exec.input().output_partitioning().partition_count() > 1 - // It's already preserving the partitioning so that it can be regarded as a local sort - && !sort_exec.preserve_partitioning() - && (sort_exec.fetch().is_some() || config.optimizer.repartition_sorts) - { - let sort = SortExec::new( - sort_exec.expr().to_vec(), - sort_exec.input().clone() - ) - .with_fetch(sort_exec.fetch()) - .with_preserve_partitioning(true); - let global_sort: Arc = - Arc::new(SortPreservingMergeExec::new( - sort_exec.expr().to_vec(), - Arc::new(sort), - )); - Some(global_sort) - } else { - None - } - }); - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) - }) - } - - fn name(&self) -> &str { - "global_sort_selection" - } - - fn schema_check(&self) -> bool { - false - } -} diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index a97ef6a3f9d30..6b2fe24acf005 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -15,36 +15,38 @@ // specific language governing permissions and limitations // under the License. -//! Select the proper PartitionMode and build side based on the avaliable statistics for hash join. -use std::sync::Arc; +//! The [`JoinSelection`] rule tries to modify a given plan so that it can +//! accommodate infinite sources and utilize statistical information (if there +//! is any) to obtain more performant plans. To achieve the first goal, it +//! tries to transform a non-runnable query (with the given infinite sources) +//! into a runnable query by replacing pipeline-breaking join operations with +//! pipeline-friendly ones. To achieve the second goal, it selects the proper +//! `PartitionMode` and the build side using the available statistics for hash joins. -use arrow::datatypes::Schema; +use std::sync::Arc; use crate::config::ConfigOptions; -use crate::logical_expr::JoinType; -use crate::physical_plan::expressions::Column; +use crate::error::Result; +use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use crate::physical_plan::joins::{ - utils::{ColumnIndex, JoinFilter, JoinSide}, - CrossJoinExec, HashJoinExec, PartitionMode, + CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode, + SymmetricHashJoinExec, }; use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; +use crate::physical_plan::ExecutionPlan; -use super::optimizer::PhysicalOptimizerRule; -use crate::error::Result; +use arrow_schema::Schema; +use datafusion_common::internal_err; use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{DataFusionError, JoinType}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalExpr; -/// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make -/// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal -/// based on the available statistics that the inputs have. -/// If the statistics information is not available, the partition mode will fall back to [PartitionMode::Partitioned]. -/// -/// JoinSelection rule will also reorder the build and probe phase of the hash joins -/// based on the avaliable statistics that the inputs have. -/// The rule optimizes the order such that the left (build) side of the join is the smallest. -/// If the statistics information is not available, the order stays the same as the original query. -/// JoinSelection rule will also swap the left and right sides for cross join to keep the left side -/// is the smallest. +/// The [`JoinSelection`] rule tries to modify a given plan so that it can +/// accommodate infinite sources and optimize joins in the plan according to +/// available statistical information, if there is any. #[derive(Default)] pub struct JoinSelection {} @@ -55,23 +57,32 @@ impl JoinSelection { } } -// TODO we need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. -// TODO In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. -fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> bool { +// TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. +// TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is is 8 times. +/// Checks statistics for join swap. +fn should_swap_join_order( + left: &dyn ExecutionPlan, + right: &dyn ExecutionPlan, +) -> Result { // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let (left_size, right_size) = match ( - left.statistics().total_byte_size, - right.statistics().total_byte_size, + let left_stats = left.statistics()?; + let right_stats = right.statistics()?; + // First compare `total_byte_size` of left and right side, + // if information in this field is insufficient fallback to the `num_rows` + match ( + left_stats.total_byte_size.get_value(), + right_stats.total_byte_size.get_value(), ) { - (Some(l), Some(r)) => (Some(l), Some(r)), - _ => (left.statistics().num_rows, right.statistics().num_rows), - }; - - match (left_size, right_size) { - (Some(l), Some(r)) => l > r, - _ => false, + (Some(l), Some(r)) => Ok(l > r), + _ => match ( + left_stats.num_rows.get_value(), + right_stats.num_rows.get_value(), + ) { + (Some(l), Some(r)) => Ok(l > r), + _ => Ok(false), + }, } } @@ -81,16 +92,21 @@ fn supports_collect_by_size( ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() - if let Some(size) = plan.statistics().total_byte_size { - size != 0 && size < collection_size_threshold - } else if let Some(row_count) = plan.statistics().num_rows { - row_count != 0 && row_count < collection_size_threshold + let Ok(stats) = plan.statistics() else { + return false; + }; + + if let Some(size) = stats.total_byte_size.get_value() { + *size != 0 && *size < collection_size_threshold + } else if let Some(row_count) = stats.num_rows.get_value() { + *row_count != 0 && *row_count < collection_size_threshold } else { false } } + /// Predicate that checks whether the given join type supports input swapping. -pub fn supports_swap(join_type: JoinType) -> bool { +fn supports_swap(join_type: JoinType) -> bool { matches!( join_type, JoinType::Inner @@ -103,9 +119,10 @@ pub fn supports_swap(join_type: JoinType) -> bool { | JoinType::RightAnti ) } + /// This function returns the new join type we get after swapping the given /// join's inputs. -pub fn swap_join_type(join_type: JoinType) -> JoinType { +fn swap_join_type(join_type: JoinType) -> JoinType { match join_type { JoinType::Inner => JoinType::Inner, JoinType::Full => JoinType::Full, @@ -119,7 +136,7 @@ pub fn swap_join_type(join_type: JoinType) -> JoinType { } /// This function swaps the inputs of the given join operator. -pub fn swap_hash_join( +fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { @@ -160,7 +177,7 @@ pub fn swap_hash_join( /// the output should not be impacted. This function creates the expressions /// that will allow to swap back the values from the original left as the first /// columns and those on the right next. -pub fn swap_reverting_projection( +fn swap_reverting_projection( left_schema: &Schema, right_schema: &Schema, ) -> Vec<(Arc, String)> { @@ -182,30 +199,26 @@ pub fn swap_reverting_projection( } /// Swaps join sides for filter column indices and produces new JoinFilter -fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(|filter| { - let column_indices = filter - .column_indices() - .iter() - .map(|idx| { - let side = if matches!(idx.side, JoinSide::Left) { - JoinSide::Right - } else { - JoinSide::Left - }; - ColumnIndex { - index: idx.index, - side, - } - }) - .collect(); +fn swap_filter(filter: &JoinFilter) -> JoinFilter { + let column_indices = filter + .column_indices() + .iter() + .map(|idx| ColumnIndex { + index: idx.index, + side: idx.side.negate(), + }) + .collect(); - JoinFilter::new( - filter.expression().clone(), - column_indices, - filter.schema().clone(), - ) - }) + JoinFilter::new( + filter.expression().clone(), + column_indices, + filter.schema().clone(), + ) +} + +/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). +fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { + filter.map(swap_filter) } impl PhysicalOptimizerRule for JoinSelection { @@ -214,63 +227,32 @@ impl PhysicalOptimizerRule for JoinSelection { plan: Arc, config: &ConfigOptions, ) -> Result> { + let pipeline = PipelineStatePropagator::new(plan); + // First, we make pipeline-fixing modifications to joins so as to accommodate + // unbounded inputs. Each pipeline-fixing subrule, which is a function + // of type `PipelineFixerSubrule`, takes a single [`PipelineStatePropagator`] + // argument storing state variables that indicate the unboundedness status + // of the current [`ExecutionPlan`] as we traverse the plan tree. + let subrules: Vec> = vec![ + Box::new(hash_join_convert_symmetric_subrule), + Box::new(hash_join_swap_subrule), + ]; + let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + // Next, we apply another subrule that tries to optimize joins using any + // statistics their inputs might have. + // - For a hash join with partition mode [`PartitionMode::Auto`], we will + // make a cost-based decision to select which `PartitionMode` mode + // (`Partitioned`/`CollectLeft`) is optimal. If the statistics information + // is not available, we will fall back to [`PartitionMode::Partitioned`]. + // - We optimize/swap join sides so that the left (build) side of the join + // is the small side. If the statistics information is not available, we + // do not modify join sides. + // - We will also swap left and right sides for cross joins so that the left + // side is the small side. let config = &config.optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; - plan.transform_up(&|plan| { - let transformed = if let Some(hash_join) = - plan.as_any().downcast_ref::() - { - match hash_join.partition_mode() { - PartitionMode::Auto => { - try_collect_left(hash_join, Some(collect_left_threshold))? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )? - } - PartitionMode::CollectLeft => try_collect_left(hash_join, None)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if should_swap_join_order(&**left, &**right) - && supports_swap(*hash_join.join_type()) - { - swap_hash_join(hash_join, PartitionMode::Partitioned) - .map(Some)? - } else { - None - } - } - } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() - { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right) { - let new_join = - CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) - } else { - None - } - } else { - None - }; - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) + state.plan.transform_up(&|plan| { + statistical_join_selection_subrule(plan, collect_left_threshold) }) } @@ -283,13 +265,17 @@ impl PhysicalOptimizerRule for JoinSelection { } } -/// Try to create the PartitionMode::CollectLeft HashJoinExec when possible. -/// The method will first consider the current join type and check whether it is applicable to run CollectLeft mode -/// and will try to swap the join if the orignal type is unapplicable to run CollectLeft. -/// When the collect_threshold is provided, the method will also check both the left side and right side sizes +/// Tries to create a [`HashJoinExec`] in [`PartitionMode::CollectLeft`] when possible. /// -/// For [JoinType::Full], it is alway unable to run CollectLeft mode and will return None. -/// For [JoinType::Left] and [JoinType::LeftAnti], can not run CollectLeft mode, should swap join type to [JoinType::Right] and [JoinType::RightAnti] +/// This function will first consider the given join type and check whether the +/// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. +/// When the `collect_threshold` is provided, this function will also check left +/// and right sizes. +/// +/// For [`JoinType::Full`], it can not use `CollectLeft` mode and will return `None`. +/// For [`JoinType::Left`] and [`JoinType::LeftAnti`], it can not run `CollectLeft` +/// mode as is, but it can do so by changing the join type to [`JoinType::Right`] +/// and [`JoinType::RightAnti`], respectively. fn try_collect_left( hash_join: &HashJoinExec, collect_threshold: Option, @@ -320,7 +306,7 @@ fn try_collect_left( }; match (left_can_collect, right_can_collect) { (true, true) => { - if should_swap_join_order(&**left, &**right) + if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) { Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?)) @@ -359,7 +345,7 @@ fn try_collect_left( fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right) && supports_swap(*hash_join.join_type()) + if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) { swap_hash_join(hash_join, PartitionMode::Partitioned) } else { @@ -375,36 +361,266 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result, + collect_left_threshold: usize, +) -> Result>> { + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { + match hash_join.partition_mode() { + PartitionMode::Auto => { + try_collect_left(hash_join, Some(collect_left_threshold))?.map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )? + } + PartitionMode::CollectLeft => try_collect_left(hash_join, None)? + .map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if should_swap_join_order(&**left, &**right)? + && supports_swap(*hash_join.join_type()) + { + swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? + } else { + None + } + } + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right)? { + let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); + // TODO avoid adding ProjectionExec again and again, only adding Final Projection + let proj: Arc = Arc::new(ProjectionExec::try_new( + swap_reverting_projection(&left.schema(), &right.schema()), + Arc::new(new_join), + )?); + Some(proj) + } else { + None + } + } else { + None + }; + + Ok(if let Some(transformed) = transformed { + Transformed::Yes(transformed) + } else { + Transformed::No(plan) + }) +} + +/// Pipeline-fixing join selection subrule. +pub type PipelineFixerSubrule = dyn Fn( + PipelineStatePropagator, + &ConfigOptions, +) -> Option>; + +/// This subrule checks if we can replace a hash join with a symmetric hash +/// join when we are dealing with infinite inputs on both sides. This change +/// avoids pipeline breaking and preserves query runnability. If possible, +/// this subrule makes this replacement; otherwise, it has no effect. +fn hash_join_convert_symmetric_subrule( + mut input: PipelineStatePropagator, + config_options: &ConfigOptions, +) -> Option> { + if let Some(hash_join) = input.plan.as_any().downcast_ref::() { + let ub_flags = input.children_unbounded(); + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + input.unbounded = left_unbounded || right_unbounded; + let result = if left_unbounded && right_unbounded { + let mode = if config_options.optimizer.repartition_joins { + StreamJoinPartitionMode::Partitioned + } else { + StreamJoinPartitionMode::SinglePartition + }; + SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.null_equals_null(), + mode, + ) + .map(|exec| { + input.plan = Arc::new(exec) as _; + input + }) + } else { + Ok(input) + }; + Some(result) + } else { + None + } +} + +/// This subrule will swap build/probe sides of a hash join depending on whether +/// one of its inputs may produce an infinite stream of records. The rule ensures +/// that the left (build) side of the hash join always operates on an input stream +/// that will produce a finite set of records. If the left side can not be chosen +/// to be "finite", the join sides stay the same as the original query. +/// ```text +/// For example, this rule makes the following transformation: +/// +/// +/// +/// +--------------+ +--------------+ +/// | | unbounded | | +/// Left | Infinite | true | Hash |\true +/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ +/// | | | | \ | | | | +/// +--------------+ +--------------+ - | Hash Join |-------| Projection | +/// - | | | | +/// +--------------+ +--------------+ / +--------------+ +--------------+ +/// | | unbounded | | / +/// Right | Finite | false | Hash |/false +/// | Data Source |--------------| Repartition | +/// | | | | +/// +--------------+ +--------------+ +/// +/// +/// +/// +--------------+ +--------------+ +/// | | unbounded | | +/// Left | Finite | false | Hash |\false +/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ +/// | | | | \ | | true | | true +/// +--------------+ +--------------+ - | Hash Join |-------| Projection |----- +/// - | | | | +/// +--------------+ +--------------+ / +--------------+ +--------------+ +/// | | unbounded | | / +/// Right | Infinite | true | Hash |/true +/// | Data Source |--------------| Repartition | +/// | | | | +/// +--------------+ +--------------+ +/// +/// ``` +fn hash_join_swap_subrule( + mut input: PipelineStatePropagator, + _config_options: &ConfigOptions, +) -> Option> { + if let Some(hash_join) = input.plan.as_any().downcast_ref::() { + let ub_flags = input.children_unbounded(); + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + input.unbounded = left_unbounded || right_unbounded; + let result = if left_unbounded + && !right_unbounded + && matches!( + *hash_join.join_type(), + JoinType::Inner + | JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + ) { + swap_join_according_to_unboundedness(hash_join).map(|plan| { + input.plan = plan; + input + }) + } else { + Ok(input) + }; + Some(result) + } else { + None + } +} + +/// This function swaps sides of a hash join to make it runnable even if one of +/// its inputs are infinite. Note that this is not always possible; i.e. +/// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`] and +/// [`JoinType::RightSemi`] can not run with an unbounded left side, even if +/// we swap join sides. Therefore, we do not consider them here. +fn swap_join_according_to_unboundedness( + hash_join: &HashJoinExec, +) -> Result> { + let partition_mode = hash_join.partition_mode(); + let join_type = hash_join.join_type(); + match (*partition_mode, *join_type) { + ( + _, + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, + ) => internal_err!("{join_type} join cannot be swapped for unbounded input."), + (PartitionMode::Partitioned, _) => { + swap_hash_join(hash_join, PartitionMode::Partitioned) + } + (PartitionMode::CollectLeft, _) => { + swap_hash_join(hash_join, PartitionMode::CollectLeft) + } + (PartitionMode::Auto, _) => { + internal_err!("Auto is not acceptable for unbounded input here.") + } + } +} + +/// Apply given `PipelineFixerSubrule`s to a given plan. This plan, along with +/// auxiliary boundedness information, is in the `PipelineStatePropagator` object. +fn apply_subrules( + mut input: PipelineStatePropagator, + subrules: &Vec>, + config_options: &ConfigOptions, +) -> Result> { + for subrule in subrules { + if let Some(value) = subrule(input.clone(), config_options).transpose()? { + input = value; + } + } + let is_unbounded = input + .plan + .unbounded_output(&input.children_unbounded()) + // Treat the case where an operator can not run on unbounded data as + // if it can and it outputs unbounded data. Do not raise an error yet. + // Such operators may be fixed, adjusted or replaced later on during + // optimization passes -- sorts may be removed, windows may be adjusted + // etc. If this doesn't happen, the final `PipelineChecker` rule will + // catch this and raise an error anyway. + .unwrap_or(true); + input.unbounded = is_unbounded; + Ok(Transformed::Yes(input)) +} + #[cfg(test)] -mod tests { +mod tests_statistical { + use std::sync::Arc; + + use super::*; use crate::{ physical_plan::{ displayable, joins::PartitionMode, ColumnStatistics, Statistics, }, - test::exec::StatisticsExec, + test::StatisticsExec, }; - use super::*; - use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{stats::Precision, JoinType, ScalarValue}; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::PhysicalExpr; fn create_big_and_small() -> (Arc, Arc) { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10), - total_byte_size: Some(100000), - ..Default::default() + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(100000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100000), - total_byte_size: Some(10), - ..Default::default() + num_rows: Precision::Inexact(100000), + total_byte_size: Precision::Inexact(10), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); @@ -420,13 +636,19 @@ mod tests { min: Option, max: Option, distinct_count: Option, - ) -> Option> { - Some(vec![ColumnStatistics { - distinct_count, - min_value: min.map(|size| ScalarValue::UInt64(Some(size))), - max_value: max.map(|size| ScalarValue::UInt64(Some(size))), + ) -> Vec { + vec![ColumnStatistics { + distinct_count: distinct_count + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + min_value: min + .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size)))) + .unwrap_or(Precision::Absent), + max_value: max + .map(|size| Precision::Inexact(ScalarValue::UInt64(Some(size)))) + .unwrap_or(Precision::Absent), ..Default::default() - }]) + }] } /// Returns three plans with statistics of (min, max, distinct_count) @@ -440,39 +662,39 @@ mod tests { ) { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(100_000), + num_rows: Precision::Inexact(100_000), column_statistics: create_column_stats( Some(0), Some(50_000), Some(50_000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let medium = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10_000), + num_rows: Precision::Inexact(10_000), column_statistics: create_column_stats( Some(1000), Some(5000), Some(1000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("medium_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(1000), + num_rows: Precision::Inexact(1000), column_statistics: create_column_stats( Some(0), Some(100_000), Some(1000), ), - ..Default::default() + total_byte_size: Precision::Absent, }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); @@ -521,10 +743,13 @@ mod tests { .downcast_ref::() .expect("The type of the plan should not be changed"); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); } @@ -556,7 +781,6 @@ mod tests { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - println!("swapping_projection {swapping_projection:?}"); let (col, name) = &swapping_projection.expr()[0]; assert_eq!(name, "small_col"); assert_col_expr(col, "small_col", 1); @@ -571,10 +795,13 @@ mod tests { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(10) ); - assert_eq!(swapped_join.right().statistics().total_byte_size, Some(10)); } #[tokio::test] @@ -612,10 +839,13 @@ mod tests { assert_eq!(swapped_join.schema().fields().len(), 1); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); assert_eq!(original_schema, swapped_join.schema()); @@ -632,7 +862,7 @@ mod tests { .optimize(Arc::new($PLAN), &ConfigOptions::new()) .unwrap(); - let plan = displayable(optimized.as_ref()).indent().to_string(); + let plan = displayable(optimized.as_ref()).indent(true).to_string(); let actual_lines = plan.split("\n").collect::>(); assert_eq!( @@ -687,13 +917,13 @@ mod tests { // has an exact cardinality of 10_000 rows). let expected = [ "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(Column { name: \"small_col\", index: 1 }, Column { name: \"medium_col\", index: 0 })]", + " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]", " ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"small_col\", index: 0 }, Column { name: \"big_col\", index: 0 })]", - " StatisticsExec: col_count=1, row_count=Some(1000)", - " StatisticsExec: col_count=1, row_count=Some(100000)", - " StatisticsExec: col_count=1, row_count=Some(10000)", - "" + " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]", + " StatisticsExec: col_count=1, row_count=Inexact(1000)", + " StatisticsExec: col_count=1, row_count=Inexact(100000)", + " StatisticsExec: col_count=1, row_count=Inexact(10000)", + "", ]; assert_optimized!(expected, join); } @@ -724,10 +954,13 @@ mod tests { .downcast_ref::() .expect("The type of the plan should not be changed"); - assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10)); assert_eq!( - swapped_join.right().statistics().total_byte_size, - Some(100000) + swapped_join.left().statistics().unwrap().total_byte_size, + Precision::Inexact(10) + ); + assert_eq!( + swapped_join.right().statistics().unwrap().total_byte_size, + Precision::Inexact(100000) ); } @@ -770,27 +1003,27 @@ mod tests { async fn test_join_selection_collect_left() { let big = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10000000), - total_byte_size: Some(10000000), - ..Default::default() + num_rows: Precision::Inexact(10000000), + total_byte_size: Precision::Inexact(10000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), )); let small = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10), - total_byte_size: Some(10), - ..Default::default() + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(10), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( Statistics { - num_rows: None, - total_byte_size: None, - ..Default::default() + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); @@ -848,27 +1081,27 @@ mod tests { async fn test_join_selection_partitioned() { let big1 = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(10000000), - total_byte_size: Some(10000000), - ..Default::default() + num_rows: Precision::Inexact(10000000), + total_byte_size: Precision::Inexact(10000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col1", DataType::Int32, false)]), )); let big2 = Arc::new(StatisticsExec::new( Statistics { - num_rows: Some(20000000), - total_byte_size: Some(20000000), - ..Default::default() + num_rows: Precision::Inexact(20000000), + total_byte_size: Precision::Inexact(20000000), + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("big_col2", DataType::Int32, false)]), )); let empty = Arc::new(StatisticsExec::new( Statistics { - num_rows: None, - total_byte_size: None, - ..Default::default() + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], }, Schema::new(vec![Field::new("empty_col", DataType::Int32, false)]), )); @@ -967,3 +1200,503 @@ mod tests { } } } + +#[cfg(test)] +mod util_tests { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; + use datafusion_physical_expr::intervals::utils::check_support; + use datafusion_physical_expr::PhysicalExpr; + + #[test] + fn check_expr_supported() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + let supported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(check_support(&supported_expr, &schema)); + let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; + assert!(check_support(&supported_expr_2, &schema)); + let unsupported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(!check_support(&unsupported_expr, &schema)); + let unsupported_expr_2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), + )) as Arc; + assert!(!check_support(&unsupported_expr_2, &schema)); + } +} + +#[cfg(test)] +mod hash_join_tests { + use super::*; + use crate::physical_optimizer::join_selection::swap_join_type; + use crate::physical_optimizer::test_utils::SourceType; + use crate::physical_plan::expressions::Column; + use crate::physical_plan::joins::PartitionMode; + use crate::physical_plan::projection::ProjectionExec; + use crate::test_util::UnboundedExec; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::utils::DataPtr; + use datafusion_common::JoinType; + use datafusion_physical_plan::empty::EmptyExec; + use std::sync::Arc; + + struct TestCase { + case: String, + initial_sources_unbounded: (SourceType, SourceType), + initial_join_type: JoinType, + initial_mode: PartitionMode, + expected_sources_unbounded: (SourceType, SourceType), + expected_join_type: JoinType, + expected_mode: PartitionMode, + expecting_swap: bool, + } + + #[tokio::test] + async fn test_join_with_swap_full() -> Result<()> { + // NOTE: Currently, some initial conditions are not viable after join order selection. + // For example, full join always comes in partitioned mode. See the warning in + // function "swap". If this changes in the future, we should update these tests. + let cases = vec![ + TestCase { + case: "Bounded - Unbounded 1".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Unbounded - Bounded 2".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Bounded - Bounded 3".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + TestCase { + case: "Unbounded - Unbounded 4".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: JoinType::Full, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: JoinType::Full, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }, + ]; + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + #[tokio::test] + async fn test_cases_without_collect_left_check() -> Result<()> { + let mut cases = vec![]; + let join_types = vec![JoinType::LeftSemi, JoinType::Inner]; + for join_type in join_types { + cases.push(TestCase { + case: "Unbounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::CollectLeft, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::Partitioned, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + #[tokio::test] + async fn test_not_support_collect_left() -> Result<()> { + let mut cases = vec![]; + // After [JoinSelection] optimization, these join types cannot run in CollectLeft mode except + // [JoinType::LeftSemi] + let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti]; + for join_type in the_ones_not_support_collect_left { + cases.push(TestCase { + case: "Unbounded - Bounded".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: swap_join_type(join_type), + expected_mode: PartitionMode::Partitioned, + expecting_swap: true, + }); + cases.push(TestCase { + case: "Bounded - Unbounded".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + #[tokio::test] + async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> { + let mut cases = vec![]; + let the_ones_not_support_collect_left = + vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi]; + for join_type in the_ones_not_support_collect_left { + // We expect that (SourceType::Unbounded, SourceType::Bounded) will change, regardless of the + // statistics. + cases.push(TestCase { + case: "Unbounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // We expect that (SourceType::Bounded, SourceType::Unbounded) will stay same, regardless of the + // statistics. + cases.push(TestCase { + case: "Bounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // + cases.push(TestCase { + case: "Bounded - Bounded / CollectLeft".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::CollectLeft, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::CollectLeft, + expecting_swap: false, + }); + // If cases are partitioned, only unbounded & bounded check will affect the order. + cases.push(TestCase { + case: "Unbounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Bounded - Bounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + cases.push(TestCase { + case: "Unbounded - Unbounded / Partitioned".to_string(), + initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), + initial_join_type: join_type, + initial_mode: PartitionMode::Partitioned, + expected_sources_unbounded: ( + SourceType::Unbounded, + SourceType::Unbounded, + ), + expected_join_type: join_type, + expected_mode: PartitionMode::Partitioned, + expecting_swap: false, + }); + } + + for case in cases.into_iter() { + test_join_with_maybe_swap_unbounded_case(case).await? + } + Ok(()) + } + + async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { + let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded; + let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded; + let left_exec = Arc::new(UnboundedExec::new( + (!left_unbounded).then_some(1), + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Int32, + false, + )]))), + 2, + )) as Arc; + let right_exec = Arc::new(UnboundedExec::new( + (!right_unbounded).then_some(1), + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "b", + DataType::Int32, + false, + )]))), + 2, + )) as Arc; + + let join = HashJoinExec::try_new( + Arc::clone(&left_exec), + Arc::clone(&right_exec), + vec![( + Column::new_with_schema("a", &left_exec.schema())?, + Column::new_with_schema("b", &right_exec.schema())?, + )], + None, + &t.initial_join_type, + t.initial_mode, + false, + )?; + + let children = vec![ + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: left_unbounded, + children: vec![], + }, + PipelineStatePropagator { + plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + unbounded: right_unbounded, + children: vec![], + }, + ]; + let initial_hash_join_state = PipelineStatePropagator { + plan: Arc::new(join), + unbounded: false, + children, + }; + + let optimized_hash_join = + hash_join_swap_subrule(initial_hash_join_state, &ConfigOptions::new()) + .unwrap()?; + let optimized_join_plan = optimized_hash_join.plan; + + // If swap did happen + let projection_added = optimized_join_plan.as_any().is::(); + let plan = if projection_added { + let proj = optimized_join_plan + .as_any() + .downcast_ref::() + .expect( + "A proj is required to swap columns back to their original order", + ); + proj.input().clone() + } else { + optimized_join_plan + }; + + if let Some(HashJoinExec { + left, + right, + join_type, + mode, + .. + }) = plan.as_any().downcast_ref::() + { + let left_changed = Arc::data_ptr_eq(left, &right_exec); + let right_changed = Arc::data_ptr_eq(right, &left_exec); + // If this is not equal, we have a bigger problem. + assert_eq!(left_changed, right_changed); + assert_eq!( + ( + t.case.as_str(), + if left.unbounded_output(&[])? { + SourceType::Unbounded + } else { + SourceType::Bounded + }, + if right.unbounded_output(&[])? { + SourceType::Unbounded + } else { + SourceType::Bounded + }, + join_type, + mode, + left_changed && right_changed + ), + ( + t.case.as_str(), + t.expected_sources_unbounded.0, + t.expected_sources_unbounded.1, + &t.expected_join_type, + &t.expected_mode, + t.expecting_swap + ) + ); + }; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs new file mode 100644 index 0000000000000..8f5dbc2e9214b --- /dev/null +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! which has no aggregate expressions or sorting requirements + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::ExecutionPlan; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use itertools::Itertools; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all +/// rows in the group to be processed for correctness. Example queries fitting this description are: +/// `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` +/// `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +pub struct LimitedDistinctAggregation {} + +impl LimitedDistinctAggregation { + /// Create a new `LimitedDistinctAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + limit: usize, + ) -> Option> { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_by().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.order_by_expr().to_vec(), + aggr.input().clone(), + aggr.input_schema(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` + /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation + fn transform_limit(plan: Arc) -> Option> { + let limit: usize; + let mut global_fetch: Option = None; + let mut global_skip: usize = 0; + let children: Vec>; + let mut is_global_limit = false; + if let Some(local_limit) = plan.as_any().downcast_ref::() { + limit = local_limit.fetch(); + children = local_limit.children(); + } else if let Some(global_limit) = plan.as_any().downcast_ref::() + { + global_fetch = global_limit.fetch(); + global_fetch?; + global_skip = global_limit.skip(); + // the aggregate must read at least fetch+skip number of rows + limit = global_fetch.unwrap() + global_skip; + children = global_limit.children(); + is_global_limit = true + } else { + return None; + } + let child = children.iter().exactly_one().ok()?; + // ensure there is no output ordering; can this rule be relaxed? + if plan.output_ordering().is_some() { + return None; + } + // ensure no ordering is required on the input + if plan.required_input_ordering()[0].is_some() { + return None; + } + + // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by + // must match that of a child aggregation in order to rewrite the child aggregation + let mut match_aggr: Arc = plan; + let mut found_match_aggr = false; + + let mut rewrite_applicable = true; + let mut closure = |plan: Arc| { + if !rewrite_applicable { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + if found_match_aggr { + if let Some(parent_aggr) = + match_aggr.as_any().downcast_ref::() + { + if !parent_aggr.group_by().eq(aggr.group_by()) { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::No(plan)); + } + } + } + // either we run into an Aggregate and transform it, or disable the rewrite + // for subsequent children + match Self::transform_agg(aggr, limit) { + None => {} + Some(new_aggr) => { + match_aggr = plan; + found_match_aggr = true; + return Ok(Transformed::Yes(new_aggr)); + } + } + } + rewrite_applicable = false; + Ok(Transformed::No(plan)) + }; + let child = child.clone().transform_down_mut(&mut closure).ok()?; + if is_global_limit { + return Some(Arc::new(GlobalLimitExec::new( + child, + global_skip, + global_fetch, + ))); + } + Some(Arc::new(LocalLimitExec::new(child, limit))) + } +} + +impl Default for LimitedDistinctAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for LimitedDistinctAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = + LimitedDistinctAggregation::transform_limit(plan.clone()) + { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? + } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitedDistinctAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; + use crate::physical_optimizer::enforce_distribution::tests::{ + parquet_exec_with_sort, schema, trim_plan_display, + }; + use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; + use crate::physical_plan::collect; + use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; + use arrow::array::Int32Array; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_batches; + use arrow_schema::SchemaRef; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::cast; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_plan::aggregates::AggregateMode; + use datafusion_physical_plan::displayable; + use std::sync::Arc; + + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(1), + Some(4), + Some(5), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + fn assert_plan_matches_expected( + plan: &Arc, + expected: &[&str], + ) -> Result<()> { + let expected_lines: Vec<&str> = expected.to_vec(); + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let optimized = LimitedDistinctAggregation::new() + .optimize(Arc::clone(plan), state.config_options())?; + + let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&optimized_result); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + + Ok(()) + } + + async fn assert_results_match_expected( + plan: Arc, + expected: &str, + ) -> Result<()> { + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + let batches = collect(plan, ctx.task_ctx()).await?; + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + Ok(()) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + #[tokio::test] + async fn test_partial_final() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + // expected to push the limit to the Partial and Final AggregateExecs + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_local() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 4, // fetch + ); + // expected to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_single_global() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = GlobalLimitExec::new( + Arc::new(single_agg), + 1, // skip + Some(3), // fetch + ); + // expected to push the skip+fetch limit to the AggregateExec + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[tokio::test] + async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec + let group_by_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let distinct_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + Arc::new(group_by_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(distinct_agg), + 4, // fetch + ); + // expected to push the limit to the outer AggregateExec only + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) + } + + #[test] + fn test_no_group_by() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec![]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![agg.count_expr()], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_filter() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let filter_expr = Some(expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?); + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + vec![None], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } + + #[test] + fn test_has_order_by() -> Result<()> { + let sort_key = vec![PhysicalSortExpr { + expr: expressions::col("a", &schema()).unwrap(), + options: SortOptions::default(), + }]; + let source = parquet_exec_with_sort(vec![sort_key]); + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a ORDER BY a LIMIT 10;`, Single AggregateExec + let order_by_expr = Some(vec![PhysicalSortExpr { + expr: expressions::col("a", &schema.clone()).unwrap(), + options: SortOptions::default(), + }]); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![None], /* filter_expr */ + vec![order_by_expr], /* order_by_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index b4c019d62ba98..e990fead610d1 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -15,24 +15,29 @@ // specific language governing permissions and limitations // under the License. -//! This module contains a query optimizer that operates against a physical plan and applies -//! rules to a physical plan, such as "Repartition". - +//! Optimizer that rewrites [`ExecutionPlan`]s. +//! +//! These rules take advantage of physical plan properties , such as +//! "Repartition" or "Sortedness" +//! +//! [`ExecutionPlan`]: crate::physical_plan::ExecutionPlan pub mod aggregate_statistics; pub mod coalesce_batches; pub mod combine_partial_final_agg; -pub mod dist_enforcement; -pub mod global_sort_selection; +pub mod enforce_distribution; +pub mod enforce_sorting; pub mod join_selection; +pub mod limited_distinct_aggregation; pub mod optimizer; +pub mod output_requirements; pub mod pipeline_checker; +mod projection_pushdown; pub mod pruning; -pub mod repartition; -pub mod sort_enforcement; +pub mod replace_with_order_preserving_variants; mod sort_pushdown; +pub mod topk_aggregation; mod utils; -pub mod pipeline_fixer; #[cfg(test)] pub mod test_utils; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 26ec137e2b7b5..f8c82576e2546 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -19,7 +19,18 @@ use std::sync::Arc; +use super::projection_pushdown::ProjectionPushdown; use crate::config::ConfigOptions; +use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; +use crate::physical_optimizer::coalesce_batches::CoalesceBatches; +use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; +use crate::physical_optimizer::enforce_distribution::EnforceDistribution; +use crate::physical_optimizer::enforce_sorting::EnforceSorting; +use crate::physical_optimizer::join_selection::JoinSelection; +use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; +use crate::physical_optimizer::output_requirements::OutputRequirements; +use crate::physical_optimizer::pipeline_checker::PipelineChecker; +use crate::physical_optimizer::topk_aggregation::TopKAggregation; use crate::{error::Result, physical_plan::ExecutionPlan}; /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which @@ -42,3 +53,80 @@ pub trait PhysicalOptimizerRule { /// and should disable the schema check. fn schema_check(&self) -> bool; } + +/// A rule-based physical optimizer. +#[derive(Clone)] +pub struct PhysicalOptimizer { + /// All rules to apply + pub rules: Vec>, +} + +impl Default for PhysicalOptimizer { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizer { + /// Create a new optimizer using the recommended list of rules + pub fn new() -> Self { + let rules: Vec> = vec![ + // If there is a output requirement of the query, make sure that + // this information is not lost across different rules during optimization. + Arc::new(OutputRequirements::new_add_mode()), + Arc::new(AggregateStatistics::new()), + // Statistics-based join selection will change the Auto mode to a real join implementation, + // like collect left, or hash join, or future sort merge join, which will influence the + // EnforceDistribution and EnforceSorting rules as they decide whether to add additional + // repartitioning and local sorting steps to meet distribution and ordering requirements. + // Therefore, it should run before EnforceDistribution and EnforceSorting. + Arc::new(JoinSelection::new()), + // The LimitedDistinctAggregation rule should be applied before the EnforceDistribution rule, + // as that rule may inject other operations in between the different AggregateExecs. + // Applying the rule early means only directly-connected AggregateExecs must be examined. + Arc::new(LimitedDistinctAggregation::new()), + // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution + // requirements. Please make sure that the whole plan tree is determined before this rule. + // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at + // least one of the operators in the plan benefits from increased parallelism. + Arc::new(EnforceDistribution::new()), + // The CombinePartialFinalAggregate rule should be applied after the EnforceDistribution rule + Arc::new(CombinePartialFinalAggregate::new()), + // The EnforceSorting rule is for adding essential local sorting to satisfy the required + // ordering. Please make sure that the whole plan tree is determined before this rule. + // Note that one should always run this rule after running the EnforceDistribution rule + // as the latter may break local sorting requirements. + Arc::new(EnforceSorting::new()), + // The CoalesceBatches rule will not influence the distribution and ordering of the + // whole plan tree. Therefore, to avoid influencing other rules, it should run last. + Arc::new(CoalesceBatches::new()), + // Remove the ancillary output requirement operator since we are done with the planning + // phase. + Arc::new(OutputRequirements::new_remove_mode()), + // The PipelineChecker rule will reject non-runnable query plans that use + // pipeline-breaking operators on infinite input(s). The rule generates a + // diagnostic error message when this happens. It makes no changes to the + // given query plan; i.e. it only acts as a final gatekeeping rule. + Arc::new(PipelineChecker::new()), + // The aggregation limiter will try to find situations where the accumulator count + // is not tied to the cardinality, i.e. when the output of the aggregation is passed + // into an `order by max(x) limit y`. In this case it will copy the limit value down + // to the aggregation, allowing it to use only y number of accumulators. + Arc::new(TopKAggregation::new()), + // The ProjectionPushdown rule tries to push projections towards + // the sources in the execution plan. As a result of this process, + // a projection can disappear if it reaches the source providers, and + // sequential projections can merge into one. Even if these two cases + // are not present, the load of executors such as join or union will be + // reduced by narrowing their input tables. + Arc::new(ProjectionPushdown::new()), + ]; + + Self::with_rules(rules) + } + + /// Create a new optimizer with the given rules + pub fn with_rules(rules: Vec>) -> Self { + Self { rules } + } +} diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs new file mode 100644 index 0000000000000..f8bf3bb965e8c --- /dev/null +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The GlobalOrderRequire optimizer rule either: +//! - Adds an auxiliary `OutputRequirementExec` operator to keep track of global +//! ordering and distribution requirement across rules, or +//! - Removes the auxiliary `OutputRequirementExec` operator from the physical plan. +//! Since the `OutputRequirementExec` operator is only a helper operator, it +//! shouldn't occur in the final plan (i.e. the executed plan). + +use std::sync::Arc; + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr::{ + Distribution, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + +/// This rule either adds or removes [`OutputRequirements`]s to/from the physical +/// plan according to its `mode` attribute, which is set by the constructors +/// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of +/// the global requirements (ordering and distribution) across rules. +/// +/// The primary usecase of this node and rule is to specify and preserve the desired output +/// ordering and distribution the entire plan. When sending to a single client, a single partition may +/// be desirable, but when sending to a multi-partitioned writer, keeping multiple partitions may be +/// better. +#[derive(Debug)] +pub struct OutputRequirements { + mode: RuleMode, +} + +impl OutputRequirements { + /// Create a new rule which works in `Add` mode; i.e. it simply adds a + /// top-level [`OutputRequirementExec`] into the physical plan to keep track + /// of global ordering and distribution requirements if there are any. + /// Note that this rule should run at the beginning. + pub fn new_add_mode() -> Self { + Self { + mode: RuleMode::Add, + } + } + + /// Create a new rule which works in `Remove` mode; i.e. it simply removes + /// the top-level [`OutputRequirementExec`] from the physical plan if there is + /// any. We do this because a `OutputRequirementExec` is an ancillary, + /// non-executable operator whose sole purpose is to track global + /// requirements during optimization. Therefore, a + /// `OutputRequirementExec` should not appear in the final plan. + pub fn new_remove_mode() -> Self { + Self { + mode: RuleMode::Remove, + } + } +} + +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Hash)] +enum RuleMode { + Add, + Remove, +} + +/// An ancillary, non-executable operator whose sole purpose is to track global +/// requirements during optimization. It imposes +/// - the ordering requirement in its `order_requirement` attribute. +/// - the distribution requirement in its `dist_requirement` attribute. +/// +/// See [`OutputRequirements`] for more details +#[derive(Debug)] +pub(crate) struct OutputRequirementExec { + input: Arc, + order_requirement: Option, + dist_requirement: Distribution, +} + +impl OutputRequirementExec { + pub(crate) fn new( + input: Arc, + requirements: Option, + dist_requirement: Distribution, + ) -> Self { + Self { + input, + order_requirement: requirements, + dist_requirement, + } + } + + pub(crate) fn input(&self) -> Arc { + self.input.clone() + } +} + +impl DisplayAs for OutputRequirementExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "OutputRequirementExec") + } +} + +impl ExecutionPlan for OutputRequirementExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> crate::physical_plan::Partitioning { + self.input.output_partitioning() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn required_input_distribution(&self) -> Vec { + vec![self.dist_requirement.clone()] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn required_input_ordering(&self) -> Vec>> { + vec![self.order_requirement.clone()] + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + // Has a single child + Ok(children[0]) + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children.remove(0), // has a single child + self.order_requirement.clone(), + self.dist_requirement.clone(), + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!(); + } + + fn statistics(&self) -> Result { + self.input.statistics() + } +} + +impl PhysicalOptimizerRule for OutputRequirements { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + match self.mode { + RuleMode::Add => require_top_ordering(plan), + RuleMode::Remove => plan.transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::Yes(sort_req.input())) + } else { + Ok(Transformed::No(plan)) + } + }), + } + } + + fn name(&self) -> &str { + "OutputRequirements" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This functions adds ancillary `OutputRequirementExec` to the the physical plan, so that +/// global requirements are not lost during optimization. +fn require_top_ordering(plan: Arc) -> Result> { + let (new_plan, is_changed) = require_top_ordering_helper(plan)?; + if is_changed { + Ok(new_plan) + } else { + // Add `OutputRequirementExec` to the top, with no specified ordering and distribution requirement. + Ok(Arc::new(OutputRequirementExec::new( + new_plan, + // there is no ordering requirement + None, + Distribution::UnspecifiedDistribution, + )) as _) + } +} + +/// Helper function that adds an ancillary `OutputRequirementExec` to the given plan. +/// First entry in the tuple is resulting plan, second entry indicates whether any +/// `OutputRequirementExec` is added to the plan. +fn require_top_ordering_helper( + plan: Arc, +) -> Result<(Arc, bool)> { + let mut children = plan.children(); + // Global ordering defines desired ordering in the final result. + if children.len() != 1 { + Ok((plan, false)) + } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let req_ordering = sort_exec.output_ordering().unwrap_or(&[]); + let req_dist = sort_exec.required_input_distribution()[0].clone(); + let reqs = PhysicalSortRequirement::from_sort_exprs(req_ordering); + Ok(( + Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, + true, + )) + } else if let Some(spm) = plan.as_any().downcast_ref::() { + let reqs = PhysicalSortRequirement::from_sort_exprs(spm.expr()); + Ok(( + Arc::new(OutputRequirementExec::new( + plan, + Some(reqs), + Distribution::SinglePartition, + )) as _, + true, + )) + } else if plan.maintains_input_order()[0] + && plan.required_input_ordering()[0].is_none() + { + // Keep searching for a `SortExec` as long as ordering is maintained, + // and on-the-way operators do not themselves require an ordering. + // When an operator requires an ordering, any `SortExec` below can not + // be responsible for (i.e. the originator of) the global ordering. + let (new_child, is_changed) = + require_top_ordering_helper(children.swap_remove(0))?; + Ok((plan.with_new_children(vec![new_child])?, is_changed)) + } else { + // Stop searching, there is no global ordering desired for the query. + Ok((plan, false)) + } +} diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index b12c4ef93fc86..d59248aadf056 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -18,17 +18,19 @@ //! The [PipelineChecker] rule ensures that a given plan can accommodate its //! infinite sources, if there are any. It will reject non-runnable query plans //! that use pipeline-breaking operators on infinite input(s). -//! + +use std::sync::Arc; + use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::joins::SymmetricHashJoinExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + use datafusion_common::config::OptimizerOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::intervals::{check_support, is_datatype_supported}; -use std::sync::Arc; +use datafusion_common::{plan_err, DataFusionError}; +use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; /// The PipelineChecker rule rejects non-runnable query plans that use /// pipeline-breaking operators on infinite input(s). @@ -68,19 +70,27 @@ impl PhysicalOptimizerRule for PipelineChecker { pub struct PipelineStatePropagator { pub(crate) plan: Arc, pub(crate) unbounded: bool, - pub(crate) children_unbounded: Vec, + pub(crate) children: Vec, } impl PipelineStatePropagator { /// Constructs a new, default pipelining state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); + let children = plan.children(); PipelineStatePropagator { plan, unbounded: false, - children_unbounded: vec![false; length], + children: children.into_iter().map(Self::new).collect(), } } + + /// Returns the children unboundedness information. + pub fn children_unbounded(&self) -> Vec { + self.children + .iter() + .map(|c| c.unbounded) + .collect::>() + } } impl TreeNode for PipelineStatePropagator { @@ -88,9 +98,8 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(&Self) -> Result, { - let children = self.plan.children(); - for child in children { - match op(&PipelineStatePropagator::new(child))? { + for child in &self.children { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -104,25 +113,18 @@ impl TreeNode for PipelineStatePropagator { where F: FnMut(Self) -> Result, { - let children = self.plan.children(); - if !children.is_empty() { - let new_children = children + if !self.children.is_empty() { + let new_children = self + .children .into_iter() - .map(|child| PipelineStatePropagator::new(child)) .map(transform) .collect::>>()?; - let children_unbounded = new_children - .iter() - .map(|c| c.unbounded) - .collect::>(); - let children_plans = new_children - .into_iter() - .map(|child| child.plan) - .collect::>(); + let children_plans = new_children.iter().map(|c| c.plan.clone()).collect(); + Ok(PipelineStatePropagator { plan: with_new_children_if_necessary(self.plan, children_plans)?.into(), unbounded: self.unbounded, - children_unbounded, + children: new_children, }) } else { Ok(self) @@ -142,12 +144,12 @@ pub fn check_finiteness_requirements( { const MSG: &str = "Join operation cannot operate on a non-prunable stream without enabling \ the 'allow_symmetric_joins_without_pruning' configuration flag"; - return Err(DataFusionError::Plan(MSG.to_owned())); + return plan_err!("{}", MSG); } } input .plan - .unbounded_output(&input.children_unbounded) + .unbounded_output(&input.children_unbounded()) .map(|value| { input.unbounded = value; Transformed::Yes(input) @@ -163,7 +165,7 @@ pub fn check_finiteness_requirements( /// [`Operator`]: datafusion_expr::Operator fn is_prunable(join: &SymmetricHashJoinExec) -> bool { join.filter().map_or(false, |filter| { - check_support(filter.expression()) + check_support(filter.expression(), &join.schema()) && filter .schema() .fields() diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs deleted file mode 100644 index caae7743450d9..0000000000000 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ /dev/null @@ -1,713 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [PipelineFixer] rule tries to modify a given plan so that it can -//! accommodate its infinite sources, if there are any. In other words, -//! it tries to obtain a runnable query (with the given infinite sources) -//! from an non-runnable query by transforming pipeline-breaking operations -//! to pipeline-friendly ones. If this can not be done, the rule emits a -//! diagnostic error message. -//! -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::join_selection::swap_hash_join; -use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SymmetricHashJoinExec}; -use crate::physical_plan::ExecutionPlan; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::DataFusionError; -use datafusion_expr::logical_plan::JoinType; - -use std::sync::Arc; - -/// The [`PipelineFixer`] rule tries to modify a given plan so that it can -/// accommodate its infinite sources, if there are any. If this is not -/// possible, the rule emits a diagnostic error message. -#[derive(Default)] -pub struct PipelineFixer {} - -impl PipelineFixer { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} -/// [`PipelineFixer`] subrules are functions of this type. Such functions take a -/// single [`PipelineStatePropagator`] argument, which stores state variables -/// indicating the unboundedness status of the current [`ExecutionPlan`] as -/// the `PipelineFixer` rule traverses the entire plan tree. -type PipelineFixerSubrule = - dyn Fn(PipelineStatePropagator) -> Option>; - -impl PhysicalOptimizerRule for PipelineFixer { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - let pipeline = PipelineStatePropagator::new(plan); - let subrules: Vec> = vec![ - Box::new(hash_join_convert_symmetric_subrule), - Box::new(hash_join_swap_subrule), - ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules))?; - Ok(state.plan) - } - - fn name(&self) -> &str { - "PipelineFixer" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// This subrule checks if one can replace a hash join with a symmetric hash -/// join so that the pipeline does not break due to the join operation in -/// question. If possible, it makes this replacement; otherwise, it has no -/// effect. -fn hash_join_convert_symmetric_subrule( - mut input: PipelineStatePropagator, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded && right_unbounded { - SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), - hash_join - .on() - .iter() - .map(|(l, r)| (l.clone(), r.clone())) - .collect(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.null_equals_null(), - ) - .map(|exec| { - input.plan = Arc::new(exec) as _; - input - }) - } else { - Ok(input) - }; - Some(result) - } else { - None - } -} - -/// This subrule will swap build/probe sides of a hash join depending on whether its inputs -/// may produce an infinite stream of records. The rule ensures that the left (build) side -/// of the hash join always operates on an input stream that will produce a finite set of. -/// records If the left side can not be chosen to be "finite", the order stays the -/// same as the original query. -/// ```text -/// For example, this rule makes the following transformation: -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Infinite | true | Hash |\true -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | | | -/// +--------------+ +--------------+ - | Hash Join |-------| Projection | -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Finite | false | Hash |/false -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Finite | false | Hash |\false -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | true | | true -/// +--------------+ +--------------+ - | Hash Join |-------| Projection |----- -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Infinite | true | Hash |/true -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// ``` -fn hash_join_swap_subrule( - mut input: PipelineStatePropagator, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = &input.children_unbounded; - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded - && !right_unbounded - && matches!( - *hash_join.join_type(), - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - ) { - swap(hash_join).map(|plan| { - input.plan = plan; - input - }) - } else { - Ok(input) - }; - Some(result) - } else { - None - } -} - -/// This function swaps sides of a hash join to make it runnable even if one of its -/// inputs are infinite. Note that this is not always possible; i.e. [JoinType::Full], -/// [JoinType::Right], [JoinType::RightAnti] and [JoinType::RightSemi] can not run with -/// an unbounded left side, even if we swap. Therefore, we do not consider them here. -fn swap(hash_join: &HashJoinExec) -> Result> { - let partition_mode = hash_join.partition_mode(); - let join_type = hash_join.join_type(); - match (*partition_mode, *join_type) { - ( - _, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, - ) => Err(DataFusionError::Internal(format!( - "{join_type} join cannot be swapped for unbounded input." - ))), - (PartitionMode::Partitioned, _) => { - swap_hash_join(hash_join, PartitionMode::Partitioned) - } - (PartitionMode::CollectLeft, _) => { - swap_hash_join(hash_join, PartitionMode::CollectLeft) - } - (PartitionMode::Auto, _) => Err(DataFusionError::Internal( - "Auto is not acceptable for unbounded input here.".to_string(), - )), - } -} - -fn apply_subrules( - mut input: PipelineStatePropagator, - subrules: &Vec>, -) -> Result> { - for subrule in subrules { - if let Some(value) = subrule(input.clone()).transpose()? { - input = value; - } - } - let is_unbounded = input - .plan - .unbounded_output(&input.children_unbounded) - // Treat the case where an operator can not run on unbounded data as - // if it can and it outputs unbounded data. Do not raise an error yet. - // Such operators may be fixed, adjusted or replaced later on during - // optimization passes -- sorts may be removed, windows may be adjusted - // etc. If this doesn't happen, the final `PipelineChecker` rule will - // catch this and raise an error anyway. - .unwrap_or(true); - input.unbounded = is_unbounded; - Ok(Transformed::Yes(input)) -} - -#[cfg(test)] -mod util_tests { - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; - use datafusion_physical_expr::intervals::check_support; - use datafusion_physical_expr::PhysicalExpr; - use std::sync::Arc; - - #[test] - fn check_expr_supported() { - let supported_expr = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )) as Arc; - assert!(check_support(&supported_expr)); - let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; - assert!(check_support(&supported_expr_2)); - let unsupported_expr = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Or, - Arc::new(Column::new("a", 0)), - )) as Arc; - assert!(!check_support(&unsupported_expr)); - let unsupported_expr_2 = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Or, - Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), - )) as Arc; - assert!(!check_support(&unsupported_expr_2)); - } -} - -#[cfg(test)] -mod hash_join_tests { - use super::*; - use crate::physical_optimizer::join_selection::swap_join_type; - use crate::physical_optimizer::test_utils::SourceType; - use crate::physical_plan::expressions::Column; - use crate::physical_plan::joins::PartitionMode; - use crate::physical_plan::projection::ProjectionExec; - use crate::test_util::UnboundedExec; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::utils::DataPtr; - use std::sync::Arc; - - struct TestCase { - case: String, - initial_sources_unbounded: (SourceType, SourceType), - initial_join_type: JoinType, - initial_mode: PartitionMode, - expected_sources_unbounded: (SourceType, SourceType), - expected_join_type: JoinType, - expected_mode: PartitionMode, - expecting_swap: bool, - } - - #[tokio::test] - async fn test_join_with_swap_full() -> Result<()> { - // NOTE: Currently, some initial conditions are not viable after join order selection. - // For example, full join always comes in partitioned mode. See the warning in - // function "swap". If this changes in the future, we should update these tests. - let cases = vec![ - TestCase { - case: "Bounded - Unbounded 1".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Unbounded - Bounded 2".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Bounded - Bounded 3".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - TestCase { - case: "Unbounded - Unbounded 4".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: JoinType::Full, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: JoinType::Full, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }, - ]; - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - #[tokio::test] - async fn test_cases_without_collect_left_check() -> Result<()> { - let mut cases = vec![]; - let join_types = vec![JoinType::LeftSemi, JoinType::Inner]; - for join_type in join_types { - cases.push(TestCase { - case: "Unbounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::CollectLeft, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::Partitioned, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - #[tokio::test] - async fn test_not_support_collect_left() -> Result<()> { - let mut cases = vec![]; - // After [JoinSelection] optimization, these join types cannot run in CollectLeft mode except - // [JoinType::LeftSemi] - let the_ones_not_support_collect_left = vec![JoinType::Left, JoinType::LeftAnti]; - for join_type in the_ones_not_support_collect_left { - cases.push(TestCase { - case: "Unbounded - Bounded".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: swap_join_type(join_type), - expected_mode: PartitionMode::Partitioned, - expecting_swap: true, - }); - cases.push(TestCase { - case: "Bounded - Unbounded".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - #[tokio::test] - async fn test_not_supporting_swaps_possible_collect_left() -> Result<()> { - let mut cases = vec![]; - let the_ones_not_support_collect_left = - vec![JoinType::Right, JoinType::RightAnti, JoinType::RightSemi]; - for join_type in the_ones_not_support_collect_left { - // We expect that (SourceType::Unbounded, SourceType::Bounded) will change, regardless of the - // statistics. - cases.push(TestCase { - case: "Unbounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // We expect that (SourceType::Bounded, SourceType::Unbounded) will stay same, regardless of the - // statistics. - cases.push(TestCase { - case: "Bounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // - cases.push(TestCase { - case: "Bounded - Bounded / CollectLeft".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::CollectLeft, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::CollectLeft, - expecting_swap: false, - }); - // If cases are partitioned, only unbounded & bounded check will affect the order. - cases.push(TestCase { - case: "Unbounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Unbounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Unbounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Bounded - Bounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: (SourceType::Bounded, SourceType::Bounded), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - cases.push(TestCase { - case: "Unbounded - Unbounded / Partitioned".to_string(), - initial_sources_unbounded: (SourceType::Unbounded, SourceType::Unbounded), - initial_join_type: join_type, - initial_mode: PartitionMode::Partitioned, - expected_sources_unbounded: ( - SourceType::Unbounded, - SourceType::Unbounded, - ), - expected_join_type: join_type, - expected_mode: PartitionMode::Partitioned, - expecting_swap: false, - }); - } - - for case in cases.into_iter() { - test_join_with_maybe_swap_unbounded_case(case).await? - } - Ok(()) - } - - async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { - let left_unbounded = t.initial_sources_unbounded.0 == SourceType::Unbounded; - let right_unbounded = t.initial_sources_unbounded.1 == SourceType::Unbounded; - let left_exec = Arc::new(UnboundedExec::new( - (!left_unbounded).then_some(1), - RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Int32, - false, - )]))), - 2, - )) as Arc; - let right_exec = Arc::new(UnboundedExec::new( - (!right_unbounded).then_some(1), - RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( - "b", - DataType::Int32, - false, - )]))), - 2, - )) as Arc; - - let join = HashJoinExec::try_new( - Arc::clone(&left_exec), - Arc::clone(&right_exec), - vec![( - Column::new_with_schema("a", &left_exec.schema())?, - Column::new_with_schema("b", &right_exec.schema())?, - )], - None, - &t.initial_join_type, - t.initial_mode, - false, - )?; - - let initial_hash_join_state = PipelineStatePropagator { - plan: Arc::new(join), - unbounded: false, - children_unbounded: vec![left_unbounded, right_unbounded], - }; - let optimized_hash_join = - hash_join_swap_subrule(initial_hash_join_state).unwrap()?; - let optimized_join_plan = optimized_hash_join.plan; - - // If swap did happen - let projection_added = optimized_join_plan.as_any().is::(); - let plan = if projection_added { - let proj = optimized_join_plan - .as_any() - .downcast_ref::() - .expect( - "A proj is required to swap columns back to their original order", - ); - proj.input().clone() - } else { - optimized_join_plan - }; - - if let Some(HashJoinExec { - left, - right, - join_type, - mode, - .. - }) = plan.as_any().downcast_ref::() - { - let left_changed = Arc::data_ptr_eq(left, &right_exec); - let right_changed = Arc::data_ptr_eq(right, &left_exec); - // If this is not equal, we have a bigger problem. - assert_eq!(left_changed, right_changed); - assert_eq!( - ( - t.case.as_str(), - if left.unbounded_output(&[])? { - SourceType::Unbounded - } else { - SourceType::Bounded - }, - if right.unbounded_output(&[])? { - SourceType::Unbounded - } else { - SourceType::Bounded - }, - join_type, - mode, - left_changed && right_changed - ), - ( - t.case.as_str(), - t.expected_sources_unbounded.0, - t.expected_sources_unbounded.1, - &t.expected_join_type, - &t.expected_mode, - t.expecting_swap - ) - ); - }; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs new file mode 100644 index 0000000000000..67a2eaf0d9b3e --- /dev/null +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -0,0 +1,2310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the `ProjectionPushdown` physical optimization rule. +//! The function [`remove_unnecessary_projections`] tries to push down all +//! projections one by one if the operator below is amenable to this. If a +//! projection reaches a source, it can even dissappear from the plan entirely. + +use std::collections::HashMap; +use std::sync::Arc; + +use super::output_requirements::OutputRequirementExec; +use super::PhysicalOptimizerRule; +use crate::datasource::physical_plan::CsvExec; +use crate::error::Result; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, + SymmetricHashJoinExec, +}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{Distribution, ExecutionPlan}; + +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::JoinSide; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr::{ + Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; +use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::union::UnionExec; + +use itertools::Itertools; + +/// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to +/// remove or swap with its child. +#[derive(Default)] +pub struct ProjectionPushdown {} + +impl ProjectionPushdown { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for ProjectionPushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(&remove_unnecessary_projections) + } + + fn name(&self) -> &str { + "ProjectionPushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function checks if `plan` is a [`ProjectionExec`], and inspects its +/// input(s) to test whether it can push `plan` under its input(s). This function +/// will operate on the entire tree and may ultimately remove `plan` entirely +/// by leveraging source providers with built-in projection capabilities. +pub fn remove_unnecessary_projections( + plan: Arc, +) -> Result>> { + let maybe_modified = if let Some(projection) = + plan.as_any().downcast_ref::() + { + // If the projection does not cause any change on the input, we can + // safely remove it: + if is_projection_removable(projection) { + return Ok(Transformed::Yes(projection.input().clone())); + } + // If it does, check if we can push it under its child(ren): + let input = projection.input().as_any(); + if let Some(csv) = input.downcast_ref::() { + try_swapping_with_csv(projection, csv) + } else if let Some(memory) = input.downcast_ref::() { + try_swapping_with_memory(projection, memory)? + } else if let Some(child_projection) = input.downcast_ref::() { + let maybe_unified = try_unifying_projections(projection, child_projection)?; + return if let Some(new_plan) = maybe_unified { + // To unify 3 or more sequential projections: + remove_unnecessary_projections(new_plan) + } else { + Ok(Transformed::No(plan)) + }; + } else if let Some(output_req) = input.downcast_ref::() { + try_swapping_with_output_req(projection, output_req)? + } else if input.is::() { + try_swapping_with_coalesce_partitions(projection)? + } else if let Some(filter) = input.downcast_ref::() { + try_swapping_with_filter(projection, filter)? + } else if let Some(repartition) = input.downcast_ref::() { + try_swapping_with_repartition(projection, repartition)? + } else if let Some(sort) = input.downcast_ref::() { + try_swapping_with_sort(projection, sort)? + } else if let Some(spm) = input.downcast_ref::() { + try_swapping_with_sort_preserving_merge(projection, spm)? + } else if let Some(union) = input.downcast_ref::() { + try_pushdown_through_union(projection, union)? + } else if let Some(hash_join) = input.downcast_ref::() { + try_pushdown_through_hash_join(projection, hash_join)? + } else if let Some(cross_join) = input.downcast_ref::() { + try_swapping_with_cross_join(projection, cross_join)? + } else if let Some(nl_join) = input.downcast_ref::() { + try_swapping_with_nested_loop_join(projection, nl_join)? + } else if let Some(sm_join) = input.downcast_ref::() { + try_swapping_with_sort_merge_join(projection, sm_join)? + } else if let Some(sym_join) = input.downcast_ref::() { + try_swapping_with_sym_hash_join(projection, sym_join)? + } else if let Some(ste) = input.downcast_ref::() { + try_swapping_with_streaming_table(projection, ste)? + } else { + // If the input plan of the projection is not one of the above, we + // conservatively assume that pushing the projection down may hurt. + // When adding new operators, consider adding them here if you + // think pushing projections under them is beneficial. + None + } + } else { + return Ok(Transformed::No(plan)); + }; + + Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) +} + +/// Tries to embed `projection` to its input (`csv`). If possible, returns +/// [`CsvExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_csv( + projection: &ProjectionExec, + csv: &CsvExec, +) -> Option> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into CsvExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()).then(|| { + let mut file_scan = csv.base_config().clone(); + let new_projections = + new_projections_for_columns(projection, &file_scan.projection); + file_scan.projection = Some(new_projections); + + Arc::new(CsvExec::new( + file_scan, + csv.has_header(), + csv.delimiter(), + csv.quote(), + csv.escape(), + csv.file_compression_type, + )) as _ + }) +} + +/// Tries to embed `projection` to its input (`memory`). If possible, returns +/// [`MemoryExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_memory( + projection: &ProjectionExec, + memory: &MemoryExec, +) -> Result>> { + // If there is any non-column or alias-carrier expression, Projection should not be removed. + // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. + all_alias_free_columns(projection.expr()) + .then(|| { + let new_projections = + new_projections_for_columns(projection, memory.projection()); + + MemoryExec::try_new( + memory.partitions(), + memory.original_schema(), + Some(new_projections), + ) + .map(|e| Arc::new(e) as _) + }) + .transpose() +} + +/// Tries to embed `projection` to its input (`streaming table`). +/// If possible, returns [`StreamingTableExec`] as the top plan. Otherwise, +/// returns `None`. +fn try_swapping_with_streaming_table( + projection: &ProjectionExec, + streaming_table: &StreamingTableExec, +) -> Result>> { + if !all_alias_free_columns(projection.expr()) { + return Ok(None); + } + + let streaming_table_projections = streaming_table + .projection() + .as_ref() + .map(|i| i.as_ref().to_vec()); + let new_projections = + new_projections_for_columns(projection, &streaming_table_projections); + + let mut lex_orderings = vec![]; + for lex_ordering in streaming_table.projected_output_ordering().into_iter() { + let mut orderings = vec![]; + for order in lex_ordering { + let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? + else { + return Ok(None); + }; + orderings.push(PhysicalSortExpr { + expr: new_ordering, + options: order.options, + }); + } + lex_orderings.push(orderings); + } + + StreamingTableExec::try_new( + streaming_table.partition_schema().clone(), + streaming_table.partitions().clone(), + Some(&new_projections), + lex_orderings, + streaming_table.is_infinite(), + ) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Unifies `projection` with its input (which is also a [`ProjectionExec`]). +fn try_unifying_projections( + projection: &ProjectionExec, + child: &ProjectionExec, +) -> Result>> { + let mut projected_exprs = vec![]; + let mut column_ref_map: HashMap = HashMap::new(); + + // Collect the column references usage in the outer projection. + projection.expr().iter().for_each(|(expr, _)| { + expr.apply(&mut |expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + VisitRecursion::Continue + }) + }) + .unwrap(); + }); + + // Merging these projections is not beneficial, e.g + // If an expression is not trivial and it is referred more than 1, unifies projections will be + // beneficial as caching mechanism for non-trivial computations. + // See discussion in: https://github.com/apache/arrow-datafusion/issues/8296 + if column_ref_map.iter().any(|(column, count)| { + *count > 1 && !is_expr_trivial(&child.expr()[column.index()].0.clone()) + }) { + return Ok(None); + } + + for (expr, alias) in projection.expr() { + // If there is no match in the input projection, we cannot unify these + // projections. This case will arise if the projection expression contains + // a `PhysicalExpr` variant `update_expr` doesn't support. + let Some(expr) = update_expr(expr, child.expr(), true)? else { + return Ok(None); + }; + projected_exprs.push((expr, alias.clone())); + } + + ProjectionExec::try_new(projected_exprs, child.input().clone()) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Checks if the given expression is trivial. +/// An expression is considered trivial if it is either a `Column` or a `Literal`. +fn is_expr_trivial(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() +} + +/// Tries to swap `projection` with its input (`output_req`). If possible, +/// performs the swap and returns [`OutputRequirementExec`] as the top plan. +/// Otherwise, returns `None`. +fn try_swapping_with_output_req( + projection: &ProjectionExec, + output_req: &OutputRequirementExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_sort_reqs = vec![]; + // None or empty_vec can be treated in the same way. + if let Some(reqs) = &output_req.required_input_ordering()[0] { + for req in reqs { + let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_sort_reqs.push(PhysicalSortRequirement { + expr: new_expr, + options: req.options, + }); + } + } + + let dist_req = match &output_req.required_input_distribution()[0] { + Distribution::HashPartitioned(exprs) => { + let mut updated_exprs = vec![]; + for expr in exprs { + let Some(new_expr) = update_expr(expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(new_expr); + } + Distribution::HashPartitioned(updated_exprs) + } + dist => dist.clone(), + }; + + make_with_child(projection, &output_req.input()) + .map(|input| { + OutputRequirementExec::new( + input, + (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), + dist_req, + ) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap `projection` with its input, which is known to be a +/// [`CoalescePartitionsExec`]. If possible, performs the swap and returns +/// [`CoalescePartitionsExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_coalesce_partitions( + projection: &ProjectionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // CoalescePartitionsExec always has a single child, so zero indexing is safe. + make_with_child(projection, &projection.input().children()[0]) + .map(|e| Some(Arc::new(CoalescePartitionsExec::new(e)) as _)) +} + +/// Tries to swap `projection` with its input (`filter`). If possible, performs +/// the swap and returns [`FilterExec`] as the top plan. Otherwise, returns `None`. +fn try_swapping_with_filter( + projection: &ProjectionExec, + filter: &FilterExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down: + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + // Each column in the predicate expression must exist after the projection. + let Some(new_predicate) = update_expr(filter.predicate(), projection.expr(), false)? + else { + return Ok(None); + }; + + FilterExec::try_new(new_predicate, make_with_child(projection, filter.input())?) + .and_then(|e| { + let selectivity = filter.default_selectivity(); + e.with_default_selectivity(selectivity) + }) + .map(|e| Some(Arc::new(e) as _)) +} + +/// Tries to swap the projection with its input [`RepartitionExec`]. If it can be done, +/// it returns the new swapped version having the [`RepartitionExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_repartition( + projection: &ProjectionExec, + repartition: &RepartitionExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + // If pushdown is not beneficial or applicable, break it. + if projection.benefits_from_input_partitioning()[0] || !all_columns(projection.expr()) + { + return Ok(None); + } + + let new_projection = make_with_child(projection, repartition.input())?; + + let new_partitioning = match repartition.partitioning() { + Partitioning::Hash(partitions, size) => { + let mut new_partitions = vec![]; + for partition in partitions { + let Some(new_partition) = + update_expr(partition, projection.expr(), false)? + else { + return Ok(None); + }; + new_partitions.push(new_partition); + } + Partitioning::Hash(new_partitions, *size) + } + others => others.clone(), + }; + + Ok(Some(Arc::new(RepartitionExec::try_new( + new_projection, + new_partitioning, + )?))) +} + +/// Tries to swap the projection with its input [`SortExec`]. If it can be done, +/// it returns the new swapped version having the [`SortExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort( + projection: &ProjectionExec, + sort: &SortExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in sort.expr() { + let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: new_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortExec::new(updated_exprs, make_with_child(projection, sort.input())?) + .with_fetch(sort.fetch()), + ))) +} + +/// Tries to swap the projection with its input [`SortPreservingMergeExec`]. +/// If this is possible, it returns the new [`SortPreservingMergeExec`] whose +/// child is a projection. Otherwise, it returns None. +fn try_swapping_with_sort_preserving_merge( + projection: &ProjectionExec, + spm: &SortPreservingMergeExec, +) -> Result>> { + // If the projection does not narrow the the schema, we should not try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let mut updated_exprs = vec![]; + for sort in spm.expr() { + let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? + else { + return Ok(None); + }; + updated_exprs.push(PhysicalSortExpr { + expr: updated_expr, + options: sort.options, + }); + } + + Ok(Some(Arc::new( + SortPreservingMergeExec::new( + updated_exprs, + make_with_child(projection, spm.input())?, + ) + .with_fetch(spm.fetch()), + ))) +} + +/// Tries to push `projection` down through `union`. If possible, performs the +/// pushdown and returns a new [`UnionExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_union( + projection: &ProjectionExec, + union: &UnionExec, +) -> Result>> { + // If the projection doesn't narrow the schema, we shouldn't try to push it down. + if projection.expr().len() >= projection.input().schema().fields().len() { + return Ok(None); + } + + let new_children = union + .children() + .into_iter() + .map(|child| make_with_child(projection, &child)) + .collect::>>()?; + + Ok(Some(Arc::new(UnionExec::new(new_children)))) +} + +/// Tries to push `projection` down through `hash_join`. If possible, performs the +/// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections +/// as its children. Otherwise, returns `None`. +fn try_pushdown_through_hash_join( + projection: &ProjectionExec, + hash_join: &HashJoinExec, +) -> Result>> { + // Convert projected expressions to columns. We can not proceed if this is + // not possible. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + hash_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + hash_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + hash_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = hash_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + hash_join.left(), + hash_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + hash_join.left(), + hash_join.right(), + )?; + + Ok(Some(Arc::new(HashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + hash_join.join_type(), + *hash_join.partition_mode(), + hash_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`CrossJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_cross_join( + projection: &ProjectionExec, + cross_join: &CrossJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + cross_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + cross_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + cross_join.left(), + cross_join.right(), + )?; + + Ok(Some(Arc::new(CrossJoinExec::new( + Arc::new(new_left), + Arc::new(new_right), + )))) +} + +/// Tries to swap the projection with its input [`NestedLoopJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`NestedLoopJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_nested_loop_join( + projection: &ProjectionExec, + nl_join: &NestedLoopJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + nl_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + nl_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let new_filter = if let Some(filter) = nl_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + nl_join.left(), + nl_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + nl_join.left(), + nl_join.right(), + )?; + + Ok(Some(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_filter, + nl_join.join_type(), + )?))) +} + +/// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sort_merge_join( + projection: &ProjectionExec, + sm_join: &SortMergeJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sm_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sm_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sm_join.on(), + ) else { + return Ok(None); + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + &sm_join.children()[0], + &sm_join.children()[1], + )?; + + Ok(Some(Arc::new(SortMergeJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + sm_join.join_type, + sm_join.sort_options.clone(), + sm_join.null_equals_null, + )?))) +} + +/// Tries to swap the projection with its input [`SymmetricHashJoinExec`]. If it can be done, +/// it returns the new swapped version having the [`SymmetricHashJoinExec`] as the top plan. +/// Otherwise, it returns None. +fn try_swapping_with_sym_hash_join( + projection: &ProjectionExec, + sym_join: &SymmetricHashJoinExec, +) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + sym_join.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + sym_join.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + sym_join.on(), + ) else { + return Ok(None); + }; + + let new_filter = if let Some(filter) = sym_join.filter() { + match update_join_filter( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + filter, + sym_join.left(), + sym_join.right(), + ) { + Some(updated_filter) => Some(updated_filter), + None => return Ok(None), + } + } else { + None + }; + + let (new_left, new_right) = new_join_children( + projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + sym_join.left(), + sym_join.right(), + )?; + + Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + new_filter, + sym_join.join_type(), + sym_join.null_equals_null(), + sym_join.partition_mode(), + )?))) +} + +/// Compare the inputs and outputs of the projection. If the projection causes +/// any change in the fields, it returns `false`. +fn is_projection_removable(projection: &ProjectionExec) -> bool { + all_alias_free_columns(projection.expr()) && { + let schema = projection.schema(); + let input_schema = projection.input().schema(); + let fields = schema.fields(); + let input_fields = input_schema.fields(); + fields.len() == input_fields.len() + && fields + .iter() + .zip(input_fields.iter()) + .all(|(out, input)| out.eq(input)) + } +} + +/// Given the expression set of a projection, checks if the projection causes +/// any renaming or constructs a non-`Column` physical expression. +fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|column| column.name() == alias) + .unwrap_or(false) + }) +} + +/// Updates a source provider's projected columns according to the given +/// projection operator's expressions. To use this function safely, one must +/// ensure that all expressions are `Column` expressions without aliases. +fn new_projections_for_columns( + projection: &ProjectionExec, + source: &Option>, +) -> Vec { + projection + .expr() + .iter() + .filter_map(|(expr, _)| { + expr.as_any() + .downcast_ref::() + .and_then(|expr| source.as_ref().map(|proj| proj[expr.index()])) + }) + .collect() +} + +/// The function operates in two modes: +/// +/// 1) When `sync_with_child` is `true`: +/// +/// The function updates the indices of `expr` if the expression resides +/// in the input plan. For instance, given the expressions `a@1 + b@2` +/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are +/// updated to `a@0 + b@1` and `c@2`. +/// +/// 2) When `sync_with_child` is `false`: +/// +/// The function determines how the expression would be updated if a projection +/// was placed before the plan associated with the expression. If the expression +/// cannot be rewritten after the projection, it returns `None`. For example, +/// given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with +/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes +/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +fn update_expr( + expr: &Arc, + projected_exprs: &[(Arc, String)], + sync_with_child: bool, +) -> Result>> { + #[derive(Debug, PartialEq)] + enum RewriteState { + /// The expression is unchanged. + Unchanged, + /// Some part of the expression has been rewritten + RewrittenValid, + /// Some part of the expression has been rewritten, but some column + /// references could not be. + RewrittenInvalid, + } + + let mut state = RewriteState::Unchanged; + + let new_expr = expr + .clone() + .transform_up_mut(&mut |expr: Arc| { + if state == RewriteState::RewrittenInvalid { + return Ok(Transformed::No(expr)); + } + + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::No(expr)); + }; + if sync_with_child { + state = RewriteState::RewrittenValid; + // Update the index of `column`: + Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + } else { + // default to invalid, in case we can't find the relevant column + state = RewriteState::RewrittenInvalid; + // Determine how to update `column` to accommodate `projected_exprs` + projected_exprs + .iter() + .enumerate() + .find_map(|(index, (projected_expr, alias))| { + projected_expr.as_any().downcast_ref::().and_then( + |projected_column| { + column.name().eq(projected_column.name()).then(|| { + state = RewriteState::RewrittenValid; + Arc::new(Column::new(alias, index)) as _ + }) + }, + ) + }) + .map_or_else( + || Ok(Transformed::No(expr)), + |c| Ok(Transformed::Yes(c)), + ) + } + }); + + new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) +} + +/// Creates a new [`ProjectionExec`] instance with the given child plan and +/// projected expressions. +fn make_with_child( + projection: &ProjectionExec, + child: &Arc, +) -> Result> { + ProjectionExec::try_new(projection.expr().to_vec(), child.clone()) + .map(|e| Arc::new(e) as _) +} + +/// Returns `true` if all the expressions in the argument are `Column`s. +fn all_columns(exprs: &[(Arc, String)]) -> bool { + exprs.iter().all(|(expr, _)| expr.as_any().is::()) +} + +/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given +/// expressions is not a `Column`, returns `None`. +fn physical_to_column_exprs( + exprs: &[(Arc, String)], +) -> Option> { + exprs + .iter() + .map(|(expr, alias)| { + expr.as_any() + .downcast_ref::() + .map(|col| (col.clone(), alias.clone())) + }) + .collect() +} + +/// Returns the last index before encountering a column coming from the right table when traveling +/// through the projection from left to right, and the last index before encountering a column +/// coming from the left table when traveling through the projection from right to left. +/// If there is no column in the projection coming from the left side, it returns (-1, ...), +/// if there is no column in the projection coming from the right side, it returns (..., projection length). +fn join_table_borders( + left_table_column_count: usize, + projection_as_columns: &[(Column, String)], +) -> (i32, i32) { + let far_right_left_col_ind = projection_as_columns + .iter() + .enumerate() + .take_while(|(_, (projection_column, _))| { + projection_column.index() < left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(-1); + + let far_left_right_col_ind = projection_as_columns + .iter() + .enumerate() + .rev() + .take_while(|(_, (projection_column, _))| { + projection_column.index() >= left_table_column_count + }) + .last() + .map(|(index, _)| index as i32) + .unwrap_or(projection_as_columns.len() as i32); + + (far_right_left_col_ind, far_left_right_col_ind) +} + +/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// the join was replaced by a projection. +fn update_join_on( + proj_left_exprs: &[(Column, String)], + proj_right_exprs: &[(Column, String)], + hash_join_on: &[(Column, Column)], +) -> Option> { + let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on + .iter() + .map(|(left, right)| (left, right)) + .unzip(); + + let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs); + let new_right_columns = new_columns_for_join_on(&right_idx, proj_right_exprs); + + match (new_left_columns, new_right_columns) { + (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()), + _ => None, + } +} + +/// This function generates a new set of columns to be used in a hash join +/// operation based on a set of equi-join conditions (`hash_join_on`) and a +/// list of projection expressions (`projection_exprs`). +fn new_columns_for_join_on( + hash_join_on: &[&Column], + projection_exprs: &[(Column, String)], +) -> Option> { + let new_columns = hash_join_on + .iter() + .filter_map(|on| { + projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| on.name() == proj_column.name()) + .map(|(index, (_, alias))| Column::new(alias, index)) + }) + .collect::>(); + (new_columns.len() == hash_join_on.len()).then_some(new_columns) +} + +/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// the join was replaced by a projection. +fn update_join_filter( + projection_left_exprs: &[(Column, String)], + projection_right_exprs: &[(Column, String)], + join_filter: &JoinFilter, + join_left: &Arc, + join_right: &Arc, +) -> Option { + let mut new_left_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Left, + projection_left_exprs, + join_left.schema(), + ) + .into_iter(); + let mut new_right_indices = new_indices_for_join_filter( + join_filter, + JoinSide::Right, + projection_right_exprs, + join_right.schema(), + ) + .into_iter(); + + // Check if all columns match: + (new_right_indices.len() + new_left_indices.len() + == join_filter.column_indices().len()) + .then(|| { + JoinFilter::new( + join_filter.expression().clone(), + join_filter + .column_indices() + .iter() + .map(|col_idx| ColumnIndex { + index: if col_idx.side == JoinSide::Left { + new_left_indices.next().unwrap() + } else { + new_right_indices.next().unwrap() + }, + side: col_idx.side, + }) + .collect(), + join_filter.schema().clone(), + ) + }) +} + +/// This function determines and returns a vector of indices representing the +/// positions of columns in `projection_exprs` that are involved in `join_filter`, +/// and correspond to a particular side (`join_side`) of the join operation. +fn new_indices_for_join_filter( + join_filter: &JoinFilter, + join_side: JoinSide, + projection_exprs: &[(Column, String)], + join_child_schema: SchemaRef, +) -> Vec { + join_filter + .column_indices() + .iter() + .filter(|col_idx| col_idx.side == join_side) + .filter_map(|col_idx| { + projection_exprs.iter().position(|(col, _)| { + col.name() == join_child_schema.fields()[col_idx.index].name() + }) + }) + .collect() +} + +/// Checks three conditions for pushing a projection down through a join: +/// - Projection must narrow the join output schema. +/// - Columns coming from left/right tables must be collected at the left/right +/// sides of the output table. +/// - Left or right table is not lost after the projection. +fn join_allows_pushdown( + projection_as_columns: &[(Column, String)], + join_schema: SchemaRef, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, +) -> bool { + // Projection must narrow the join output: + projection_as_columns.len() < join_schema.fields().len() + // Are the columns from different tables mixed? + && (far_right_left_col_ind + 1 == far_left_right_col_ind) + // Left or right table is not lost after the projection. + && far_right_left_col_ind >= 0 + && far_left_right_col_ind < projection_as_columns.len() as i32 +} + +/// If pushing down the projection over this join's children seems possible, +/// this function constructs the new [`ProjectionExec`]s that will come on top +/// of the original children of the join. +fn new_join_children( + projection_as_columns: Vec<(Column, String)>, + far_right_left_col_ind: i32, + far_left_right_col_ind: i32, + left_child: &Arc, + right_child: &Arc, +) -> Result<(ProjectionExec, ProjectionExec)> { + let new_left = ProjectionExec::try_new( + projection_as_columns[0..=far_right_left_col_ind as _] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new(col.name(), col.index())) as _, + alias.clone(), + ) + }) + .collect_vec(), + left_child.clone(), + )?; + let left_size = left_child.schema().fields().len() as i32; + let new_right = ProjectionExec::try_new( + projection_as_columns[far_left_right_col_ind as _..] + .iter() + .map(|(col, alias)| { + ( + Arc::new(Column::new( + col.name(), + // Align projected expressions coming from the right + // table with the new right child projection: + (col.index() as i32 - left_size) as _, + )) as _, + alias.clone(), + ) + }) + .collect_vec(), + right_child.clone(), + )?; + + Ok((new_left, new_right)) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_optimizer::output_requirements::OutputRequirementExec; + use crate::physical_optimizer::projection_pushdown::{ + join_table_borders, update_expr, ProjectionPushdown, + }; + use crate::physical_optimizer::PhysicalOptimizerRule; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; + use crate::physical_plan::joins::StreamJoinPartitionMode; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::sorts::sort::SortExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::{get_plan_string, ExecutionPlan}; + + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, + }; + use datafusion_physical_expr::{ + Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, ScalarFunctionExpr, + }; + use datafusion_physical_plan::joins::SymmetricHashJoinExec; + use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; + use datafusion_physical_plan::union::UnionExec; + + use itertools::Itertools; + + #[test] + fn test_update_matching_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let child: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("b", 1)), "b".to_owned()), + (Arc::new(Column::new("d", 3)), "d".to_owned()), + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("f", 5)), "f".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &child, true)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_update_projected_exprs() -> Result<()> { + let exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Divide, + Arc::new(Column::new("e", 5)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 3)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Divide, + Arc::new(Column::new("c", 0)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Divide, + Arc::new(Column::new("b", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d", 2))), + vec![ + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 2)), + Operator::Plus, + Arc::new(Column::new("e", 5)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 3)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 5)), + Operator::Plus, + Arc::new(Column::new("d", 2)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 3)), + Operator::Modulo, + Arc::new(Column::new("e", 5)), + ))), + )?), + ]; + let projected_exprs: Vec<(Arc, String)> = vec![ + (Arc::new(Column::new("a", 0)), "a".to_owned()), + (Arc::new(Column::new("b", 1)), "b_new".to_owned()), + (Arc::new(Column::new("c", 2)), "c".to_owned()), + (Arc::new(Column::new("d", 3)), "d_new".to_owned()), + (Arc::new(Column::new("e", 4)), "e".to_owned()), + (Arc::new(Column::new("f", 5)), "f_new".to_owned()), + ]; + + let expected_exprs: Vec> = vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Divide, + Arc::new(Column::new("e", 4)), + )), + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Float32, + None, + )), + Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), + Arc::new(ScalarFunctionExpr::new( + "scalar_expr", + Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + vec![ + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_new", 1)), + Operator::Divide, + Arc::new(Column::new("c", 2)), + )), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Divide, + Arc::new(Column::new("b_new", 1)), + )), + ], + DataType::Int32, + None, + )), + Arc::new(CaseExpr::try_new( + Some(Arc::new(Column::new("d_new", 3))), + vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d_new", 3)), + Operator::Plus, + Arc::new(Column::new("e", 4)), + )) as Arc, + ), + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("d_new", 3)), + )) as Arc, + ), + ], + Some(Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Modulo, + Arc::new(Column::new("e", 4)), + ))), + )?), + ]; + + for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { + assert!(update_expr(&expr, &projected_exprs, false)? + .unwrap() + .eq(&expected_expr)); + } + + Ok(()) + } + + #[test] + fn test_join_table_borders() -> Result<()> { + let projections = vec![ + (Column::new("b", 1), "b".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("c", 2), "c".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("h", 7), "h".to_owned()), + (Column::new("g", 6), "g".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (4, 5) + ); + + let left_table_column_count = 8; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (7, 8) + ); + + let left_table_column_count = 1; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (-1, 0) + ); + + let projections = vec![ + (Column::new("a", 0), "a".to_owned()), + (Column::new("b", 1), "b".to_owned()), + (Column::new("d", 3), "d".to_owned()), + (Column::new("g", 6), "g".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("f", 5), "f".to_owned()), + (Column::new("e", 4), "e".to_owned()), + (Column::new("h", 7), "h".to_owned()), + ]; + let left_table_column_count = 5; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (2, 7) + ); + + let left_table_column_count = 7; + assert_eq!( + join_table_borders(left_table_column_count, &projections), + (6, 7) + ); + + Ok(()) + } + + fn create_simple_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![0, 1, 2, 3, 4]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + infinite_source: false, + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_csv_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ])); + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(&schema), + projection: Some(vec![3, 2, 1]), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![vec![]], + infinite_source: false, + }, + false, + 0, + 0, + None, + FileCompressionType::UNCOMPRESSED, + )) + } + + fn create_projecting_memory_exec() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[], schema, Some(vec![2, 0, 3, 4])).unwrap()) + } + + #[test] + fn test_csv_after_projection() -> Result<()> { + let csv = create_projecting_csv_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 2)), "b".to_string()), + (Arc::new(Column::new("d", 0)), "d".to_string()), + ], + csv.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@2 as b, d@0 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[d, c, b], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CsvExec: file_groups={1 group: [[x]]}, projection=[b, d], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_memory_after_projection() -> Result<()> { + let memory = create_projecting_memory_exec(); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 2)), "d".to_string()), + (Arc::new(Column::new("e", 3)), "e".to_string()), + (Arc::new(Column::new("a", 1)), "a".to_string()), + ], + memory.clone(), + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = ["MemoryExec: partitions=0, partition_sizes=[]"]; + assert_eq!(get_plan_string(&after_optimize), expected); + assert_eq!( + after_optimize + .clone() + .as_any() + .downcast_ref::() + .unwrap() + .projection() + .clone() + .unwrap(), + vec![3, 4, 0] + ); + + Ok(()) + } + + #[test] + fn test_streaming_table_after_projection() -> Result<()> { + struct DummyStreamPartition { + schema: SchemaRef, + } + impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } + } + + let streaming_table = StreamingTableExec::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + vec![Arc::new(DummyStreamPartition { + schema: Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])), + }) as _], + Some(&vec![0_usize, 2, 4, 3]), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 2)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options: SortOptions::default(), + }], + ] + .into_iter(), + true, + )?; + let projection = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("d", 3)), "d".to_string()), + (Arc::new(Column::new("e", 2)), "e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + ], + Arc::new(streaming_table) as _, + )?) as _; + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let result = after_optimize + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result.partition_schema(), + &Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + ])) + ); + assert_eq!( + result.projection().clone().unwrap().to_vec(), + vec![3_usize, 4, 0] + ); + assert_eq!( + result.projected_schema(), + &Schema::new(vec![ + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]) + ); + assert_eq!( + result.projected_output_ordering().into_iter().collect_vec(), + vec![ + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("e", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 2)), + options: SortOptions::default(), + }, + ], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("d", 0)), + options: SortOptions::default(), + }], + ] + ); + assert!(result.is_infinite()); + + Ok(()) + } + + #[test] + fn test_projection_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let child_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("e", 4)), "new_e".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ], + csv.clone(), + )?); + let top_projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_e", 1)), + )), + "binary".to_string(), + ), + (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ], + child_projection.clone(), + )?); + + let initial = get_plan_string(&top_projection); + let expected_initial = [ + "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", + " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_output_req_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(OutputRequirementExec::new( + csv.clone(), + Some(vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 1)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: Some(SortOptions::default()), + }, + ]), + Distribution::HashPartitioned(vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " OutputRequirementExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected: [&str; 3] = [ + "OutputRequirementExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + + assert_eq!(get_plan_string(&after_optimize), expected); + let expected_reqs = vec![ + PhysicalSortRequirement { + expr: Arc::new(Column::new("b", 2)), + options: Some(SortOptions::default()), + }, + PhysicalSortRequirement { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + options: Some(SortOptions::default()), + }, + ]; + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .clone() + .unwrap(), + expected_reqs + ); + let expected_distribution: Vec> = vec![ + Arc::new(Column::new("new_a", 1)), + Arc::new(Column::new("b", 2)), + ]; + if let Distribution::HashPartitioned(vec) = after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .required_input_distribution()[0] + .clone() + { + assert!(vec + .iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected))); + } else { + panic!("Expected HashPartitioned distribution!"); + }; + + Ok(()) + } + + #[test] + fn test_coalesce_partitions_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + coalesce_partitions, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CoalescePartitionsExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "CoalescePartitionsExec", + " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_filter_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + Operator::Gt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("d", 3)), + Operator::Minus, + Arc::new(Column::new("a", 0)), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, csv)?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("a", 0)), "a_new".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + (Arc::new(Column::new("d", 3)), "d".to_string()), + ], + filter.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " FilterExec: b@1 - a@0 > d@3 - a@0", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", + " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_join_after_projection() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + left_csv, + right_csv, + vec![(Column::new("b", 1), Column::new("c", 2))], + // b_left-(1+a_right)<=a_right+c_left + Some(JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b_left_inter", 0)), + Operator::Minus, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::Plus, + Arc::new(Column::new("a_right_inter", 1)), + )), + )), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a_right_inter", 1)), + Operator::Plus, + Arc::new(Column::new("c_left_inter", 2)), + )), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ], + Schema::new(vec![ + Field::new("b_left_inter", DataType::Int32, true), + Field::new("a_right_inter", DataType::Int32, true), + Field::new("c_left_inter", DataType::Int32, true), + ]), + )), + &JoinType::Inner, + true, + StreamJoinPartitionMode::SinglePartition, + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), + (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ], + join, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + let expected_filter_col_ind = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ]; + + assert_eq!( + expected_filter_col_ind, + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .filter() + .unwrap() + .column_indices() + ); + + Ok(()) + } + + #[test] + fn test_repartition_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let repartition: Arc = Arc::new(RepartitionExec::try_new( + csv, + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("d", 3)), + ], + 6, + ), + )?); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("b", 1)), "b_new".to_string()), + (Arc::new(Column::new("a", 0)), "a".to_string()), + (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ], + repartition, + )?); + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", + " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + assert_eq!( + after_optimize + .as_any() + .downcast_ref::() + .unwrap() + .partitioning() + .clone(), + Partitioning::Hash( + vec![ + Arc::new(Column::new("a", 1)), + Arc::new(Column::new("b_new", 0)), + Arc::new(Column::new("d_new", 2)), + ], + 6, + ), + ); + + Ok(()) + } + + #[test] + fn test_sort_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_sort_preserving_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + options: SortOptions::default(), + }, + ], + csv.clone(), + )); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + sort_req.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } + + #[test] + fn test_union_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let union: Arc = + Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + (Arc::new(Column::new("c", 2)), "c".to_string()), + (Arc::new(Column::new("a", 0)), "new_a".to_string()), + (Arc::new(Column::new("b", 1)), "b".to_string()), + ], + union.clone(), + )?); + + let initial = get_plan_string(&projection); + let expected_initial = [ + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " UnionExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let expected = [ + "UnionExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 8fb3058308fba..b2ba7596db8d2 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -15,19 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! This module contains code to prune "containers" of row groups -//! based on statistics prior to execution. This can lead to -//! significant performance improvements by avoiding the need -//! to evaluate a plan on entire containers (e.g. an entire file) +//! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers" +//! based on statistics (e.g. Parquet Row Groups) //! -//! For example, DataFusion uses this code to prune (skip) row groups -//! while reading parquet files if it can be determined from the -//! predicate that nothing in the row group can match. -//! -//! This code can also be used by other systems to prune other -//! entities (e.g. entire files) if the statistics are known via some -//! other source (e.g. a catalog) - +//! [`Expr`]: crate::prelude::Expr use std::collections::HashSet; use std::convert::TryFrom; use std::sync::Arc; @@ -44,24 +35,30 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{downcast_value, ScalarValue}; +use datafusion_common::{downcast_value, plan_datafusion_err, ScalarValue}; +use datafusion_common::{ + internal_err, plan_err, + tree_node::{Transformed, TreeNode}, +}; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; -/// Interface to pass statistics information to [`PruningPredicate`] +/// Interface to pass statistics (min/max/nulls) information to [`PruningPredicate`]. /// -/// Returns statistics for containers / files of data in Arrays. +/// Returns statistics for containers / files as Arrow [`ArrayRef`], so the +/// evaluation happens once on a single `RecordBatch`, amortizing the overhead +/// of evaluating of the predicate. This is important when pruning 1000s of +/// containers which often happens in analytic systems. /// -/// For example, for the following three files with a single column +/// For example, for the following three files with a single column `a`: /// ```text /// file1: column a: min=5, max=10 /// file2: column a: No stats /// file2: column a: min=20, max=30 /// ``` /// -/// PruningStatistics should return: +/// PruningStatistics would return: /// /// ```text /// min_values("a") -> Some([5, Null, 20]) @@ -69,29 +66,78 @@ use log::trace; /// min_values("X") -> None /// ``` pub trait PruningStatistics { - /// return the minimum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn min_values(&self, column: &Column) -> Option; - /// return the maximum values for the named column, if known. - /// Note: the returned array must contain `num_containers()` rows. + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn max_values(&self, column: &Column) -> Option; - /// return the number of containers (e.g. row groups) being - /// pruned with these statistics + /// Return the number of containers (e.g. row groups) being + /// pruned with these statistics (the number of rows in each returned array) fn num_containers(&self) -> usize; - /// return the number of null values for the named column as an + /// Return the number of null values for the named column as an /// `Option`. /// - /// Note: the returned array must contain `num_containers()` rows. + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; } -/// Evaluates filter expressions on statistics in order to -/// prune data containers (e.g. parquet row group) +/// Evaluates filter expressions on statistics such as min/max values and null +/// counts, attempting to prove a "container" (e.g. Parquet Row Group) can be +/// skipped without reading the actual data, potentially leading to significant +/// performance improvements. +/// +/// For example, [`PruningPredicate`]s are used to prune Parquet Row Groups +/// based on the min/max values found in the Parquet metadata. If the +/// `PruningPredicate` can guarantee that no rows in the Row Group match the +/// filter, the entire Row Group is skipped during query execution. +/// +/// The `PruningPredicate` API is general, allowing it to be used for pruning +/// other types of containers (e.g. files) based on statistics that may be +/// known from external catalogs (e.g. Delta Lake) or other sources. Thus it +/// supports: +/// +/// 1. Arbitrary expressions expressions (including user defined functions) +/// +/// 2. Vectorized evaluation (provide more than one set of statistics at a time) +/// so it is suitable for pruning 1000s of containers. +/// +/// 3. Anything that implements the [`PruningStatistics`] trait, not just +/// Parquet metadata. +/// +/// # Example /// -/// See [`PruningPredicate::try_new`] for more information. +/// Given an expression like `x = 5` and statistics for 3 containers (Row +/// Groups, files, etc) `A`, `B`, and `C`: +/// +/// ```text +/// A: {x_min = 0, x_max = 4} +/// B: {x_min = 2, x_max = 10} +/// C: {x_min = 5, x_max = 8} +/// ``` +/// +/// Applying the `PruningPredicate` will concludes that `A` can be pruned: +/// +/// ```text +/// A: false (no rows could possibly match x = 5) +/// B: true (rows might match x = 5) +/// C: true (rows might match x = 5) +/// ``` +/// +/// See [`PruningPredicate::try_new`] and [`PruningPredicate::prune`] for more information. #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated @@ -143,17 +189,14 @@ impl PruningPredicate { /// /// `true`: There MAY be rows that match the predicate /// - /// `false`: There are no rows that could match the predicate - /// - /// Note this function takes a slice of statistics as a parameter - /// to amortize the cost of the evaluation of the predicate - /// against a single record batch. + /// `false`: There are no rows that could possibly match the predicate /// - /// Note: the predicate passed to `prune` should be simplified as + /// Note: the predicate passed to `prune` should already be simplified as /// much as possible (e.g. this pass doesn't handle some /// expressions like `b = false`, but it does handle the - /// simplified version `b`. The predicates are simplified via the - /// ConstantFolding optimizer pass + /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions. + /// + /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune(&self, statistics: &S) -> Result> { // build a RecordBatch that contains the min/max values in the // appropriate statistics columns @@ -183,10 +226,10 @@ impl PruningPredicate { Ok(vec![v; statistics.num_containers()]) } other => { - Err(DataFusionError::Internal(format!( + internal_err!( "Unexpected result of pruning predicate evaluation. Expected Boolean array \ or scalar but got {other:?}" - ))) + ) } } } @@ -223,8 +266,12 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } -/// Records for which columns statistics are necessary to evaluate a -/// pruning predicate. +/// Describes which columns statistics are necessary to evaluate a +/// [`PruningPredicate`]. +/// +/// This structure permits reading and creating the minimum number statistics, +/// which is important since statistics may be non trivial to read (e.g. large +/// strings or when there are 1000s of columns). /// /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed @@ -396,11 +443,11 @@ fn build_statistics_record_batch( let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); if num_containers != array.len() { - return Err(DataFusionError::Internal(format!( + return internal_err!( "mismatched statistics length. Expected {}, got {}", num_containers, array.len() - ))); + ); } // cast statistics array to required data type (e.g. parquet @@ -423,7 +470,7 @@ fn build_statistics_record_batch( ); RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { - DataFusionError::Plan(format!("Can not create statistics record batch: {err}")) + plan_datafusion_err!("Can not create statistics record batch: {err}") }) } @@ -453,10 +500,9 @@ impl<'a> PruningExpressionBuilder<'a> { (0, 1) => (right, left, right_columns, reverse_operator(op)?), _ => { // if more than one column used in expression - not supported - return Err(DataFusionError::Plan( + return plan_err!( "Multi-column expressions are not currently supported" - .to_string(), - )); + ); } }; @@ -471,9 +517,7 @@ impl<'a> PruningExpressionBuilder<'a> { let field = match schema.column_with_name(column.name()) { Some((_, f)) => f, _ => { - return Err(DataFusionError::Plan( - "Field not found in schema".to_string(), - )); + return plan_err!("Field not found in schema"); } }; @@ -525,9 +569,7 @@ fn rewrite_expr_to_prunable( schema: DFSchema, ) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { if !is_compare_op(op) { - return Err(DataFusionError::Plan( - "rewrite_expr_to_prunable only support compare expression".to_string(), - )); + return plan_err!("rewrite_expr_to_prunable only support compare expression"); } let column_expr_any = column_expr.as_any(); @@ -574,9 +616,7 @@ fn rewrite_expr_to_prunable( } else if let Some(not) = column_expr_any.downcast_ref::() { // `!col = true` --> `col = !true` if op != Operator::Eq && op != Operator::NotEq { - return Err(DataFusionError::Plan( - "Not with operator other than Eq / NotEq is not supported".to_string(), - )); + return plan_err!("Not with operator other than Eq / NotEq is not supported"); } if not .arg() @@ -588,14 +628,10 @@ fn rewrite_expr_to_prunable( let right = Arc::new(phys_expr::NotExpr::new(scalar_expr.clone())); Ok((left, reverse_operator(op)?, right)) } else { - Err(DataFusionError::Plan(format!( - "Not with complex expression {column_expr:?} is not supported" - ))) + plan_err!("Not with complex expression {column_expr:?} is not supported") } } else { - Err(DataFusionError::Plan(format!( - "column expression {column_expr:?} is not supported" - ))) + plan_err!("column expression {column_expr:?} is not supported") } } @@ -630,9 +666,9 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re ) { Ok(()) } else { - Err(DataFusionError::Plan(format!( + plan_err!( "Try Cast/Cast with from type {from_type} to type {to_type} is not supported" - ))) + ) } } @@ -841,85 +877,85 @@ fn build_predicate_expression( fn build_statistics_expr( expr_builder: &mut PruningExpressionBuilder, ) -> Result> { - let statistics_expr: Arc = - match expr_builder.op() { - Operator::NotEq => { - // column != literal => (min, max) = literal => - // !(min != literal && max != literal) ==> - // min != literal || literal != max - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::NotEq, - expr_builder.scalar_expr().clone(), - )), - Operator::Or, - Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), - Operator::NotEq, - max_column_expr, - )), - )) - } - Operator::Eq => { - // column = literal => (min, max) = literal => min <= literal && literal <= max - // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::LtEq, - expr_builder.scalar_expr().clone(), - )), - Operator::And, - Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), - Operator::LtEq, - max_column_expr, - )), - )) - } - Operator::Gt => { - // column > literal => (min, max) > literal => max > literal + let statistics_expr: Arc = match expr_builder.op() { + Operator::NotEq => { + // column != literal => (min, max) = literal => + // !(min != literal && max != literal) ==> + // min != literal || literal != max + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Arc::new(phys_expr::BinaryExpr::new( Arc::new(phys_expr::BinaryExpr::new( - expr_builder.max_column_expr()?, - Operator::Gt, + min_column_expr, + Operator::NotEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::GtEq => { - // column >= literal => (min, max) >= literal => max >= literal + )), + Operator::Or, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.max_column_expr()?, - Operator::GtEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::Lt => { - // column < literal => (min, max) < literal => min < literal + Operator::NotEq, + max_column_expr, + )), + )) + } + Operator::Eq => { + // column = literal => (min, max) = literal => min <= literal && literal <= max + // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Arc::new(phys_expr::BinaryExpr::new( Arc::new(phys_expr::BinaryExpr::new( - expr_builder.min_column_expr()?, - Operator::Lt, + min_column_expr, + Operator::LtEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::LtEq => { - // column <= literal => (min, max) <= literal => min <= literal + )), + Operator::And, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.min_column_expr()?, - Operator::LtEq, expr_builder.scalar_expr().clone(), - )) - } - // other expressions are not supported - _ => return Err(DataFusionError::Plan( + Operator::LtEq, + max_column_expr, + )), + )) + } + Operator::Gt => { + // column > literal => (min, max) > literal => max > literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::Gt, + expr_builder.scalar_expr().clone(), + )) + } + Operator::GtEq => { + // column >= literal => (min, max) >= literal => max >= literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::GtEq, + expr_builder.scalar_expr().clone(), + )) + } + Operator::Lt => { + // column < literal => (min, max) < literal => min < literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::Lt, + expr_builder.scalar_expr().clone(), + )) + } + Operator::LtEq => { + // column <= literal => (min, max) <= literal => min <= literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::LtEq, + expr_builder.scalar_expr().clone(), + )) + } + // other expressions are not supported + _ => { + return plan_err!( "expressions other than (neq, eq, gt, gteq, lt, lteq) are not supported" - .to_string(), - )), - }; + ); + } + }; Ok(statistics_expr) } @@ -1219,7 +1255,7 @@ mod tests { let batch = build_statistics_record_batch(&statistics, &required_columns).unwrap(); - let expected = vec![ + let expected = [ "+--------+--------+--------+--------+", "| s1_min | s2_max | s3_max | s3_min |", "+--------+--------+--------+--------+", @@ -1258,7 +1294,7 @@ mod tests { let batch = build_statistics_record_batch(&statistics, &required_columns).unwrap(); - let expected = vec![ + let expected = [ "+-------------------------------+", "| s1_min |", "+-------------------------------+", @@ -1304,7 +1340,7 @@ mod tests { let batch = build_statistics_record_batch(&statistics, &required_columns).unwrap(); - let expected = vec![ + let expected = [ "+--------+", "| s1_min |", "+--------+", diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs deleted file mode 100644 index fb867ff36c62d..0000000000000 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ /dev/null @@ -1,1196 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Repartition optimizer that introduces repartition nodes to increase the level of parallelism available -use datafusion_common::tree_node::Transformed; -use std::sync::Arc; - -use super::optimizer::PhysicalOptimizerRule; -use crate::config::ConfigOptions; -use crate::datasource::physical_plan::ParquetExec; -use crate::error::Result; -use crate::physical_plan::Partitioning::*; -use crate::physical_plan::{ - repartition::RepartitionExec, with_new_children_if_necessary, ExecutionPlan, -}; - -/// Optimizer that introduces repartition to introduce more -/// parallelism in the plan -/// -/// For example, given an input such as: -/// -/// -/// ```text -/// ┌─────────────────────────────────┐ -/// │ │ -/// │ ExecutionPlan │ -/// │ │ -/// └─────────────────────────────────┘ -/// ▲ ▲ -/// │ │ -/// ┌─────┘ └─────┐ -/// │ │ -/// │ │ -/// │ │ -/// ┌───────────┐ ┌───────────┐ -/// │ │ │ │ -/// │ batch A1 │ │ batch B1 │ -/// │ │ │ │ -/// ├───────────┤ ├───────────┤ -/// │ │ │ │ -/// │ batch A2 │ │ batch B2 │ -/// │ │ │ │ -/// ├───────────┤ ├───────────┤ -/// │ │ │ │ -/// │ batch A3 │ │ batch B3 │ -/// │ │ │ │ -/// └───────────┘ └───────────┘ -/// -/// Input Input -/// A B -/// ``` -/// -/// This optimizer will attempt to add a `RepartitionExec` to increase -/// the parallelism (to 3 in this case) -/// -/// ```text -/// ┌─────────────────────────────────┐ -/// │ │ -/// │ ExecutionPlan │ -/// │ │ -/// └─────────────────────────────────┘ -/// ▲ ▲ ▲ Input now has 3 -/// │ │ │ partitions -/// ┌───────┘ │ └───────┐ -/// │ │ │ -/// │ │ │ -/// ┌───────────┐ ┌───────────┐ ┌───────────┐ -/// │ │ │ │ │ │ -/// │ batch A1 │ │ batch A3 │ │ batch B3 │ -/// │ │ │ │ │ │ -/// ├───────────┤ ├───────────┤ ├───────────┤ -/// │ │ │ │ │ │ -/// │ batch B2 │ │ batch B1 │ │ batch A2 │ -/// │ │ │ │ │ │ -/// └───────────┘ └───────────┘ └───────────┘ -/// ▲ ▲ ▲ -/// │ │ │ -/// └─────────┐ │ ┌──────────┘ -/// │ │ │ -/// │ │ │ -/// ┌─────────────────────────────────┐ batches are -/// │ RepartitionExec(3) │ repartitioned -/// │ RoundRobin │ -/// │ │ -/// └─────────────────────────────────┘ -/// ▲ ▲ -/// │ │ -/// ┌─────┘ └─────┐ -/// │ │ -/// │ │ -/// │ │ -/// ┌───────────┐ ┌───────────┐ -/// │ │ │ │ -/// │ batch A1 │ │ batch B1 │ -/// │ │ │ │ -/// ├───────────┤ ├───────────┤ -/// │ │ │ │ -/// │ batch A2 │ │ batch B2 │ -/// │ │ │ │ -/// ├───────────┤ ├───────────┤ -/// │ │ │ │ -/// │ batch A3 │ │ batch B3 │ -/// │ │ │ │ -/// └───────────┘ └───────────┘ -/// -/// -/// Input Input -/// A B -/// ``` -#[derive(Default)] -pub struct Repartition {} - -impl Repartition { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -/// Recursively attempts to increase the overall parallelism of the -/// plan, while respecting ordering, by adding a `RepartitionExec` at -/// the output of `plan` if it would help parallelism and not destroy -/// any possibly useful ordering. -/// -/// It does so using a depth first scan of the tree, and repartitions -/// any plan that: -/// -/// 1. Has fewer partitions than `target_partitions` -/// -/// 2. Has a direct parent that `benefits_from_input_partitioning` -/// -/// 3. Does not destroy any existing sort order if the parent is -/// relying on it. -/// -/// if `can_reorder` is false, it means the parent node of `plan` is -/// trying to take advantage of the output sort order of plan, so it -/// should not be repartitioned if doing so would destroy the output -/// sort order. -/// -/// (Parent) - If can_reorder is false, means this parent node is -/// trying to use the sort ouder order this plan. If true -/// means parent doesn't care about sort order -/// -/// (plan) - We are deciding to add a partition above here -/// -/// (children) - Recursively visit all children first -/// -/// If 'would_benefit` is true, the upstream operator would benefit -/// from additional partitions and thus repatitioning is considered. -/// -/// if `is_root` is true, no repartition is added. -fn optimize_partitions( - target_partitions: usize, - plan: Arc, - is_root: bool, - can_reorder: bool, - would_benefit: bool, - repartition_file_scans: bool, - repartition_file_min_size: usize, -) -> Result>> { - // Recurse into children bottom-up (attempt to repartition as - // early as possible) - let new_plan = if plan.children().is_empty() { - // leaf node - don't replace children - Transformed::No(plan) - } else { - let children = plan - .children() - .iter() - .enumerate() - .map(|(idx, child)| { - // Does plan itself (not its parent) require its input to - // be sorted in some way? - let required_input_ordering = - plan_has_required_input_ordering(plan.as_ref()); - - // We can reorder a child if: - // - It has no ordering to preserve, or - // - Its parent has no required input ordering and does not - // maintain input ordering. - // Check if this condition holds: - let can_reorder_child = child.output_ordering().is_none() - || (!required_input_ordering - && (can_reorder || !plan.maintains_input_order()[idx])); - - optimize_partitions( - target_partitions, - child.clone(), - false, // child is not root - can_reorder_child, - plan.benefits_from_input_partitioning(), - repartition_file_scans, - repartition_file_min_size, - ) - .map(Transformed::into) - }) - .collect::>()?; - with_new_children_if_necessary(plan, children)? - }; - - let (new_plan, transformed) = new_plan.into_pair(); - - // decide if we should bother trying to repartition the output of this plan - let mut could_repartition = match new_plan.output_partitioning() { - // Apply when underlying node has less than `self.target_partitions` amount of concurrency - RoundRobinBatch(x) => x < target_partitions, - UnknownPartitioning(x) => x < target_partitions, - // we don't want to introduce partitioning after hash partitioning - // as the plan will likely depend on this - Hash(_, _) => false, - }; - - // Don't need to apply when the returned row count is not greater than 1 - let stats = new_plan.statistics(); - if stats.is_exact { - could_repartition = could_repartition - && stats.num_rows.map(|num_rows| num_rows > 1).unwrap_or(true); - } - - // don't repartition root of the plan - if is_root { - could_repartition = false; - } - - let repartition_allowed = would_benefit && could_repartition && can_reorder; - - // If repartition is not allowed - return plan as it is - if !repartition_allowed { - return Ok(if transformed { - Transformed::Yes(new_plan) - } else { - Transformed::No(new_plan) - }); - } - - // For ParquetExec return internally repartitioned version of the plan in case `repartition_file_scans` is set - if let Some(parquet_exec) = new_plan.as_any().downcast_ref::() { - if repartition_file_scans { - return Ok(Transformed::Yes(Arc::new( - parquet_exec - .get_repartitioned(target_partitions, repartition_file_min_size), - ))); - } - } - - // Otherwise - return plan wrapped up in RepartitionExec - Ok(Transformed::Yes(Arc::new(RepartitionExec::try_new( - new_plan, - RoundRobinBatch(target_partitions), - )?))) -} - -/// Returns true if `plan` requires any of inputs to be sorted in some -/// way for correctness. If this is true, its output should not be -/// repartitioned if it would destroy the required order. -fn plan_has_required_input_ordering(plan: &dyn ExecutionPlan) -> bool { - // NB: checking `is_empty()` is not the right check! - plan.required_input_ordering().iter().any(Option::is_some) -} - -impl PhysicalOptimizerRule for Repartition { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - let target_partitions = config.execution.target_partitions; - let enabled = config.optimizer.enable_round_robin_repartition; - let repartition_file_scans = config.optimizer.repartition_file_scans; - let repartition_file_min_size = config.optimizer.repartition_file_min_size; - // Don't run optimizer if target_partitions == 1 - if !enabled || target_partitions == 1 { - Ok(plan) - } else { - let is_root = true; - let can_reorder = plan.output_ordering().is_none(); - let would_benefit = false; - optimize_partitions( - target_partitions, - plan.clone(), - is_root, - can_reorder, - would_benefit, - repartition_file_scans, - repartition_file_min_size, - ) - .map(Transformed::into) - } - } - - fn name(&self) -> &str { - "repartition" - } - - fn schema_check(&self) -> bool { - true - } -} - -#[cfg(test)] -#[ctor::ctor] -fn init() { - let _ = env_logger::try_init(); -} - -#[cfg(test)] -mod tests { - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - - use super::*; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; - use crate::physical_optimizer::dist_enforcement::EnforceDistribution; - use crate::physical_optimizer::sort_enforcement::EnforceSorting; - use crate::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, - }; - use crate::physical_plan::expressions::{col, PhysicalSortExpr}; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; - use crate::physical_plan::projection::ProjectionExec; - use crate::physical_plan::sorts::sort::SortExec; - use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; - use crate::physical_plan::union::UnionExec; - use crate::physical_plan::{displayable, DisplayFormatType, Statistics}; - use datafusion_physical_expr::PhysicalSortRequirement; - - fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])) - } - - /// Create a non sorted parquet exec - fn parquet_exec() -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - )) - } - - /// Create a non sorted parquet exec over two files / partitions - fn parquet_exec_two_partitions() -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![ - vec![PartitionedFile::new("x".to_string(), 100)], - vec![PartitionedFile::new("y".to_string(), 200)], - ], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - None, - )) - } - - // Created a sorted parquet exec - fn parquet_exec_sorted() -> Arc { - let sort_exprs = vec![PhysicalSortExpr { - expr: col("c1", &schema()).unwrap(), - options: SortOptions::default(), - }]; - - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - infinite_source: false, - }, - None, - None, - )) - } - - // Created a sorted parquet exec with multiple files - fn parquet_exec_multiple_sorted() -> Arc { - let sort_exprs = vec![PhysicalSortExpr { - expr: col("c1", &schema()).unwrap(), - options: SortOptions::default(), - }]; - - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![ - vec![PartitionedFile::new("x".to_string(), 100)], - vec![PartitionedFile::new("y".to_string(), 100)], - ], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - infinite_source: false, - }, - None, - None, - )) - } - - fn sort_preserving_merge_exec( - input: Arc, - ) -> Arc { - let expr = vec![PhysicalSortExpr { - expr: col("c1", &schema()).unwrap(), - options: arrow::compute::SortOptions::default(), - }]; - - Arc::new(SortPreservingMergeExec::new(expr, input)) - } - - fn filter_exec(input: Arc) -> Arc { - Arc::new(FilterExec::try_new(col("c1", &schema()).unwrap(), input).unwrap()) - } - - fn sort_exec( - input: Arc, - preserve_partitioning: bool, - ) -> Arc { - let sort_exprs = vec![PhysicalSortExpr { - expr: col("c1", &schema()).unwrap(), - options: SortOptions::default(), - }]; - let new_sort = SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning); - Arc::new(new_sort) - } - - fn projection_exec(input: Arc) -> Arc { - let exprs = vec![(col("c1", &schema()).unwrap(), "c1".to_string())]; - Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) - } - - fn aggregate(input: Arc) -> Arc { - let schema = schema(); - Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![], - vec![], - vec![], - Arc::new( - AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![], - vec![], - vec![], - input, - schema.clone(), - ) - .unwrap(), - ), - schema, - ) - .unwrap(), - ) - } - - fn limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new( - Arc::new(LocalLimitExec::new(input, 100)), - 0, - Some(100), - )) - } - - fn limit_exec_with_skip(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new( - Arc::new(LocalLimitExec::new(input, 100)), - 5, - Some(100), - )) - } - - fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) - } - - fn sort_required_exec(input: Arc) -> Arc { - Arc::new(SortRequiredExec::new(input)) - } - - fn trim_plan_display(plan: &str) -> Vec<&str> { - plan.split('\n') - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .collect() - } - - /// Runs the repartition optimizer and asserts the plan against the expected - macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, 10, false, 1024); - }; - - ($EXPECTED_LINES: expr, $PLAN: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - - let mut config = ConfigOptions::new(); - config.execution.target_partitions = $TARGET_PARTITIONS; - config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; - config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; - - // run optimizer - let optimizers: Vec> = vec![ - Arc::new(Repartition::new()), - // EnforceDistribution is an essential rule to be applied. - // Otherwise, the correctness of the generated optimized plan cannot be guaranteed - Arc::new(EnforceDistribution::new()), - // EnforceSorting is an essential rule to be applied. - // Otherwise, the correctness of the generated optimized plan cannot be guaranteed - Arc::new(EnforceSorting::new()), - ]; - let optimized = optimizers.into_iter().fold($PLAN, |plan, optimizer| { - optimizer.optimize(plan, &config).unwrap() - }); - - // Now format correctly - let plan = displayable(optimized.as_ref()).indent().to_string(); - let actual_lines = trim_plan_display(&plan); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; - } - - #[test] - fn added_repartition_to_single_partition() -> Result<()> { - let plan = aggregate(parquet_exec()); - - let expected = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_deepest_node() -> Result<()> { - let plan = aggregate(filter_exec(parquet_exec())); - - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "FilterExec: c1@0", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_unsorted_limit() -> Result<()> { - let plan = limit_exec(filter_exec(parquet_exec())); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - "CoalescePartitionsExec", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // nothing sorts the data, so the local limit doesn't require sorted data either - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_unsorted_limit_with_skip() -> Result<()> { - let plan = limit_exec_with_skip(filter_exec(parquet_exec())); - - let expected = &[ - "GlobalLimitExec: skip=5, fetch=100", - "CoalescePartitionsExec", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // nothing sorts the data, so the local limit doesn't require sorted data either - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_sorted_limit() -> Result<()> { - let plan = limit_exec(sort_exec(parquet_exec(), false)); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - // data is sorted so can't repartition here - "SortExec: expr=[c1@0 ASC]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_sorted_limit_with_filter() -> Result<()> { - let plan = limit_exec(filter_exec(sort_exec(parquet_exec(), false))); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // data is sorted so can't repartition here even though - // filter would benefit from parallelism, the answers might be wrong - "SortExec: expr=[c1@0 ASC]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_ignores_limit() -> Result<()> { - let plan = aggregate(limit_exec(filter_exec(limit_exec(parquet_exec())))); - - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "GlobalLimitExec: skip=0, fetch=100", - "CoalescePartitionsExec", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // repartition should happen prior to the filter to maximize parallelism - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - // Expect no repartition to happen for local limit - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_ignores_limit_with_skip() -> Result<()> { - let plan = aggregate(limit_exec_with_skip(filter_exec(limit_exec( - parquet_exec(), - )))); - - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "GlobalLimitExec: skip=5, fetch=100", - "CoalescePartitionsExec", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // repartition should happen prior to the filter to maximize parallelism - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - // Expect no repartition to happen for local limit - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - // repartition works differently for limit when there is a sort below it - - #[test] - fn repartition_ignores_union() -> Result<()> { - let plan = union_exec(vec![parquet_exec(); 5]); - - let expected = &[ - "UnionExec", - // Expect no repartition of ParquetExec - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_through_sort_preserving_merge() -> Result<()> { - // sort preserving merge with non-sorted input - let plan = sort_preserving_merge_exec(parquet_exec()); - - // need repartiton and resort as the data was not sorted correctly - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "SortExec: expr=[c1@0 ASC]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_ignores_sort_preserving_merge() -> Result<()> { - // sort preserving merge already sorted input, - let plan = sort_preserving_merge_exec(parquet_exec_multiple_sorted()); - - // should not repartition / sort (as the data was already sorted) - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { - // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) - let input = union_exec(vec![parquet_exec_sorted(); 2]); - let plan = sort_preserving_merge_exec(input); - - // should not repartition / sort (as the data was already sorted) - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "UnionExec", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_does_not_destroy_sort() -> Result<()> { - // SortRequired - // Parquet(sorted) - - let plan = sort_required_exec(parquet_exec_sorted()); - - // should not repartition as doing so destroys the necessary sort order - let expected = &[ - "SortRequiredExec", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { - // model a more complicated scenario where one child of a union can be repartitioned for performance - // but the other can not be - // - // Union - // SortRequired - // Parquet(sorted) - // Filter - // Parquet(unsorted) - - let input1 = sort_required_exec(parquet_exec_sorted()); - let input2 = filter_exec(parquet_exec()); - let plan = union_exec(vec![input1, input2]); - - // should not repartition below the SortRequired as that - // destroys the sort order but should still repartition for - // FilterExec - let expected = &[ - "UnionExec", - // union input 1: no repartitioning - "SortRequiredExec", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - // union input 2: should repartition - "FilterExec: c1@0", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_transitively_with_projection() -> Result<()> { - // non sorted input - let plan = sort_preserving_merge_exec(projection_exec(parquet_exec())); - - // needs to repartition / sort as the data was not sorted correctly - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "SortExec: expr=[c1@0 ASC]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ProjectionExec: expr=[c1@0 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_ignores_transitively_with_projection() -> Result<()> { - // sorted input - let plan = - sort_preserving_merge_exec(projection_exec(parquet_exec_multiple_sorted())); - - // data should not be repartitioned / resorted - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "ProjectionExec: expr=[c1@0 as c1]", - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_transitively_past_sort_with_projection() -> Result<()> { - let plan = - sort_preserving_merge_exec(sort_exec(projection_exec(parquet_exec()), true)); - - let expected = &[ - "SortExec: expr=[c1@0 ASC]", - "ProjectionExec: expr=[c1@0 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_transitively_past_sort_with_filter() -> Result<()> { - let plan = - sort_preserving_merge_exec(sort_exec(filter_exec(parquet_exec()), true)); - - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[c1@0 ASC]", - "FilterExec: c1@0", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { - let plan = sort_preserving_merge_exec(sort_exec( - projection_exec(filter_exec(parquet_exec())), - true, - )); - - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[c1@0 ASC]", - "ProjectionExec: expr=[c1@0 as c1]", - "FilterExec: c1@0", - // repartition is lowest down - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn parallelization_single_partition() -> Result<()> { - let plan = aggregate(parquet_exec()); - - let expected = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "ParquetExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_two_partitions() -> Result<()> { - let plan = aggregate(parquet_exec_two_partitions()); - - let expected = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - // Plan already has two partitions - "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_two_partitions_into_four() -> Result<()> { - let plan = aggregate(parquet_exec_two_partitions()); - - let expected = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - // Multiple source files splitted across partitions - "ParquetExec: file_groups={4 groups: [[x:0..75], [x:75..100, y:0..50], [y:50..125], [y:125..200]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 4, true, 10); - Ok(()) - } - - #[test] - fn parallelization_sorted_limit() -> Result<()> { - let plan = limit_exec(sort_exec(parquet_exec(), false)); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - // data is sorted so can't repartition here - "SortExec: expr=[c1@0 ASC]", - // Doesn't parallelize for SortExec without preserve_partitioning - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_limit_with_filter() -> Result<()> { - let plan = limit_exec(filter_exec(sort_exec(parquet_exec(), false))); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // data is sorted so can't repartition here even though - // filter would benefit from parallelism, the answers might be wrong - "SortExec: expr=[c1@0 ASC]", - // SortExec doesn't benefit from input partitioning - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_ignores_limit() -> Result<()> { - let plan = aggregate(limit_exec(filter_exec(limit_exec(parquet_exec())))); - - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[]", - "CoalescePartitionsExec", - "AggregateExec: mode=Partial, gby=[], aggr=[]", - "RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - "GlobalLimitExec: skip=0, fetch=100", - "CoalescePartitionsExec", - "LocalLimitExec: fetch=100", - "FilterExec: c1@0", - // repartition should happen prior to the filter to maximize parallelism - "RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - "GlobalLimitExec: skip=0, fetch=100", - // Limit doesn't benefit from input partitionins - no parallelism - "LocalLimitExec: fetch=100", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_union_inputs() -> Result<()> { - let plan = union_exec(vec![parquet_exec(); 5]); - - let expected = &[ - "UnionExec", - // Union doesn't benefit from input partitioning - no parallelism - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { - // sort preserving merge already sorted input, - let plan = sort_preserving_merge_exec(parquet_exec_sorted()); - - // parallelization potentially could break sort order - let expected = &[ - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_sort_preserving_merge_with_union() -> Result<()> { - // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) - let input = union_exec(vec![parquet_exec_sorted(); 2]); - let plan = sort_preserving_merge_exec(input); - - // should not repartition / sort (as the data was already sorted) - let expected = &[ - "SortPreservingMergeExec: [c1@0 ASC]", - "UnionExec", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_does_not_destroy_sort() -> Result<()> { - // SortRequired - // Parquet(sorted) - - let plan = sort_required_exec(parquet_exec_sorted()); - - // no parallelization to preserve sort order - let expected = &[ - "SortRequiredExec", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - #[test] - fn parallelization_ignores_transitively_with_projection() -> Result<()> { - // sorted input - let plan = sort_preserving_merge_exec(projection_exec(parquet_exec_sorted())); - - // data should not be repartitioned / resorted - let expected = &[ - "ProjectionExec: expr=[c1@0 as c1]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[c1], output_ordering=[c1@0 ASC]", - ]; - - assert_optimized!(expected, plan, 2, true, 10); - Ok(()) - } - - /// Models operators like BoundedWindowExec that require an input - /// ordering but is easy to construct - #[derive(Debug)] - struct SortRequiredExec { - input: Arc, - } - - impl SortRequiredExec { - fn new(input: Arc) -> Self { - Self { input } - } - } - - impl ExecutionPlan for SortRequiredExec { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn schema(&self) -> SchemaRef { - self.input.schema() - } - - fn output_partitioning(&self) -> crate::physical_plan::Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.input.output_ordering() - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - // model that it requires the output ordering of its input - fn required_input_ordering(&self) -> Vec>> { - vec![self - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs)] - } - - fn with_new_children( - self: Arc, - mut children: Vec>, - ) -> Result> { - assert_eq!(children.len(), 1); - let child = children.pop().unwrap(); - Ok(Arc::new(Self::new(child))) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unreachable!(); - } - - fn statistics(&self) -> Statistics { - self.input.statistics() - } - - fn fmt_as( - &self, - _t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - write!(f, "SortRequiredExec") - } - } -} diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs new file mode 100644 index 0000000000000..09274938cbcea --- /dev/null +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -0,0 +1,961 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule that replaces executors that lose ordering with their +//! order-preserving variants when it is helpful; either in terms of +//! performance or to accommodate unbounded streams by fixing the pipeline. + +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_optimizer::utils::{is_coalesce_partitions, is_sort, ExecTree}; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + +use super::utils::is_repartition; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_physical_plan::unbounded_output; + +/// For a given `plan`, this object carries the information one needs from its +/// descendants to decide whether it is beneficial to replace order-losing (but +/// somewhat faster) variants of certain operators with their order-preserving +/// (but somewhat slower) cousins. +#[derive(Debug, Clone)] +pub(crate) struct OrderPreservationContext { + pub(crate) plan: Arc, + ordering_onwards: Vec>, +} + +impl OrderPreservationContext { + /// Creates a "default" order-preservation context. + pub fn new(plan: Arc) -> Self { + let length = plan.children().len(); + OrderPreservationContext { + plan, + ordering_onwards: vec![None; length], + } + } + + /// Creates a new order-preservation context from those of children nodes. + pub fn new_from_children_nodes( + children_nodes: Vec, + parent_plan: Arc, + ) -> Result { + let children_plans = children_nodes + .iter() + .map(|item| item.plan.clone()) + .collect(); + let ordering_onwards = children_nodes + .into_iter() + .enumerate() + .map(|(idx, item)| { + // `ordering_onwards` tree keeps track of executors that maintain + // ordering, (or that can maintain ordering with the replacement of + // its variant) + let plan = item.plan; + let children = plan.children(); + let ordering_onwards = item.ordering_onwards; + if children.is_empty() { + // Plan has no children, there is nothing to propagate. + None + } else if ordering_onwards[0].is_none() + && ((is_repartition(&plan) && !plan.maintains_input_order()[0]) + || (is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some())) + { + Some(ExecTree::new(plan, idx, vec![])) + } else { + let children = ordering_onwards + .into_iter() + .flatten() + .filter(|item| { + // Only consider operators that maintains ordering + plan.maintains_input_order()[item.idx] + || is_coalesce_partitions(&plan) + || is_repartition(&plan) + }) + .collect::>(); + if children.is_empty() { + None + } else { + Some(ExecTree::new(plan, idx, children)) + } + } + }) + .collect(); + let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); + Ok(OrderPreservationContext { + plan, + ordering_onwards, + }) + } + + /// Computes order-preservation contexts for every child of the plan. + pub fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(OrderPreservationContext::new) + .collect() + } +} + +impl TreeNode for OrderPreservationContext { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in self.children() { + match op(&child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_nodes = children + .into_iter() + .map(transform) + .collect::>>()?; + OrderPreservationContext::new_from_children_nodes(children_nodes, self.plan) + } + } +} + +/// Calculates the updated plan by replacing executors that lose ordering +/// inside the `ExecTree` with their order-preserving variants. This will +/// generate an alternative plan, which will be accepted or rejected later on +/// depending on whether it helps us remove a `SortExec`. +fn get_updated_plan( + exec_tree: &ExecTree, + // Flag indicating that it is desirable to replace `RepartitionExec`s with + // `SortPreservingRepartitionExec`s: + is_spr_better: bool, + // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s + // with `SortPreservingMergeExec`s: + is_spm_better: bool, +) -> Result> { + let plan = exec_tree.plan.clone(); + + let mut children = plan.children(); + // Update children and their descendants in the given tree: + for item in &exec_tree.children { + children[item.idx] = get_updated_plan(item, is_spr_better, is_spm_better)?; + } + // Construct the plan with updated children: + let mut plan = plan.with_new_children(children)?; + + // When a `RepartitionExec` doesn't preserve ordering, replace it with + // a `SortPreservingRepartitionExec` if appropriate: + if is_repartition(&plan) && !plan.maintains_input_order()[0] && is_spr_better { + let child = plan.children().swap_remove(0); + let repartition = RepartitionExec::try_new(child, plan.output_partitioning())? + .with_preserve_order(); + plan = Arc::new(repartition) as _ + } + // When the input of a `CoalescePartitionsExec` has an ordering, replace it + // with a `SortPreservingMergeExec` if appropriate: + let mut children = plan.children(); + if is_coalesce_partitions(&plan) + && children[0].output_ordering().is_some() + && is_spm_better + { + let child = children.swap_remove(0); + plan = Arc::new(SortPreservingMergeExec::new( + child.output_ordering().unwrap_or(&[]).to_vec(), + child, + )) as _ + } + Ok(plan) +} + +/// The `replace_with_order_preserving_variants` optimizer sub-rule tries to +/// remove `SortExec`s from the physical plan by replacing operators that do +/// not preserve ordering with their order-preserving variants; i.e. by replacing +/// `RepartitionExec`s with `SortPreservingRepartitionExec`s or by replacing +/// `CoalescePartitionsExec`s with `SortPreservingMergeExec`s. +/// +/// If this replacement is helpful for removing a `SortExec`, it updates the plan. +/// Otherwise, it leaves the plan unchanged. +/// +/// Note: this optimizer sub-rule will only produce `SortPreservingRepartitionExec`s +/// if the query is bounded or if the config option `bounded_order_preserving_variants` +/// is set to `true`. +/// +/// The algorithm flow is simply like this: +/// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. +/// 1_1. During the traversal, build an `ExecTree` to keep track of operators +/// that maintain ordering (or can maintain ordering when replaced by an +/// order-preserving variant) until a `SortExec` is found. +/// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing +/// operators that do not preserve ordering in the `ExecTree` with their order +/// preserving variants. +/// 3. Check if the `SortExec` is still necessary in the updated plan by comparing +/// its input ordering with the output ordering it imposes. We do this because +/// replacing operators that lose ordering with their order-preserving variants +/// enables us to preserve the previously lost ordering at the input of `SortExec`. +/// 4. If the `SortExec` in question turns out to be unnecessary, remove it and use +/// updated plan. Otherwise, use the original plan. +/// 5. Continue the bottom-up traversal until another `SortExec` is seen, or the traversal +/// is complete. +pub(crate) fn replace_with_order_preserving_variants( + requirements: OrderPreservationContext, + // A flag indicating that replacing `RepartitionExec`s with + // `SortPreservingRepartitionExec`s is desirable when it helps + // to remove a `SortExec` from the plan. If this flag is `false`, + // this replacement should only be made to fix the pipeline (streaming). + is_spr_better: bool, + // A flag indicating that replacing `CoalescePartitionsExec`s with + // `SortPreservingMergeExec`s is desirable when it helps to remove + // a `SortExec` from the plan. If this flag is `false`, this replacement + // should only be made to fix the pipeline (streaming). + is_spm_better: bool, + config: &ConfigOptions, +) -> Result> { + let plan = &requirements.plan; + let ordering_onwards = &requirements.ordering_onwards; + if is_sort(plan) { + let exec_tree = if let Some(exec_tree) = &ordering_onwards[0] { + exec_tree + } else { + return Ok(Transformed::No(requirements)); + }; + // For unbounded cases, replace with the order-preserving variant in + // any case, as doing so helps fix the pipeline. + // Also do the replacement if opted-in via config options. + let use_order_preserving_variant = + config.optimizer.prefer_existing_sort || unbounded_output(plan); + let updated_sort_input = get_updated_plan( + exec_tree, + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + )?; + // If this sort is unnecessary, we should remove it and update the plan: + if updated_sort_input + .equivalence_properties() + .ordering_satisfy(plan.output_ordering().unwrap_or(&[])) + { + return Ok(Transformed::Yes(OrderPreservationContext { + plan: updated_sort_input, + ordering_onwards: vec![None], + })); + } + } + + Ok(Transformed::No(requirements)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::datasource::file_format::file_compression_type::FileCompressionType; + use crate::datasource::listing::PartitionedFile; + use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; + use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; + use crate::physical_plan::repartition::RepartitionExec; + use crate::physical_plan::sorts::sort::SortExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::{displayable, get_plan_string, Partitioning}; + use crate::prelude::SessionConfig; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::tree_node::TreeNode; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_expr::{JoinType, Operator}; + use datafusion_physical_expr::expressions::{self, col, Column}; + use datafusion_physical_expr::PhysicalSortExpr; + + /// Runs the `replace_with_order_preserving_variants` sub-rule and asserts the plan + /// against the original and expected plans. + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan + /// `$PLAN`: the plan to optimized + /// `$ALLOW_BOUNDED`: whether to allow the plan to be optimized for bounded cases + macro_rules! assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + assert_optimized!( + $EXPECTED_PLAN_LINES, + $EXPECTED_OPTIMIZED_PLAN_LINES, + $PLAN, + false + ); + }; + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $ALLOW_BOUNDED: expr) => { + let physical_plan = $PLAN; + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); + + // Run the rule top-down + // let optimized_physical_plan = physical_plan.transform_down(&replace_repartition_execs)?; + let config = SessionConfig::new().with_prefer_existing_sort($ALLOW_BOUNDED); + let plan_with_pipeline_fixer = OrderPreservationContext::new(physical_plan); + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options()))?; + let optimized_physical_plan = parallel.plan; + + // Get string representation of the plan + let actual = get_plan_string(&optimized_physical_plan); + assert_eq!( + expected_optimized_lines, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + }; + } + + #[tokio::test] + // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected + async fn test_replace_multiple_input_repartition_1() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); + let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_inter_children_change_only() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr_default("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let sort = sort_exec( + vec![sort_expr_default("a", &coalesce_partitions.schema())], + coalesce_partitions, + false, + ); + let repartition_rr2 = repartition_exec_round_robin(sort); + let repartition_hash2 = repartition_exec_hash(repartition_rr2); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); + + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort2.schema())], + sort2, + ); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[a@0 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + ]; + + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_replace_multiple_input_repartition_2() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let filter = filter_exec(repartition_rr); + let repartition_hash = repartition_exec_hash(filter); + let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_replace_multiple_input_repartition_with_extra_steps() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let filter = filter_exec(repartition_hash); + let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); + let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_replace_multiple_input_repartition_with_extra_steps_2() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); + let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); + let filter = filter_exec(repartition_hash); + let coalesce_batches_exec_2 = coalesce_batches_exec(filter); + let sort = + sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_not_replacing_when_no_need_to_preserve_sorting() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let filter = filter_exec(repartition_hash); + let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); + + let physical_plan: Arc = + coalesce_partitions_exec(coalesce_batches_exec); + + let expected_input = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_multiple_replacable_repartitions() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let filter = filter_exec(repartition_hash); + let coalesce_batches = coalesce_batches_exec(filter); + let repartition_hash_2 = repartition_exec_hash(coalesce_batches); + let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_not_replace_with_different_orderings() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let sort = sort_exec( + vec![sort_expr_default("c", &repartition_hash.schema())], + repartition_hash, + true, + ); + + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort.schema())], + sort, + ); + + let expected_input = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_lost_ordering() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let physical_plan = + sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); + + let expected_input = [ + "SortExec: expr=[a@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_lost_and_kept_ordering() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, true); + let repartition_rr = repartition_exec_round_robin(source); + let repartition_hash = repartition_exec_hash(repartition_rr); + let coalesce_partitions = coalesce_partitions_exec(repartition_hash); + let sort = sort_exec( + vec![sort_expr_default("c", &coalesce_partitions.schema())], + coalesce_partitions, + false, + ); + let repartition_rr2 = repartition_exec_round_robin(sort); + let repartition_hash2 = repartition_exec_hash(repartition_rr2); + let filter = filter_exec(repartition_hash2); + let sort2 = + sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); + + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("c", &sort2.schema())], + sort2, + ); + + let expected_input = [ + "SortPreservingMergeExec: [c@1 ASC]", + " SortExec: expr=[c@1 ASC]", + " FilterExec: c@1 > 3", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + let expected_optimized = [ + "SortPreservingMergeExec: [c@1 ASC]", + " FilterExec: c@1 > 3", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=c@1 ASC", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " SortExec: expr=[c@1 ASC]", + " CoalescePartitionsExec", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_multiple_child_trees() -> Result<()> { + let schema = create_test_schema()?; + + let left_sort_exprs = vec![sort_expr("a", &schema)]; + let left_source = csv_exec_sorted(&schema, left_sort_exprs, true); + let left_repartition_rr = repartition_exec_round_robin(left_source); + let left_repartition_hash = repartition_exec_hash(left_repartition_rr); + let left_coalesce_partitions = + Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); + + let right_sort_exprs = vec![sort_expr("a", &schema)]; + let right_source = csv_exec_sorted(&schema, right_sort_exprs, true); + let right_repartition_rr = repartition_exec_round_robin(right_source); + let right_repartition_hash = repartition_exec_hash(right_repartition_rr); + let right_coalesce_partitions = + Arc::new(CoalesceBatchesExec::new(right_repartition_hash, 4096)); + + let hash_join_exec = + hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); + let sort = sort_exec( + vec![sort_expr_default("a", &hash_join_exec.schema())], + hash_join_exec, + true, + ); + + let physical_plan = sort_preserving_merge_exec( + vec![sort_expr_default("a", &sort.schema())], + sort, + ); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_with_bounded_input() -> Result<()> { + let schema = create_test_schema()?; + let sort_exprs = vec![sort_expr("a", &schema)]; + let source = csv_exec_sorted(&schema, sort_exprs, false); + let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); + let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); + + let physical_plan = + sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + + let expected_input = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST]", + " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", + " SortPreservingRepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, sort_exprs=a@0 ASC NULLS LAST", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) + } + + // End test cases + // Start test helpers + + fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { + let sort_opts = SortOptions { + nulls_first: false, + descending: false, + }; + sort_expr_options(name, schema, sort_opts) + } + + fn sort_expr_default(name: &str, schema: &Schema) -> PhysicalSortExpr { + let sort_opts = SortOptions::default(); + sort_expr_options(name, schema, sort_opts) + } + + fn sort_expr_options( + name: &str, + schema: &Schema, + options: SortOptions, + ) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options, + } + } + + fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, + preserve_partitioning: bool, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new( + SortExec::new(sort_exprs, input) + .with_preserve_partitioning(preserve_partitioning), + ) + } + + fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + } + + fn repartition_exec_round_robin( + input: Arc, + ) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8)).unwrap(), + ) + } + + fn repartition_exec_hash(input: Arc) -> Arc { + let input_schema = input.schema(); + Arc::new( + RepartitionExec::try_new( + input, + Partitioning::Hash(vec![col("c", &input_schema).unwrap()], 8), + ) + .unwrap(), + ) + } + + fn filter_exec(input: Arc) -> Arc { + let input_schema = input.schema(); + let predicate = expressions::binary( + col("c", &input_schema).unwrap(), + Operator::Gt, + expressions::lit(3i32), + &input_schema, + ) + .unwrap(); + Arc::new(FilterExec::try_new(predicate, input).unwrap()) + } + + fn coalesce_batches_exec(input: Arc) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, 8192)) + } + + fn coalesce_partitions_exec(input: Arc) -> Arc { + Arc::new(CoalescePartitionsExec::new(input)) + } + + fn hash_join_exec( + left: Arc, + right: Arc, + ) -> Arc { + let left_on = col("c", &left.schema()).unwrap(); + let right_on = col("c", &right.schema()).unwrap(); + let left_col = left_on.as_any().downcast_ref::().unwrap(); + let right_col = right_on.as_any().downcast_ref::().unwrap(); + Arc::new( + HashJoinExec::try_new( + left, + right, + vec![(left_col.clone(), right_col.clone())], + None, + &JoinType::Inner, + PartitionMode::Partitioned, + false, + ) + .unwrap(), + ) + } + + fn create_test_schema() -> Result { + let column_a = Field::new("a", DataType::Int32, false); + let column_b = Field::new("b", DataType::Int32, false); + let column_c = Field::new("c", DataType::Int32, false); + let column_d = Field::new("d", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![column_a, column_b, column_c, column_d])); + + Ok(schema) + } + + // creates a csv exec source for the test purposes + // projection and has_header parameters are given static due to testing needs + fn csv_exec_sorted( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + infinite_source: bool, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + let projection: Vec = vec![0, 2, 3]; + + Arc::new(CsvExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new( + "file_path".to_string(), + 100, + )]], + statistics: Statistics::new_unknown(schema), + projection: Some(projection), + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + infinite_source, + }, + true, + 0, + b'"', + None, + FileCompressionType::UNCOMPRESSED, + )) + } +} diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 20a5038b7aa71..b9502d92ac12f 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -14,31 +14,35 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use crate::physical_optimizer::utils::{add_sort_above, is_limit, is_union, is_window}; + +use std::sync::Arc; + +use crate::physical_optimizer::utils::{ + add_sort_above, is_limit, is_sort_preserving_merge, is_union, is_window, +}; use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::utils::JoinSide; -use crate::physical_plan::joins::SortMergeJoinExec; +use crate::physical_plan::joins::utils::calculate_join_output_ordering; +use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{ - ordering_satisfy_requirement, requirements_compatible, +use datafusion_physical_expr::{ + LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; + use itertools::izip; -use std::ops::Deref; -use std::sync::Arc; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total /// computational cost by pushing down `SortExec`s through some executors. /// -/// [`EnforceSorting`]: crate::physical_optimizer::sort_enforcement::EnforceSorting +/// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting #[derive(Debug, Clone)] pub(crate) struct SortPushDown { /// Current plan @@ -120,35 +124,31 @@ pub(crate) fn pushdown_sorts( requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; - let parent_required = requirements.required_ordering.as_deref(); - const ERR_MSG: &str = "Expects parent requirement to contain something"; - let err = || DataFusionError::Plan(ERR_MSG.to_string()); + let parent_required = requirements.required_ordering.as_deref().unwrap_or(&[]); if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let mut new_plan = plan.clone(); - if !ordering_satisfy_requirement( - plan.output_ordering(), - parent_required, - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) { + let new_plan = if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { // If the current plan is a SortExec, modify it to satisfy parent requirements: - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); - new_plan = sort_exec.input.clone(); - add_sort_above(&mut new_plan, parent_required_expr)?; + let mut new_plan = sort_exec.input().clone(); + add_sort_above(&mut new_plan, parent_required, sort_exec.fetch()); + new_plan + } else { + requirements.plan }; let required_ordering = new_plan .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs); + .map(PhysicalSortRequirement::from_sort_exprs) + .unwrap_or_default(); // Since new_plan is a SortExec, we can safely get the 0th index. - let child = &new_plan.children()[0]; + let child = new_plan.children().swap_remove(0); if let Some(adjusted) = - pushdown_requirement_to_children(child, required_ordering.as_deref())? + pushdown_requirement_to_children(&child, &required_ordering)? { // Can push down requirements Ok(Transformed::Yes(SortPushDown { - plan: child.clone(), + plan: child, required_ordering: None, adjusted_request_ordering: adjusted, })) @@ -158,12 +158,10 @@ pub(crate) fn pushdown_sorts( } } else { // Executors other than SortExec - if ordering_satisfy_requirement( - plan.output_ordering(), - parent_required, - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) { + if plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required) + { // Satisfies parent requirements, immediately return. return Ok(Transformed::Yes(SortPushDown { required_ordering: None, @@ -173,17 +171,14 @@ pub(crate) fn pushdown_sorts( // Can not satisfy the parent requirements, check whether the requirements can be pushed down: if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_required)? { Ok(Transformed::Yes(SortPushDown { - plan: plan.clone(), + plan: requirements.plan, required_ordering: None, adjusted_request_ordering: adjusted, })) } else { // Can not push down requirements, add new SortExec: - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); - let mut new_plan = plan.clone(); - add_sort_above(&mut new_plan, parent_required_expr)?; + let mut new_plan = requirements.plan; + add_sort_above(&mut new_plan, parent_required, None); Ok(Transformed::Yes(SortPushDown::init(new_plan))) } } @@ -191,18 +186,21 @@ pub(crate) fn pushdown_sorts( fn pushdown_requirement_to_children( plan: &Arc, - parent_required: Option<&[PhysicalSortRequirement]>, + parent_required: LexRequirementRef, ) -> Result>>>> { - const ERR_MSG: &str = "Expects parent requirement to contain something"; - let err = || DataFusionError::Plan(ERR_MSG.to_string()); let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].as_deref(); - let child_plan = plan.children()[0].clone(); + let request_child = required_input_ordering[0].as_deref().unwrap_or(&[]); + let child_plan = plan.children().swap_remove(0); match determine_children_requirement(parent_required, request_child, child_plan) { RequirementsCompatibility::Satisfy => { - Ok(Some(vec![request_child.map(|r| r.to_vec())])) + let req = if request_child.is_empty() { + None + } else { + Some(request_child.to_vec()) + }; + Ok(Some(vec![req])) } RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])), RequirementsCompatibility::NonCompatible => Ok(None), @@ -210,44 +208,37 @@ fn pushdown_requirement_to_children( } else if is_union(plan) { // UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - Ok(Some(vec![ - parent_required.map(|elem| elem.to_vec()); - plan.children().len() - ])) + let req = if parent_required.is_empty() { + None + } else { + Some(parent_required.to_vec()) + }; + Ok(Some(vec![req; plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { // If the current plan is SortMergeJoinExec - let left_columns_len = smj.left.schema().fields().len(); - let parent_required_expr = PhysicalSortRequirement::to_sort_exprs( - parent_required.ok_or_else(err)?.iter().cloned(), - ); + let left_columns_len = smj.left().schema().fields().len(); + let parent_required_expr = + PhysicalSortRequirement::to_sort_exprs(parent_required.iter().cloned()); let expr_source_side = - expr_source_sides(&parent_required_expr, smj.join_type, left_columns_len); + expr_source_sides(&parent_required_expr, smj.join_type(), left_columns_len); match expr_source_side { - Some(JoinSide::Left) if maintains_input_order[0] => { + Some(JoinSide::Left) => try_pushdown_requirements_to_join( + smj, + parent_required, + parent_required_expr, + JoinSide::Left, + ), + Some(JoinSide::Right) => { + let right_offset = + smj.schema().fields.len() - smj.right().schema().fields.len(); + let new_right_required = + shift_right_required(parent_required, right_offset)?; + let new_right_required_expr = + PhysicalSortRequirement::to_sort_exprs(new_right_required); try_pushdown_requirements_to_join( - plan, + smj, parent_required, - parent_required_expr, - JoinSide::Left, - ) - } - Some(JoinSide::Right) if maintains_input_order[1] => { - let new_right_required = match smj.join_type { - JoinType::Inner | JoinType::Right => shift_right_required( - parent_required.ok_or_else(err)?, - left_columns_len, - )?, - JoinType::RightSemi | JoinType::RightAnti => { - parent_required.ok_or_else(err)?.to_vec() - } - _ => Err(DataFusionError::Plan( - "Unexpected SortMergeJoin type here".to_string(), - ))?, - }; - try_pushdown_requirements_to_join( - plan, - Some(new_right_required.deref()), - parent_required_expr, + new_right_required_expr, JoinSide::Right, ) } @@ -263,16 +254,45 @@ fn pushdown_requirement_to_children( // TODO: Add support for Projection push down || plan.as_any().is::() || is_limit(plan) + || plan.as_any().is::() { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. // Pushing down is not beneficial Ok(None) + } else if is_sort_preserving_merge(plan) { + let new_ordering = + PhysicalSortRequirement::to_sort_exprs(parent_required.to_vec()); + let mut spm_eqs = plan.equivalence_properties(); + // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. + spm_eqs = spm_eqs.with_reorder(new_ordering); + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + if !spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or(&[])) { + Ok(None) + } else { + // Can push-down through SortPreservingMergeExec, because parent requirement is finer + // than SortPreservingMergeExec output ordering. + let req = if parent_required.is_empty() { + None + } else { + Some(parent_required.to_vec()) + }; + Ok(Some(vec![req])) + } } else { - Ok(Some(vec![ - parent_required.map(|elem| elem.to_vec()); - plan.children().len() - ])) + Ok(Some( + maintains_input_order + .into_iter() + .map(|flag| { + if flag && !parent_required.is_empty() { + Some(parent_required.to_vec()) + } else { + None + } + }) + .collect(), + )) } // TODO: Add support for Projection push down } @@ -282,64 +302,71 @@ fn pushdown_requirement_to_children( /// If the the parent requirements are more specific, push down the parent requirements /// If they are not compatible, need to add Sort. fn determine_children_requirement( - parent_required: Option<&[PhysicalSortRequirement]>, - request_child: Option<&[PhysicalSortRequirement]>, + parent_required: LexRequirementRef, + request_child: LexRequirementRef, child_plan: Arc, ) -> RequirementsCompatibility { - if requirements_compatible( - request_child, - parent_required, - || child_plan.ordering_equivalence_properties(), - || child_plan.equivalence_properties(), - ) { + if child_plan + .equivalence_properties() + .requirements_compatible(request_child, parent_required) + { // request child requirements are more specific, no need to push down the parent requirements RequirementsCompatibility::Satisfy - } else if requirements_compatible( - parent_required, - request_child, - || child_plan.ordering_equivalence_properties(), - || child_plan.equivalence_properties(), - ) { + } else if child_plan + .equivalence_properties() + .requirements_compatible(parent_required, request_child) + { // parent requirements are more specific, adjust the request child requirements and push down the new requirements - let adjusted = parent_required.map(|r| r.to_vec()); + let adjusted = if parent_required.is_empty() { + None + } else { + Some(parent_required.to_vec()) + }; RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible } } - fn try_pushdown_requirements_to_join( - plan: &Arc, - parent_required: Option<&[PhysicalSortRequirement]>, + smj: &SortMergeJoinExec, + parent_required: LexRequirementRef, sort_expr: Vec, push_side: JoinSide, ) -> Result>>>> { - let child_idx = match push_side { - JoinSide::Left => 0, - JoinSide::Right => 1, + let left_ordering = smj.left().output_ordering().unwrap_or(&[]); + let right_ordering = smj.right().output_ordering().unwrap_or(&[]); + let (new_left_ordering, new_right_ordering) = match push_side { + JoinSide::Left => (sort_expr.as_slice(), right_ordering), + JoinSide::Right => (left_ordering, sort_expr.as_slice()), }; - let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[child_idx].as_deref(); - let child_plan = plan.children()[child_idx].clone(); - match determine_children_requirement(parent_required, request_child, child_plan) { - RequirementsCompatibility::Satisfy => Ok(None), - RequirementsCompatibility::Compatible(adjusted) => { - let new_adjusted = match push_side { - JoinSide::Left => { - vec![adjusted, required_input_ordering[1].clone()] - } - JoinSide::Right => { - vec![required_input_ordering[0].clone(), adjusted] - } - }; - Ok(Some(new_adjusted)) - } - RequirementsCompatibility::NonCompatible => { - // Can not push down, add new SortExec - add_sort_above(&mut plan.clone(), sort_expr)?; - Ok(None) + let join_type = smj.join_type(); + let probe_side = SortMergeJoinExec::probe_side(&join_type); + let new_output_ordering = calculate_join_output_ordering( + new_left_ordering, + new_right_ordering, + join_type, + smj.on(), + smj.left().schema().fields.len(), + &smj.maintains_input_order(), + Some(probe_side), + ); + let mut smj_eqs = smj.equivalence_properties(); + // smj will have this ordering when its input changes. + smj_eqs = smj_eqs.with_reorder(new_output_ordering.unwrap_or_default()); + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); + Ok(should_pushdown.then(|| { + let mut required_input_ordering = smj.required_input_ordering(); + let new_req = Some(PhysicalSortRequirement::from_sort_exprs(&sort_expr)); + match push_side { + JoinSide::Left => { + required_input_ordering[0] = new_req; + } + JoinSide::Right => { + required_input_ordering[1] = new_req; + } } - } + required_input_ordering + })) } fn expr_source_sides( @@ -391,7 +418,7 @@ fn expr_source_sides( } fn shift_right_required( - parent_required: &[PhysicalSortRequirement], + parent_required: LexRequirementRef, left_columns_len: usize, ) -> Result> { let new_right_required: Vec = parent_required @@ -413,10 +440,9 @@ fn shift_right_required( if new_right_required.len() == parent_required.len() { Ok(new_right_required) } else { - Err(DataFusionError::Plan( + plan_err!( "Expect to shift all the parent required column indexes for SortMergeJoin" - .to_string(), - )) + ) } } diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 8689b016b01c1..37a76eff1ee28 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -17,10 +17,35 @@ //! Collection of testing utility functions that are leveraged by the query optimizer rules +use std::sync::Arc; + +use crate::datasource::listing::PartitionedFile; +use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; use crate::error::Result; +use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::joins::utils::{JoinFilter, JoinOn}; +use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::union::UnionExec; +use crate::physical_plan::windows::create_window_expr; +use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; + +use arrow_schema::{Schema, SchemaRef, SortOptions}; +use datafusion_common::{JoinType, Statistics}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use async_trait::async_trait; -use std::sync::Arc; async fn register_current_csv( ctx: &SessionContext, @@ -115,14 +140,13 @@ impl QueryCase { async fn run_case(&self, ctx: SessionContext, error: Option<&String>) -> Result<()> { let dataframe = ctx.sql(self.sql.as_str()).await?; let plan = dataframe.create_physical_plan().await; - if error.is_some() { + if let Some(error) = error { let plan_error = plan.unwrap_err(); - let initial = error.unwrap().to_string(); assert!( - plan_error.to_string().contains(initial.as_str()), + plan_error.to_string().contains(error.as_str()), "plan_error: {:?} doesn't contain message: {:?}", plan_error, - initial.as_str() + error.as_str() ); } else { assert!(plan.is_ok()) @@ -130,3 +154,207 @@ impl QueryCase { Ok(()) } } + +pub fn sort_merge_join_exec( + left: Arc, + right: Arc, + join_on: &JoinOn, + join_type: &JoinType, +) -> Arc { + Arc::new( + SortMergeJoinExec::try_new( + left, + right, + join_on.clone(), + *join_type, + vec![SortOptions::default(); join_on.len()], + false, + ) + .unwrap(), + ) +} + +/// make PhysicalSortExpr with default options +pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { + sort_expr_options(name, schema, SortOptions::default()) +} + +/// PhysicalSortExpr with specified options +pub fn sort_expr_options( + name: &str, + schema: &Schema, + options: SortOptions, +) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options, + } +} + +pub fn coalesce_partitions_exec(input: Arc) -> Arc { + Arc::new(CoalescePartitionsExec::new(input)) +} + +pub(crate) fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) +} + +pub fn hash_join_exec( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, +) -> Result> { + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + on, + filter, + join_type, + PartitionMode::Partitioned, + true, + )?)) +} + +pub fn bounded_window_exec( + col_name: &str, + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs: Vec<_> = sort_exprs.into_iter().collect(); + let schema = input.schema(); + + Arc::new( + crate::physical_plan::windows::BoundedWindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col(col_name, &schema).unwrap()], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + ) + .unwrap()], + input.clone(), + vec![], + InputOrderMode::Sorted, + ) + .unwrap(), + ) +} + +pub fn filter_exec( + predicate: Arc, + input: Arc, +) -> Arc { + Arc::new(FilterExec::try_new(predicate, input).unwrap()) +} + +pub fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) +} + +/// Create a non sorted parquet exec +pub fn parquet_exec(schema: &SchemaRef) -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }, + None, + None, + )) +} + +// Created a sorted parquet exec +pub fn parquet_exec_sorted( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + + Arc::new(ParquetExec::new( + FileScanConfig { + object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), + file_schema: schema.clone(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::new_unknown(schema), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![sort_exprs], + infinite_source: false, + }, + None, + None, + )) +} + +pub fn union_exec(input: Vec>) -> Arc { + Arc::new(UnionExec::new(input)) +} + +pub fn limit_exec(input: Arc) -> Arc { + global_limit_exec(local_limit_exec(input)) +} + +pub fn local_limit_exec(input: Arc) -> Arc { + Arc::new(LocalLimitExec::new(input, 100)) +} + +pub fn global_limit_exec(input: Arc) -> Arc { + Arc::new(GlobalLimitExec::new(input, 0, Some(100))) +} + +pub fn repartition_exec(input: Arc) -> Arc { + Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) +} + +pub fn spr_repartition_exec(input: Arc) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(), + ) +} + +pub fn aggregate_exec(input: Arc) -> Arc { + let schema = input.schema(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![], + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +pub fn coalesce_batches_exec(input: Arc) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, 128)) +} + +pub fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortExec::new(sort_exprs, input)) +} diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs new file mode 100644 index 0000000000000..52d34d4f81986 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An optimizer rule that detects aggregate operations that could use a limited bucket count + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::ExecutionPlan; +use arrow_schema::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed +pub struct TopKAggregation {} + +impl TopKAggregation { + /// Create a new `LimitAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + order: &PhysicalSortExpr, + limit: usize, + ) -> Option> { + // ensure the sort direction matches aggregate function + let (field, desc) = aggr.get_minmax_desc()?; + if desc != order.options.descending { + return None; + } + let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; + let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; + if !kt.is_primitive() && kt != DataType::Utf8 { + return None; + } + if aggr.filter_expr().iter().any(|e| e.is_some()) { + return None; + } + + // ensure the sort is on the same field as the aggregate output + let col = order.expr.as_any().downcast_ref::()?; + if col.name() != field.name() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_by().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.order_by_expr().to_vec(), + aggr.input().clone(), + aggr.input_schema(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + fn transform_sort(plan: Arc) -> Option> { + let sort = plan.as_any().downcast_ref::()?; + + let children = sort.children(); + let child = children.iter().exactly_one().ok()?; + let order = sort.output_ordering()?; + let order = order.iter().exactly_one().ok()?; + let limit = sort.fetch()?; + + let is_cardinality_preserving = |plan: Arc| { + plan.as_any() + .downcast_ref::() + .is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + }; + + let mut cardinality_preserved = true; + let mut closure = |plan: Arc| { + if !cardinality_preserved { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + // either we run into an Aggregate and transform it + match Self::transform_agg(aggr, order, limit) { + None => cardinality_preserved = false, + Some(plan) => return Ok(Transformed::Yes(plan)), + } + } else { + // or we continue down whitelisted nodes of other types + if !is_cardinality_preserving(plan.clone()) { + cardinality_preserved = false; + } + } + Ok(Transformed::No(plan)) + }; + let child = child.clone().transform_down_mut(&mut closure).ok()?; + let sort = SortExec::new(sort.expr().to_vec(), child) + .with_fetch(sort.fetch()) + .with_preserve_partitioning(sort.preserve_partitioning()); + Some(Arc::new(sort)) + } +} + +impl Default for TopKAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for TopKAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_topk_aggregation { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? + } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +// see `aggregate.slt` for tests diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 68efa06c3fbf5..fccc1db0d3598 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -17,11 +17,10 @@ //! Collection of utility functions that are leveraged by the query optimizer rules -use std::borrow::Borrow; -use std::collections::HashSet; +use std::fmt; +use std::fmt::Formatter; use std::sync::Arc; -use crate::error::Result; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; @@ -29,25 +28,87 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; -use crate::physical_plan::ExecutionPlan; -use datafusion_common::DataFusionError; -use datafusion_physical_expr::utils::ordering_satisfy; -use datafusion_physical_expr::PhysicalSortExpr; +use crate::physical_plan::{get_plan_string, ExecutionPlan}; + +use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; + +/// This object implements a tree that we use while keeping track of paths +/// leading to [`SortExec`]s. +#[derive(Debug, Clone)] +pub(crate) struct ExecTree { + /// The `ExecutionPlan` associated with this node + pub plan: Arc, + /// Child index of the plan in its parent + pub idx: usize, + /// Children of the plan that would need updating if we remove leaf executors + pub children: Vec, +} + +impl fmt::Display for ExecTree { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let plan_string = get_plan_string(&self.plan); + write!(f, "\nidx: {:?}", self.idx)?; + write!(f, "\nplan: {:?}", plan_string)?; + for child in self.children.iter() { + write!(f, "\nexec_tree:{}", child)?; + } + writeln!(f) + } +} + +impl ExecTree { + /// Create new Exec tree + pub fn new( + plan: Arc, + idx: usize, + children: Vec, + ) -> Self { + ExecTree { + plan, + idx, + children, + } + } +} + +/// Get `ExecTree` for each child of the plan if they are tracked. +/// # Arguments +/// +/// * `n_children` - Children count of the plan of interest +/// * `onward` - Contains `Some(ExecTree)` of the plan tracked. +/// - Contains `None` is plan is not tracked. +/// +/// # Returns +/// +/// A `Vec>` that contains tracking information of each child. +/// If a child is `None`, it is not tracked. If `Some(ExecTree)` child is tracked also. +pub(crate) fn get_children_exectrees( + n_children: usize, + onward: &Option, +) -> Vec> { + let mut children_onward = vec![None; n_children]; + if let Some(exec_tree) = &onward { + for child in &exec_tree.children { + children_onward[child.idx] = Some(child.clone()); + } + } + children_onward +} /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. pub fn add_sort_above( node: &mut Arc, - sort_expr: Vec, -) -> Result<()> { + sort_requirement: LexRequirementRef, + fetch: Option, +) { // If the ordering requirement is already satisfied, do not add a sort. - if !ordering_satisfy( - node.output_ordering(), - Some(&sort_expr), - || node.equivalence_properties(), - || node.ordering_equivalence_properties(), - ) { - let new_sort = SortExec::new(sort_expr, node.clone()); + if !node + .equivalence_properties() + .ordering_satisfy_requirement(sort_requirement) + { + let sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirement.to_vec()); + let new_sort = SortExec::new(sort_expr, node.clone()).with_fetch(fetch); *node = Arc::new(if node.output_partitioning().partition_count() > 1 { new_sort.with_preserve_partitioning(true) @@ -55,65 +116,6 @@ pub fn add_sort_above( new_sort }) as _ } - Ok(()) -} - -/// Find indices of each element in `targets` inside `items`. If one of the -/// elements is absent in `items`, returns an error. -pub fn find_indices>( - items: &[T], - targets: impl IntoIterator, -) -> Result> { - targets - .into_iter() - .map(|target| items.iter().position(|e| target.borrow().eq(e))) - .collect::>() - .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) -} - -/// Merges collections `first` and `second`, removes duplicates and sorts the -/// result, returning it as a [`Vec`]. -pub fn merge_and_order_indices, S: Borrow>( - first: impl IntoIterator, - second: impl IntoIterator, -) -> Vec { - let mut result: Vec<_> = first - .into_iter() - .map(|e| *e.borrow()) - .chain(second.into_iter().map(|e| *e.borrow())) - .collect::>() - .into_iter() - .collect(); - result.sort(); - result -} - -/// Checks whether the given index sequence is monotonically non-decreasing. -pub fn is_sorted>(sequence: impl IntoIterator) -> bool { - // TODO: Remove this function when `is_sorted` graduates from Rust nightly. - let mut previous = 0; - for item in sequence.into_iter() { - let current = *item.borrow(); - if current < previous { - return false; - } - previous = current; - } - true -} - -/// Calculates the set difference between sequences `first` and `second`, -/// returning the result as a [`Vec`]. Preserves the ordering of `first`. -pub fn set_difference, S: Borrow>( - first: impl IntoIterator, - second: impl IntoIterator, -) -> Vec { - let set: HashSet<_> = second.into_iter().map(|e| *e.borrow()).collect(); - first - .into_iter() - .map(|e| *e.borrow()) - .filter(|e| !set.contains(e)) - .collect() } /// Checks whether the given operator is a limit; @@ -152,53 +154,3 @@ pub fn is_union(plan: &Arc) -> bool { pub fn is_repartition(plan: &Arc) -> bool { plan.as_any().is::() } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_find_indices() -> Result<()> { - assert_eq!(find_indices(&[0, 3, 4], [0, 3, 4])?, vec![0, 1, 2]); - assert_eq!(find_indices(&[0, 3, 4], [0, 4, 3])?, vec![0, 2, 1]); - assert_eq!(find_indices(&[3, 0, 4], [0, 3])?, vec![1, 0]); - assert!(find_indices(&[0, 3], [0, 3, 4]).is_err()); - assert!(find_indices(&[0, 3, 4], [0, 2]).is_err()); - Ok(()) - } - - #[tokio::test] - async fn test_merge_and_order_indices() { - assert_eq!( - merge_and_order_indices([0, 3, 4], [1, 3, 5]), - vec![0, 1, 3, 4, 5] - ); - // Result should be ordered, even if inputs are not - assert_eq!( - merge_and_order_indices([3, 0, 4], [5, 1, 3]), - vec![0, 1, 3, 4, 5] - ); - } - - #[tokio::test] - async fn test_is_sorted() { - assert!(is_sorted::([])); - assert!(is_sorted([0])); - assert!(is_sorted([0, 3, 4])); - assert!(is_sorted([0, 1, 2])); - assert!(is_sorted([0, 1, 4])); - assert!(is_sorted([0usize; 0])); - assert!(is_sorted([1, 2])); - assert!(!is_sorted([3, 2])); - } - - #[tokio::test] - async fn test_set_difference() { - assert_eq!(set_difference([0, 3, 4], [1, 2]), vec![0, 3, 4]); - assert_eq!(set_difference([0, 3, 4], [1, 2, 4]), vec![0, 3]); - // return value should have same ordering with the in1 - assert_eq!(set_difference([3, 4, 0], [1, 2, 4]), vec![3, 0]); - assert_eq!(set_difference([0, 3, 4], [4, 1, 2]), vec![0, 3]); - assert_eq!(set_difference([3, 4, 0], [4, 1, 2]), vec![3, 0]); - } -} diff --git a/datafusion/core/src/physical_plan/aggregates/bounded_aggregate_stream.rs b/datafusion/core/src/physical_plan/aggregates/bounded_aggregate_stream.rs deleted file mode 100644 index 4bbac3c4a52ac..0000000000000 --- a/datafusion/core/src/physical_plan/aggregates/bounded_aggregate_stream.rs +++ /dev/null @@ -1,1044 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file implements streaming aggregation on ordered GROUP BY expressions. -//! Generated output will itself have an ordering and the executor can run with -//! bounded memory, ensuring composability in streaming cases. - -use std::cmp::min; -use std::ops::Range; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::vec; - -use ahash::RandomState; -use futures::ready; -use futures::stream::{Stream, StreamExt}; -use hashbrown::raw::RawTable; -use itertools::izip; - -use crate::physical_plan::aggregates::{ - evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, - AggregationOrdering, GroupByOrderMode, PhysicalGroupBy, RowAccumulatorItem, -}; -use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; -use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; -use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_execution::TaskContext; - -use crate::physical_plan::aggregates::utils::{ - aggr_state_schema, col_to_scalar, get_at_indices, get_optional_filters, - read_as_batch, slice_and_maybe_filter, ExecutionState, GroupState, -}; -use arrow::array::{new_null_array, ArrayRef, UInt32Builder}; -use arrow::compute::{cast, SortColumn}; -use arrow::datatypes::DataType; -use arrow::row::{OwnedRow, RowConverter, SortField}; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::utils::{evaluate_partition_ranges, get_row_at_idx}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; -use datafusion_physical_expr::hash_utils::create_hashes; -use datafusion_row::accessor::RowAccessor; -use datafusion_row::layout::RowLayout; - -use super::AggregateExec; - -/// Grouping aggregate with row-format aggregation states inside. -/// -/// For each aggregation entry, we use: -/// - [Arrow-row] represents grouping keys for fast hash computation and comparison directly on raw bytes. -/// - [WordAligned] row to store aggregation state, designed to be CPU-friendly when updates over every field are often. -/// -/// The architecture is the following: -/// -/// 1. For each input RecordBatch, update aggregation states corresponding to all appeared grouping keys. -/// 2. At the end of the aggregation (e.g. end of batches in a partition), the accumulator converts its state to a RecordBatch of a single row -/// 3. The RecordBatches of all accumulators are merged (`concatenate` in `rust/arrow`) together to a single RecordBatch. -/// 4. The state's RecordBatch is `merge`d to a new state -/// 5. The state is mapped to the final value -/// -/// [Arrow-row]: OwnedRow -/// [WordAligned]: datafusion_row::layout -pub(crate) struct BoundedAggregateStream { - schema: SchemaRef, - input: SendableRecordBatchStream, - mode: AggregateMode, - - normal_aggr_expr: Vec>, - /// Aggregate expressions not supporting row accumulation - normal_aggregate_expressions: Vec>>, - /// Filter expression for each normal aggregate expression - normal_filter_expressions: Vec>>, - - /// Aggregate expressions supporting row accumulation - row_aggregate_expressions: Vec>>, - /// Filter expression for each row aggregate expression - row_filter_expressions: Vec>>, - row_accumulators: Vec, - row_converter: RowConverter, - row_aggr_schema: SchemaRef, - row_aggr_layout: Arc, - - group_by: PhysicalGroupBy, - - aggr_state: AggregationState, - exec_state: ExecutionState, - baseline_metrics: BaselineMetrics, - random_state: RandomState, - /// size to be used for resulting RecordBatches - batch_size: usize, - /// threshold for using `ScalarValue`s to update - /// accumulators during high-cardinality aggregations for each input batch. - scalar_update_factor: usize, - /// if the result is chunked into batches, - /// last offset is preserved for continuation. - row_group_skip_position: usize, - /// keeps range for each accumulator in the field - /// first element in the array corresponds to normal accumulators - /// second element in the array corresponds to row accumulators - indices: [Vec>; 2], - aggregation_ordering: AggregationOrdering, - is_end: bool, -} - -impl BoundedAggregateStream { - /// Create a new BoundedAggregateStream - pub fn new( - agg: &AggregateExec, - context: Arc, - partition: usize, - aggregation_ordering: AggregationOrdering, // Stores algorithm mode and output ordering - ) -> Result { - let agg_schema = Arc::clone(&agg.schema); - let agg_group_by = agg.group_by.clone(); - let agg_filter_expr = agg.filter_expr.clone(); - - let batch_size = context.session_config().batch_size(); - let scalar_update_factor = context.session_config().agg_scalar_update_factor(); - let input = agg.input.execute(partition, Arc::clone(&context))?; - let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); - - let timer = baseline_metrics.elapsed_compute().timer(); - - let mut start_idx = agg_group_by.expr.len(); - let mut row_aggr_expr = vec![]; - let mut row_agg_indices = vec![]; - let mut row_aggregate_expressions = vec![]; - let mut row_filter_expressions = vec![]; - let mut normal_aggr_expr = vec![]; - let mut normal_agg_indices = vec![]; - let mut normal_aggregate_expressions = vec![]; - let mut normal_filter_expressions = vec![]; - // The expressions to evaluate the batch, one vec of expressions per aggregation. - // Assuming create_schema() always puts group columns in front of aggregation columns, we set - // col_idx_base to the group expression count. - let all_aggregate_expressions = - aggregates::aggregate_expressions(&agg.aggr_expr, &agg.mode, start_idx)?; - let filter_expressions = match agg.mode { - AggregateMode::Partial | AggregateMode::Single => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } - }; - for ((expr, others), filter) in agg - .aggr_expr - .iter() - .zip(all_aggregate_expressions.into_iter()) - .zip(filter_expressions.into_iter()) - { - let n_fields = match agg.mode { - // In partial aggregation, we keep additional fields in order to successfully - // merge aggregation results downstream. - AggregateMode::Partial => expr.state_fields()?.len(), - _ => 1, - }; - // Stores range of each expression: - let aggr_range = Range { - start: start_idx, - end: start_idx + n_fields, - }; - if expr.row_accumulator_supported() { - row_aggregate_expressions.push(others); - row_filter_expressions.push(filter.clone()); - row_agg_indices.push(aggr_range); - row_aggr_expr.push(expr.clone()); - } else { - normal_aggregate_expressions.push(others); - normal_filter_expressions.push(filter.clone()); - normal_agg_indices.push(aggr_range); - normal_aggr_expr.push(expr.clone()); - } - start_idx += n_fields; - } - - let row_accumulators = aggregates::create_row_accumulators(&row_aggr_expr)?; - - let row_aggr_schema = aggr_state_schema(&row_aggr_expr); - - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); - let row_converter = RowConverter::new( - group_schema - .fields() - .iter() - .map(|f| SortField::new(f.data_type().clone())) - .collect(), - )?; - - let row_aggr_layout = Arc::new(RowLayout::new(&row_aggr_schema)); - - let name = format!("BoundedAggregateStream[{partition}]"); - let aggr_state = AggregationState { - reservation: MemoryConsumer::new(name).register(context.memory_pool()), - map: RawTable::with_capacity(0), - ordered_group_states: Vec::with_capacity(0), - }; - - timer.done(); - - let exec_state = ExecutionState::ReadingInput; - - Ok(BoundedAggregateStream { - schema: agg_schema, - input, - mode: agg.mode, - normal_aggr_expr, - normal_aggregate_expressions, - normal_filter_expressions, - row_aggregate_expressions, - row_filter_expressions, - row_accumulators, - row_converter, - row_aggr_schema, - row_aggr_layout, - group_by: agg_group_by, - aggr_state, - exec_state, - baseline_metrics, - random_state: Default::default(), - batch_size, - scalar_update_factor, - row_group_skip_position: 0, - indices: [normal_agg_indices, row_agg_indices], - is_end: false, - aggregation_ordering, - }) - } -} - -impl Stream for BoundedAggregateStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); - - loop { - match self.exec_state { - ExecutionState::ReadingInput => { - match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = self.group_aggregate_batch(batch); - timer.done(); - - // allocate memory - // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with - // overshooting a bit. Also this means we either store the whole record batch or not. - let result = result.and_then(|allocated| { - self.aggr_state.reservation.try_grow(allocated) - }); - - if let Err(e) = result { - return Poll::Ready(Some(Err(e))); - } - } - // inner had error, return to caller - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - // inner is done, producing output - None => { - for element in self.aggr_state.ordered_group_states.iter_mut() - { - element.status = GroupStatus::CanEmit; - } - self.exec_state = ExecutionState::ProducingOutput; - } - } - } - - ExecutionState::ProducingOutput => { - let timer = elapsed_compute.timer(); - let result = self.create_batch_from_map(); - - timer.done(); - - match result { - // made output - Ok(Some(result)) => { - let batch = result.record_output(&self.baseline_metrics); - self.row_group_skip_position += batch.num_rows(); - self.exec_state = ExecutionState::ReadingInput; - self.prune(); - return Poll::Ready(Some(Ok(batch))); - } - // end of output - Ok(None) => { - self.exec_state = ExecutionState::Done; - } - // error making output - Err(error) => return Poll::Ready(Some(Err(error))), - } - } - ExecutionState::Done => return Poll::Ready(None), - } - } - } -} - -impl RecordBatchStream for BoundedAggregateStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -/// This utility object encapsulates the row object, the hash and the group -/// indices for a group. This information is used when executing streaming -/// GROUP BY calculations. -struct GroupOrderInfo { - owned_row: OwnedRow, - hash: u64, - range: Range, -} - -impl BoundedAggregateStream { - // Update the aggr_state according to group_by values (result of group_by_expressions) when group by - // expressions are fully ordered. - fn update_ordered_group_state( - &mut self, - group_values: &[ArrayRef], - per_group_indices: Vec, - allocated: &mut usize, - ) -> Result> { - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - - // track which entries in `aggr_state` have rows in this batch to aggregate - let mut groups_with_rows = vec![]; - - let AggregationState { - map: row_map, - ordered_group_states: row_group_states, - .. - } = &mut self.aggr_state; - - for GroupOrderInfo { - owned_row, - hash, - range, - } in per_group_indices - { - let entry = row_map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - let ordered_group_state = &row_group_states[*group_idx]; - let group_state = &ordered_group_state.group_state; - owned_row.row() == group_state.group_by_values.row() - }); - - match entry { - // Existing entry for this group value - Some((_hash, group_idx)) => { - let group_state = &mut row_group_states[*group_idx].group_state; - - // 1.3 - if group_state.indices.is_empty() { - groups_with_rows.push(*group_idx); - }; - for row in range.start..range.end { - // remember this row - group_state.indices.push_accounted(row as u32, allocated); - } - } - // 1.2 Need to create new entry - None => { - let accumulator_set = - aggregates::create_accumulators(&self.normal_aggr_expr)?; - let row = get_row_at_idx(group_values, range.start)?; - let ordered_columns = self - .aggregation_ordering - .order_indices - .iter() - .map(|idx| row[*idx].clone()) - .collect::>(); - // Add new entry to group_states and save newly created index - let group_state = GroupState { - group_by_values: owned_row, - aggregation_buffer: vec![ - 0; - self.row_aggr_layout.fixed_part_width() - ], - accumulator_set, - indices: (range.start as u32..range.end as u32) - .collect::>(), // 1.3 - }; - let group_idx = row_group_states.len(); - - // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by - // `group_states` (see allocation down below) - *allocated += std::mem::size_of_val(&group_state.group_by_values) - + (std::mem::size_of::() - * group_state.aggregation_buffer.capacity()) - + (std::mem::size_of::() * group_state.indices.capacity()); - - // Allocation done by normal accumulators - *allocated += (std::mem::size_of::>() - * group_state.accumulator_set.capacity()) - + group_state - .accumulator_set - .iter() - .map(|accu| accu.size()) - .sum::(); - - // for hasher function, use precomputed hash value - row_map.insert_accounted( - (hash, group_idx), - |(hash, _group_index)| *hash, - allocated, - ); - - let ordered_group_state = OrderedGroupState { - group_state, - ordered_columns, - status: GroupStatus::GroupProgress, - hash, - }; - row_group_states.push_accounted(ordered_group_state, allocated); - - groups_with_rows.push(group_idx); - } - }; - } - Ok(groups_with_rows) - } - - // Update the aggr_state according to group_by values (result of group_by_expressions) - fn update_group_state( - &mut self, - group_values: &[ArrayRef], - allocated: &mut usize, - ) -> Result> { - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - - // track which entries in `aggr_state` have rows in this batch to aggregate - let mut groups_with_rows = vec![]; - - let group_rows = self.row_converter.convert_columns(group_values)?; - let n_rows = group_rows.num_rows(); - // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; n_rows]; - create_hashes(group_values, &self.random_state, &mut batch_hashes)?; - - let AggregationState { - map, - ordered_group_states: group_states, - .. - } = &mut self.aggr_state; - - for (row, hash) in batch_hashes.into_iter().enumerate() { - let entry = map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - let group_state = &group_states[*group_idx].group_state; - group_rows.row(row) == group_state.group_by_values.row() - }); - - match entry { - // Existing entry for this group value - Some((_hash, group_idx)) => { - let group_state = &mut group_states[*group_idx].group_state; - - // 1.3 - if group_state.indices.is_empty() { - groups_with_rows.push(*group_idx); - }; - - group_state.indices.push_accounted(row as u32, allocated); // remember this row - } - // 1.2 Need to create new entry - None => { - let accumulator_set = - aggregates::create_accumulators(&self.normal_aggr_expr)?; - let row_values = get_row_at_idx(group_values, row)?; - let ordered_columns = self - .aggregation_ordering - .order_indices - .iter() - .map(|idx| row_values[*idx].clone()) - .collect::>(); - let group_state = GroupState { - group_by_values: group_rows.row(row).owned(), - aggregation_buffer: vec![ - 0; - self.row_aggr_layout.fixed_part_width() - ], - accumulator_set, - indices: vec![row as u32], // 1.3 - }; - let group_idx = group_states.len(); - - // NOTE: do NOT include the `GroupState` struct size in here because this is captured by - // `group_states` (see allocation down below) - *allocated += std::mem::size_of_val(&group_state.group_by_values) - + (std::mem::size_of::() - * group_state.aggregation_buffer.capacity()) - + (std::mem::size_of::() * group_state.indices.capacity()); - - // Allocation done by normal accumulators - *allocated += (std::mem::size_of::>() - * group_state.accumulator_set.capacity()) - + group_state - .accumulator_set - .iter() - .map(|accu| accu.size()) - .sum::(); - - // for hasher function, use precomputed hash value - map.insert_accounted( - (hash, group_idx), - |(hash, _group_index)| *hash, - allocated, - ); - - // Add new entry to group_states and save newly created index - let ordered_group_state = OrderedGroupState { - group_state, - ordered_columns, - status: GroupStatus::GroupProgress, - hash, - }; - group_states.push_accounted(ordered_group_state, allocated); - - groups_with_rows.push(group_idx); - } - }; - } - Ok(groups_with_rows) - } - - // Update the accumulator results, according to aggr_state. - #[allow(clippy::too_many_arguments)] - fn update_accumulators_using_batch( - &mut self, - groups_with_rows: &[usize], - offsets: &[usize], - row_values: &[Vec], - normal_values: &[Vec], - row_filter_values: &[Option], - normal_filter_values: &[Option], - allocated: &mut usize, - ) -> Result<()> { - // 2.1 for each key in this batch - // 2.2 for each aggregation - // 2.3 `slice` from each of its arrays the keys' values - // 2.4 update / merge the accumulator with the values - // 2.5 clear indices - groups_with_rows - .iter() - .zip(offsets.windows(2)) - .try_for_each(|(group_idx, offsets)| { - let group_state = - &mut self.aggr_state.ordered_group_states[*group_idx].group_state; - // 2.2 - // Process row accumulators - self.row_accumulators - .iter_mut() - .zip(row_values.iter()) - .zip(row_filter_values.iter()) - .try_for_each(|((accumulator, aggr_array), filter_opt)| { - let values = slice_and_maybe_filter( - aggr_array, - filter_opt.as_ref(), - offsets, - )?; - let mut state_accessor = - RowAccessor::new_from_layout(self.row_aggr_layout.clone()); - state_accessor - .point_to(0, group_state.aggregation_buffer.as_mut_slice()); - match self.mode { - AggregateMode::Partial | AggregateMode::Single => { - accumulator.update_batch(&values, &mut state_accessor) - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values, &mut state_accessor) - } - } - })?; - // normal accumulators - group_state - .accumulator_set - .iter_mut() - .zip(normal_values.iter()) - .zip(normal_filter_values.iter()) - .try_for_each(|((accumulator, aggr_array), filter_opt)| { - let values = slice_and_maybe_filter( - aggr_array, - filter_opt.as_ref(), - offsets, - )?; - let size_pre = accumulator.size(); - let res = match self.mode { - AggregateMode::Partial | AggregateMode::Single => { - accumulator.update_batch(&values) - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values) - } - }; - let size_post = accumulator.size(); - *allocated += size_post.saturating_sub(size_pre); - res - }) - // 2.5 - .and({ - group_state.indices.clear(); - Ok(()) - }) - })?; - Ok(()) - } - - // Update the accumulator results, according to aggr_state. - fn update_accumulators_using_scalar( - &mut self, - groups_with_rows: &[usize], - row_values: &[Vec], - row_filter_values: &[Option], - ) -> Result<()> { - let filter_bool_array = row_filter_values - .iter() - .map(|filter_opt| match filter_opt { - Some(f) => Ok(Some(as_boolean_array(f)?)), - None => Ok(None), - }) - .collect::>>()?; - - for group_idx in groups_with_rows { - let group_state = - &mut self.aggr_state.ordered_group_states[*group_idx].group_state; - let mut state_accessor = - RowAccessor::new_from_layout(self.row_aggr_layout.clone()); - state_accessor.point_to(0, group_state.aggregation_buffer.as_mut_slice()); - for idx in &group_state.indices { - for (accumulator, values_array, filter_array) in izip!( - self.row_accumulators.iter_mut(), - row_values.iter(), - filter_bool_array.iter() - ) { - if values_array.len() == 1 { - let scalar_value = - col_to_scalar(&values_array[0], filter_array, *idx as usize)?; - accumulator.update_scalar(&scalar_value, &mut state_accessor)?; - } else { - let scalar_values = values_array - .iter() - .map(|array| { - col_to_scalar(array, filter_array, *idx as usize) - }) - .collect::>>()?; - accumulator - .update_scalar_values(&scalar_values, &mut state_accessor)?; - } - } - } - // clear the group indices in this group - group_state.indices.clear(); - } - - Ok(()) - } - - /// Perform group-by aggregation for the given [`RecordBatch`]. - /// - /// If successful, this returns the additional number of bytes that were allocated during this process. - /// - fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result { - // Evaluate the grouping expressions: - let group_by_values = evaluate_group_by(&self.group_by, &batch)?; - // Keep track of memory allocated: - let mut allocated = 0usize; - - // Evaluate the aggregation expressions. - // We could evaluate them after the `take`, but since we need to evaluate all - // of them anyways, it is more performant to do it while they are together. - let row_aggr_input_values = - evaluate_many(&self.row_aggregate_expressions, &batch)?; - let normal_aggr_input_values = - evaluate_many(&self.normal_aggregate_expressions, &batch)?; - let row_filter_values = evaluate_optional(&self.row_filter_expressions, &batch)?; - let normal_filter_values = - evaluate_optional(&self.normal_filter_expressions, &batch)?; - - let row_converter_size_pre = self.row_converter.size(); - for group_values in &group_by_values { - let groups_with_rows = if let AggregationOrdering { - mode: GroupByOrderMode::FullyOrdered, - order_indices, - ordering, - } = &self.aggregation_ordering - { - let group_rows = self.row_converter.convert_columns(group_values)?; - let n_rows = group_rows.num_rows(); - // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; n_rows]; - create_hashes(group_values, &self.random_state, &mut batch_hashes)?; - let sort_column = order_indices - .iter() - .enumerate() - .map(|(idx, cur_idx)| SortColumn { - values: group_values[*cur_idx].clone(), - options: Some(ordering[idx].options), - }) - .collect::>(); - let n_rows = group_rows.num_rows(); - let ranges = evaluate_partition_ranges(n_rows, &sort_column)?; - let per_group_indices = ranges - .into_iter() - .map(|range| GroupOrderInfo { - owned_row: group_rows.row(range.start).owned(), - hash: batch_hashes[range.start], - range, - }) - .collect::>(); - self.update_ordered_group_state( - group_values, - per_group_indices, - &mut allocated, - )? - } else { - self.update_group_state(group_values, &mut allocated)? - }; - - // Decide the accumulators update mode, use scalar value to update the accumulators when all of the conditions are meet: - // 1) The aggregation mode is Partial or Single - // 2) There is not normal aggregation expressions - // 3) The number of affected groups is high (entries in `aggr_state` have rows need to update). Usually the high cardinality case - if matches!(self.mode, AggregateMode::Partial | AggregateMode::Single) - && normal_aggr_input_values.is_empty() - && normal_filter_values.is_empty() - && groups_with_rows.len() >= batch.num_rows() / self.scalar_update_factor - { - self.update_accumulators_using_scalar( - &groups_with_rows, - &row_aggr_input_values, - &row_filter_values, - )?; - } else { - // Collect all indices + offsets based on keys in this vec - let mut batch_indices = UInt32Builder::with_capacity(0); - let mut offsets = vec![0]; - let mut offset_so_far = 0; - for &group_idx in groups_with_rows.iter() { - let indices = &self.aggr_state.ordered_group_states[group_idx] - .group_state - .indices; - batch_indices.append_slice(indices); - offset_so_far += indices.len(); - offsets.push(offset_so_far); - } - let batch_indices = batch_indices.finish(); - - let row_filter_values = - get_optional_filters(&row_filter_values, &batch_indices); - let normal_filter_values = - get_optional_filters(&normal_filter_values, &batch_indices); - if self.aggregation_ordering.mode == GroupByOrderMode::FullyOrdered { - self.update_accumulators_using_batch( - &groups_with_rows, - &offsets, - &row_aggr_input_values, - &normal_aggr_input_values, - &row_filter_values, - &normal_filter_values, - &mut allocated, - )?; - } else { - let row_values = - get_at_indices(&row_aggr_input_values, &batch_indices)?; - let normal_values = - get_at_indices(&normal_aggr_input_values, &batch_indices)?; - self.update_accumulators_using_batch( - &groups_with_rows, - &offsets, - &row_values, - &normal_values, - &row_filter_values, - &normal_filter_values, - &mut allocated, - )?; - }; - } - } - allocated += self - .row_converter - .size() - .saturating_sub(row_converter_size_pre); - - let mut new_result = false; - let last_ordered_columns = self - .aggr_state - .ordered_group_states - .last() - .map(|item| item.ordered_columns.clone()); - - if let Some(last_ordered_columns) = last_ordered_columns { - for cur_group in &mut self.aggr_state.ordered_group_states { - if cur_group.ordered_columns != last_ordered_columns { - // We will no longer receive value. Set status to GroupStatus::CanEmit - // meaning we can generate result for this group. - cur_group.status = GroupStatus::CanEmit; - new_result = true; - } - } - } - if new_result { - self.exec_state = ExecutionState::ProducingOutput; - } - - Ok(allocated) - } -} - -#[derive(Debug, PartialEq)] -enum GroupStatus { - // `GroupProgress` means data for current group is not complete. New data may arrive. - GroupProgress, - // `CanEmit` means data for current group is completed. And its result can emitted. - CanEmit, - // Emitted means that result for the groups is outputted. Group can be pruned from state. - Emitted, -} - -/// The state that is built for each output group. -#[derive(Debug)] -pub struct OrderedGroupState { - group_state: GroupState, - ordered_columns: Vec, - status: GroupStatus, - hash: u64, -} - -/// The state of all the groups -pub struct AggregationState { - pub reservation: MemoryReservation, - - /// Logically maps group values to an index in `group_states` - /// - /// Uses the raw API of hashbrown to avoid actually storing the - /// keys in the table - /// - /// keys: u64 hashes of the GroupValue - /// values: (hash, index into `group_states`) - pub map: RawTable<(u64, usize)>, - - /// State for each group - pub ordered_group_states: Vec, -} - -impl std::fmt::Debug for AggregationState { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - // hashes are not store inline, so could only get values - let map_string = "RawTable"; - f.debug_struct("AggregationState") - .field("map", &map_string) - .field("ordered_group_states", &self.ordered_group_states) - .finish() - } -} - -impl BoundedAggregateStream { - /// Prune the groups from the `self.aggr_state.group_states` which are in - /// `GroupStatus::Emitted`(this status means that result of this group emitted/outputted already, and - /// we are sure that these groups cannot receive new rows.) status. - fn prune(&mut self) { - let n_partition = self.aggr_state.ordered_group_states.len(); - self.aggr_state - .ordered_group_states - .retain(|elem| elem.status != GroupStatus::Emitted); - let n_partition_new = self.aggr_state.ordered_group_states.len(); - let n_pruned = n_partition - n_partition_new; - self.aggr_state.map.clear(); - for (idx, item) in self.aggr_state.ordered_group_states.iter().enumerate() { - self.aggr_state - .map - .insert(item.hash, (item.hash, idx), |(hash, _)| *hash); - } - self.row_group_skip_position -= n_pruned; - } - - /// Create a RecordBatch with all group keys and accumulator' states or values. - fn create_batch_from_map(&mut self) -> Result> { - let skip_items = self.row_group_skip_position; - if skip_items > self.aggr_state.ordered_group_states.len() || self.is_end { - return Ok(None); - } - self.is_end |= skip_items == self.aggr_state.ordered_group_states.len(); - if self.aggr_state.ordered_group_states.is_empty() { - let schema = self.schema.clone(); - return Ok(Some(RecordBatch::new_empty(schema))); - } - - let end_idx = min( - skip_items + self.batch_size, - self.aggr_state.ordered_group_states.len(), - ); - let group_state_chunk = - &self.aggr_state.ordered_group_states[skip_items..end_idx]; - // Consider only the groups that can be emitted. (The ones we are sure that will not receive new entry.) - let group_state_chunk = group_state_chunk - .iter() - .filter(|item| item.status == GroupStatus::CanEmit) - .collect::>(); - - if group_state_chunk.is_empty() { - let schema = self.schema.clone(); - return Ok(Some(RecordBatch::new_empty(schema))); - } - - // Buffers for each distinct group (i.e. row accumulator memories) - let mut state_buffers = group_state_chunk - .iter() - .map(|gs| gs.group_state.aggregation_buffer.clone()) - .collect::>(); - - let output_fields = self.schema.fields(); - // Store row accumulator results (either final output or intermediate state): - let row_columns = match self.mode { - AggregateMode::Partial => { - read_as_batch(&state_buffers, &self.row_aggr_schema) - } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single => { - let mut results = vec![]; - for (idx, acc) in self.row_accumulators.iter().enumerate() { - let mut state_accessor = RowAccessor::new(&self.row_aggr_schema); - let current = state_buffers - .iter_mut() - .map(|buffer| { - state_accessor.point_to(0, buffer); - acc.evaluate(&state_accessor) - }) - .collect::>>()?; - // Get corresponding field for row accumulator - let field = &output_fields[self.indices[1][idx].start]; - let result = if current.is_empty() { - Ok(arrow::array::new_empty_array(field.data_type())) - } else { - let item = ScalarValue::iter_to_array(current)?; - // cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - cast(&item, field.data_type()) - }?; - results.push(result); - } - results - } - }; - - // Store normal accumulator results (either final output or intermediate state): - let mut columns = vec![]; - for (idx, &Range { start, end }) in self.indices[0].iter().enumerate() { - for (field_idx, field) in output_fields[start..end].iter().enumerate() { - let current = match self.mode { - AggregateMode::Partial => ScalarValue::iter_to_array( - group_state_chunk.iter().map(|group_state| { - group_state.group_state.accumulator_set[idx] - .state() - .map(|v| v[field_idx].clone()) - .expect("Unexpected accumulator state in hash aggregate") - }), - ), - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single => ScalarValue::iter_to_array( - group_state_chunk.iter().map(|group_state| { - group_state.group_state.accumulator_set[idx] - .evaluate() - .expect("Unexpected accumulator state in hash aggregate") - }), - ), - }?; - // Cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - let result = cast(¤t, field.data_type())?; - columns.push(result); - } - } - - // Stores the group by fields - let group_buffers = group_state_chunk - .iter() - .map(|gs| gs.group_state.group_by_values.row()) - .collect::>(); - let mut output: Vec = self.row_converter.convert_rows(group_buffers)?; - - // The size of the place occupied by row and normal accumulators - let extra: usize = self - .indices - .iter() - .flatten() - .map(|Range { start, end }| end - start) - .sum(); - let empty_arr = new_null_array(&DataType::Null, 1); - output.extend(std::iter::repeat(empty_arr).take(extra)); - - // Write results of both accumulator types to the corresponding location in - // the output schema: - let results = [columns.into_iter(), row_columns.into_iter()]; - for (outer, mut current) in results.into_iter().enumerate() { - for &Range { start, end } in self.indices[outer].iter() { - for item in output.iter_mut().take(end).skip(start) { - *item = current.next().expect("Columns cannot be empty"); - } - } - } - - // Set status of the emitted groups to GroupStatus::Emitted mode. - for gs in self.aggr_state.ordered_group_states[skip_items..end_idx].iter_mut() { - if gs.status == GroupStatus::CanEmit { - gs.status = GroupStatus::Emitted; - } - } - - Ok(Some(RecordBatch::try_new(self.schema.clone(), output)?)) - } -} diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs deleted file mode 100644 index ba02bc096bbab..0000000000000 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ /dev/null @@ -1,764 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Hash aggregation through row format - -use std::cmp::min; -use std::ops::Range; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::vec; - -use ahash::RandomState; -use arrow::row::{RowConverter, SortField}; -use datafusion_physical_expr::hash_utils::create_hashes; -use futures::ready; -use futures::stream::{Stream, StreamExt}; - -use crate::physical_plan::aggregates::utils::{ - aggr_state_schema, col_to_scalar, get_at_indices, get_optional_filters, - read_as_batch, slice_and_maybe_filter, ExecutionState, GroupState, -}; -use crate::physical_plan::aggregates::{ - evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, - PhysicalGroupBy, RowAccumulatorItem, -}; -use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; -use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; -use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; -use arrow::array::*; -use arrow::compute::cast; -use arrow::datatypes::DataType; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::{Result, ScalarValue}; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; -use datafusion_row::accessor::RowAccessor; -use datafusion_row::layout::RowLayout; -use hashbrown::raw::RawTable; -use itertools::izip; - -use super::AggregateExec; - -/// Grouping aggregate with row-format aggregation states inside. -/// -/// For each aggregation entry, we use: -/// - [Arrow-row] represents grouping keys for fast hash computation and comparison directly on raw bytes. -/// - [WordAligned] row to store aggregation state, designed to be CPU-friendly when updates over every field are often. -/// -/// The architecture is the following: -/// -/// 1. For each input RecordBatch, update aggregation states corresponding to all appeared grouping keys. -/// 2. At the end of the aggregation (e.g. end of batches in a partition), the accumulator converts its state to a RecordBatch of a single row -/// 3. The RecordBatches of all accumulators are merged (`concatenate` in `rust/arrow`) together to a single RecordBatch. -/// 4. The state's RecordBatch is `merge`d to a new state -/// 5. The state is mapped to the final value -/// -/// [WordAligned]: datafusion_row::layout -pub(crate) struct GroupedHashAggregateStream { - schema: SchemaRef, - input: SendableRecordBatchStream, - mode: AggregateMode, - - normal_aggr_expr: Vec>, - /// Aggregate expressions not supporting row accumulation - normal_aggregate_expressions: Vec>>, - /// Filter expression for each normal aggregate expression - normal_filter_expressions: Vec>>, - - /// Aggregate expressions supporting row accumulation - row_aggregate_expressions: Vec>>, - /// Filter expression for each row aggregate expression - row_filter_expressions: Vec>>, - row_accumulators: Vec, - row_converter: RowConverter, - row_aggr_schema: SchemaRef, - row_aggr_layout: Arc, - - group_by: PhysicalGroupBy, - - aggr_state: AggregationState, - exec_state: ExecutionState, - baseline_metrics: BaselineMetrics, - random_state: RandomState, - /// size to be used for resulting RecordBatches - batch_size: usize, - /// threshold for using `ScalarValue`s to update - /// accumulators during high-cardinality aggregations for each input batch. - scalar_update_factor: usize, - /// if the result is chunked into batches, - /// last offset is preserved for continuation. - row_group_skip_position: usize, - /// keeps range for each accumulator in the field - /// first element in the array corresponds to normal accumulators - /// second element in the array corresponds to row accumulators - indices: [Vec>; 2], -} - -impl GroupedHashAggregateStream { - /// Create a new GroupedHashAggregateStream - pub fn new( - agg: &AggregateExec, - context: Arc, - partition: usize, - ) -> Result { - let agg_schema = Arc::clone(&agg.schema); - let agg_group_by = agg.group_by.clone(); - let agg_filter_expr = agg.filter_expr.clone(); - - let batch_size = context.session_config().batch_size(); - let scalar_update_factor = context.session_config().agg_scalar_update_factor(); - let input = agg.input.execute(partition, Arc::clone(&context))?; - let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); - - let timer = baseline_metrics.elapsed_compute().timer(); - - let mut start_idx = agg_group_by.expr.len(); - let mut row_aggr_expr = vec![]; - let mut row_agg_indices = vec![]; - let mut row_aggregate_expressions = vec![]; - let mut row_filter_expressions = vec![]; - let mut normal_aggr_expr = vec![]; - let mut normal_agg_indices = vec![]; - let mut normal_aggregate_expressions = vec![]; - let mut normal_filter_expressions = vec![]; - // The expressions to evaluate the batch, one vec of expressions per aggregation. - // Assuming create_schema() always puts group columns in front of aggregation columns, we set - // col_idx_base to the group expression count. - let all_aggregate_expressions = - aggregates::aggregate_expressions(&agg.aggr_expr, &agg.mode, start_idx)?; - let filter_expressions = match agg.mode { - AggregateMode::Partial | AggregateMode::Single => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } - }; - for ((expr, others), filter) in agg - .aggr_expr - .iter() - .zip(all_aggregate_expressions.into_iter()) - .zip(filter_expressions.into_iter()) - { - let n_fields = match agg.mode { - // In partial aggregation, we keep additional fields in order to successfully - // merge aggregation results downstream. - AggregateMode::Partial => expr.state_fields()?.len(), - _ => 1, - }; - // Stores range of each expression: - let aggr_range = Range { - start: start_idx, - end: start_idx + n_fields, - }; - if expr.row_accumulator_supported() { - row_aggregate_expressions.push(others); - row_filter_expressions.push(filter.clone()); - row_agg_indices.push(aggr_range); - row_aggr_expr.push(expr.clone()); - } else { - normal_aggregate_expressions.push(others); - normal_filter_expressions.push(filter.clone()); - normal_agg_indices.push(aggr_range); - normal_aggr_expr.push(expr.clone()); - } - start_idx += n_fields; - } - - let row_accumulators = aggregates::create_row_accumulators(&row_aggr_expr)?; - - let row_aggr_schema = aggr_state_schema(&row_aggr_expr); - - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); - let row_converter = RowConverter::new( - group_schema - .fields() - .iter() - .map(|f| SortField::new(f.data_type().clone())) - .collect(), - )?; - - let row_aggr_layout = Arc::new(RowLayout::new(&row_aggr_schema)); - - let name = format!("GroupedHashAggregateStream[{partition}]"); - let aggr_state = AggregationState { - reservation: MemoryConsumer::new(name).register(context.memory_pool()), - map: RawTable::with_capacity(0), - group_states: Vec::with_capacity(0), - }; - - timer.done(); - - let exec_state = ExecutionState::ReadingInput; - - Ok(GroupedHashAggregateStream { - schema: agg_schema, - input, - mode: agg.mode, - normal_aggr_expr, - normal_aggregate_expressions, - normal_filter_expressions, - row_aggregate_expressions, - row_filter_expressions, - row_accumulators, - row_converter, - row_aggr_schema, - row_aggr_layout, - group_by: agg_group_by, - aggr_state, - exec_state, - baseline_metrics, - random_state: Default::default(), - batch_size, - scalar_update_factor, - row_group_skip_position: 0, - indices: [normal_agg_indices, row_agg_indices], - }) - } -} - -impl Stream for GroupedHashAggregateStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); - - loop { - match self.exec_state { - ExecutionState::ReadingInput => { - match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = self.group_aggregate_batch(batch); - timer.done(); - - // allocate memory - // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with - // overshooting a bit. Also this means we either store the whole record batch or not. - let result = result.and_then(|allocated| { - self.aggr_state.reservation.try_grow(allocated) - }); - - if let Err(e) = result { - return Poll::Ready(Some(Err(e))); - } - } - // inner had error, return to caller - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - // inner is done, producing output - None => { - self.exec_state = ExecutionState::ProducingOutput; - } - } - } - - ExecutionState::ProducingOutput => { - let timer = elapsed_compute.timer(); - let result = self.create_batch_from_map(); - - timer.done(); - self.row_group_skip_position += self.batch_size; - - match result { - // made output - Ok(Some(result)) => { - let batch = result.record_output(&self.baseline_metrics); - return Poll::Ready(Some(Ok(batch))); - } - // end of output - Ok(None) => { - self.exec_state = ExecutionState::Done; - } - // error making output - Err(error) => return Poll::Ready(Some(Err(error))), - } - } - ExecutionState::Done => return Poll::Ready(None), - } - } - } -} - -impl RecordBatchStream for GroupedHashAggregateStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl GroupedHashAggregateStream { - // Update the row_aggr_state according to groub_by values (result of group_by_expressions) - fn update_group_state( - &mut self, - group_values: &[ArrayRef], - allocated: &mut usize, - ) -> Result> { - let group_rows = self.row_converter.convert_columns(group_values)?; - let n_rows = group_rows.num_rows(); - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - - // track which entries in `aggr_state` have rows in this batch to aggregate - let mut groups_with_rows = vec![]; - - // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; n_rows]; - create_hashes(group_values, &self.random_state, &mut batch_hashes)?; - - let AggregationState { - map, group_states, .. - } = &mut self.aggr_state; - - for (row, hash) in batch_hashes.into_iter().enumerate() { - let entry = map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - let group_state = &group_states[*group_idx]; - group_rows.row(row) == group_state.group_by_values.row() - }); - - match entry { - // Existing entry for this group value - Some((_hash, group_idx)) => { - let group_state = &mut group_states[*group_idx]; - - // 1.3 - if group_state.indices.is_empty() { - groups_with_rows.push(*group_idx); - }; - - group_state.indices.push_accounted(row as u32, allocated); // remember this row - } - // 1.2 Need to create new entry - None => { - let accumulator_set = - aggregates::create_accumulators(&self.normal_aggr_expr)?; - // Add new entry to group_states and save newly created index - let group_state = GroupState { - group_by_values: group_rows.row(row).owned(), - aggregation_buffer: vec![ - 0; - self.row_aggr_layout.fixed_part_width() - ], - accumulator_set, - indices: vec![row as u32], // 1.3 - }; - let group_idx = group_states.len(); - - // NOTE: do NOT include the `GroupState` struct size in here because this is captured by - // `group_states` (see allocation down below) - *allocated += std::mem::size_of_val(&group_state.group_by_values) - + (std::mem::size_of::() - * group_state.aggregation_buffer.capacity()) - + (std::mem::size_of::() * group_state.indices.capacity()); - - // Allocation done by normal accumulators - *allocated += (std::mem::size_of::>() - * group_state.accumulator_set.capacity()) - + group_state - .accumulator_set - .iter() - .map(|accu| accu.size()) - .sum::(); - - // for hasher function, use precomputed hash value - map.insert_accounted( - (hash, group_idx), - |(hash, _group_index)| *hash, - allocated, - ); - - group_states.push_accounted(group_state, allocated); - - groups_with_rows.push(group_idx); - } - }; - } - Ok(groups_with_rows) - } - - // Update the accumulator results, according to row_aggr_state. - #[allow(clippy::too_many_arguments)] - fn update_accumulators_using_batch( - &mut self, - groups_with_rows: &[usize], - offsets: &[usize], - row_values: &[Vec], - normal_values: &[Vec], - row_filter_values: &[Option], - normal_filter_values: &[Option], - allocated: &mut usize, - ) -> Result<()> { - // 2.1 for each key in this batch - // 2.2 for each aggregation - // 2.3 `slice` from each of its arrays the keys' values - // 2.4 update / merge the accumulator with the values - // 2.5 clear indices - groups_with_rows - .iter() - .zip(offsets.windows(2)) - .try_for_each(|(group_idx, offsets)| { - let group_state = &mut self.aggr_state.group_states[*group_idx]; - // 2.2 - // Process row accumulators - self.row_accumulators - .iter_mut() - .zip(row_values.iter()) - .zip(row_filter_values.iter()) - .try_for_each(|((accumulator, aggr_array), filter_opt)| { - let values = slice_and_maybe_filter( - aggr_array, - filter_opt.as_ref(), - offsets, - )?; - let mut state_accessor = - RowAccessor::new_from_layout(self.row_aggr_layout.clone()); - state_accessor - .point_to(0, group_state.aggregation_buffer.as_mut_slice()); - match self.mode { - AggregateMode::Partial | AggregateMode::Single => { - accumulator.update_batch(&values, &mut state_accessor) - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values, &mut state_accessor) - } - } - })?; - // normal accumulators - group_state - .accumulator_set - .iter_mut() - .zip(normal_values.iter()) - .zip(normal_filter_values.iter()) - .try_for_each(|((accumulator, aggr_array), filter_opt)| { - let values = slice_and_maybe_filter( - aggr_array, - filter_opt.as_ref(), - offsets, - )?; - let size_pre = accumulator.size(); - let res = match self.mode { - AggregateMode::Partial | AggregateMode::Single => { - accumulator.update_batch(&values) - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values) - } - }; - let size_post = accumulator.size(); - *allocated += size_post.saturating_sub(size_pre); - res - }) - // 2.5 - .and({ - group_state.indices.clear(); - Ok(()) - }) - })?; - Ok(()) - } - - // Update the accumulator results, according to row_aggr_state. - fn update_accumulators_using_scalar( - &mut self, - groups_with_rows: &[usize], - row_values: &[Vec], - row_filter_values: &[Option], - ) -> Result<()> { - let filter_bool_array = row_filter_values - .iter() - .map(|filter_opt| match filter_opt { - Some(f) => Ok(Some(as_boolean_array(f)?)), - None => Ok(None), - }) - .collect::>>()?; - - for group_idx in groups_with_rows { - let group_state = &mut self.aggr_state.group_states[*group_idx]; - let mut state_accessor = - RowAccessor::new_from_layout(self.row_aggr_layout.clone()); - state_accessor.point_to(0, group_state.aggregation_buffer.as_mut_slice()); - for idx in &group_state.indices { - for (accumulator, values_array, filter_array) in izip!( - self.row_accumulators.iter_mut(), - row_values.iter(), - filter_bool_array.iter() - ) { - if values_array.len() == 1 { - let scalar_value = - col_to_scalar(&values_array[0], filter_array, *idx as usize)?; - accumulator.update_scalar(&scalar_value, &mut state_accessor)?; - } else { - let scalar_values = values_array - .iter() - .map(|array| { - col_to_scalar(array, filter_array, *idx as usize) - }) - .collect::>>()?; - accumulator - .update_scalar_values(&scalar_values, &mut state_accessor)?; - } - } - } - // clear the group indices in this group - group_state.indices.clear(); - } - - Ok(()) - } - - /// Perform group-by aggregation for the given [`RecordBatch`]. - /// - /// If successful, this returns the additional number of bytes that were allocated during this process. - /// - fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result { - // Evaluate the grouping expressions: - let group_by_values = evaluate_group_by(&self.group_by, &batch)?; - // Keep track of memory allocated: - let mut allocated = 0usize; - - // Evaluate the aggregation expressions. - // We could evaluate them after the `take`, but since we need to evaluate all - // of them anyways, it is more performant to do it while they are together. - let row_aggr_input_values = - evaluate_many(&self.row_aggregate_expressions, &batch)?; - let normal_aggr_input_values = - evaluate_many(&self.normal_aggregate_expressions, &batch)?; - let row_filter_values = evaluate_optional(&self.row_filter_expressions, &batch)?; - let normal_filter_values = - evaluate_optional(&self.normal_filter_expressions, &batch)?; - - let row_converter_size_pre = self.row_converter.size(); - for group_values in &group_by_values { - let groups_with_rows = - self.update_group_state(group_values, &mut allocated)?; - // Decide the accumulators update mode, use scalar value to update the accumulators when all of the conditions are meet: - // 1) The aggregation mode is Partial or Single - // 2) There is not normal aggregation expressions - // 3) The number of affected groups is high (entries in `aggr_state` have rows need to update). Usually the high cardinality case - if matches!(self.mode, AggregateMode::Partial | AggregateMode::Single) - && normal_aggr_input_values.is_empty() - && normal_filter_values.is_empty() - && groups_with_rows.len() >= batch.num_rows() / self.scalar_update_factor - { - self.update_accumulators_using_scalar( - &groups_with_rows, - &row_aggr_input_values, - &row_filter_values, - )?; - } else { - // Collect all indices + offsets based on keys in this vec - let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0); - let mut offsets = vec![0]; - let mut offset_so_far = 0; - for &group_idx in groups_with_rows.iter() { - let indices = &self.aggr_state.group_states[group_idx].indices; - batch_indices.append_slice(indices); - offset_so_far += indices.len(); - offsets.push(offset_so_far); - } - let batch_indices = batch_indices.finish(); - - let row_values = get_at_indices(&row_aggr_input_values, &batch_indices)?; - let normal_values = - get_at_indices(&normal_aggr_input_values, &batch_indices)?; - let row_filter_values = - get_optional_filters(&row_filter_values, &batch_indices); - let normal_filter_values = - get_optional_filters(&normal_filter_values, &batch_indices); - self.update_accumulators_using_batch( - &groups_with_rows, - &offsets, - &row_values, - &normal_values, - &row_filter_values, - &normal_filter_values, - &mut allocated, - )?; - } - } - allocated += self - .row_converter - .size() - .saturating_sub(row_converter_size_pre); - Ok(allocated) - } -} - -/// The state of all the groups -pub(crate) struct AggregationState { - pub reservation: MemoryReservation, - - /// Logically maps group values to an index in `group_states` - /// - /// Uses the raw API of hashbrown to avoid actually storing the - /// keys in the table - /// - /// keys: u64 hashes of the GroupValue - /// values: (hash, index into `group_states`) - pub map: RawTable<(u64, usize)>, - - /// State for each group - pub group_states: Vec, -} - -impl std::fmt::Debug for AggregationState { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - // hashes are not store inline, so could only get values - let map_string = "RawTable"; - f.debug_struct("AggregationState") - .field("map", &map_string) - .field("group_states", &self.group_states) - .finish() - } -} - -impl GroupedHashAggregateStream { - /// Create a RecordBatch with all group keys and accumulator' states or values. - fn create_batch_from_map(&mut self) -> Result> { - let skip_items = self.row_group_skip_position; - if skip_items > self.aggr_state.group_states.len() { - return Ok(None); - } - if self.aggr_state.group_states.is_empty() { - let schema = self.schema.clone(); - return Ok(Some(RecordBatch::new_empty(schema))); - } - - let end_idx = min( - skip_items + self.batch_size, - self.aggr_state.group_states.len(), - ); - let group_state_chunk = &self.aggr_state.group_states[skip_items..end_idx]; - - if group_state_chunk.is_empty() { - let schema = self.schema.clone(); - return Ok(Some(RecordBatch::new_empty(schema))); - } - - // Buffers for each distinct group (i.e. row accumulator memories) - let mut state_buffers = group_state_chunk - .iter() - .map(|gs| gs.aggregation_buffer.clone()) - .collect::>(); - - let output_fields = self.schema.fields(); - // Store row accumulator results (either final output or intermediate state): - let row_columns = match self.mode { - AggregateMode::Partial => { - read_as_batch(&state_buffers, &self.row_aggr_schema) - } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single => { - let mut results = vec![]; - for (idx, acc) in self.row_accumulators.iter().enumerate() { - let mut state_accessor = RowAccessor::new(&self.row_aggr_schema); - let current = state_buffers - .iter_mut() - .map(|buffer| { - state_accessor.point_to(0, buffer); - acc.evaluate(&state_accessor) - }) - .collect::>>()?; - // Get corresponding field for row accumulator - let field = &output_fields[self.indices[1][idx].start]; - let result = if current.is_empty() { - Ok(arrow::array::new_empty_array(field.data_type())) - } else { - let item = ScalarValue::iter_to_array(current)?; - // cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - cast(&item, field.data_type()) - }?; - results.push(result); - } - results - } - }; - - // Store normal accumulator results (either final output or intermediate state): - let mut columns = vec![]; - for (idx, &Range { start, end }) in self.indices[0].iter().enumerate() { - for (field_idx, field) in output_fields[start..end].iter().enumerate() { - let current = match self.mode { - AggregateMode::Partial => ScalarValue::iter_to_array( - group_state_chunk.iter().map(|group_state| { - group_state.accumulator_set[idx] - .state() - .map(|v| v[field_idx].clone()) - .expect("Unexpected accumulator state in hash aggregate") - }), - ), - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single => ScalarValue::iter_to_array( - group_state_chunk.iter().map(|group_state| { - group_state.accumulator_set[idx] - .evaluate() - .expect("Unexpected accumulator state in hash aggregate") - }), - ), - }?; - // Cast output if needed (e.g. for types like Dictionary where - // the intermediate GroupByScalar type was not the same as the - // output - let result = cast(¤t, field.data_type())?; - columns.push(result); - } - } - - // Stores the group by fields - let group_buffers = group_state_chunk - .iter() - .map(|gs| gs.group_by_values.row()) - .collect::>(); - let mut output: Vec = self.row_converter.convert_rows(group_buffers)?; - - // The size of the place occupied by row and normal accumulators - let extra: usize = self - .indices - .iter() - .flatten() - .map(|Range { start, end }| end - start) - .sum(); - let empty_arr = new_null_array(&DataType::Null, 1); - output.extend(std::iter::repeat(empty_arr).take(extra)); - - // Write results of both accumulator types to the corresponding location in - // the output schema: - let results = [columns.into_iter(), row_columns.into_iter()]; - for (outer, mut current) in results.into_iter().enumerate() { - for &Range { start, end } in self.indices[outer].iter() { - for item in output.iter_mut().take(end).skip(start) { - *item = current.next().expect("Columns cannot be empty"); - } - } - } - Ok(Some(RecordBatch::try_new(self.schema.clone(), output)?)) - } -} diff --git a/datafusion/core/src/physical_plan/aggregates/utils.rs b/datafusion/core/src/physical_plan/aggregates/utils.rs deleted file mode 100644 index a55464edd145c..0000000000000 --- a/datafusion/core/src/physical_plan/aggregates/utils.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file contains various utility functions that are common to both -//! batch and streaming aggregation code. - -use crate::physical_plan::aggregates::AccumulatorItem; -use arrow::compute; -use arrow::compute::filter; -use arrow::row::OwnedRow; -use arrow_array::types::UInt32Type; -use arrow_array::{Array, ArrayRef, BooleanArray, PrimitiveArray}; -use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::utils::get_arrayref_at_indices; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::AggregateExpr; -use datafusion_row::reader::{read_row, RowReader}; -use datafusion_row::MutableRecordBatch; -use std::sync::Arc; - -/// This object encapsulates the state that is built for each output group. -#[derive(Debug)] -pub(crate) struct GroupState { - /// The actual group by values, stored sequentially - pub group_by_values: OwnedRow, - - // Accumulator state, stored sequentially - pub aggregation_buffer: Vec, - - // Accumulator state, one for each aggregate that doesn't support row accumulation - pub accumulator_set: Vec, - - /// Scratch space used to collect indices for input rows in a - /// batch that have values to aggregate, reset on each batch. - pub indices: Vec, -} - -#[derive(Debug)] -/// This object tracks the aggregation phase. -pub(crate) enum ExecutionState { - ReadingInput, - ProducingOutput, - Done, -} - -pub(crate) fn aggr_state_schema(aggr_expr: &[Arc]) -> SchemaRef { - let fields = aggr_expr - .iter() - .flat_map(|expr| expr.state_fields().unwrap().into_iter()) - .collect::>(); - Arc::new(Schema::new(fields)) -} - -pub(crate) fn read_as_batch(rows: &[Vec], schema: &Schema) -> Vec { - let mut output = MutableRecordBatch::new(rows.len(), Arc::new(schema.clone())); - let mut row = RowReader::new(schema); - - for data in rows { - row.point_to(0, data); - read_row(&row, &mut output, schema); - } - - output.output_as_columns() -} - -pub(crate) fn get_at_indices( - input_values: &[Vec], - batch_indices: &PrimitiveArray, -) -> Result>> { - input_values - .iter() - .map(|array| get_arrayref_at_indices(array, batch_indices)) - .collect() -} - -pub(crate) fn get_optional_filters( - original_values: &[Option>], - batch_indices: &PrimitiveArray, -) -> Vec>> { - original_values - .iter() - .map(|array| { - array.as_ref().map(|array| { - compute::take( - array.as_ref(), - batch_indices, - None, // None: no index check - ) - .unwrap() - }) - }) - .collect() -} - -pub(crate) fn slice_and_maybe_filter( - aggr_array: &[ArrayRef], - filter_opt: Option<&Arc>, - offsets: &[usize], -) -> Result> { - let (offset, length) = (offsets[0], offsets[1] - offsets[0]); - let sliced_arrays: Vec = aggr_array - .iter() - .map(|array| array.slice(offset, length)) - .collect(); - - if let Some(f) = filter_opt { - let sliced = f.slice(offset, length); - let filter_array = as_boolean_array(&sliced)?; - - sliced_arrays - .iter() - .map(|array| filter(array, filter_array).map_err(DataFusionError::ArrowError)) - .collect() - } else { - Ok(sliced_arrays) - } -} - -/// This method is similar to Scalar::try_from_array except for the Null handling. -/// This method returns [ScalarValue::Null] instead of [ScalarValue::Type(None)]. -pub(crate) fn col_to_scalar( - array: &ArrayRef, - filter: &Option<&BooleanArray>, - row_index: usize, -) -> Result { - if array.is_null(row_index) { - return Ok(ScalarValue::Null); - } - if let Some(filter) = filter { - if !filter.value(row_index) { - return Ok(ScalarValue::Null); - } - } - ScalarValue::try_from_array(array, row_index) -} diff --git a/datafusion/core/src/physical_plan/display.rs b/datafusion/core/src/physical_plan/display.rs deleted file mode 100644 index 5f286eed185cf..0000000000000 --- a/datafusion/core/src/physical_plan/display.rs +++ /dev/null @@ -1,202 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Implementation of physical plan display. See -//! [`crate::physical_plan::displayable`] for examples of how to -//! format - -use std::fmt; - -use crate::logical_expr::{StringifiedPlan, ToStringifiedPlan}; - -use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; - -/// Options for controlling how each [`ExecutionPlan`] should format itself -#[derive(Debug, Clone, Copy)] -pub enum DisplayFormatType { - /// Default, compact format. Example: `FilterExec: c12 < 10.0` - Default, -} - -/// Wraps an `ExecutionPlan` with various ways to display this plan -pub struct DisplayableExecutionPlan<'a> { - inner: &'a dyn ExecutionPlan, - /// How to show metrics - show_metrics: ShowMetrics, -} - -impl<'a> DisplayableExecutionPlan<'a> { - /// Create a wrapper around an [`'ExecutionPlan'] which can be - /// pretty printed in a variety of ways - pub fn new(inner: &'a dyn ExecutionPlan) -> Self { - Self { - inner, - show_metrics: ShowMetrics::None, - } - } - - /// Create a wrapper around an [`'ExecutionPlan'] which can be - /// pretty printed in a variety of ways that also shows aggregated - /// metrics - pub fn with_metrics(inner: &'a dyn ExecutionPlan) -> Self { - Self { - inner, - show_metrics: ShowMetrics::Aggregated, - } - } - - /// Create a wrapper around an [`'ExecutionPlan'] which can be - /// pretty printed in a variety of ways that also shows all low - /// level metrics - pub fn with_full_metrics(inner: &'a dyn ExecutionPlan) -> Self { - Self { - inner, - show_metrics: ShowMetrics::Full, - } - } - - /// Return a `format`able structure that produces a single line - /// per node. - /// - /// ```text - /// ProjectionExec: expr=[a] - /// CoalesceBatchesExec: target_batch_size=8192 - /// FilterExec: a < 5 - /// RepartitionExec: partitioning=RoundRobinBatch(16) - /// CsvExec: source=...", - /// ``` - pub fn indent(&self) -> impl fmt::Display + 'a { - struct Wrapper<'a> { - plan: &'a dyn ExecutionPlan, - show_metrics: ShowMetrics, - } - impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let t = DisplayFormatType::Default; - let mut visitor = IndentVisitor { - t, - f, - indent: 0, - show_metrics: self.show_metrics, - }; - accept(self.plan, &mut visitor) - } - } - Wrapper { - plan: self.inner, - show_metrics: self.show_metrics, - } - } - - /// Return a single-line summary of the root of the plan - /// Example: `ProjectionExec: expr=[a@0 as a]`. - pub fn one_line(&self) -> impl fmt::Display + 'a { - struct Wrapper<'a> { - plan: &'a dyn ExecutionPlan, - show_metrics: ShowMetrics, - } - - impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut visitor = IndentVisitor { - f, - t: DisplayFormatType::Default, - indent: 0, - show_metrics: self.show_metrics, - }; - visitor.pre_visit(self.plan)?; - Ok(()) - } - } - - Wrapper { - plan: self.inner, - show_metrics: self.show_metrics, - } - } -} - -#[derive(Debug, Clone, Copy)] -enum ShowMetrics { - /// Do not show any metrics - None, - - /// Show aggregrated metrics across partition - Aggregated, - - /// Show full per-partition metrics - Full, -} - -/// Formats plans with a single line per node. -struct IndentVisitor<'a, 'b> { - /// How to format each node - t: DisplayFormatType, - /// Write to this formatter - f: &'a mut fmt::Formatter<'b>, - /// Indent size - indent: usize, - /// How to show metrics - show_metrics: ShowMetrics, -} - -impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { - type Error = fmt::Error; - fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { - write!(self.f, "{:indent$}", "", indent = self.indent * 2)?; - plan.fmt_as(self.t, self.f)?; - match self.show_metrics { - ShowMetrics::None => {} - ShowMetrics::Aggregated => { - if let Some(metrics) = plan.metrics() { - let metrics = metrics - .aggregate_by_name() - .sorted_for_display() - .timestamps_removed(); - - write!(self.f, ", metrics=[{metrics}]")?; - } else { - write!(self.f, ", metrics=[]")?; - } - } - ShowMetrics::Full => { - if let Some(metrics) = plan.metrics() { - write!(self.f, ", metrics=[{metrics}]")?; - } else { - write!(self.f, ", metrics=[]")?; - } - } - } - writeln!(self.f)?; - self.indent += 1; - Ok(true) - } - - fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { - self.indent -= 1; - Ok(true) - } -} - -impl<'a> ToStringifiedPlan for DisplayableExecutionPlan<'a> { - fn to_stringified( - &self, - plan_type: crate::logical_expr::PlanType, - ) -> StringifiedPlan { - StringifiedPlan::new(plan_type, self.indent().to_string()) - } -} diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs deleted file mode 100644 index a6f00846de8ff..0000000000000 --- a/datafusion/core/src/physical_plan/filter.rs +++ /dev/null @@ -1,671 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! FilterExec evaluates a boolean predicate against all input batches to determine which rows to -//! include in its output batches. - -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use super::expressions::PhysicalSortExpr; -use super::{ColumnStatistics, RecordBatchStream, SendableRecordBatchStream, Statistics}; -use crate::physical_plan::{ - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - Column, DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, - PhysicalExpr, -}; -use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::{split_conjunction, AnalysisContext}; - -use log::trace; - -use datafusion_execution::TaskContext; -use futures::stream::{Stream, StreamExt}; - -/// FilterExec evaluates a boolean predicate against all input batches to determine which rows to -/// include in its output batches. -#[derive(Debug)] -pub struct FilterExec { - /// The expression to filter on. This expression must evaluate to a boolean value. - predicate: Arc, - /// The input plan - input: Arc, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, -} - -impl FilterExec { - /// Create a FilterExec on an input - pub fn try_new( - predicate: Arc, - input: Arc, - ) -> Result { - match predicate.data_type(input.schema().as_ref())? { - DataType::Boolean => Ok(Self { - predicate, - input: input.clone(), - metrics: ExecutionPlanMetricsSet::new(), - }), - other => Err(DataFusionError::Plan(format!( - "Filter predicate must return boolean values, not {other:?}" - ))), - } - } - - /// The expression to filter on. This expression must evaluate to a boolean value. - pub fn predicate(&self) -> &Arc { - &self.predicate - } - - /// The input plan - pub fn input(&self) -> &Arc { - &self.input - } -} - -impl ExecutionPlan for FilterExec { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - // The filter operator does not make any changes to the schema of its input - self.input.schema() - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - /// Get the output partitioning of this plan - fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() - } - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns an error to indicate this. - fn unbounded_output(&self, children: &[bool]) -> Result { - Ok(children[0]) - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.input.output_ordering() - } - - fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input - vec![true] - } - - fn equivalence_properties(&self) -> EquivalenceProperties { - // Combine the equal predicates with the input equivalence properties - let mut input_properties = self.input.equivalence_properties(); - let (equal_pairs, _ne_pairs) = collect_columns_from_predicate(&self.predicate); - for new_condition in equal_pairs { - input_properties.add_equal_conditions(new_condition) - } - input_properties - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(FilterExec::try_new( - self.predicate.clone(), - children[0].clone(), - )?)) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - trace!("Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - Ok(Box::pin(FilterExecStream { - schema: self.input.schema(), - predicate: self.predicate.clone(), - input: self.input.execute(partition, context)?, - baseline_metrics, - })) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "FilterExec: {}", self.predicate) - } - } - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - /// The output statistics of a filtering operation can be estimated if the - /// predicate's selectivity value can be determined for the incoming data. - fn statistics(&self) -> Statistics { - let input_stats = self.input.statistics(); - let starter_ctx = - AnalysisContext::from_statistics(self.input.schema().as_ref(), &input_stats); - - let analysis_ctx = self.predicate.analyze(starter_ctx); - - match analysis_ctx.boundaries { - Some(boundaries) => { - // Build back the column level statistics from the boundaries inside the - // analysis context. It is possible that these are going to be different - // than the input statistics, especially when a comparison is made inside - // the predicate expression (e.g. `col1 > 100`). - let column_statistics = analysis_ctx - .column_boundaries - .iter() - .map(|boundary| match boundary { - Some(boundary) => ColumnStatistics { - min_value: Some(boundary.min_value.clone()), - max_value: Some(boundary.max_value.clone()), - ..Default::default() - }, - None => ColumnStatistics::default(), - }) - .collect(); - - Statistics { - num_rows: input_stats.num_rows.zip(boundaries.selectivity).map( - |(num_rows, selectivity)| { - (num_rows as f64 * selectivity).ceil() as usize - }, - ), - total_byte_size: input_stats - .total_byte_size - .zip(boundaries.selectivity) - .map(|(num_rows, selectivity)| { - (num_rows as f64 * selectivity).ceil() as usize - }), - column_statistics: Some(column_statistics), - ..Default::default() - } - } - None => Statistics::default(), - } - } -} - -/// The FilterExec streams wraps the input iterator and applies the predicate expression to -/// determine which rows to include in its output batches -struct FilterExecStream { - /// Output schema, which is the same as the input schema for this operator - schema: SchemaRef, - /// The expression to filter on. This expression must evaluate to a boolean value. - predicate: Arc, - /// The input partition to filter. - input: SendableRecordBatchStream, - /// runtime metrics recording - baseline_metrics: BaselineMetrics, -} - -pub(crate) fn batch_filter( - batch: &RecordBatch, - predicate: &Arc, -) -> Result { - predicate - .evaluate(batch) - .map(|v| v.into_array(batch.num_rows())) - .and_then(|array| { - Ok(as_boolean_array(&array)?) - // apply filter array to record batch - .and_then(|filter_array| Ok(filter_record_batch(batch, filter_array)?)) - }) -} - -impl Stream for FilterExecStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let poll; - loop { - match self.input.poll_next_unpin(cx) { - Poll::Ready(value) => match value { - Some(Ok(batch)) => { - let timer = self.baseline_metrics.elapsed_compute().timer(); - let filtered_batch = batch_filter(&batch, &self.predicate)?; - // skip entirely filtered batches - if filtered_batch.num_rows() == 0 { - continue; - } - timer.done(); - poll = Poll::Ready(Some(Ok(filtered_batch))); - break; - } - _ => { - poll = Poll::Ready(value); - break; - } - }, - Poll::Pending => { - poll = Poll::Pending; - break; - } - } - } - self.baseline_metrics.record_poll(poll) - } - - fn size_hint(&self) -> (usize, Option) { - // same number of record batches - self.input.size_hint() - } -} - -impl RecordBatchStream for FilterExecStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -/// Return the equals Column-Pairs and Non-equals Column-Pairs -fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { - let mut eq_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); - let mut ne_predicate_columns: Vec<(&Column, &Column)> = Vec::new(); - - let predicates = split_conjunction(predicate); - predicates.into_iter().for_each(|p| { - if let Some(binary) = p.as_any().downcast_ref::() { - let left = binary.left(); - let right = binary.right(); - if left.as_any().is::() && right.as_any().is::() { - let left_column = left.as_any().downcast_ref::().unwrap(); - let right_column = right.as_any().downcast_ref::().unwrap(); - match binary.op() { - Operator::Eq => { - eq_predicate_columns.push((left_column, right_column)) - } - Operator::NotEq => { - ne_predicate_columns.push((left_column, right_column)) - } - _ => {} - } - } - } - }); - - (eq_predicate_columns, ne_predicate_columns) -} -/// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates -pub type EqualAndNonEqual<'a> = - (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>); - -#[cfg(test)] -mod tests { - - use super::*; - use crate::physical_plan::expressions::*; - use crate::physical_plan::ExecutionPlan; - use crate::physical_plan::{collect, with_new_children_if_necessary}; - use crate::prelude::SessionContext; - use crate::test; - use crate::test::exec::StatisticsExec; - use crate::test_util; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::utils::DataPtr; - use datafusion_common::ColumnStatistics; - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; - use std::iter::Iterator; - use std::sync::Arc; - - #[tokio::test] - async fn simple_predicate() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - - let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; - - let predicate: Arc = binary( - binary(col("c2", &schema)?, Operator::Gt, lit(1u32), &schema)?, - Operator::And, - binary(col("c2", &schema)?, Operator::Lt, lit(4u32), &schema)?, - &schema, - )?; - - let filter: Arc = - Arc::new(FilterExec::try_new(predicate, csv)?); - - let results = collect(filter, task_ctx).await?; - - results - .iter() - .for_each(|batch| assert_eq!(13, batch.num_columns())); - let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); - assert_eq!(41, row_count); - - Ok(()) - } - - #[tokio::test] - async fn with_new_children() -> Result<()> { - let schema = test_util::aggr_test_schema(); - let partitions = 4; - let input = test::scan_partitioned_csv(partitions)?; - - let predicate: Arc = - binary(col("c2", &schema)?, Operator::Gt, lit(1u32), &schema)?; - - let filter: Arc = - Arc::new(FilterExec::try_new(predicate, input.clone())?); - - let new_filter = filter.clone().with_new_children(vec![input.clone()])?; - assert!(!Arc::data_ptr_eq(&filter, &new_filter)); - - let new_filter2 = - with_new_children_if_necessary(filter.clone(), vec![input])?.into(); - assert!(Arc::data_ptr_eq(&filter, &new_filter2)); - - Ok(()) - } - - #[tokio::test] - async fn collect_columns_predicates() -> Result<()> { - let schema = test_util::aggr_test_schema(); - let predicate: Arc = binary( - binary( - binary(col("c2", &schema)?, Operator::GtEq, lit(1u32), &schema)?, - Operator::And, - binary(col("c2", &schema)?, Operator::Eq, lit(4u32), &schema)?, - &schema, - )?, - Operator::And, - binary( - binary( - col("c2", &schema)?, - Operator::Eq, - col("c9", &schema)?, - &schema, - )?, - Operator::And, - binary( - col("c1", &schema)?, - Operator::NotEq, - col("c13", &schema)?, - &schema, - )?, - &schema, - )?, - &schema, - )?; - - let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate); - - assert_eq!(1, equal_pairs.len()); - assert_eq!(equal_pairs[0].0.name(), "c2"); - assert_eq!(equal_pairs[0].1.name(), "c9"); - - assert_eq!(1, ne_pairs.len()); - assert_eq!(ne_pairs[0].0.name(), "c1"); - assert_eq!(ne_pairs[0].1.name(), "c13"); - - Ok(()) - } - - #[tokio::test] - async fn test_filter_statistics_basic_expr() -> Result<()> { - // Table: - // a: min=1, max=100 - let bytes_per_row = 4; - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let input = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Some(100), - total_byte_size: Some(100 * bytes_per_row), - column_statistics: Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), - ..Default::default() - }]), - ..Default::default() - }, - schema.clone(), - )); - - // a <= 25 - let predicate: Arc = - binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?; - - // WHERE a <= 25 - let filter: Arc = - Arc::new(FilterExec::try_new(predicate, input)?); - - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, Some(25)); - assert_eq!(statistics.total_byte_size, Some(25 * bytes_per_row)); - - Ok(()) - } - - #[tokio::test] - async fn test_filter_statistics_column_level_basic_expr() -> Result<()> { - // Table: - // a: min=1, max=100 - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let input = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Some(100), - column_statistics: Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), - ..Default::default() - }]), - ..Default::default() - }, - schema.clone(), - )); - - // a <= 25 - let predicate: Arc = - binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?; - - // WHERE a <= 25 - let filter: Arc = - Arc::new(FilterExec::try_new(predicate, input)?); - - let statistics = filter.statistics(); - - // a must be in [1, 25] range now! - assert_eq!(statistics.num_rows, Some(25)); - assert_eq!( - statistics.column_statistics, - Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(25))), - ..Default::default() - }]) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_filter_statistics_column_level_nested() -> Result<()> { - // Table: - // a: min=1, max=100 - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let input = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Some(100), - column_statistics: Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), - ..Default::default() - }]), - ..Default::default() - }, - schema.clone(), - )); - - // WHERE a <= 25 - let sub_filter: Arc = Arc::new(FilterExec::try_new( - binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?, - input, - )?); - - // Nested filters (two separate physical plans, instead of AND chain in the expr) - // WHERE a >= 10 - // WHERE a <= 25 - let filter: Arc = Arc::new(FilterExec::try_new( - binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, - sub_filter, - )?); - - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, Some(16)); - assert_eq!( - statistics.column_statistics, - Some(vec![ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(10))), - max_value: Some(ScalarValue::Int32(Some(25))), - ..Default::default() - }]) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_filter_statistics_column_level_nested_multiple() -> Result<()> { - // Table: - // a: min=1, max=100 - // b: min=1, max=50 - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - let input = Arc::new(StatisticsExec::new( - Statistics { - num_rows: Some(100), - column_statistics: Some(vec![ - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), - ..Default::default() - }, - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(50))), - ..Default::default() - }, - ]), - ..Default::default() - }, - schema.clone(), - )); - - // WHERE a <= 25 - let a_lte_25: Arc = Arc::new(FilterExec::try_new( - binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?, - input, - )?); - - // WHERE b > 45 - let b_gt_5: Arc = Arc::new(FilterExec::try_new( - binary(col("b", &schema)?, Operator::Gt, lit(45i32), &schema)?, - a_lte_25, - )?); - - // WHERE a >= 10 - let filter: Arc = Arc::new(FilterExec::try_new( - binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, - b_gt_5, - )?); - - let statistics = filter.statistics(); - // On a uniform distribution, only fifteen rows will satisfy the - // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only - // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). - // - // Which would result with a selectivity of '15/100 * 5/50' or 0.015 - // and that means about %1.5 of the all rows (rounded up to 2 rows). - assert_eq!(statistics.num_rows, Some(2)); - assert_eq!( - statistics.column_statistics, - Some(vec![ - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(10))), - max_value: Some(ScalarValue::Int32(Some(25))), - ..Default::default() - }, - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(45))), - max_value: Some(ScalarValue::Int32(Some(50))), - ..Default::default() - } - ]) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_filter_statistics_when_input_stats_missing() -> Result<()> { - // Table: - // a: min=???, max=??? (missing) - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let input = Arc::new(StatisticsExec::new( - Statistics { - column_statistics: Some(vec![ColumnStatistics { - ..Default::default() - }]), - ..Default::default() - }, - schema.clone(), - )); - - // a <= 25 - let predicate: Arc = - binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?; - - // WHERE a <= 25 - let filter: Arc = - Arc::new(FilterExec::try_new(predicate, input)?); - - let statistics = filter.statistics(); - assert_eq!(statistics.num_rows, None); - - Ok(()) - } -} diff --git a/datafusion/core/src/physical_plan/insert.rs b/datafusion/core/src/physical_plan/insert.rs deleted file mode 100644 index f3bd701a565de..0000000000000 --- a/datafusion/core/src/physical_plan/insert.rs +++ /dev/null @@ -1,221 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Execution plan for writing data to [`DataSink`]s - -use super::expressions::PhysicalSortExpr; -use super::{ - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, -}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use arrow_array::{ArrayRef, UInt64Array}; -use arrow_schema::{DataType, Field, Schema}; -use async_trait::async_trait; -use core::fmt; -use datafusion_common::Result; -use datafusion_physical_expr::PhysicalSortRequirement; -use futures::StreamExt; -use std::any::Any; -use std::fmt::{Debug, Display}; -use std::sync::Arc; - -use crate::physical_plan::stream::RecordBatchStreamAdapter; -use crate::physical_plan::Distribution; -use datafusion_common::DataFusionError; -use datafusion_execution::TaskContext; - -/// `DataSink` implements writing streams of [`RecordBatch`]es to -/// user defined destinations. -/// -/// The `Display` impl is used to format the sink for explain plan -/// output. -#[async_trait] -pub trait DataSink: Display + Debug + Send + Sync { - // TODO add desired input ordering - // How does this sink want its input ordered? - - /// Writes the data to the sink, returns the number of values written - /// - /// This method will be called exactly once during each DML - /// statement. Thus prior to return, the sink should do any commit - /// or rollback required. - async fn write_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result; -} - -/// Execution plan for writing record batches to a [`DataSink`] -/// -/// Returns a single row with the number of values written -pub struct InsertExec { - /// Input plan that produces the record batches to be written. - input: Arc, - /// Sink to whic to write - sink: Arc, - /// Schema describing the structure of the data. - schema: SchemaRef, -} - -impl fmt::Debug for InsertExec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "InsertExec schema: {:?}", self.schema) - } -} - -impl InsertExec { - /// Create a plan to write to `sink` - pub fn new(input: Arc, sink: Arc) -> Self { - Self { - input, - sink, - schema: make_count_schema(), - } - } -} - -impl ExecutionPlan for InsertExec { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - Partitioning::UnknownPartitioning(1) - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn required_input_distribution(&self) -> Vec { - vec![Distribution::SinglePartition] - } - - fn required_input_ordering(&self) -> Vec>> { - // Require that the InsertExec gets the data in the order the - // input produced it (otherwise the optimizer may chose to reorder - // the input which could result in unintended / poor UX) - // - // More rationale: - // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 - vec![self - .input - .output_ordering() - .map(PhysicalSortRequirement::from_sort_exprs)] - } - - fn maintains_input_order(&self) -> Vec { - vec![false] - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(Self { - input: children[0].clone(), - sink: self.sink.clone(), - schema: self.schema.clone(), - })) - } - - /// Execute the plan and return a stream of `RecordBatch`es for - /// the specified partition. - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - if partition != 0 { - return Err(DataFusionError::Internal( - format!("Invalid requested partition {partition}. InsertExec requires a single input partition." - ))); - } - - // Execute each of our own input's partitions and pass them to the sink - let input_partition_count = self.input.output_partitioning().partition_count(); - if input_partition_count != 1 { - return Err(DataFusionError::Internal(format!( - "Invalid input partition count {input_partition_count}. \ - InsertExec needs only a single partition." - ))); - } - - let data = self.input.execute(0, context.clone())?; - let schema = self.schema.clone(); - let sink = self.sink.clone(); - - let stream = futures::stream::once(async move { - sink.write_all(data, &context).await.map(make_count_batch) - }) - .boxed(); - - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "InsertExec: sink={}", self.sink) - } - } - } - - fn statistics(&self) -> Statistics { - Statistics::default() - } -} - -/// Create a output record batch with a count -/// -/// ```text -/// +-------+, -/// | count |, -/// +-------+, -/// | 6 |, -/// +-------+, -/// ``` -fn make_count_batch(count: u64) -> RecordBatch { - let array = Arc::new(UInt64Array::from(vec![count])) as ArrayRef; - - RecordBatch::try_from_iter_with_nullable(vec![("count", array, false)]).unwrap() -} - -fn make_count_schema() -> SchemaRef { - // define a schema. - Arc::new(Schema::new(vec![Field::new( - "count", - DataType::UInt64, - false, - )])) -} diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs deleted file mode 100644 index 992de86dfe177..0000000000000 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ /dev/null @@ -1,667 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file contains common subroutines for regular and symmetric hash join -//! related functionality, used both in join calculations and optimization rules. - -use std::collections::HashMap; -use std::sync::Arc; -use std::{fmt, usize}; - -use arrow::datatypes::SchemaRef; - -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::intervals::Interval; -use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use hashbrown::raw::RawTable; -use smallvec::SmallVec; - -use crate::physical_plan::joins::utils::{JoinFilter, JoinSide}; -use datafusion_common::Result; - -// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -// -// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used -// to put the indices in a certain bucket. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// TODO: speed up collision check and move away from using a hashbrown HashMap -// https://github.com/apache/arrow-datafusion/issues/50 -pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>); - -impl JoinHashMap { - /// In this implementation, the scale_factor variable determines how conservative the shrinking strategy is. - /// The value of scale_factor is set to 4, which means the capacity will be reduced by 25% - /// when necessary. You can adjust the scale_factor value to achieve the desired - /// ,balance between memory usage and performance. - // - // If you increase the scale_factor, the capacity will shrink less aggressively, - // leading to potentially higher memory usage but fewer resizes. - // Conversely, if you decrease the scale_factor, the capacity will shrink more aggressively, - // potentially leading to lower memory usage but more frequent resizing. - pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) { - let capacity = self.0.capacity(); - let len = self.0.len(); - - if capacity > scale_factor * len { - let new_capacity = (capacity * (scale_factor - 1)) / scale_factor; - self.0.shrink_to(new_capacity, |(hash, _)| *hash) - } - } - - pub(crate) fn size(&self) -> usize { - self.0.allocation_info().1.size() - } -} - -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) - } -} - -fn check_filter_expr_contains_sort_information( - expr: &Arc, - reference: &Arc, -) -> bool { - expr.eq(reference) - || expr - .children() - .iter() - .any(|e| check_filter_expr_contains_sort_information(e, reference)) -} - -/// Create a one to one mapping from main columns to filter columns using -/// filter column indices. A column index looks like: -/// ```text -/// ColumnIndex { -/// index: 0, // field index in main schema -/// side: JoinSide::Left, // child side -/// } -/// ``` -pub fn map_origin_col_to_filter_col( - filter: &JoinFilter, - schema: &SchemaRef, - side: &JoinSide, -) -> Result> { - let filter_schema = filter.schema(); - let mut col_to_col_map: HashMap = HashMap::new(); - for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { - if index.side.eq(side) { - // Get the main field from column index: - let main_field = schema.field(index.index); - // Create a column expression: - let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?; - // Since the order of by filter.column_indices() is the same with - // that of intermediate schema fields, we can get the column directly. - let filter_field = filter_schema.field(filter_schema_index); - let filter_col = Column::new(filter_field.name(), filter_schema_index); - // Insert mapping: - col_to_col_map.insert(main_col, filter_col); - } - } - Ok(col_to_col_map) -} - -/// This function analyzes [`PhysicalSortExpr`] graphs with respect to monotonicity -/// (sorting) properties. This is necessary since monotonically increasing and/or -/// decreasing expressions are required when using join filter expressions for -/// data pruning purposes. -/// -/// The method works as follows: -/// 1. Maps the original columns to the filter columns using the [`map_origin_col_to_filter_col`] function. -/// 2. Collects all columns in the sort expression using the [`collect_columns`] function. -/// 3. Checks if all columns are included in the map we obtain in the first step. -/// 4. If all columns are included, the sort expression is converted into a filter expression using -/// the [`convert_filter_columns`] function. -/// 5. Searches for the converted filter expression in the filter expression using the -/// [`check_filter_expr_contains_sort_information`] function. -/// 6. If an exact match is found, returns the converted filter expression as [`Some(Arc)`]. -/// 7. If all columns are not included or an exact match is not found, returns [`None`]. -/// -/// Examples: -/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100". -/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. -/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. -/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However, -/// there is no exact match, so this expression does not indicate pruning. -pub fn convert_sort_expr_with_filter_schema( - side: &JoinSide, - filter: &JoinFilter, - schema: &SchemaRef, - sort_expr: &PhysicalSortExpr, -) -> Result>> { - let column_map = map_origin_col_to_filter_col(filter, schema, side)?; - let expr = sort_expr.expr.clone(); - // Get main schema columns: - let expr_columns = collect_columns(&expr); - // Calculation is possible with `column_map` since sort exprs belong to a child. - let all_columns_are_included = - expr_columns.iter().all(|col| column_map.contains_key(col)); - if all_columns_are_included { - // Since we are sure that one to one column mapping includes all columns, we convert - // the sort expression into a filter expression. - let converted_filter_expr = expr.transform_up(&|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { - Some(transformed) => Transformed::Yes(transformed), - None => Transformed::No(p), - } - }) - })?; - // Search the converted `PhysicalExpr` in filter expression; if an exact - // match is found, use this sorted expression in graph traversals. - if check_filter_expr_contains_sort_information( - filter.expression(), - &converted_filter_expr, - ) { - return Ok(Some(converted_filter_expr)); - } - } - Ok(None) -} - -/// This function is used to build the filter expression based on the sort order of input columns. -/// -/// It first calls the [`convert_sort_expr_with_filter_schema`] method to determine if the sort -/// order of columns can be used in the filter expression. If it returns a [`Some`] value, the -/// method wraps the result in a [`SortedFilterExpr`] instance with the original sort expression and -/// the converted filter expression. Otherwise, this function returns an error. -/// -/// The `SortedFilterExpr` instance contains information about the sort order of columns that can -/// be used in the filter expression, which can be used to optimize the query execution process. -pub fn build_filter_input_order( - side: JoinSide, - filter: &JoinFilter, - schema: &SchemaRef, - order: &PhysicalSortExpr, -) -> Result> { - let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?; - Ok(opt_expr.map(|filter_expr| SortedFilterExpr::new(order.clone(), filter_expr))) -} - -/// Convert a physical expression into a filter expression using the given -/// column mapping information. -fn convert_filter_columns( - input: &dyn PhysicalExpr, - column_map: &HashMap, -) -> Result>> { - // Attempt to downcast the input expression to a Column type. - Ok(if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } else { - // If the downcast fails, return the input expression as is. - None - }) -} - -/// The [SortedFilterExpr] object represents a sorted filter expression. It -/// contains the following information: The origin expression, the filter -/// expression, an interval encapsulating expression bounds, and a stable -/// index identifying the expression in the expression DAG. -/// -/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides -/// and uses new column names. In this process, a column exchange is done so -/// we can utilize sorting information while traversing the filter expression -/// DAG for interval calculations. When evaluating the inner buffer, we use -/// `origin_sorted_expr`. -#[derive(Debug, Clone)] -pub struct SortedFilterExpr { - /// Sorted expression from a join side (i.e. a child of the join) - origin_sorted_expr: PhysicalSortExpr, - /// Expression adjusted for filter schema. - filter_expr: Arc, - /// Interval containing expression bounds - interval: Interval, - /// Node index in the expression DAG - node_index: usize, -} - -impl SortedFilterExpr { - /// Constructor - pub fn new( - origin_sorted_expr: PhysicalSortExpr, - filter_expr: Arc, - ) -> Self { - Self { - origin_sorted_expr, - filter_expr, - interval: Interval::default(), - node_index: 0, - } - } - /// Get origin expr information - pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { - &self.origin_sorted_expr - } - /// Get filter expr information - pub fn filter_expr(&self) -> &Arc { - &self.filter_expr - } - /// Get interval information - pub fn interval(&self) -> &Interval { - &self.interval - } - /// Sets interval - pub fn set_interval(&mut self, interval: Interval) { - self.interval = interval; - } - /// Node index in ExprIntervalGraph - pub fn node_index(&self) -> usize { - self.node_index - } - /// Node index setter in ExprIntervalGraph - pub fn set_node_index(&mut self, node_index: usize) { - self.node_index = node_index; - } -} - -#[cfg(test)] -pub mod tests { - use super::*; - use crate::physical_plan::{ - expressions::Column, - expressions::PhysicalSortExpr, - joins::utils::{ColumnIndex, JoinFilter, JoinSide}, - }; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, cast, col, lit}; - use smallvec::smallvec; - use std::sync::Arc; - - /// Filter expr for a + b > c + 10 AND a + b < c + 100 - pub(crate) fn complicated_filter( - filter_schema: &Schema, - ) -> Result> { - let left_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Gt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(10))), - filter_schema, - )?, - filter_schema, - )?; - - let right_expr = binary( - cast( - binary( - col("0", filter_schema)?, - Operator::Plus, - col("1", filter_schema)?, - filter_schema, - )?, - filter_schema, - DataType::Int64, - )?, - Operator::Lt, - binary( - cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, - Operator::Plus, - lit(ScalarValue::Int64(Some(100))), - filter_schema, - )?, - filter_schema, - )?; - binary(left_expr, Operator::And, right_expr, filter_schema) - } - - #[test] - fn test_column_exchange() -> Result<()> { - let left_child_schema = - Schema::new(vec![Field::new("left_1", DataType::Int32, true)]); - // Sorting information for the left side: - let left_child_sort_expr = PhysicalSortExpr { - expr: col("left_1", &left_child_schema)?, - options: SortOptions::default(), - }; - - let right_child_schema = Schema::new(vec![ - Field::new("right_1", DataType::Int32, true), - Field::new("right_2", DataType::Int32, true), - ]); - // Sorting information for the right side: - let right_child_sort_expr = PhysicalSortExpr { - expr: binary( - col("right_1", &right_child_schema)?, - Operator::Plus, - col("right_2", &right_child_schema)?, - &right_child_schema, - )?, - options: SortOptions::default(), - }; - - let intermediate_schema = Schema::new(vec![ - Field::new("filter_1", DataType::Int32, true), - Field::new("filter_2", DataType::Int32, true), - Field::new("filter_3", DataType::Int32, true), - ]); - // Our filter expression is: left_1 > right_1 + right_2. - let filter_left = col("filter_1", &intermediate_schema)?; - let filter_right = binary( - col("filter_2", &intermediate_schema)?, - Operator::Plus, - col("filter_3", &intermediate_schema)?, - &intermediate_schema, - )?; - let filter_expr = binary( - filter_left.clone(), - Operator::Gt, - filter_right.clone(), - &intermediate_schema, - )?; - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ColumnIndex { - index: 1, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - let left_sort_filter_expr = build_filter_input_order( - JoinSide::Left, - &filter, - &Arc::new(left_child_schema), - &left_child_sort_expr, - )? - .unwrap(); - assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr())); - - let right_sort_filter_expr = build_filter_input_order( - JoinSide::Right, - &filter, - &Arc::new(right_child_schema), - &right_child_sort_expr, - )? - .unwrap(); - assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr())); - - // Assert that adjusted (left) filter expression matches with `left_child_sort_expr`: - assert!(filter_left.eq(left_sort_filter_expr.filter_expr())); - // Assert that adjusted (right) filter expression matches with `right_child_sort_expr`: - assert!(filter_right.eq(right_sort_filter_expr.filter_expr())); - Ok(()) - } - - #[test] - fn test_column_collector() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let filter_expr = complicated_filter(&schema)?; - let columns = collect_columns(&filter_expr); - assert_eq!(columns.len(), 3); - Ok(()) - } - - #[test] - fn find_expr_inside_expr() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let filter_expr = complicated_filter(&schema)?; - - let expr_1 = Arc::new(Column::new("gnz", 0)) as _; - assert!(!check_filter_expr_contains_sort_information( - &filter_expr, - &expr_1 - )); - - let expr_2 = col("1", &schema)? as _; - - assert!(check_filter_expr_contains_sort_information( - &filter_expr, - &expr_2 - )); - - let expr_3 = cast( - binary( - col("0", &schema)?, - Operator::Plus, - col("1", &schema)?, - &schema, - )?, - &schema, - DataType::Int64, - )?; - - assert!(check_filter_expr_contains_sort_information( - &filter_expr, - &expr_3 - )); - - let expr_4 = Arc::new(Column::new("1", 42)) as _; - - assert!(!check_filter_expr_contains_sort_information( - &filter_expr, - &expr_4, - )); - Ok(()) - } - - #[test] - fn build_sorted_expr() -> Result<()> { - let left_schema = Schema::new(vec![ - Field::new("la1", DataType::Int32, false), - Field::new("lb1", DataType::Int32, false), - Field::new("lc1", DataType::Int32, false), - Field::new("lt1", DataType::Int32, false), - Field::new("la2", DataType::Int32, false), - Field::new("la1_des", DataType::Int32, false), - ]); - - let right_schema = Schema::new(vec![ - Field::new("ra1", DataType::Int32, false), - Field::new("rb1", DataType::Int32, false), - Field::new("rc1", DataType::Int32, false), - Field::new("rt1", DataType::Int32, false), - Field::new("ra2", DataType::Int32, false), - Field::new("ra1_des", DataType::Int32, false), - ]); - - let intermediate_schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let filter_expr = complicated_filter(&intermediate_schema)?; - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 4, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - let left_schema = Arc::new(left_schema); - let right_schema = Arc::new(right_schema); - - assert!(build_filter_input_order( - JoinSide::Left, - &filter, - &left_schema, - &PhysicalSortExpr { - expr: col("la1", left_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_some()); - assert!(build_filter_input_order( - JoinSide::Left, - &filter, - &left_schema, - &PhysicalSortExpr { - expr: col("lt1", left_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_none()); - assert!(build_filter_input_order( - JoinSide::Right, - &filter, - &right_schema, - &PhysicalSortExpr { - expr: col("ra1", right_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_some()); - assert!(build_filter_input_order( - JoinSide::Right, - &filter, - &right_schema, - &PhysicalSortExpr { - expr: col("rb1", right_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_none()); - - Ok(()) - } - - // Test the case when we have an "ORDER BY a + b", and join filter condition includes "a - b". - #[test] - fn sorted_filter_expr_build() -> Result<()> { - let intermediate_schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - ]); - let filter_expr = binary( - col("0", &intermediate_schema)?, - Operator::Minus, - col("1", &intermediate_schema)?, - &intermediate_schema, - )?; - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 1, - side: JoinSide::Left, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let sorted = PhysicalSortExpr { - expr: binary( - col("a", &schema)?, - Operator::Plus, - col("b", &schema)?, - &schema, - )?, - options: SortOptions::default(), - }; - - let res = convert_sort_expr_with_filter_schema( - &JoinSide::Left, - &filter, - &Arc::new(schema), - &sorted, - )?; - assert!(res.is_none()); - Ok(()) - } - - #[test] - fn test_shrink_if_necessary() { - let scale_factor = 4; - let mut join_hash_map = JoinHashMap(RawTable::with_capacity(100)); - let data_size = 2000; - let deleted_part = 3 * data_size / 4; - // Add elements to the JoinHashMap - for hash_value in 0..data_size { - join_hash_map.0.insert( - hash_value, - (hash_value, smallvec![hash_value]), - |(hash, _)| *hash, - ); - } - - assert_eq!(join_hash_map.0.len(), data_size as usize); - assert!(join_hash_map.0.capacity() >= data_size as usize); - - // Remove some elements from the JoinHashMap - for hash_value in 0..deleted_part { - join_hash_map - .0 - .remove_entry(hash_value, |(hash, _)| hash_value == *hash); - } - - assert_eq!(join_hash_map.0.len(), (data_size - deleted_part) as usize); - - // Old capacity - let old_capacity = join_hash_map.0.capacity(); - - // Test shrink_if_necessary - join_hash_map.shrink_if_necessary(scale_factor); - - // The capacity should be reduced by the scale factor - let new_expected_capacity = - join_hash_map.0.capacity() * (scale_factor - 1) / scale_factor; - assert!(join_hash_map.0.capacity() >= new_expected_capacity); - assert!(join_hash_map.0.capacity() <= old_capacity); - } -} diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs deleted file mode 100644 index 7eac619687b28..0000000000000 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ /dev/null @@ -1,3179 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This file implements the symmetric hash join algorithm with range-based -//! data pruning to join two (potentially infinite) streams. -//! -//! A [SymmetricHashJoinExec] plan takes two children plan (with appropriate -//! output ordering) and produces the join output according to the given join -//! type and other options. -//! -//! This plan uses the [OneSideHashJoiner] object to facilitate join calculations -//! for both its children. - -use std::collections::{HashMap, VecDeque}; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; -use std::task::Poll; -use std::vec; -use std::{any::Any, usize}; - -use ahash::RandomState; -use arrow::array::{ - ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, - PrimitiveBuilder, -}; -use arrow::compute::concat_batches; -use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use futures::stream::{select, BoxStream}; -use futures::{Stream, StreamExt}; -use hashbrown::{raw::RawTable, HashSet}; -use parking_lot::Mutex; - -use datafusion_common::{utils::bisect, ScalarValue}; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound}; - -use crate::physical_plan::common::SharedMemoryReservation; -use crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema; -use crate::physical_plan::joins::hash_join_utils::JoinHashMap; -use crate::physical_plan::{ - expressions::Column, - expressions::PhysicalSortExpr, - joins::{ - hash_join::{build_join_indices, update_hash}, - hash_join_utils::{build_filter_input_order, SortedFilterExpr}, - utils::{ - build_batch_from_indices, build_join_schema, check_join_is_valid, - combine_join_equivalence_properties, partitioned_join_output_partitioning, - ColumnIndex, JoinFilter, JoinOn, JoinSide, - }, - }, - metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; -use datafusion_common::JoinType; -use datafusion_common::{DataFusionError, Result}; -use datafusion_execution::TaskContext; - -const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; - -/// A symmetric hash join with range conditions is when both streams are hashed on the -/// join key and the resulting hash tables are used to join the streams. -/// The join is considered symmetric because the hash table is built on the join keys from both -/// streams, and the matching of rows is based on the values of the join keys in both streams. -/// This type of join is efficient in streaming context as it allows for fast lookups in the hash -/// table, rather than having to scan through one or both of the streams to find matching rows, also it -/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions), -/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming -/// data without any memory issues. -/// -/// For each input stream, create a hash table. -/// - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets. -/// - Test if input is equal to a predefined set of other inputs. -/// - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch]. -/// - Try to prune other side (probe) with new [RecordBatch]. -/// - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.), -/// output the [RecordBatch] when a pruning happens or at the end of the data. -/// -/// -/// ``` text -/// +-------------------------+ -/// | | -/// left stream ---------| Left OneSideHashJoiner |---+ -/// | | | -/// +-------------------------+ | -/// | -/// |--------- Joined output -/// | -/// +-------------------------+ | -/// | | | -/// right stream ---------| Right OneSideHashJoiner |---+ -/// | | -/// +-------------------------+ -/// -/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic -/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range. -/// -/// -/// PROBE SIDE BUILD SIDE -/// BUFFER BUFFER -/// +-------------+ +------------+ -/// | | | | Unjoinable -/// | | | | Range -/// | | | | -/// | | |--------------------------------- -/// | | | | | -/// | | | | | -/// | | / | | -/// | | | | | -/// | | | | | -/// | | | | | -/// | | | | | -/// | | | | | Joinable -/// | |/ | | Range -/// | || | | -/// |+-----------+|| | | -/// || Record || | | -/// || Batch || | | -/// |+-----------+|| | | -/// +-------------+\ +------------+ -/// | -/// \ -/// |--------------------------------- -/// -/// This happens when range conditions are provided on sorted columns. E.g. -/// -/// SELECT * FROM left_table, right_table -/// ON -/// left_key = right_key AND -/// left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR -/// -/// or -/// SELECT * FROM left_table, right_table -/// ON -/// left_key = right_key AND -/// left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10 -/// -/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to -/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the -/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios) -/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning -/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" , -/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) -/// than that can be dropped from the inner buffer. -/// ``` -#[derive(Debug)] -pub struct SymmetricHashJoinExec { - /// Left side stream - pub(crate) left: Arc, - /// Right side stream - pub(crate) right: Arc, - /// Set of common columns used to join on - pub(crate) on: Vec<(Column, Column)>, - /// Filters applied when finding matching rows - pub(crate) filter: Option, - /// How the join is performed - pub(crate) join_type: JoinType, - /// Expression graph and `SortedFilterExpr`s for interval calculations - filter_state: Option>>, - /// The schema once the join is applied - schema: SchemaRef, - /// Shares the `RandomState` for the hashing algorithm - random_state: RandomState, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// Information of index and left / right placement of columns - column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, -} - -struct IntervalCalculatorInnerState { - /// Expression graph for interval calculations - graph: Option, - sorted_exprs: Vec>, - calculated: bool, -} - -impl Debug for IntervalCalculatorInnerState { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Exprs({:?})", self.sorted_exprs) - } -} - -#[derive(Debug)] -struct SymmetricHashJoinSideMetrics { - /// Number of batches consumed by this operator - input_batches: metrics::Count, - /// Number of rows consumed by this operator - input_rows: metrics::Count, -} - -/// Metrics for HashJoinExec -#[derive(Debug)] -struct SymmetricHashJoinMetrics { - /// Number of left batches/rows consumed by this operator - left: SymmetricHashJoinSideMetrics, - /// Number of right batches/rows consumed by this operator - right: SymmetricHashJoinSideMetrics, - /// Memory used by sides in bytes - pub(crate) stream_memory_usage: metrics::Gauge, - /// Number of batches produced by this operator - output_batches: metrics::Count, - /// Number of rows produced by this operator - output_rows: metrics::Count, -} - -impl SymmetricHashJoinMetrics { - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let left = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let right = SymmetricHashJoinSideMetrics { - input_batches, - input_rows, - }; - - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); - - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - - Self { - left, - right, - output_batches, - stream_memory_usage, - output_rows, - } - } -} - -impl SymmetricHashJoinExec { - /// Tries to create a new [SymmetricHashJoinExec]. - /// # Error - /// This function errors when: - /// - It is not possible to join the left and right sides on keys `on`, or - /// - It fails to construct `SortedFilterExpr`s, or - /// - It fails to create the [ExprIntervalGraph]. - pub fn try_new( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, - null_equals_null: bool, - ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - - // Error out if no "on" contraints are given: - if on.is_empty() { - return Err(DataFusionError::Plan( - "On constraints in SymmetricHashJoinExec should be non-empty".to_string(), - )); - } - - // Check if the join is valid with the given on constraints: - check_join_is_valid(&left_schema, &right_schema, &on)?; - - // Build the join schema from the left and right schemas: - let (schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); - - // Initialize the random state for the join operation: - let random_state = RandomState::with_seeds(0, 0, 0, 0); - - let filter_state = if filter.is_some() { - let inner_state = IntervalCalculatorInnerState { - graph: None, - sorted_exprs: vec![], - calculated: false, - }; - Some(Arc::new(Mutex::new(inner_state))) - } else { - None - }; - - Ok(SymmetricHashJoinExec { - left, - right, - on, - filter, - join_type: *join_type, - filter_state, - schema: Arc::new(schema), - random_state, - metrics: ExecutionPlanMetricsSet::new(), - column_indices, - null_equals_null, - }) - } - - /// left stream - pub fn left(&self) -> &Arc { - &self.left - } - - /// right stream - pub fn right(&self) -> &Arc { - &self.right - } - - /// Set of common columns used to join on - pub fn on(&self) -> &[(Column, Column)] { - &self.on - } - - /// Filters applied before join output - pub fn filter(&self) -> Option<&JoinFilter> { - self.filter.as_ref() - } - - /// How the join is performed - pub fn join_type(&self) -> &JoinType { - &self.join_type - } - - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null - } - - /// Check if order information covers every column in the filter expression. - pub fn check_if_order_information_available(&self) -> Result { - if let Some(filter) = self.filter() { - let left = self.left(); - if let Some(left_ordering) = left.output_ordering() { - let right = self.right(); - if let Some(right_ordering) = right.output_ordering() { - let left_convertible = convert_sort_expr_with_filter_schema( - &JoinSide::Left, - filter, - &left.schema(), - &left_ordering[0], - )? - .is_some(); - let right_convertible = convert_sort_expr_with_filter_schema( - &JoinSide::Right, - filter, - &right.schema(), - &right_ordering[0], - )? - .is_some(); - return Ok(left_convertible && right_convertible); - } - } - } - Ok(false) - } -} - -impl ExecutionPlan for SymmetricHashJoinExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn unbounded_output(&self, children: &[bool]) -> Result { - Ok(children.iter().any(|u| *u)) - } - - fn benefits_from_input_partitioning(&self) -> bool { - false - } - - fn required_input_distribution(&self) -> Vec { - let (left_expr, right_expr) = self - .on - .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) - .unzip(); - // TODO: This will change when we extend collected executions. - vec![ - Distribution::HashPartitioned(left_expr), - Distribution::HashPartitioned(right_expr), - ] - } - - fn output_partitioning(&self) -> Partitioning { - let left_columns_len = self.left.schema().fields.len(); - partitioned_join_output_partitioning( - self.join_type, - self.left.output_partitioning(), - self.right.output_partitioning(), - left_columns_len, - ) - } - - // TODO: Output ordering might be kept for some cases. - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, - self.left.equivalence_properties(), - self.right.equivalence_properties(), - left_columns_len, - self.on(), - self.schema(), - ) - } - - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(SymmetricHashJoinExec::try_new( - children[0].clone(), - children[1].clone(), - self.on.clone(), - self.filter.clone(), - &self.join_type, - self.null_equals_null, - )?)) - } - - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default => { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - write!( - f, - "SymmetricHashJoinExec: join_type={:?}, on={:?}{}", - self.join_type, self.on, display_filter - ) - } - } - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Statistics { - // TODO stats: it is not possible in general to know the output size of joins - Statistics::default() - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let left_partitions = self.left.output_partitioning().partition_count(); - let right_partitions = self.right.output_partitioning().partition_count(); - if left_partitions != right_partitions { - return Err(DataFusionError::Internal(format!( - "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ - consider using RepartitionExec", - ))); - } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph if one is not already built. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.filter_state, &self.filter) { - (Some(interval_state), Some(filter)) => { - // Lock the mutex of the interval state: - let mut filter_state = interval_state.lock(); - // If this is the first partition to be invoked, then we need to initialize our state - // (the expression graph for pruning, sorted filter expressions etc.) - if !filter_state.calculated { - // Interval calculations require each column to exhibit monotonicity - // independently. However, a `PhysicalSortExpr` object defines a - // lexicographical ordering, so we can only use their first elements. - // when deducing column monotonicities. - // TODO: Extend the `PhysicalSortExpr` mechanism to express independent - // (i.e. simultaneous) ordering properties of columns. - - // Build sorted filter expressions for the left and right join side: - let join_sides = [JoinSide::Left, JoinSide::Right]; - let children = [&self.left, &self.right]; - for (join_side, child) in join_sides.iter().zip(children.iter()) { - let sorted_expr = child - .output_ordering() - .and_then(|orders| { - build_filter_input_order( - *join_side, - filter, - &child.schema(), - &orders[0], - ) - .transpose() - }) - .transpose()?; - - filter_state.sorted_exprs.push(sorted_expr); - } - - // Collect available sorted filter expressions: - let sorted_exprs_size = filter_state.sorted_exprs.len(); - let mut sorted_exprs = filter_state - .sorted_exprs - .iter_mut() - .flatten() - .collect::>(); - - // Create the expression graph if we can create sorted filter expressions for both children: - filter_state.graph = if sorted_exprs.len() == sorted_exprs_size { - let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone())?; - - // Gather filter expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|sorted_expr| sorted_expr.filter_expr().clone()) - .collect::>(); - - // Gather node indices of converted filter expressions in `SortedFilterExpr`s - // using the filter columns vector: - let child_node_indices = - graph.gather_node_indices(&filter_exprs); - - // Update SortedFilterExpr instances with the corresponding node indices: - for (sorted_expr, (_, index)) in - sorted_exprs.iter_mut().zip(child_node_indices.iter()) - { - sorted_expr.set_node_index(*index); - } - - Some(graph) - } else { - None - }; - filter_state.calculated = true; - } - // Return the sorted filter expressions for both sides along with the expression graph: - ( - filter_state.sorted_exprs[0].clone(), - filter_state.sorted_exprs[1].clone(), - filter_state.graph.as_ref().cloned(), - ) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - (_, _) => (None, None, None), - }; - - let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); - let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); - - let left_side_joiner = - OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema()); - let right_side_joiner = - OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - - let left_stream = self - .left - .execute(partition, context.clone())? - .map(|val| (JoinSide::Left, val)); - - let right_stream = self - .right - .execute(partition, context.clone())? - .map(|val| (JoinSide::Right, val)); - // This function will attempt to pull items from both streams. - // Each stream will be polled in a round-robin fashion, and whenever a stream is - // ready to yield an item that item is yielded. - // After one of the two input streams completes, the remaining one will be polled exclusively. - // The returned stream completes when both input streams have completed. - let input_stream = select(left_stream, right_stream).boxed(); - - let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) - .register(context.memory_pool()), - )); - if let Some(g) = graph.as_ref() { - reservation.lock().try_grow(g.size())?; - } - - Ok(Box::pin(SymmetricHashJoinStream { - input_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - final_result: false, - reservation, - })) - } -} - -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { - /// Input stream - input_stream: BoxStream<'static, (JoinSide, Result)>, - /// Input schema - schema: Arc, - /// join filter - filter: Option, - /// type of the join - join_type: JoinType, - // left hash joiner - left: OneSideHashJoiner, - /// right hash joiner - right: OneSideHashJoiner, - /// Information of index and left / right placement of columns - column_indices: Vec, - // Expression graph for range pruning. - graph: Option, - // Left globally sorted filter expr - left_sorted_filter_expr: Option, - // Right globally sorted filter expr - right_sorted_filter_expr: Option, - /// Random state used for hashing initialization - random_state: RandomState, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, - /// Metrics - metrics: SymmetricHashJoinMetrics, - /// Memory reservation - reservation: SharedMemoryReservation, - /// Flag indicating whether there is nothing to process anymore - final_result: bool, -} - -impl RecordBatchStream for SymmetricHashJoinStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Stream for SymmetricHashJoinStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.poll_next_impl(cx) - } -} - -fn prune_hash_values( - prune_length: usize, - hashmap: &mut JoinHashMap, - row_hash_values: &mut VecDeque, - offset: u64, -) -> Result<()> { - // Create a (hash)-(row number set) map - let mut hash_value_map: HashMap> = HashMap::new(); - for index in 0..prune_length { - let hash_value = row_hash_values.pop_front().unwrap(); - if let Some(set) = hash_value_map.get_mut(&hash_value) { - set.insert(offset + index as u64); - } else { - let mut set = HashSet::new(); - set.insert(offset + index as u64); - hash_value_map.insert(hash_value, set); - } - } - for (hash_value, index_set) in hash_value_map.iter() { - if let Some((_, separation_chain)) = hashmap - .0 - .get_mut(*hash_value, |(hash, _)| hash_value == hash) - { - separation_chain.retain(|n| !index_set.contains(n)); - if separation_chain.is_empty() { - hashmap - .0 - .remove_entry(*hash_value, |(hash, _)| hash_value == hash); - } - } - } - hashmap.shrink_if_necessary(HASHMAP_SHRINK_SCALE_FACTOR); - Ok(()) -} - -/// Calculate the filter expression intervals. -/// -/// This function updates the `interval` field of each `SortedFilterExpr` based -/// on the first or the last value of the expression in `build_input_buffer` -/// and `probe_batch`. -/// -/// # Arguments -/// -/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. -/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. -/// * `probe_batch` - The `RecordBatch` on the probe side of the join. -/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. -/// -/// ### Note -/// ```text -/// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. -/// -/// As a concrete example, consider the following query: -/// -/// SELECT * FROM left_table, right_table -/// WHERE -/// left_key = right_key AND -/// a > b - 3 AND -/// a < b + 10 -/// -/// where columns "a" and "b" come from tables "left_table" and "right_table", -/// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left -/// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right -/// side (i.e. when the left side is the build side): -/// -/// Build Probe -/// +-------+ +-------+ -/// | a | z | | b | y | -/// |+--|--+| |+--|--+| -/// | 1 | 2 | | 4 | 3 | -/// |+--|--+| |+--|--+| -/// | 3 | 1 | | 4 | 3 | -/// |+--|--+| |+--|--+| -/// | 5 | 7 | | 6 | 1 | -/// |+--|--+| |+--|--+| -/// | 7 | 1 | | 6 | 3 | -/// +-------+ +-------+ -/// -/// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate -/// intervals for the whole filter expression and propagate join constraint by -/// traversing the expression graph. -/// ``` -fn calculate_filter_expr_intervals( - build_input_buffer: &RecordBatch, - build_sorted_filter_expr: &mut SortedFilterExpr, - probe_batch: &RecordBatch, - probe_sorted_filter_expr: &mut SortedFilterExpr, -) -> Result<()> { - // If either build or probe side has no data, return early: - if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { - return Ok(()); - } - // Calculate the interval for the build side filter expression (if present): - update_filter_expr_interval( - &build_input_buffer.slice(0, 1), - build_sorted_filter_expr, - )?; - // Calculate the interval for the probe side filter expression (if present): - update_filter_expr_interval( - &probe_batch.slice(probe_batch.num_rows() - 1, 1), - probe_sorted_filter_expr, - ) -} - -/// This is a subroutine of the function [`calculate_filter_expr_intervals`]. -/// It constructs the current interval using the given `batch` and updates -/// the filter expression (i.e. `sorted_expr`) with this interval. -fn update_filter_expr_interval( - batch: &RecordBatch, - sorted_expr: &mut SortedFilterExpr, -) -> Result<()> { - // Evaluate the filter expression and convert the result to an array: - let array = sorted_expr - .origin_sorted_expr() - .expr - .evaluate(batch)? - .into_array(1); - // Convert the array to a ScalarValue: - let value = ScalarValue::try_from_array(&array, 0)?; - // Create a ScalarValue representing positive or negative infinity for the same data type: - let unbounded = IntervalBound::make_unbounded(value.get_datatype())?; - // Update the interval with lower and upper bounds based on the sort option: - let interval = if sorted_expr.origin_sorted_expr().options.descending { - Interval::new(unbounded, IntervalBound::new(value, false)) - } else { - Interval::new(IntervalBound::new(value, false), unbounded) - }; - // Set the calculated interval for the sorted filter expression: - sorted_expr.set_interval(interval); - Ok(()) -} - -/// Determine the pruning length for `buffer`. -/// -/// This function evaluates the build side filter expression, converts the -/// result into an array and determines the pruning length by performing a -/// binary search on the array. -/// -/// # Arguments -/// -/// * `buffer`: The record batch to be pruned. -/// * `build_side_filter_expr`: The filter expression on the build side used -/// to determine the pruning length. -/// -/// # Returns -/// -/// A [Result] object that contains the pruning length. The function will return -/// an error if there is an issue evaluating the build side filter expression. -fn determine_prune_length( - buffer: &RecordBatch, - build_side_filter_expr: &SortedFilterExpr, -) -> Result { - let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr(); - let interval = build_side_filter_expr.interval(); - // Evaluate the build side filter expression and convert it into an array - let batch_arr = origin_sorted_expr - .expr - .evaluate(buffer)? - .into_array(buffer.num_rows()); - - // Get the lower or upper interval based on the sort direction - let target = if origin_sorted_expr.options.descending { - interval.upper.value.clone() - } else { - interval.lower.value.clone() - }; - - // Perform binary search on the array to determine the length of the record batch to be pruned - bisect::(&[batch_arr], &[target], &[origin_sorted_expr.options]) -} - -/// This method determines if the result of the join should be produced in the final step or not. -/// -/// # Arguments -/// -/// * `build_side` - Enum indicating the side of the join used as the build side. -/// * `join_type` - Enum indicating the type of join to be performed. -/// -/// # Returns -/// -/// A boolean indicating whether the result of the join should be produced in the final step or not. -/// The result will be true if the build side is JoinSide::Left and the join type is one of -/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi. -/// If the build side is JoinSide::Right, the result will be true if the join type -/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi. -fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { - if build_side == JoinSide::Left { - matches!( - join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi - ) - } else { - matches!( - join_type, - JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi - ) - } -} - -/// Get the anti join indices from the visited hash set. -/// -/// This method returns the indices from the original input that were not present in the visited hash set. -/// -/// # Arguments -/// -/// * `prune_length` - The length of the pruned record batch. -/// * `deleted_offset` - The offset to the indices. -/// * `visited_rows` - The hash set of visited indices. -/// -/// # Returns -/// -/// A `PrimitiveArray` of the anti join indices. -fn get_anti_indices( - prune_length: usize, - deleted_offset: usize, - visited_rows: &HashSet, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(prune_length); - bitmap.append_n(prune_length, false); - // mark the indices as true if they are present in the visited hash set - for v in 0..prune_length { - let row = v + deleted_offset; - bitmap.set_bit(v, visited_rows.contains(&row)); - } - // get the anti index - (0..prune_length) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect() -} - -/// This method creates a boolean buffer from the visited rows hash set -/// and the indices of the pruned record batch slice. -/// -/// It gets the indices from the original input that were present in the visited hash set. -/// -/// # Arguments -/// -/// * `prune_length` - The length of the pruned record batch. -/// * `deleted_offset` - The offset to the indices. -/// * `visited_rows` - The hash set of visited indices. -/// -/// # Returns -/// -/// A [PrimitiveArray] of the specified type T, containing the semi indices. -fn get_semi_indices( - prune_length: usize, - deleted_offset: usize, - visited_rows: &HashSet, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(prune_length); - bitmap.append_n(prune_length, false); - // mark the indices as true if they are present in the visited hash set - (0..prune_length).for_each(|v| { - let row = &(v + deleted_offset); - bitmap.set_bit(v, visited_rows.contains(row)); - }); - // get the semi index - (0..prune_length) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() -} -/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`. -/// This function will insert the indices (offset by `offset`) into the `visited` hash set. -/// -/// # Arguments -/// -/// * `visited` - A hash set to store the visited indices. -/// * `offset` - An offset to the indices in the `PrimitiveArray`. -/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded. -/// -fn record_visited_indices( - visited: &mut HashSet, - offset: usize, - indices: &PrimitiveArray, -) { - for i in indices.values() { - visited.insert(i.as_usize() + offset); - } -} - -/// Calculate indices by join type. -/// -/// This method returns a tuple of two arrays: build and probe indices. -/// The length of both arrays will be the same. -/// -/// # Arguments -/// -/// * `build_side`: Join side which defines the build side. -/// * `prune_length`: Length of the prune data. -/// * `visited_rows`: Hash set of visited rows of the build side. -/// * `deleted_offset`: Deleted offset of the build side. -/// * `join_type`: The type of join to be performed. -/// -/// # Returns -/// -/// A tuple of two arrays of primitive types representing the build and probe indices. -/// -fn calculate_indices_by_join_type( - build_side: JoinSide, - prune_length: usize, - visited_rows: &HashSet, - deleted_offset: usize, - join_type: JoinType, -) -> Result<(PrimitiveArray, PrimitiveArray)> -where - NativeAdapter: From<::Native>, -{ - // Store the result in a tuple - let result = match (build_side, join_type) { - // In the case of `Left` or `Right` join, or `Full` join, get the anti indices - (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) - | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) - | (_, JoinType::Full) => { - let build_unmatched_indices = - get_anti_indices(prune_length, deleted_offset, visited_rows); - let mut builder = - PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); - builder.append_nulls(build_unmatched_indices.len()); - let probe_indices = builder.finish(); - (build_unmatched_indices, probe_indices) - } - // In the case of `LeftSemi` or `RightSemi` join, get the semi indices - (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { - let build_unmatched_indices = - get_semi_indices(prune_length, deleted_offset, visited_rows); - let mut builder = - PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); - builder.append_nulls(build_unmatched_indices.len()); - let probe_indices = builder.finish(); - (build_unmatched_indices, probe_indices) - } - // The case of other join types is not considered - _ => unreachable!(), - }; - Ok(result) -} - -struct OneSideHashJoiner { - /// Build side - build_side: JoinSide, - /// Input record batch buffer - input_buffer: RecordBatch, - /// Columns from the side - on: Vec, - /// Hashmap - hashmap: JoinHashMap, - /// To optimize hash deleting in case of pruning, we hold them in memory - row_hash_values: VecDeque, - /// Reuse the hashes buffer - hashes_buffer: Vec, - /// Matched rows - visited_rows: HashSet, - /// Offset - offset: usize, - /// Deleted offset - deleted_offset: usize, -} - -impl OneSideHashJoiner { - pub fn size(&self) -> usize { - let mut size = 0; - size += std::mem::size_of_val(self); - size += std::mem::size_of_val(&self.build_side); - size += self.input_buffer.get_array_memory_size(); - size += std::mem::size_of_val(&self.on); - size += self.hashmap.size(); - size += self.row_hash_values.capacity() * std::mem::size_of::(); - size += self.hashes_buffer.capacity() * std::mem::size_of::(); - size += self.visited_rows.capacity() * std::mem::size_of::(); - size += std::mem::size_of_val(&self.offset); - size += std::mem::size_of_val(&self.deleted_offset); - size - } - pub fn new(build_side: JoinSide, on: Vec, schema: SchemaRef) -> Self { - Self { - build_side, - input_buffer: RecordBatch::new_empty(schema), - on, - hashmap: JoinHashMap(RawTable::with_capacity(0)), - row_hash_values: VecDeque::new(), - hashes_buffer: vec![], - visited_rows: HashSet::new(), - offset: 0, - deleted_offset: 0, - } - } - - /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch. - /// - /// # Arguments - /// - /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer - /// * `random_state` - The random state used to hash values - /// - /// # Returns - /// - /// Returns a [Result] encapsulating any intermediate errors. - fn update_internal_state( - &mut self, - batch: &RecordBatch, - random_state: &RandomState, - ) -> Result<()> { - // Merge the incoming batch with the existing input buffer: - self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?; - // Resize the hashes buffer to the number of rows in the incoming batch: - self.hashes_buffer.resize(batch.num_rows(), 0); - // Get allocation_info before adding the item - // Update the hashmap with the join key values and hashes of the incoming batch: - update_hash( - &self.on, - batch, - &mut self.hashmap, - self.offset, - random_state, - &mut self.hashes_buffer, - )?; - // Add the hashes buffer to the hash value deque: - self.row_hash_values.extend(self.hashes_buffer.iter()); - Ok(()) - } - - /// This method performs a join between the build side input buffer and the probe side batch. - /// - /// # Arguments - /// - /// * `schema` - A reference to the schema of the output record batch. - /// * `join_type` - The type of join to be performed. - /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. - /// * `filter` - An optional filter on the join condition. - /// * `probe_batch` - The second record batch to be joined. - /// * `probe_visited` - A hash set to store the visited indices from the probe batch. - /// * `probe_offset` - The offset of the probe side for visited indices calculations. - /// * `column_indices` - An array of columns to be selected for the result of the join. - /// * `random_state` - The random state for the join. - /// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. - /// - /// # Returns - /// - /// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. - /// If the join type is one of the above four, the function will return [None]. - #[allow(clippy::too_many_arguments)] - fn join_with_probe_batch( - &mut self, - schema: &SchemaRef, - join_type: JoinType, - on_probe: &[Column], - filter: Option<&JoinFilter>, - probe_batch: &RecordBatch, - probe_visited: &mut HashSet, - probe_offset: usize, - column_indices: &[ColumnIndex], - random_state: &RandomState, - null_equals_null: bool, - ) -> Result> { - if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { - return Ok(None); - } - let (build_indices, probe_indices) = build_join_indices( - probe_batch, - &self.hashmap, - &self.input_buffer, - &self.on, - on_probe, - filter, - random_state, - null_equals_null, - &mut self.hashes_buffer, - Some(self.deleted_offset), - self.build_side, - )?; - if need_to_produce_result_in_final(self.build_side, join_type) { - record_visited_indices( - &mut self.visited_rows, - self.deleted_offset, - &build_indices, - ); - } - if need_to_produce_result_in_final(self.build_side.negate(), join_type) { - record_visited_indices(probe_visited, probe_offset, &probe_indices); - } - if matches!( - join_type, - JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftSemi - | JoinType::RightSemi - ) { - Ok(None) - } else { - build_batch_from_indices( - schema, - &self.input_buffer, - probe_batch, - build_indices, - probe_indices, - column_indices, - self.build_side, - ) - .map(|batch| (batch.num_rows() > 0).then_some(batch)) - } - } - - /// This function produces unmatched record results based on the build side, - /// join type and other parameters. - /// - /// The method uses first `prune_length` rows from the build side input buffer - /// to produce results. - /// - /// # Arguments - /// - /// * `output_schema` - The schema of the final output record batch. - /// * `prune_length` - The length of the determined prune length. - /// * `probe_schema` - The schema of the probe [RecordBatch]. - /// * `join_type` - The type of join to be performed. - /// * `column_indices` - Indices of columns that are being joined. - /// - /// # Returns - /// - /// * `Option` - The final output record batch if required, otherwise [None]. - fn build_side_determined_results( - &self, - output_schema: &SchemaRef, - prune_length: usize, - probe_schema: SchemaRef, - join_type: JoinType, - column_indices: &[ColumnIndex], - ) -> Result> { - // Check if we need to produce a result in the final output: - if need_to_produce_result_in_final(self.build_side, join_type) { - // Calculate the indices for build and probe sides based on join type and build side: - let (build_indices, probe_indices) = calculate_indices_by_join_type( - self.build_side, - prune_length, - &self.visited_rows, - self.deleted_offset, - join_type, - )?; - - // Create an empty probe record batch: - let empty_probe_batch = RecordBatch::new_empty(probe_schema); - // Build the final result from the indices of build and probe sides: - build_batch_from_indices( - output_schema.as_ref(), - &self.input_buffer, - &empty_probe_batch, - build_indices, - probe_indices, - column_indices, - self.build_side, - ) - .map(|batch| (batch.num_rows() > 0).then_some(batch)) - } else { - // If we don't need to produce a result, return None - Ok(None) - } - } - - /// Prunes the internal buffer. - /// - /// Argument `probe_batch` is used to update the intervals of the sorted - /// filter expressions. The updated build interval determines the new length - /// of the build side. If there are rows to prune, they are removed from the - /// internal buffer. - /// - /// # Arguments - /// - /// * `schema` - The schema of the final output record batch - /// * `probe_batch` - Incoming RecordBatch of the probe side. - /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. - /// * `join_type` - The type of join (e.g. inner, left, right, etc.). - /// * `column_indices` - A vector of column indices that specifies which columns from the - /// build side should be included in the output. - /// * `graph` - A mutable reference to the physical expression graph. - /// - /// # Returns - /// - /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. - /// Otherwise, returns `Ok(None)`. - fn calculate_prune_length_with_probe_batch( - &mut self, - build_side_sorted_filter_expr: &mut SortedFilterExpr, - probe_side_sorted_filter_expr: &mut SortedFilterExpr, - graph: &mut ExprIntervalGraph, - ) -> Result { - // Return early if the input buffer is empty: - if self.input_buffer.num_rows() == 0 { - return Ok(0); - } - // Process the build and probe side sorted filter expressions if both are present: - // Collect the sorted filter expressions into a vector of (node_index, interval) tuples: - let mut filter_intervals = vec![]; - for expr in [ - &build_side_sorted_filter_expr, - &probe_side_sorted_filter_expr, - ] { - filter_intervals.push((expr.node_index(), expr.interval().clone())) - } - // Update the physical expression graph using the join filter intervals: - graph.update_ranges(&mut filter_intervals)?; - // Extract the new join filter interval for the build side: - let calculated_build_side_interval = filter_intervals.remove(0).1; - // If the intervals have not changed, return early without pruning: - if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) { - return Ok(0); - } - // Update the build side interval and determine the pruning length: - build_side_sorted_filter_expr.set_interval(calculated_build_side_interval); - - determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr) - } - - fn prune_internal_state_and_build_anti_result( - &mut self, - prune_length: usize, - schema: &SchemaRef, - probe_batch: &RecordBatch, - join_type: JoinType, - column_indices: &[ColumnIndex], - ) -> Result> { - // Compute the result and perform pruning if there are rows to prune: - let result = self.build_side_determined_results( - schema, - prune_length, - probe_batch.schema(), - join_type, - column_indices, - ); - // Prune the hash values: - prune_hash_values( - prune_length, - &mut self.hashmap, - &mut self.row_hash_values, - self.deleted_offset as u64, - )?; - // Remove pruned rows from the visited rows set: - for row in self.deleted_offset..(self.deleted_offset + prune_length) { - self.visited_rows.remove(&row); - } - // Update the input buffer after pruning: - self.input_buffer = self - .input_buffer - .slice(prune_length, self.input_buffer.num_rows() - prune_length); - // Increment the deleted offset: - self.deleted_offset += prune_length; - result - } -} - -fn combine_two_batches( - output_schema: &SchemaRef, - left_batch: Option, - right_batch: Option, -) -> Result> { - match (left_batch, right_batch) { - (Some(batch), None) | (None, Some(batch)) => { - // If only one of the batches are present, return it: - Ok(Some(batch)) - } - (Some(left_batch), Some(right_batch)) => { - // If both batches are present, concatenate them: - concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) - .map(Some) - } - (None, None) => { - // If neither is present, return an empty batch: - Ok(None) - } - } -} - -impl SymmetricHashJoinStream { - fn size(&self) -> usize { - let mut size = 0; - size += std::mem::size_of_val(&self.input_stream); - size += std::mem::size_of_val(&self.schema); - size += std::mem::size_of_val(&self.filter); - size += std::mem::size_of_val(&self.join_type); - size += self.left.size(); - size += self.right.size(); - size += std::mem::size_of_val(&self.column_indices); - size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); - size += std::mem::size_of_val(&self.left_sorted_filter_expr); - size += std::mem::size_of_val(&self.right_sorted_filter_expr); - size += std::mem::size_of_val(&self.random_state); - size += std::mem::size_of_val(&self.null_equals_null); - size += std::mem::size_of_val(&self.metrics); - size += std::mem::size_of_val(&self.final_result); - size - } - /// Polls the next result of the join operation. - /// - /// If the result of the join is ready, it returns the next record batch. - /// If the join has completed and there are no more results, it returns - /// `Poll::Ready(None)`. If the join operation is not complete, but the - /// current stream is not ready yet, it returns `Poll::Pending`. - fn poll_next_impl( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - // Poll the next batch from `input_stream`: - match self.input_stream.poll_next_unpin(cx) { - // Batch is available - Poll::Ready(Some((side, Ok(probe_batch)))) => { - // Determine which stream should be polled next. The side the - // RecordBatch comes from becomes the probe side. - let ( - probe_hash_joiner, - build_hash_joiner, - probe_side_sorted_filter_expr, - build_side_sorted_filter_expr, - probe_side_metrics, - ) = if side.eq(&JoinSide::Left) { - ( - &mut self.left, - &mut self.right, - &mut self.left_sorted_filter_expr, - &mut self.right_sorted_filter_expr, - &mut self.metrics.left, - ) - } else { - ( - &mut self.right, - &mut self.left, - &mut self.right_sorted_filter_expr, - &mut self.left_sorted_filter_expr, - &mut self.metrics.right, - ) - }; - // Update the metrics for the stream that was polled: - probe_side_metrics.input_batches.add(1); - probe_side_metrics.input_rows.add(probe_batch.num_rows()); - // Update the internal state of the hash joiner for the build side: - probe_hash_joiner - .update_internal_state(&probe_batch, &self.random_state)?; - // Join the two sides: - let equal_result = build_hash_joiner.join_with_probe_batch( - &self.schema, - self.join_type, - &probe_hash_joiner.on, - self.filter.as_ref(), - &probe_batch, - &mut probe_hash_joiner.visited_rows, - probe_hash_joiner.offset, - &self.column_indices, - &self.random_state, - self.null_equals_null, - )?; - // Increment the offset for the probe hash joiner: - probe_hash_joiner.offset += probe_batch.num_rows(); - - let anti_result = if let ( - Some(build_side_sorted_filter_expr), - Some(probe_side_sorted_filter_expr), - Some(graph), - ) = ( - build_side_sorted_filter_expr.as_mut(), - probe_side_sorted_filter_expr.as_mut(), - self.graph.as_mut(), - ) { - // Calculate filter intervals: - calculate_filter_expr_intervals( - &build_hash_joiner.input_buffer, - build_side_sorted_filter_expr, - &probe_batch, - probe_side_sorted_filter_expr, - )?; - let prune_length = build_hash_joiner - .calculate_prune_length_with_probe_batch( - build_side_sorted_filter_expr, - probe_side_sorted_filter_expr, - graph, - )?; - - if prune_length > 0 { - build_hash_joiner.prune_internal_state_and_build_anti_result( - prune_length, - &self.schema, - &probe_batch, - self.join_type, - &self.column_indices, - )? - } else { - None - } - } else { - None - }; - - // Combine results: - let result = - combine_two_batches(&self.schema, equal_result, anti_result)?; - let capacity = self.size(); - self.metrics.stream_memory_usage.set(capacity); - self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Ready(Some((_, Err(e)))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - // If the final result has already been obtained, return `Poll::Ready(None)`: - if self.final_result { - return Poll::Ready(None); - } - self.final_result = true; - // Get the left side results: - let left_result = self.left.build_side_determined_results( - &self.schema, - self.left.input_buffer.num_rows(), - self.right.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - // Get the right side results: - let right_result = self.right.build_side_determined_results( - &self.schema, - self.right.input_buffer.num_rows(), - self.left.input_buffer.schema(), - self.join_type, - &self.column_indices, - )?; - - // Combine the left and right results: - let result = - combine_two_batches(&self.schema, left_result, right_result)?; - - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Ok(result).transpose()); - } - } - Poll::Pending => return Poll::Pending, - } - } - } -} - -#[cfg(test)] -mod tests { - use std::fs::File; - - use arrow::array::{ArrayRef, Float64Array, IntervalDayTimeArray}; - use arrow::array::{Int32Array, TimestampMillisecondArray}; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; - use arrow::util::pretty::pretty_format_batches; - use rstest::*; - use tempfile::TempDir; - - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, col, Column}; - use datafusion_physical_expr::intervals::test_utils::{ - gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, - }; - use datafusion_physical_expr::PhysicalExpr; - - use crate::physical_plan::joins::{ - hash_join_utils::tests::complicated_filter, HashJoinExec, PartitionMode, - }; - use crate::physical_plan::{ - common, displayable, memory::MemoryExec, repartition::RepartitionExec, - }; - use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; - use crate::test_util::register_unbounded_file_with_ordering; - - use super::*; - - const TABLE_SIZE: i32 = 100; - - fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { - // compare - let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); - let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); - - let mut first_formatted_sorted: Vec<&str> = - first_formatted.trim().lines().collect(); - first_formatted_sorted.sort_unstable(); - - let mut second_formatted_sorted: Vec<&str> = - second_formatted.trim().lines().collect(); - second_formatted_sorted.sort_unstable(); - - for (i, (first_line, second_line)) in first_formatted_sorted - .iter() - .zip(&second_formatted_sorted) - .enumerate() - { - assert_eq!((i, first_line), (i, second_line)); - } - } - - async fn partitioned_sym_join_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, - null_equals_null: bool, - context: Arc, - ) -> Result> { - let partition_count = 4; - - let left_expr = on - .iter() - .map(|(l, _)| Arc::new(l.clone()) as _) - .collect::>(); - - let right_expr = on - .iter() - .map(|(_, r)| Arc::new(r.clone()) as _) - .collect::>(); - - let join = SymmetricHashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, partition_count), - )?), - Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, partition_count), - )?), - on, - filter, - join_type, - null_equals_null, - )?; - - let mut batches = vec![]; - for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; - let more_batches = common::collect(stream).await?; - batches.extend( - more_batches - .into_iter() - .filter(|b| b.num_rows() > 0) - .collect::>(), - ); - } - - Ok(batches) - } - - async fn partitioned_hash_join_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, - null_equals_null: bool, - context: Arc, - ) -> Result> { - let partition_count = 4; - - let (left_expr, right_expr) = on - .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) - .unzip(); - - let join = HashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, partition_count), - )?), - Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, partition_count), - )?), - on, - filter, - join_type, - PartitionMode::Partitioned, - null_equals_null, - )?; - - let mut batches = vec![]; - for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; - let more_batches = common::collect(stream).await?; - batches.extend( - more_batches - .into_iter() - .filter(|b| b.num_rows() > 0) - .collect::>(), - ); - } - - Ok(batches) - } - - pub fn split_record_batches( - batch: &RecordBatch, - batch_size: usize, - ) -> Result> { - let row_num = batch.num_rows(); - let number_of_batch = row_num / batch_size; - let mut sizes = vec![batch_size; number_of_batch]; - sizes.push(row_num - (batch_size * number_of_batch)); - let mut result = vec![]; - for (i, size) in sizes.iter().enumerate() { - result.push(batch.slice(i * batch_size, *size)); - } - Ok(result) - } - - // It creates join filters for different type of fields for testing. - macro_rules! join_expr_tests { - ($func_name:ident, $type:ty, $SCALAR:ident) => { - fn $func_name( - expr_id: usize, - left_col: Arc, - right_col: Arc, - ) -> Arc { - match expr_id { - // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 0 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 1 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 - 2 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 - 3 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(10 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 - 4 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ), - ScalarValue::$SCALAR(Some(10 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(30 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 - 5 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Minus, - ), - ScalarValue::$SCALAR(Some(2 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(7 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - (Operator::GtEq, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 6 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Minus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(28 as $type)), - ScalarValue::$SCALAR(Some(11 as $type)), - ScalarValue::$SCALAR(Some(21 as $type)), - ScalarValue::$SCALAR(Some(39 as $type)), - (Operator::Gt, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 7 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(28 as $type)), - ScalarValue::$SCALAR(Some(11 as $type)), - ScalarValue::$SCALAR(Some(21 as $type)), - ScalarValue::$SCALAR(Some(39 as $type)), - (Operator::GtEq, Operator::Lt), - ), - _ => panic!("No case"), - } - } - }; - } - - join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); - join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - use std::iter::Iterator; - - struct AscendingRandomFloatIterator { - prev: f64, - max: f64, - rng: StdRng, - } - - impl AscendingRandomFloatIterator { - fn new(min: f64, max: f64) -> Self { - let mut rng = StdRng::seed_from_u64(42); - let initial = rng.gen_range(min..max); - AscendingRandomFloatIterator { - prev: initial, - max, - rng, - } - } - } - - impl Iterator for AscendingRandomFloatIterator { - type Item = f64; - - fn next(&mut self) -> Option { - let value = self.rng.gen_range(self.prev..self.max); - self.prev = value; - Some(value) - } - } - - fn join_expr_tests_fixture_temporal( - expr_id: usize, - left_col: Arc, - right_col: Arc, - schema: &Schema, - ) -> Result> { - match expr_id { - // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms')) - 0 => gen_conjunctive_temporal_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ScalarValue::new_interval_dt(0, 100), // 100 ms - ScalarValue::new_interval_dt(0, 200), // 200 ms - ScalarValue::new_interval_dt(0, 450), // 450 ms - ScalarValue::new_interval_dt(0, 300), // 300 ms - schema, - ), - // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02')) - 1 => gen_conjunctive_temporal_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03 - ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01 - ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00 - ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 - schema, - ), - _ => unreachable!(), - } - } - fn build_sides_record_batches( - table_size: i32, - key_cardinality: (i32, i32), - ) -> Result<(RecordBatch, RecordBatch)> { - let null_ratio: f64 = 0.4; - let initial_range = 0..table_size; - let index = (table_size as f64 * null_ratio).round() as i32; - let rest_of = index..table_size; - let ordered: ArrayRef = Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )); - let ordered_des = Arc::new(Int32Array::from_iter( - initial_range.clone().rev().collect::>(), - )); - let cardinality = Arc::new(Int32Array::from_iter( - initial_range.clone().map(|x| x % 4).collect::>(), - )); - let cardinality_key_left = Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.0) - .collect::>(), - )); - let cardinality_key_right = Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.1) - .collect::>(), - )); - let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.clone().map(Some)) - .collect::>>() - })); - let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ - rest_of - .clone() - .map(Some) - .chain(std::iter::repeat(None).take(index as usize)) - .collect::>>() - })); - - let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.rev().map(Some)) - .collect::>>() - })); - - let time = Arc::new(TimestampMillisecondArray::from( - initial_range - .clone() - .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00 - .collect::>(), - )); - let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( - initial_range - .map(|x| x as i64 * 100) // x * 100ms - .collect::>(), - )); - - let float_asc = Arc::new(Float64Array::from_iter_values( - AscendingRandomFloatIterator::new(0., table_size as f64) - .take(table_size as usize), - )); - - let left = RecordBatch::try_from_iter(vec![ - ("la1", ordered.clone()), - ("lb1", cardinality.clone()), - ("lc1", cardinality_key_left), - ("lt1", time.clone()), - ("la2", ordered.clone()), - ("la1_des", ordered_des.clone()), - ("l_asc_null_first", ordered_asc_null_first.clone()), - ("l_asc_null_last", ordered_asc_null_last.clone()), - ("l_desc_null_first", ordered_desc_null_first.clone()), - ("li1", interval_time.clone()), - ("l_float", float_asc.clone()), - ])?; - let right = RecordBatch::try_from_iter(vec![ - ("ra1", ordered.clone()), - ("rb1", cardinality), - ("rc1", cardinality_key_right), - ("rt1", time), - ("ra2", ordered), - ("ra1_des", ordered_des), - ("r_asc_null_first", ordered_asc_null_first), - ("r_asc_null_last", ordered_asc_null_last), - ("r_desc_null_first", ordered_desc_null_first), - ("ri1", interval_time), - ("r_float", float_asc), - ])?; - Ok((left, right)) - } - - fn create_memory_table( - left_batch: RecordBatch, - right_batch: RecordBatch, - left_sorted: Option>, - right_sorted: Option>, - batch_size: usize, - ) -> Result<(Arc, Arc)> { - let mut left = MemoryExec::try_new( - &[split_record_batches(&left_batch, batch_size)?], - left_batch.schema(), - None, - )?; - if let Some(sorted) = left_sorted { - left = left.with_sort_information(sorted); - } - let mut right = MemoryExec::try_new( - &[split_record_batches(&right_batch, batch_size)?], - right_batch.schema(), - None, - )?; - if let Some(sorted) = right_sorted { - right = right.with_sort_information(sorted); - } - Ok((Arc::new(left), Arc::new(right))) - } - - async fn experiment( - left: Arc, - right: Arc, - filter: Option, - join_type: JoinType, - on: JoinOn, - task_ctx: Arc, - ) -> Result<()> { - let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), - on.clone(), - filter.clone(), - &join_type, - false, - task_ctx.clone(), - ) - .await?; - let second_batches = partitioned_hash_join_with_filter( - left, right, on, filter, &join_type, false, task_ctx, - ) - .await?; - compare_batches(&first_batches, &second_batches); - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn complex_join_all_one_ascending_numeric( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), - )] - cardinality: (i32, i32), - ) -> Result<()> { - // a + b > c + 10 AND a + b < c + 100 - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: binary( - col("la1", left_schema)?, - Operator::Plus, - col("la2", left_schema)?, - left_schema, - )?, - options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("ra1", right_schema)?, - options: SortOptions::default(), - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let filter_expr = complicated_filter(&intermediate_schema)?; - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 4, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn join_all_one_ascending_numeric( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), - )] - cardinality: (i32, i32), - #[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize, - ) -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("la1", left_schema)?, - options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("ra1", right_schema)?, - options: SortOptions::default(), - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn join_without_sort_information( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), - )] - cardinality: (i32, i32), - #[values(0, 1, 2, 3, 4, 5, 6)] case_expr: usize, - ) -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let (left, right) = create_memory_table(left_batch, right_batch, None, None, 13)?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 5, - side: JoinSide::Left, - }, - ColumnIndex { - index: 5, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn join_without_filter( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - ) -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = build_sides_record_batches(TABLE_SIZE, (11, 21))?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let (left, right) = create_memory_table(left_batch, right_batch, None, None, 13)?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - experiment(left, right, None, join_type, on, task_ctx).await?; - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn join_all_one_descending_numeric_particular( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), - )] - cardinality: (i32, i32), - #[values(0, 1, 2, 3, 4, 5, 6)] case_expr: usize, - ) -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("la1_des", left_schema)?, - options: SortOptions { - descending: true, - nulls_first: true, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("ra1_des", right_schema)?, - options: SortOptions { - descending: true, - nulls_first: true, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 5, - side: JoinSide::Left, - }, - ColumnIndex { - index: 5, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[tokio::test] - async fn join_change_in_planner() -> Result<()> { - let config = SessionConfig::new().with_target_partitions(8); - let ctx = SessionContext::with_config(config); - let tmp_dir = TempDir::new().unwrap(); - let left_file_path = tmp_dir.path().join("left.csv"); - File::create(left_file_path.clone()).unwrap(); - // Create schema - let schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::UInt32, false), - Field::new("a2", DataType::UInt32, false), - ])); - // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; - register_unbounded_file_with_ordering( - &ctx, - schema.clone(), - &left_file_path, - "left", - file_sort_order.clone(), - true, - ) - .await?; - let right_file_path = tmp_dir.path().join("right.csv"); - File::create(right_file_path.clone()).unwrap(); - register_unbounded_file_with_ordering( - &ctx, - schema, - &right_file_path, - "right", - file_sort_order, - true, - ) - .await?; - let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; - let dataframe = ctx.sql(sql).await?; - let physical_plan = dataframe.create_physical_plan().await?; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - [ - "SymmetricHashJoinExec: join_type=Full, on=[(Column { name: \"a2\", index: 1 }, Column { name: \"a2\", index: 1 })], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 1 }], 8), input_partitions=1", - // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 1 }], 8), input_partitions=1", - // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(3); - actual.remove(5); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - Ok(()) - } - - #[tokio::test] - async fn join_change_in_planner_without_sort() -> Result<()> { - let config = SessionConfig::new().with_target_partitions(8); - let ctx = SessionContext::with_config(config); - let tmp_dir = TempDir::new()?; - let left_file_path = tmp_dir.path().join("left.csv"); - File::create(left_file_path.clone())?; - let schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::UInt32, false), - Field::new("a2", DataType::UInt32, false), - ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; - let right_file_path = tmp_dir.path().join("right.csv"); - File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; - let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; - let dataframe = ctx.sql(sql).await?; - let physical_plan = dataframe.create_physical_plan().await?; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - [ - "SymmetricHashJoinExec: join_type=Full, on=[(Column { name: \"a2\", index: 1 }, Column { name: \"a2\", index: 1 })], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 1 }], 8), input_partitions=1", - // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 1 }], 8), input_partitions=1", - // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(3); - actual.remove(5); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - Ok(()) - } - - #[tokio::test] - async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { - let config = SessionConfig::new() - .with_target_partitions(8) - .with_allow_symmetric_joins_without_pruning(false); - let ctx = SessionContext::with_config(config); - let tmp_dir = TempDir::new()?; - let left_file_path = tmp_dir.path().join("left.csv"); - File::create(left_file_path.clone())?; - let schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::UInt32, false), - Field::new("a2", DataType::UInt32, false), - ])); - ctx.register_csv( - "left", - left_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; - let right_file_path = tmp_dir.path().join("right.csv"); - File::create(right_file_path.clone())?; - ctx.register_csv( - "right", - right_file_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new().schema(&schema).mark_infinite(true), - ) - .await?; - let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; - match df.create_physical_plan().await { - Ok(_) => panic!("Expecting error."), - Err(e) => { - assert_eq!(e.to_string(), "PipelineChecker\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") - } - } - Ok(()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn build_null_columns_first() -> Result<()> { - let join_type = JoinType::Full; - let cardinality = (10, 11); - let case_expr = 1; - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("l_asc_null_first", left_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("r_asc_null_first", right_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 6, - side: JoinSide::Left, - }, - ColumnIndex { - index: 6, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn build_null_columns_last() -> Result<()> { - let join_type = JoinType::Full; - let cardinality = (10, 11); - let case_expr = 1; - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("l_asc_null_last", left_schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("r_asc_null_last", right_schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 7, - side: JoinSide::Left, - }, - ColumnIndex { - index: 7, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn build_null_columns_first_descending() -> Result<()> { - let join_type = JoinType::Full; - let cardinality = (10, 11); - let case_expr = 1; - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("l_desc_null_first", left_schema)?, - options: SortOptions { - descending: true, - nulls_first: true, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("r_desc_null_first", right_schema)?, - options: SortOptions { - descending: true, - nulls_first: true, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Int32, true), - Field::new("right", DataType::Int32, true), - ]); - let filter_expr = join_expr_tests_fixture_i32( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 8, - side: JoinSide::Left, - }, - ColumnIndex { - index: 8, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> { - let cardinality = (3, 4); - let join_type = JoinType::Full; - - // a + b > c + 10 AND a + b < c + 100 - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("la1", left_schema)?, - options: SortOptions::default(), - }]; - - let right_sorted = vec![PhysicalSortExpr { - expr: col("ra1", right_schema)?, - options: SortOptions::default(), - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - Field::new("2", DataType::Int32, true), - ]); - let filter_expr = complicated_filter(&intermediate_schema)?; - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 4, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn test_one_side_hash_joiner_visited_rows( - #[values( - (JoinType::Inner, true), - (JoinType::Left,false), - (JoinType::Right, true), - (JoinType::RightSemi, true), - (JoinType::LeftSemi, false), - (JoinType::LeftAnti, false), - (JoinType::RightAnti, true), - (JoinType::Full, false), - )] - case: (JoinType, bool), - ) -> Result<()> { - // Set a random state for the join - let join_type = case.0; - let should_be_empty = case.1; - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - // Ensure there will be matching rows - let (left_batch, right_batch) = build_sides_record_batches(20, (1, 1))?; - let left_schema = left_batch.schema(); - let right_schema = right_batch.schema(); - - // Build the join schema from the left and right schemas - let (schema, join_column_indices) = - build_join_schema(&left_schema, &right_schema, &join_type); - let join_schema = Arc::new(schema); - - // Sort information for MemoryExec - let left_sorted = vec![PhysicalSortExpr { - expr: col("la1", &left_schema)?, - options: SortOptions::default(), - }]; - // Sort information for MemoryExec - let right_sorted = vec![PhysicalSortExpr { - expr: col("ra1", &right_schema)?, - options: SortOptions::default(), - }]; - // Construct MemoryExec - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 10, - )?; - - // Filter columns, ensure first batches will have matching rows. - let intermediate_schema = Schema::new(vec![ - Field::new("0", DataType::Int32, true), - Field::new("1", DataType::Int32, true), - ]); - let filter_expr = gen_conjunctive_numerical_expr( - col("0", &intermediate_schema)?, - col("1", &intermediate_schema)?, - ( - Operator::Plus, - Operator::Minus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::Int32(Some(0)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(0)), - ScalarValue::Int32(Some(3)), - (Operator::Gt, Operator::Lt), - ); - let column_indices = vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - let mut left_side_joiner = OneSideHashJoiner::new( - JoinSide::Left, - vec![Column::new_with_schema("lc1", &left_schema)?], - left_schema, - ); - - let mut right_side_joiner = OneSideHashJoiner::new( - JoinSide::Right, - vec![Column::new_with_schema("rc1", &right_schema)?], - right_schema, - ); - - let mut left_stream = left.execute(0, task_ctx.clone())?; - let mut right_stream = right.execute(0, task_ctx)?; - - let initial_left_batch = left_stream.next().await.unwrap()?; - left_side_joiner.update_internal_state(&initial_left_batch, &random_state)?; - assert_eq!( - left_side_joiner.input_buffer.num_rows(), - initial_left_batch.num_rows() - ); - - let initial_right_batch = right_stream.next().await.unwrap()?; - right_side_joiner.update_internal_state(&initial_right_batch, &random_state)?; - assert_eq!( - right_side_joiner.input_buffer.num_rows(), - initial_right_batch.num_rows() - ); - - left_side_joiner.join_with_probe_batch( - &join_schema, - join_type, - &right_side_joiner.on, - Some(&filter), - &initial_right_batch, - &mut right_side_joiner.visited_rows, - right_side_joiner.offset, - &join_column_indices, - &random_state, - false, - )?; - assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty); - Ok(()) - } - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn testing_with_temporal_columns( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (99, 12), - )] - cardinality: (i32, i32), - #[values(0, 1)] case_expr: usize, - ) -> Result<()> { - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - let left_sorted = vec![PhysicalSortExpr { - expr: col("lt1", left_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("rt1", right_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - let intermediate_schema = Schema::new(vec![ - Field::new( - "left", - DataType::Timestamp(TimeUnit::Millisecond, None), - false, - ), - Field::new( - "right", - DataType::Timestamp(TimeUnit::Millisecond, None), - false, - ), - ]); - let filter_expr = join_expr_tests_fixture_temporal( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - &intermediate_schema, - )?; - let column_indices = vec![ - ColumnIndex { - index: 3, - side: JoinSide::Left, - }, - ColumnIndex { - index: 3, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn test_with_interval_columns( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (99, 12), - )] - cardinality: (i32, i32), - ) -> Result<()> { - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - let left_sorted = vec![PhysicalSortExpr { - expr: col("li1", left_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("ri1", right_schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Interval(IntervalUnit::DayTime), false), - Field::new("right", DataType::Interval(IntervalUnit::DayTime), false), - ]); - let filter_expr = join_expr_tests_fixture_temporal( - 0, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - &intermediate_schema, - )?; - let column_indices = vec![ - ColumnIndex { - index: 9, - side: JoinSide::Left, - }, - ColumnIndex { - index: 9, - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - - Ok(()) - } - - #[rstest] - #[tokio::test(flavor = "multi_thread")] - async fn testing_ascending_float_pruning( - #[values( - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::RightSemi, - JoinType::LeftSemi, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::Full - )] - join_type: JoinType, - #[values( - (4, 5), - (99, 12), - )] - cardinality: (i32, i32), - #[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize, - ) -> Result<()> { - let config = SessionConfig::new().with_repartition_joins(false); - let session_ctx = SessionContext::with_config(config); - let task_ctx = session_ctx.task_ctx(); - let (left_batch, right_batch) = - build_sides_record_batches(TABLE_SIZE, cardinality)?; - let left_schema = &left_batch.schema(); - let right_schema = &right_batch.schema(); - let left_sorted = vec![PhysicalSortExpr { - expr: col("l_float", left_schema)?, - options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { - expr: col("r_float", right_schema)?, - options: SortOptions::default(), - }]; - let (left, right) = create_memory_table( - left_batch, - right_batch, - Some(left_sorted), - Some(right_sorted), - 13, - )?; - - let on = vec![( - Column::new_with_schema("lc1", left_schema)?, - Column::new_with_schema("rc1", right_schema)?, - )]; - - let intermediate_schema = Schema::new(vec![ - Field::new("left", DataType::Float64, true), - Field::new("right", DataType::Float64, true), - ]); - let filter_expr = join_expr_tests_fixture_f64( - case_expr, - col("left", &intermediate_schema)?, - col("right", &intermediate_schema)?, - ); - let column_indices = vec![ - ColumnIndex { - index: 10, // l_float - side: JoinSide::Left, - }, - ColumnIndex { - index: 10, // r_float - side: JoinSide::Right, - }, - ]; - let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); - - experiment(left, right, Some(filter), join_type, on, task_ctx).await?; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs deleted file mode 100644 index 13a1888508aa4..0000000000000 --- a/datafusion/core/src/physical_plan/mod.rs +++ /dev/null @@ -1,784 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Traits for physical query plan, supporting parallel execution for partitioned relations. - -pub use self::metrics::Metric; -use self::metrics::MetricsSet; -use self::{ - coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, -}; -use crate::physical_plan::expressions::PhysicalSortExpr; -use datafusion_common::Result; -pub use datafusion_common::{ColumnStatistics, Statistics}; - -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; - -use datafusion_common::utils::DataPtr; -pub use datafusion_expr::Accumulator; -pub use datafusion_expr::ColumnarValue; -pub use datafusion_physical_expr::aggregate::row_accumulator::RowAccumulator; -use datafusion_physical_expr::equivalence::OrderingEquivalenceProperties; -pub use display::DisplayFormatType; -use futures::stream::{Stream, TryStreamExt}; -use std::fmt; -use std::fmt::Debug; - -use datafusion_common::tree_node::Transformed; -use datafusion_common::DataFusionError; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{any::Any, pin::Pin}; - -/// Trait for types that stream [arrow::record_batch::RecordBatch] -pub trait RecordBatchStream: Stream> { - /// Returns the schema of this `RecordBatchStream`. - /// - /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this - /// stream should have the same schema as returned from this method. - fn schema(&self) -> SchemaRef; -} - -/// Trait for a stream of record batches. -pub type SendableRecordBatchStream = Pin>; - -/// EmptyRecordBatchStream can be used to create a RecordBatchStream -/// that will produce no results -pub struct EmptyRecordBatchStream { - /// Schema wrapped by Arc - schema: SchemaRef, -} - -impl EmptyRecordBatchStream { - /// Create an empty RecordBatchStream - pub fn new(schema: SchemaRef) -> Self { - Self { schema } - } -} - -impl RecordBatchStream for EmptyRecordBatchStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Stream for EmptyRecordBatchStream { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(None) - } -} - -/// Physical planner interface -pub use self::planner::PhysicalPlanner; - -/// `ExecutionPlan` represent nodes in the DataFusion Physical Plan. -/// -/// Each `ExecutionPlan` is partition-aware and is responsible for -/// creating the actual `async` [`SendableRecordBatchStream`]s -/// of [`RecordBatch`] that incrementally compute the operator's -/// output from its input partition. -/// -/// [`ExecutionPlan`] can be displayed in a simplified form using the -/// return value from [`displayable`] in addition to the (normally -/// quite verbose) `Debug` output. -pub trait ExecutionPlan: Debug + Send + Sync { - /// Returns the execution plan as [`Any`](std::any::Any) so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef; - - /// Specifies the output partitioning scheme of this plan - fn output_partitioning(&self) -> Partitioning; - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns an error to indicate this. - fn unbounded_output(&self, _children: &[bool]) -> Result { - Ok(false) - } - - /// If the output of this operator within each partition is sorted, - /// returns `Some(keys)` with the description of how it was sorted. - /// - /// For example, Sort, (obviously) produces sorted output as does - /// SortPreservingMergeStream. Less obviously `Projection` - /// produces sorted output if its input was sorted as it does not - /// reorder the input rows, - /// - /// It is safe to return `None` here if your operator does not - /// have any particular output order here - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; - - /// Specifies the data distribution requirements for all the - /// children for this operator, By default it's [[Distribution::UnspecifiedDistribution]] for each child, - fn required_input_distribution(&self) -> Vec { - vec![Distribution::UnspecifiedDistribution; self.children().len()] - } - - /// Specifies the ordering requirements for all of the children - /// For each child, it's the local ordering requirement within - /// each partition rather than the global ordering - /// - /// NOTE that checking `!is_empty()` does **not** check for a - /// required input ordering. Instead, the correct check is that at - /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec>> { - vec![None; self.children().len()] - } - - /// Returns `false` if this operator's implementation may reorder - /// rows within or between partitions. - /// - /// For example, Projection, Filter, and Limit maintain the order - /// of inputs -- they may transform values (Projection) or not - /// produce the same number of rows that went in (Filter and - /// Limit), but the rows that are produced go in the same way. - /// - /// DataFusion uses this metadata to apply certain optimizations - /// such as automatically repartitioning correctly. - /// - /// The default implementation returns `false` - /// - /// WARNING: if you override this default, you *MUST* ensure that - /// the operator's maintains the ordering invariant or else - /// DataFusion may produce incorrect results. - fn maintains_input_order(&self) -> Vec { - vec![false; self.children().len()] - } - - /// Returns `true` if this operator would benefit from - /// partitioning its input (and thus from more parallelism). For - /// operators that do very little work the overhead of extra - /// parallelism may outweigh any benefits - /// - /// The default implementation returns `true` unless this operator - /// has signalled it requires a single child input partition. - fn benefits_from_input_partitioning(&self) -> bool { - // By default try to maximize parallelism with more CPUs if - // possible - !self - .required_input_distribution() - .into_iter() - .any(|dist| matches!(dist, Distribution::SinglePartition)) - } - - /// Get the EquivalenceProperties within the plan - fn equivalence_properties(&self) -> EquivalenceProperties { - EquivalenceProperties::new(self.schema()) - } - - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - OrderingEquivalenceProperties::new(self.schema()) - } - - /// Get a list of child execution plans that provide the input for this plan. The returned list - /// will be empty for leaf nodes, will contain a single value for unary nodes, or two - /// values for binary nodes (such as joins). - fn children(&self) -> Vec>; - - /// Returns a new plan where all children were replaced by new plans. - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result>; - - /// creates an iterator - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result; - - /// Return a snapshot of the set of [`Metric`]s for this - /// [`ExecutionPlan`]. - /// - /// While the values of the metrics in the returned - /// [`MetricsSet`]s may change as execution progresses, the - /// specific metrics will not. - /// - /// Once `self.execute()` has returned (technically the future is - /// resolved) for all available partitions, the set of metrics - /// should be complete. If this function is called prior to - /// `execute()` new metrics may appear in subsequent calls. - fn metrics(&self) -> Option { - None - } - - /// Format this `ExecutionPlan` to `f` in the specified type. - /// - /// Should not include a newline - /// - /// Note this function prints a placeholder by default to preserve - /// backwards compatibility. - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ExecutionPlan(PlaceHolder)") - } - - /// Returns the global output statistics for this `ExecutionPlan` node. - fn statistics(&self) -> Statistics; -} - -/// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful -/// especially for the distributed engine to judge whether need to deal with shuffling. -/// Currently there are 3 kinds of execution plan which needs data exchange -/// 1. RepartitionExec for changing the partition number between two operators -/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee -/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee -pub fn need_data_exchange(plan: Arc) -> bool { - if let Some(repart) = plan.as_any().downcast_ref::() { - !matches!( - repart.output_partitioning(), - Partitioning::RoundRobinBatch(_) - ) - } else if let Some(coalesce) = plan.as_any().downcast_ref::() - { - coalesce.input().output_partitioning().partition_count() > 1 - } else if let Some(sort_preserving_merge) = - plan.as_any().downcast_ref::() - { - sort_preserving_merge - .input() - .output_partitioning() - .partition_count() - > 1 - } else { - false - } -} - -/// Returns a copy of this plan if we change any child according to the pointer comparison. -/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. -pub fn with_new_children_if_necessary( - plan: Arc, - children: Vec>, -) -> Result>> { - let old_children = plan.children(); - if children.len() != old_children.len() { - Err(DataFusionError::Internal( - "Wrong number of children".to_string(), - )) - } else if children.is_empty() - || children - .iter() - .zip(old_children.iter()) - .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) - { - Ok(Transformed::Yes(plan.with_new_children(children)?)) - } else { - Ok(Transformed::No(plan)) - } -} - -/// Return a [wrapper](DisplayableExecutionPlan) around an -/// [`ExecutionPlan`] which can be displayed in various easier to -/// understand ways. -/// -/// ``` -/// use datafusion::prelude::*; -/// use datafusion::physical_plan::displayable; -/// use object_store::path::Path; -/// -/// #[tokio::main] -/// async fn main() { -/// // Hard code target_partitions as it appears in the RepartitionExec output -/// let config = SessionConfig::new() -/// .with_target_partitions(3); -/// let mut ctx = SessionContext::with_config(config); -/// -/// // register the a table -/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await.unwrap(); -/// -/// // create a plan to run a SQL query -/// let dataframe = ctx.sql("SELECT a FROM example WHERE a < 5").await.unwrap(); -/// let physical_plan = dataframe.create_physical_plan().await.unwrap(); -/// -/// // Format using display string -/// let displayable_plan = displayable(physical_plan.as_ref()); -/// let plan_string = format!("{}", displayable_plan.indent()); -/// -/// let working_directory = std::env::current_dir().unwrap(); -/// let normalized = Path::from_filesystem_path(working_directory).unwrap(); -/// let plan_string = plan_string.replace(normalized.as_ref(), "WORKING_DIR"); -/// -/// assert_eq!("CoalesceBatchesExec: target_batch_size=8192\ -/// \n FilterExec: a@0 < 5\ -/// \n RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1\ -/// \n CsvExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.csv]]}, projection=[a], has_header=true", -/// plan_string.trim()); -/// -/// let one_line = format!("{}", displayable_plan.one_line()); -/// assert_eq!("CoalesceBatchesExec: target_batch_size=8192", one_line.trim()); -/// } -/// ``` -/// -pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { - DisplayableExecutionPlan::new(plan) -} - -/// Visit all children of this plan, according to the order defined on `ExecutionPlanVisitor`. -// Note that this would be really nice if it were a method on -// ExecutionPlan, but it can not be because it takes a generic -// parameter and `ExecutionPlan` is a trait -pub fn accept( - plan: &dyn ExecutionPlan, - visitor: &mut V, -) -> Result<(), V::Error> { - visitor.pre_visit(plan)?; - for child in plan.children() { - visit_execution_plan(child.as_ref(), visitor)?; - } - visitor.post_visit(plan)?; - Ok(()) -} - -/// Trait that implements the [Visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for a -/// depth first walk of `ExecutionPlan` nodes. `pre_visit` is called -/// before any children are visited, and then `post_visit` is called -/// after all children have been visited. -//// -/// To use, define a struct that implements this trait and then invoke -/// ['accept']. -/// -/// For example, for an execution plan that looks like: -/// -/// ```text -/// ProjectionExec: id -/// FilterExec: state = CO -/// CsvExec: -/// ``` -/// -/// The sequence of visit operations would be: -/// ```text -/// visitor.pre_visit(ProjectionExec) -/// visitor.pre_visit(FilterExec) -/// visitor.pre_visit(CsvExec) -/// visitor.post_visit(CsvExec) -/// visitor.post_visit(FilterExec) -/// visitor.post_visit(ProjectionExec) -/// ``` -pub trait ExecutionPlanVisitor { - /// The type of error returned by this visitor - type Error; - - /// Invoked on an `ExecutionPlan` plan before any of its child - /// inputs have been visited. If Ok(true) is returned, the - /// recursion continues. If Err(..) or Ok(false) are returned, the - /// recursion stops immediately and the error, if any, is returned - /// to `accept` - fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result; - - /// Invoked on an `ExecutionPlan` plan *after* all of its child - /// inputs have been visited. The return value is handled the same - /// as the return value of `pre_visit`. The provided default - /// implementation returns `Ok(true)`. - fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { - Ok(true) - } -} - -/// Recursively calls `pre_visit` and `post_visit` for this node and -/// all of its children, as described on [`ExecutionPlanVisitor`] -pub fn visit_execution_plan( - plan: &dyn ExecutionPlan, - visitor: &mut V, -) -> Result<(), V::Error> { - visitor.pre_visit(plan)?; - for child in plan.children() { - visit_execution_plan(child.as_ref(), visitor)?; - } - visitor.post_visit(plan)?; - Ok(()) -} - -/// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect( - plan: Arc, - context: Arc, -) -> Result> { - let stream = execute_stream(plan, context)?; - common::collect(stream).await -} - -/// Execute the [ExecutionPlan] and return a single stream of results -pub fn execute_stream( - plan: Arc, - context: Arc, -) -> Result { - match plan.output_partitioning().partition_count() { - 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0, context), - _ => { - // merge into a single partition - let plan = CoalescePartitionsExec::new(plan.clone()); - // CoalescePartitionsExec must produce a single partition - assert_eq!(1, plan.output_partitioning().partition_count()); - plan.execute(0, context) - } - } -} - -/// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect_partitioned( - plan: Arc, - context: Arc, -) -> Result>> { - let streams = execute_stream_partitioned(plan, context)?; - - // Execute the plan and collect the results into batches. - let handles = streams - .into_iter() - .enumerate() - .map(|(idx, stream)| async move { - let handle = tokio::task::spawn(stream.try_collect()); - AbortOnDropSingle::new(handle).await.map_err(|e| { - DataFusionError::Execution(format!( - "collect_partitioned partition {idx} panicked: {e}" - )) - })? - }); - - futures::future::try_join_all(handles).await -} - -/// Execute the [ExecutionPlan] and return a vec with one stream per output partition -pub fn execute_stream_partitioned( - plan: Arc, - context: Arc, -) -> Result> { - let num_partitions = plan.output_partitioning().partition_count(); - let mut streams = Vec::with_capacity(num_partitions); - for i in 0..num_partitions { - streams.push(plan.execute(i, context.clone())?); - } - Ok(streams) -} - -/// Partitioning schemes supported by operators. -#[derive(Debug, Clone)] -pub enum Partitioning { - /// Allocate batches using a round-robin algorithm and the specified number of partitions - RoundRobinBatch(usize), - /// Allocate rows based on a hash of one of more expressions and the specified number of - /// partitions - Hash(Vec>, usize), - /// Unknown partitioning scheme with a known number of partitions - UnknownPartitioning(usize), -} - -impl Partitioning { - /// Returns the number of partitions in this partitioning scheme - pub fn partition_count(&self) -> usize { - use Partitioning::*; - match self { - RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, - } - } - - /// Returns true when the guarantees made by this [[Partitioning]] are sufficient to - /// satisfy the partitioning scheme mandated by the `required` [[Distribution]] - pub fn satisfy EquivalenceProperties>( - &self, - required: Distribution, - equal_properties: F, - ) -> bool { - match required { - Distribution::UnspecifiedDistribution => true, - Distribution::SinglePartition if self.partition_count() == 1 => true, - Distribution::HashPartitioned(required_exprs) => { - match self { - // Here we do not check the partition count for hash partitioning and assumes the partition count - // and hash functions in the system are the same. In future if we plan to support storage partition-wise joins, - // then we need to have the partition count and hash functions validation. - Partitioning::Hash(partition_exprs, _) => { - let fast_match = - expr_list_eq_strict_order(&required_exprs, partition_exprs); - // If the required exprs do not match, need to leverage the eq_properties provided by the child - // and normalize both exprs based on the eq_properties - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required_exprs - .iter() - .map(|e| { - normalize_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_partition_exprs = partition_exprs - .iter() - .map(|e| { - normalize_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - expr_list_eq_strict_order( - &normalized_required_exprs, - &normalized_partition_exprs, - ) - } else { - fast_match - } - } else { - fast_match - } - } - _ => false, - } - } - _ => false, - } - } -} - -impl PartialEq for Partitioning { - fn eq(&self, other: &Partitioning) -> bool { - match (self, other) { - ( - Partitioning::RoundRobinBatch(count1), - Partitioning::RoundRobinBatch(count2), - ) if count1 == count2 => true, - (Partitioning::Hash(exprs1, count1), Partitioning::Hash(exprs2, count2)) - if expr_list_eq_strict_order(exprs1, exprs2) && (count1 == count2) => - { - true - } - _ => false, - } - } -} - -/// Retrieves the ordering equivalence properties for a given schema and output ordering. -pub fn ordering_equivalence_properties_helper( - schema: SchemaRef, - eq_orderings: &[LexOrdering], -) -> OrderingEquivalenceProperties { - let mut oep = OrderingEquivalenceProperties::new(schema); - let first_ordering = if let Some(first) = eq_orderings.first() { - first - } else { - // Return an empty OrderingEquivalenceProperties: - return oep; - }; - // First entry among eq_orderings is the head, skip it: - for ordering in eq_orderings.iter().skip(1) { - if !ordering.is_empty() { - oep.add_equal_conditions((first_ordering, ordering)) - } - } - oep -} - -/// Distribution schemes -#[derive(Debug, Clone)] -pub enum Distribution { - /// Unspecified distribution - UnspecifiedDistribution, - /// A single partition is required - SinglePartition, - /// Requires children to be distributed in such a way that the same - /// values of the keys end up in the same partition - HashPartitioned(Vec>), -} - -impl Distribution { - /// Creates a Partitioning for this Distribution to satisfy itself - pub fn create_partitioning(&self, partition_count: usize) -> Partitioning { - match self { - Distribution::UnspecifiedDistribution => { - Partitioning::UnknownPartitioning(partition_count) - } - Distribution::SinglePartition => Partitioning::UnknownPartitioning(1), - Distribution::HashPartitioned(expr) => { - Partitioning::Hash(expr.clone(), partition_count) - } - } - } -} - -use datafusion_physical_expr::expressions::Column; -pub use datafusion_physical_expr::window::WindowExpr; -use datafusion_physical_expr::{ - expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, LexOrdering, -}; -pub use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; - -/// Applies an optional projection to a [`SchemaRef`], returning the -/// projected schema -/// -/// Example: -/// ``` -/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; -/// use datafusion::physical_plan::project_schema; -/// -/// // Schema with columns 'a', 'b', and 'c' -/// let schema = SchemaRef::new(Schema::new(vec![ -/// Field::new("a", DataType::Int32, true), -/// Field::new("b", DataType::Int64, true), -/// Field::new("c", DataType::Utf8, true), -/// ])); -/// -/// // Pick columns 'c' and 'b' -/// let projection = Some(vec![2,1]); -/// let projected_schema = project_schema( -/// &schema, -/// projection.as_ref() -/// ).unwrap(); -/// -/// let expected_schema = SchemaRef::new(Schema::new(vec![ -/// Field::new("c", DataType::Utf8, true), -/// Field::new("b", DataType::Int64, true), -/// ])); -/// -/// assert_eq!(projected_schema, expected_schema); -/// ``` -pub fn project_schema( - schema: &SchemaRef, - projection: Option<&Vec>, -) -> Result { - let schema = match projection { - Some(columns) => Arc::new(schema.project(columns)?), - None => Arc::clone(schema), - }; - Ok(schema) -} - -pub mod aggregates; -pub mod analyze; -pub mod coalesce_batches; -pub mod coalesce_partitions; -pub mod common; -pub mod display; -pub mod empty; -pub mod explain; -pub mod filter; -pub mod insert; -pub mod joins; -pub mod limit; -pub mod memory; -pub mod metrics; -pub mod planner; -pub mod projection; -pub mod repartition; -pub mod sorts; -pub mod stream; -pub mod streaming; -pub mod tree_node; -pub mod udaf; -pub mod union; -pub mod unnest; -pub mod values; -pub mod windows; - -use crate::physical_plan::common::AbortOnDropSingle; -use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_execution::TaskContext; -pub use datafusion_physical_expr::{expressions, functions, hash_utils, udf}; - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - use arrow::datatypes::Schema; - - use crate::physical_plan::Distribution; - use crate::physical_plan::Partitioning; - use crate::physical_plan::PhysicalExpr; - use datafusion_physical_expr::expressions::Column; - - use std::sync::Arc; - - #[tokio::test] - async fn partitioning_satisfy_distribution() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - arrow::datatypes::Field::new("column_1", DataType::Int64, false), - arrow::datatypes::Field::new("column_2", DataType::Utf8, false), - ])); - - let partition_exprs1: Vec> = vec![ - Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), - Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), - ]; - - let partition_exprs2: Vec> = vec![ - Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), - Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), - ]; - - let distribution_types = vec![ - Distribution::UnspecifiedDistribution, - Distribution::SinglePartition, - Distribution::HashPartitioned(partition_exprs1.clone()), - ]; - - let single_partition = Partitioning::UnknownPartitioning(1); - let unspecified_partition = Partitioning::UnknownPartitioning(10); - let round_robin_partition = Partitioning::RoundRobinBatch(10); - let hash_partition1 = Partitioning::Hash(partition_exprs1, 10); - let hash_partition2 = Partitioning::Hash(partition_exprs2, 10); - - for distribution in distribution_types { - let result = ( - single_partition.satisfy(distribution.clone(), || { - EquivalenceProperties::new(schema.clone()) - }), - unspecified_partition.satisfy(distribution.clone(), || { - EquivalenceProperties::new(schema.clone()) - }), - round_robin_partition.satisfy(distribution.clone(), || { - EquivalenceProperties::new(schema.clone()) - }), - hash_partition1.satisfy(distribution.clone(), || { - EquivalenceProperties::new(schema.clone()) - }), - hash_partition2.satisfy(distribution.clone(), || { - EquivalenceProperties::new(schema.clone()) - }), - ); - - match distribution { - Distribution::UnspecifiedDistribution => { - assert_eq!(result, (true, true, true, true, true)) - } - Distribution::SinglePartition => { - assert_eq!(result, (true, false, false, false, false)) - } - Distribution::HashPartitioned(_) => { - assert_eq!(result, (false, false, false, true, false)) - } - } - } - - Ok(()) - } -} diff --git a/datafusion/core/src/physical_plan/projection.rs b/datafusion/core/src/physical_plan/projection.rs deleted file mode 100644 index 5eb578334ed2e..0000000000000 --- a/datafusion/core/src/physical_plan/projection.rs +++ /dev/null @@ -1,572 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines the projection execution plan. A projection determines which columns or expressions -//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example -//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the -//! projection expressions. `SELECT` without `FROM` will only evaluate expressions. - -use std::any::Any; -use std::collections::HashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use crate::physical_plan::{ - ColumnStatistics, DisplayFormatType, EquivalenceProperties, ExecutionPlan, - Partitioning, PhysicalExpr, -}; -use arrow::datatypes::{Field, Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion_common::Result; -use datafusion_execution::TaskContext; -use futures::stream::{Stream, StreamExt}; -use log::trace; - -use super::expressions::{Column, PhysicalSortExpr}; -use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; - -use datafusion_physical_expr::{ - normalize_out_expr_with_columns_map, project_equivalence_properties, - project_ordering_equivalence_properties, OrderingEquivalenceProperties, -}; - -/// Execution plan for a projection -#[derive(Debug)] -pub struct ProjectionExec { - /// The projection expressions stored as tuples of (expression, output column name) - pub(crate) expr: Vec<(Arc, String)>, - /// The schema once the projection has been applied to the input - schema: SchemaRef, - /// The input plan - input: Arc, - /// The output ordering - output_ordering: Option>, - /// The columns map used to normalize out expressions like Partitioning and PhysicalSortExpr - /// The key is the column from the input schema and the values are the columns from the output schema - columns_map: HashMap>, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, -} - -impl ProjectionExec { - /// Create a projection on an input - pub fn try_new( - expr: Vec<(Arc, String)>, - input: Arc, - ) -> Result { - let input_schema = input.schema(); - - let fields: Result> = expr - .iter() - .map(|(e, name)| { - let mut field = Field::new( - name, - e.data_type(&input_schema)?, - e.nullable(&input_schema)?, - ); - field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), - ); - - Ok(field) - }) - .collect(); - - let schema = Arc::new(Schema::new_with_metadata( - fields?, - input_schema.metadata().clone(), - )); - - // construct a map from the input columns to the output columns of the Projection - let mut columns_map: HashMap> = HashMap::new(); - for (expression, name) in expr.iter() { - if let Some(column) = expression.as_any().downcast_ref::() { - // For some executors, logical and physical plan schema fields - // are not the same. The information in a `Column` comes from - // the logical plan schema. Therefore, to produce correct results - // we use the field in the input schema with the same index. This - // corresponds to the physical plan `Column`. - let idx = column.index(); - let matching_input_field = input_schema.field(idx); - let matching_input_column = Column::new(matching_input_field.name(), idx); - let new_col_idx = schema.index_of(name)?; - let entry = columns_map - .entry(matching_input_column) - .or_insert_with(Vec::new); - entry.push(Column::new(name, new_col_idx)); - }; - } - - // Output Ordering need to respect the alias - let child_output_ordering = input.output_ordering(); - let output_ordering = match child_output_ordering { - Some(sort_exprs) => { - let normalized_exprs = sort_exprs - .iter() - .map(|sort_expr| { - let expr = normalize_out_expr_with_columns_map( - sort_expr.expr.clone(), - &columns_map, - ); - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }) - .collect::>(); - Some(normalized_exprs) - } - None => None, - }; - - Ok(Self { - expr, - schema, - input: input.clone(), - output_ordering, - columns_map, - metrics: ExecutionPlanMetricsSet::new(), - }) - } - - /// The projection expressions stored as tuples of (expression, output column name) - pub fn expr(&self) -> &[(Arc, String)] { - &self.expr - } - - /// The input plan - pub fn input(&self) -> &Arc { - &self.input - } -} - -impl ExecutionPlan for ProjectionExec { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns an error to indicate this. - fn unbounded_output(&self, children: &[bool]) -> Result { - Ok(children[0]) - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - /// Get the output partitioning of this plan - fn output_partitioning(&self) -> Partitioning { - // Output partition need to respect the alias - let input_partition = self.input.output_partitioning(); - match input_partition { - Partitioning::Hash(exprs, part) => { - let normalized_exprs = exprs - .into_iter() - .map(|expr| { - normalize_out_expr_with_columns_map(expr, &self.columns_map) - }) - .collect::>(); - - Partitioning::Hash(normalized_exprs, part) - } - _ => input_partition, - } - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.output_ordering.as_deref() - } - - fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input - vec![true] - } - - fn equivalence_properties(&self) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(self.schema()); - project_equivalence_properties( - self.input.equivalence_properties(), - &self.columns_map, - &mut new_properties, - ); - new_properties - } - - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - let mut new_properties = OrderingEquivalenceProperties::new(self.schema()); - project_ordering_equivalence_properties( - self.input.ordering_equivalence_properties(), - &self.columns_map, - &mut new_properties, - ); - new_properties - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(ProjectionExec::try_new( - self.expr.clone(), - children[0].clone(), - )?)) - } - - fn benefits_from_input_partitioning(&self) -> bool { - let all_column_expr = self - .expr - .iter() - .all(|(e, _)| e.as_any().downcast_ref::().is_some()); - // If expressions are all column_expr, then all computations in this projection are reorder or rename, - // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. - !all_column_expr - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - Ok(Box::pin(ProjectionStream { - schema: self.schema.clone(), - expr: self.expr.iter().map(|x| x.0.clone()).collect(), - input: self.input.execute(partition, context)?, - baseline_metrics: BaselineMetrics::new(&self.metrics, partition), - })) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let expr: Vec = self - .expr - .iter() - .map(|(e, alias)| { - let e = e.to_string(); - if &e != alias { - format!("{e} as {alias}") - } else { - e - } - }) - .collect(); - - write!(f, "ProjectionExec: expr=[{}]", expr.join(", ")) - } - } - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Statistics { - stats_projection( - self.input.statistics(), - self.expr.iter().map(|(e, _)| Arc::clone(e)), - ) - } -} - -/// If e is a direct column reference, returns the field level -/// metadata for that field, if any. Otherwise returns None -fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - let name = if let Some(column) = e.as_any().downcast_ref::() { - column.name() - } else { - return None; - }; - - input_schema - .field_with_name(name) - .ok() - .map(|f| f.metadata().clone()) -} - -fn stats_projection( - stats: Statistics, - exprs: impl Iterator>, -) -> Statistics { - let column_statistics = stats.column_statistics.map(|input_col_stats| { - exprs - .map(|e| { - if let Some(col) = e.as_any().downcast_ref::() { - input_col_stats[col.index()].clone() - } else { - // TODO stats: estimate more statistics from expressions - // (expressions should compute their statistics themselves) - ColumnStatistics::default() - } - }) - .collect() - }); - - Statistics { - is_exact: stats.is_exact, - num_rows: stats.num_rows, - column_statistics, - // TODO stats: knowing the type of the new columns we can guess the output size - total_byte_size: None, - } -} - -impl ProjectionStream { - fn batch_project(&self, batch: &RecordBatch) -> Result { - // records time on drop - let _timer = self.baseline_metrics.elapsed_compute().timer(); - let arrays = self - .expr - .iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - - if arrays.is_empty() { - let options = - RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - RecordBatch::try_new_with_options(self.schema.clone(), arrays, &options) - .map_err(Into::into) - } else { - RecordBatch::try_new(self.schema.clone(), arrays).map_err(Into::into) - } - } -} - -/// Projection iterator -struct ProjectionStream { - schema: SchemaRef, - expr: Vec>, - input: SendableRecordBatchStream, - baseline_metrics: BaselineMetrics, -} - -impl Stream for ProjectionStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let poll = self.input.poll_next_unpin(cx).map(|x| match x { - Some(Ok(batch)) => Some(self.batch_project(&batch)), - other => other, - }); - - self.baseline_metrics.record_poll(poll) - } - - fn size_hint(&self) -> (usize, Option) { - // same number of record batches - self.input.size_hint() - } -} - -impl RecordBatchStream for ProjectionStream { - /// Get the schema - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::physical_plan::common::collect; - use crate::physical_plan::expressions::{self, col}; - use crate::prelude::SessionContext; - use crate::test::{self}; - use crate::test_util; - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::binary; - use futures::future; - - // Create a binary expression without coercion. Used here when we do not want to coerce the expressions - // to valid types. Usage can result in an execution (after plan) error. - fn binary_simple( - l: Arc, - op: Operator, - r: Arc, - input_schema: &Schema, - ) -> Arc { - binary(l, op, r, input_schema).unwrap() - } - - #[tokio::test] - async fn project_first_column() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - - let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; - - // pick column c1 and name it column c1 in the output schema - let projection = - ProjectionExec::try_new(vec![(col("c1", &schema)?, "c1".to_string())], csv)?; - - let col_field = projection.schema.field(0); - let col_metadata = col_field.metadata(); - let data: &str = &col_metadata["testing"]; - assert_eq!(data, "test"); - - let mut partition_count = 0; - let mut row_count = 0; - for partition in 0..projection.output_partitioning().partition_count() { - partition_count += 1; - let stream = projection.execute(partition, task_ctx.clone())?; - - row_count += stream - .map(|batch| { - let batch = batch.unwrap(); - assert_eq!(1, batch.num_columns()); - batch.num_rows() - }) - .fold(0, |acc, x| future::ready(acc + x)) - .await; - } - assert_eq!(partitions, partition_count); - assert_eq!(100, row_count); - - Ok(()) - } - - #[tokio::test] - async fn project_input_not_partitioning() -> Result<()> { - let schema = test_util::aggr_test_schema(); - - let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; - - // pick column c1 and name it column c1 in the output schema - let projection = - ProjectionExec::try_new(vec![(col("c1", &schema)?, "c1".to_string())], csv)?; - assert!(!projection.benefits_from_input_partitioning()); - Ok(()) - } - - #[tokio::test] - async fn project_input_partitioning() -> Result<()> { - let schema = test_util::aggr_test_schema(); - - let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; - - let c1 = col("c2", &schema).unwrap(); - let c2 = col("c9", &schema).unwrap(); - let c1_plus_c2 = binary_simple(c1, Operator::Plus, c2, &schema); - - let projection = - ProjectionExec::try_new(vec![(c1_plus_c2, "c2 + c9".to_string())], csv)?; - - assert!(projection.benefits_from_input_partitioning()); - Ok(()) - } - - #[tokio::test] - async fn project_no_column() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - - let csv = test::scan_partitioned_csv(1)?; - let expected = collect(csv.execute(0, task_ctx.clone())?).await.unwrap(); - - let projection = ProjectionExec::try_new(vec![], csv)?; - let stream = projection.execute(0, task_ctx.clone())?; - let output = collect(stream).await.unwrap(); - assert_eq!(output.len(), expected.len()); - - Ok(()) - } - - #[tokio::test] - async fn test_stats_projection_columns_only() { - let source = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(23), - column_statistics: Some(vec![ - ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), - }, - ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), - }, - ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Float32(Some(1.1))), - min_value: Some(ScalarValue::Float32(Some(0.1))), - null_count: None, - }, - ]), - }; - - let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col1", 1)), - Arc::new(expressions::Column::new("col0", 0)), - ]; - - let result = stats_projection(source, exprs.into_iter()); - - let expected = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: None, - column_statistics: Some(vec![ - ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), - }, - ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), - }, - ]), - }; - - assert_eq!(result, expected); - } -} diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs deleted file mode 100644 index a9e5122130572..0000000000000 --- a/datafusion/core/src/physical_plan/sorts/cursor.rs +++ /dev/null @@ -1,419 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::physical_plan::sorts::sort::SortOptions; -use arrow::buffer::ScalarBuffer; -use arrow::datatypes::ArrowNativeTypeOp; -use arrow::row::{Row, Rows}; -use arrow_array::types::ByteArrayType; -use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray}; -use std::cmp::Ordering; - -/// A [`Cursor`] for [`Rows`] -pub struct RowCursor { - cur_row: usize, - num_rows: usize, - - rows: Rows, -} - -impl std::fmt::Debug for RowCursor { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SortKeyCursor") - .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) - .finish() - } -} - -impl RowCursor { - /// Create a new SortKeyCursor - pub fn new(rows: Rows) -> Self { - Self { - cur_row: 0, - num_rows: rows.num_rows(), - rows, - } - } - - /// Returns the current row - fn current(&self) -> Row<'_> { - self.rows.row(self.cur_row) - } -} - -impl PartialEq for RowCursor { - fn eq(&self, other: &Self) -> bool { - self.current() == other.current() - } -} - -impl Eq for RowCursor {} - -impl PartialOrd for RowCursor { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for RowCursor { - fn cmp(&self, other: &Self) -> Ordering { - self.current().cmp(&other.current()) - } -} - -/// A cursor into a sorted batch of rows -pub trait Cursor: Ord { - /// Returns true if there are no more rows in this cursor - fn is_finished(&self) -> bool; - - /// Advance the cursor, returning the previous row index - fn advance(&mut self) -> usize; -} - -impl Cursor for RowCursor { - #[inline] - fn is_finished(&self) -> bool { - self.num_rows == self.cur_row - } - - #[inline] - fn advance(&mut self) -> usize { - let t = self.cur_row; - self.cur_row += 1; - t - } -} - -/// An [`Array`] that can be converted into [`FieldValues`] -pub trait FieldArray: Array + 'static { - type Values: FieldValues; - - fn values(&self) -> Self::Values; -} - -/// A comparable set of non-nullable values -pub trait FieldValues { - type Value: ?Sized; - - fn len(&self) -> usize; - - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering; - - fn value(&self, idx: usize) -> &Self::Value; -} - -impl FieldArray for PrimitiveArray { - type Values = PrimitiveValues; - - fn values(&self) -> Self::Values { - PrimitiveValues(self.values().clone()) - } -} - -#[derive(Debug)] -pub struct PrimitiveValues(ScalarBuffer); - -impl FieldValues for PrimitiveValues { - type Value = T; - - fn len(&self) -> usize { - self.0.len() - } - - #[inline] - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { - T::compare(*a, *b) - } - - #[inline] - fn value(&self, idx: usize) -> &Self::Value { - &self.0[idx] - } -} - -impl FieldArray for GenericByteArray { - type Values = Self; - - fn values(&self) -> Self::Values { - // Once https://github.com/apache/arrow-rs/pull/4048 is released - // Could potentially destructure array into buffers to reduce codegen, - // in a similar vein to what is done for PrimitiveArray - self.clone() - } -} - -impl FieldValues for GenericByteArray { - type Value = T::Native; - - fn len(&self) -> usize { - Array::len(self) - } - - #[inline] - fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { - let a: &[u8] = a.as_ref(); - let b: &[u8] = b.as_ref(); - a.cmp(b) - } - - #[inline] - fn value(&self, idx: usize) -> &Self::Value { - self.value(idx) - } -} - -/// A cursor over sorted, nullable [`FieldValues`] -/// -/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering -#[derive(Debug)] -pub struct FieldCursor { - values: T, - offset: usize, - // If nulls first, the first non-null index - // Otherwise, the first null index - null_threshold: usize, - options: SortOptions, -} - -impl FieldCursor { - /// Create a new [`FieldCursor`] from the provided `values` sorted according to `options` - pub fn new>(options: SortOptions, array: &A) -> Self { - let null_threshold = match options.nulls_first { - true => array.null_count(), - false => array.len() - array.null_count(), - }; - - Self { - values: array.values(), - offset: 0, - null_threshold, - options, - } - } - - fn is_null(&self) -> bool { - (self.offset < self.null_threshold) == self.options.nulls_first - } -} - -impl PartialEq for FieldCursor { - fn eq(&self, other: &Self) -> bool { - self.cmp(other).is_eq() - } -} - -impl Eq for FieldCursor {} -impl PartialOrd for FieldCursor { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for FieldCursor { - fn cmp(&self, other: &Self) -> Ordering { - match (self.is_null(), other.is_null()) { - (true, true) => Ordering::Equal, - (true, false) => match self.options.nulls_first { - true => Ordering::Less, - false => Ordering::Greater, - }, - (false, true) => match self.options.nulls_first { - true => Ordering::Greater, - false => Ordering::Less, - }, - (false, false) => { - let s_v = self.values.value(self.offset); - let o_v = other.values.value(other.offset); - - match self.options.descending { - true => T::compare(o_v, s_v), - false => T::compare(s_v, o_v), - } - } - } - } -} - -impl Cursor for FieldCursor { - fn is_finished(&self) -> bool { - self.offset == self.values.len() - } - - fn advance(&mut self) -> usize { - let t = self.offset; - self.offset += 1; - t - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn new_primitive( - options: SortOptions, - values: ScalarBuffer, - null_count: usize, - ) -> FieldCursor> { - let null_threshold = match options.nulls_first { - true => null_count, - false => values.len() - null_count, - }; - - FieldCursor { - offset: 0, - values: PrimitiveValues(values), - null_threshold, - options, - } - } - - #[test] - fn test_primitive_nulls_first() { - let options = SortOptions { - descending: false, - nulls_first: true, - }; - - let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]); - let mut a = new_primitive(options, buffer, 1); - let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]); - let mut b = new_primitive(options, buffer, 2); - - // NULL == NULL - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - // NULL == NULL - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - // NULL < -2 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 1 > -2 - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Greater); - - // 1 > -1 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Greater); - - // 1 == 1 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - // 9 > 1 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 9 > 2 - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - let options = SortOptions { - descending: false, - nulls_first: false, - }; - - let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]); - let mut a = new_primitive(options, buffer, 2); - let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]); - let mut b = new_primitive(options, buffer, 2); - - // 0 > -1 - assert_eq!(a.cmp(&b), Ordering::Greater); - - // 0 < NULL - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 1 < NULL - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // NULL = NULL - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - let options = SortOptions { - descending: true, - nulls_first: false, - }; - - let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]); - let mut a = new_primitive(options, buffer, 3); - let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]); - let mut b = new_primitive(options, buffer, 2); - - // 6 > 67 - assert_eq!(a.cmp(&b), Ordering::Greater); - - // 6 < -3 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 6 < NULL - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 6 < NULL - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // NULL == NULL - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - let options = SortOptions { - descending: true, - nulls_first: true, - }; - - let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]); - let mut a = new_primitive(options, buffer, 2); - let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]); - let mut b = new_primitive(options, buffer, 1); - - // NULL == NULL - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - // NULL == NULL - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Equal); - assert_eq!(a, b); - - // NULL < 4546 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - - // 6 > 4546 - a.advance(); - assert_eq!(a.cmp(&b), Ordering::Greater); - - // 6 < -3 - b.advance(); - assert_eq!(a.cmp(&b), Ordering::Less); - } -} diff --git a/datafusion/core/src/physical_plan/streaming.rs b/datafusion/core/src/physical_plan/streaming.rs deleted file mode 100644 index 0555c1ce2899d..0000000000000 --- a/datafusion/core/src/physical_plan/streaming.rs +++ /dev/null @@ -1,132 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Execution plan for streaming [`PartitionStream`] - -use std::any::Any; -use std::sync::Arc; - -use arrow::datatypes::SchemaRef; -use async_trait::async_trait; -use futures::stream::StreamExt; - -use datafusion_common::{DataFusionError, Result, Statistics}; -use datafusion_physical_expr::PhysicalSortExpr; - -use crate::datasource::streaming::PartitionStream; -use crate::physical_plan::stream::RecordBatchStreamAdapter; -use crate::physical_plan::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; -use datafusion_execution::TaskContext; - -/// An [`ExecutionPlan`] for [`PartitionStream`] -pub struct StreamingTableExec { - partitions: Vec>, - projection: Option>, - projected_schema: SchemaRef, - infinite: bool, -} - -impl StreamingTableExec { - /// Try to create a new [`StreamingTableExec`] returning an error if the schema is incorrect - pub fn try_new( - schema: SchemaRef, - partitions: Vec>, - projection: Option<&Vec>, - infinite: bool, - ) -> Result { - if !partitions.iter().all(|x| schema.contains(x.schema())) { - return Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )); - } - - let projected_schema = match projection { - Some(p) => Arc::new(schema.project(p)?), - None => schema, - }; - - Ok(Self { - partitions, - projected_schema, - projection: projection.cloned().map(Into::into), - infinite, - }) - } -} - -impl std::fmt::Debug for StreamingTableExec { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("LazyMemTableExec").finish_non_exhaustive() - } -} - -#[async_trait] -impl ExecutionPlan for StreamingTableExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.projected_schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - Partitioning::UnknownPartitioning(self.partitions.len()) - } - - fn unbounded_output(&self, _children: &[bool]) -> Result { - Ok(self.infinite) - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _children: Vec>, - ) -> Result> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {self:?}" - ))) - } - - fn execute( - &self, - partition: usize, - ctx: Arc, - ) -> Result { - let stream = self.partitions[partition].execute(ctx); - Ok(match self.projection.clone() { - Some(projection) => Box::pin(RecordBatchStreamAdapter::new( - self.projected_schema.clone(), - stream.map(move |x| { - x.and_then(|b| b.project(projection.as_ref()).map_err(Into::into)) - }), - )), - None => stream, - }) - } - - fn statistics(&self) -> Statistics { - Default::default() - } -} diff --git a/datafusion/core/src/physical_plan/unnest.rs b/datafusion/core/src/physical_plan/unnest.rs deleted file mode 100644 index cd42c3305f2d0..0000000000000 --- a/datafusion/core/src/physical_plan/unnest.rs +++ /dev/null @@ -1,300 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines the unnest column plan for unnesting values in a column that contains a list -//! type, conceptually is like joining each row with all the values in the list column. -use arrow::array::{ - new_null_array, Array, ArrayAccessor, ArrayRef, FixedSizeListArray, LargeListArray, - ListArray, -}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use async_trait::async_trait; -use datafusion_execution::TaskContext; -use futures::Stream; -use futures::StreamExt; -use log::trace; -use std::time::Instant; -use std::{any::Any, sync::Arc}; - -use crate::physical_plan::{ - coalesce_batches::concat_batches, expressions::Column, DisplayFormatType, - Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, PhysicalExpr, - PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, -}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; - -/// Unnest the given column by joining the row with each value in the nested type. -#[derive(Debug)] -pub struct UnnestExec { - /// Input execution plan - input: Arc, - /// The schema once the unnest is applied - schema: SchemaRef, - /// The unnest column - column: Column, -} - -impl UnnestExec { - /// Create a new [UnnestExec]. - pub fn new(input: Arc, column: Column, schema: SchemaRef) -> Self { - UnnestExec { - input, - schema, - column, - } - } -} - -impl ExecutionPlan for UnnestExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns an error to indicate this. - fn unbounded_output(&self, children: &[bool]) -> Result { - Ok(children[0]) - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(UnnestExec::new( - children[0].clone(), - self.column.clone(), - self.schema.clone(), - ))) - } - - fn required_input_distribution(&self) -> Vec { - vec![Distribution::UnspecifiedDistribution] - } - - fn output_partitioning(&self) -> Partitioning { - self.input.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let input = self.input.execute(partition, context)?; - - Ok(Box::pin(UnnestStream { - input, - schema: self.schema.clone(), - column: self.column.clone(), - num_input_batches: 0, - num_input_rows: 0, - num_output_batches: 0, - num_output_rows: 0, - unnest_time: 0, - })) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "UnnestExec") - } - } - } - - fn statistics(&self) -> Statistics { - Default::default() - } -} - -/// A stream that issues [RecordBatch]es with unnested column data. -struct UnnestStream { - /// Input stream - input: SendableRecordBatchStream, - /// Unnested schema - schema: Arc, - /// The unnest column - column: Column, - /// number of input batches - num_input_batches: usize, - /// number of input rows - num_input_rows: usize, - /// number of batches produced - num_output_batches: usize, - /// number of rows produced - num_output_rows: usize, - /// total time for column unnesting, in ms - unnest_time: usize, -} - -impl RecordBatchStream for UnnestStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[async_trait] -impl Stream for UnnestStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.poll_next_impl(cx) - } -} - -impl UnnestStream { - /// Separate implementation function that unpins the [`UnnestStream`] so - /// that partial borrows work correctly - fn poll_next_impl( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { - self.input - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(batch)) => { - let start = Instant::now(); - let result = build_batch(&batch, &self.schema, &self.column); - self.num_input_batches += 1; - self.num_input_rows += batch.num_rows(); - if let Ok(ref batch) = result { - self.unnest_time += start.elapsed().as_millis() as usize; - self.num_output_batches += 1; - self.num_output_rows += batch.num_rows(); - } - - Some(result) - } - other => { - trace!( - "Processed {} probe-side input batches containing {} rows and \ - produced {} output batches containing {} rows in {} ms", - self.num_input_batches, - self.num_input_rows, - self.num_output_batches, - self.num_output_rows, - self.unnest_time, - ); - other - } - }) - } -} - -fn build_batch( - batch: &RecordBatch, - schema: &SchemaRef, - column: &Column, -) -> Result { - let list_array = column.evaluate(batch)?.into_array(batch.num_rows()); - match list_array.data_type() { - arrow::datatypes::DataType::List(_) => { - let list_array = list_array.as_any().downcast_ref::().unwrap(); - unnest_batch(batch, schema, column, &list_array) - } - arrow::datatypes::DataType::LargeList(_) => { - let list_array = list_array - .as_any() - .downcast_ref::() - .unwrap(); - unnest_batch(batch, schema, column, &list_array) - } - arrow::datatypes::DataType::FixedSizeList(_, _) => { - let list_array = list_array - .as_any() - .downcast_ref::() - .unwrap(); - unnest_batch(batch, schema, column, list_array) - } - _ => Err(DataFusionError::Execution(format!( - "Invalid unnest column {column}" - ))), - } -} - -fn unnest_batch( - batch: &RecordBatch, - schema: &SchemaRef, - column: &Column, - list_array: &T, -) -> Result -where - T: ArrayAccessor, -{ - let mut batches = Vec::new(); - let mut num_rows = 0; - - for row in 0..batch.num_rows() { - let arrays = batch - .columns() - .iter() - .enumerate() - .map(|(col_idx, arr)| { - if col_idx == column.index() { - // Unnest the value at the given row. - if list_array.value(row).is_empty() { - // If nested array is empty add an array with 1 null. - Ok(new_null_array(list_array.value(row).data_type(), 1)) - } else { - Ok(list_array.value(row)) - } - } else { - // Number of elements to duplicate, use max(1) to handle null. - let nested_len = list_array.value(row).len().max(1); - // Duplicate rows for each value in the nested array. - if arr.is_null(row) { - Ok(new_null_array(arr.data_type(), nested_len)) - } else { - let scalar = ScalarValue::try_from_array(arr, row)?; - Ok(scalar.to_array_of_size(nested_len)) - } - } - }) - .collect::>>()?; - - let rb = RecordBatch::try_new(schema.clone(), arrays.to_vec())?; - num_rows += rb.num_rows(); - batches.push(rb); - } - - concat_batches(schema, &batches, num_rows).map_err(Into::into) -} diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs deleted file mode 100644 index 73a3eb10c28fe..0000000000000 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ /dev/null @@ -1,597 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Physical expressions for window functions - -use crate::physical_plan::{ - aggregates, - expressions::{ - cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, - PhysicalSortExpr, RowNumber, - }, - udaf, ExecutionPlan, PhysicalExpr, -}; -use arrow::datatypes::Schema; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - WindowFrame, -}; -use datafusion_physical_expr::window::{ - BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr, -}; -use std::borrow::Borrow; -use std::convert::TryInto; -use std::sync::Arc; - -mod bounded_window_agg_exec; -mod window_agg_exec; - -pub use bounded_window_agg_exec::BoundedWindowAggExec; -pub use bounded_window_agg_exec::PartitionSearchMode; -use datafusion_common::utils::longest_consecutive_prefix; -use datafusion_physical_expr::equivalence::OrderingEquivalenceBuilder; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{convert_to_expr, get_indices_of_matching_exprs}; -pub use datafusion_physical_expr::window::{ - BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, -}; -use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; -pub use window_agg_exec::WindowAggExec; - -/// Create a physical expression for window function -pub fn create_window_expr( - fun: &WindowFunction, - name: String, - args: &[Arc], - partition_by: &[Arc], - order_by: &[PhysicalSortExpr], - window_frame: Arc, - input_schema: &Schema, -) -> Result> { - Ok(match fun { - WindowFunction::AggregateFunction(fun) => { - let aggregate = - aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?; - if !window_frame.start_bound.is_unbounded() { - Arc::new(SlidingAggregateWindowExpr::new( - aggregate, - partition_by, - order_by, - window_frame, - )) - } else { - Arc::new(PlainAggregateWindowExpr::new( - aggregate, - partition_by, - order_by, - window_frame, - )) - } - } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new( - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - }) -} - -fn get_scalar_value_from_args( - args: &[Arc], - index: usize, -) -> Result> { - Ok(if let Some(field) = args.get(index) { - let tmp = field - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::NotImplemented( - format!("There is only support Literal types for field at idx: {index} in Window Function"), - ))? - .value() - .clone(); - Some(tmp) - } else { - None - }) -} - -fn create_built_in_window_expr( - fun: &BuiltInWindowFunction, - args: &[Arc], - input_schema: &Schema, - name: String, -) -> Result> { - Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)), - BuiltInWindowFunction::Rank => Arc::new(rank(name)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), - BuiltInWindowFunction::Ntile => { - let n: i64 = get_scalar_value_from_args(args, 0)? - .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires at least 1 argument".to_string(), - ) - })? - .try_into()?; - let n: u64 = n as u64; - Arc::new(Ntile::new(name, n)) - } - BuiltInWindowFunction::Lag => { - let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); - let default_value = get_scalar_value_from_args(args, 2)?; - Arc::new(lag(name, data_type, arg, shift_offset, default_value)) - } - BuiltInWindowFunction::Lead => { - let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); - let default_value = get_scalar_value_from_args(args, 2)?; - Arc::new(lead(name, data_type, arg, shift_offset, default_value)) - } - BuiltInWindowFunction::NthValue => { - let arg = args[0].clone(); - let n = args[1].as_any().downcast_ref::().unwrap().value(); - let n: i64 = n - .clone() - .try_into() - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; - let n: u32 = n as u32; - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::nth(name, arg, data_type, n)?) - } - BuiltInWindowFunction::FirstValue => { - let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::first(name, arg, data_type)) - } - BuiltInWindowFunction::LastValue => { - let arg = args[0].clone(); - let data_type = args[0].data_type(input_schema)?; - Arc::new(NthValue::last(name, arg, data_type)) - } - }) -} - -pub(crate) fn calc_requirements< - T: Borrow>, - S: Borrow, ->( - partition_by_exprs: impl IntoIterator, - orderby_sort_exprs: impl IntoIterator, -) -> Option> { - let mut sort_reqs = partition_by_exprs - .into_iter() - .map(|partition_by| { - PhysicalSortRequirement::new(partition_by.borrow().clone(), None) - }) - .collect::>(); - for element in orderby_sort_exprs.into_iter() { - let PhysicalSortExpr { expr, options } = element.borrow(); - if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { - sort_reqs.push(PhysicalSortRequirement::new(expr.clone(), Some(*options))); - } - } - // Convert empty result to None. Otherwise wrap result inside Some() - (!sort_reqs.is_empty()).then_some(sort_reqs) -} - -/// This function calculates the indices such that when partition by expressions reordered with this indices -/// resulting expressions define a preset for existing ordering. -// For instance, if input is ordered by a, b, c and PARTITION BY b, a is used -// This vector will be [1, 0]. It means that when we iterate b,a columns with the order [1, 0] -// resulting vector (a, b) is a preset of the existing ordering (a, b, c). -pub(crate) fn get_ordered_partition_by_indices( - partition_by_exprs: &[Arc], - input: &Arc, -) -> Vec { - let input_ordering = input.output_ordering().unwrap_or(&[]); - let input_ordering_exprs = convert_to_expr(input_ordering); - let equal_properties = || input.equivalence_properties(); - let input_places = get_indices_of_matching_exprs( - &input_ordering_exprs, - partition_by_exprs, - equal_properties, - ); - let mut partition_places = get_indices_of_matching_exprs( - partition_by_exprs, - &input_ordering_exprs, - equal_properties, - ); - partition_places.sort(); - let first_n = longest_consecutive_prefix(partition_places); - input_places[0..first_n].to_vec() -} - -pub(crate) fn window_ordering_equivalence( - schema: &SchemaRef, - input: &Arc, - window_expr: &[Arc], -) -> OrderingEquivalenceProperties { - // We need to update the schema, so we can not directly use - // `input.ordering_equivalence_properties()`. - let mut builder = OrderingEquivalenceBuilder::new(schema.clone()) - .with_equivalences(input.equivalence_properties()) - .with_existing_ordering(input.output_ordering().map(|elem| elem.to_vec())) - .extend(input.ordering_equivalence_properties()); - for expr in window_expr { - if let Some(builtin_window_expr) = - expr.as_any().downcast_ref::() - { - // Only the built-in `RowNumber` window function introduces a new - // ordering: - if builtin_window_expr - .get_built_in_func_expr() - .as_any() - .is::() - { - if let Some((idx, field)) = - schema.column_with_name(builtin_window_expr.name()) - { - let column = Column::new(field.name(), idx); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - let rhs = PhysicalSortExpr { - expr: Arc::new(column) as _, - options, - }; - builder.add_equal_conditions(vec![rhs]); - } - } - } - } - builder.build() -} -#[cfg(test)] -mod tests { - use super::*; - use crate::datasource::physical_plan::CsvExec; - use crate::physical_plan::aggregates::AggregateFunction; - use crate::physical_plan::expressions::col; - use crate::physical_plan::{collect, ExecutionPlan}; - use crate::prelude::SessionContext; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; - use crate::test::{self, assert_is_pending, csv_exec_sorted}; - use arrow::array::*; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, SchemaRef}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_primitive_array; - use datafusion_expr::{create_udaf, Accumulator, Volatility}; - use futures::FutureExt; - - fn create_test_schema(partitions: usize) -> Result<(Arc, SchemaRef)> { - let csv = test::scan_partitioned_csv(partitions)?; - let schema = csv.schema(); - Ok((csv, schema)) - } - - fn create_test_schema2() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); - Ok(schema) - } - - /// make PhysicalSortExpr with default options - fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { - sort_expr_options(name, schema, SortOptions::default()) - } - - /// PhysicalSortExpr with specified options - fn sort_expr_options( - name: &str, - schema: &Schema, - options: SortOptions, - ) -> PhysicalSortExpr { - PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options, - } - } - - #[tokio::test] - async fn test_get_partition_by_ordering() -> Result<()> { - let test_schema = create_test_schema2()?; - // Columns a,c are nullable whereas b,d are not nullable. - // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST - // Column e is not ordered. - let sort_exprs = vec![ - sort_expr("a", &test_schema), - sort_expr("b", &test_schema), - sort_expr("c", &test_schema), - sort_expr("d", &test_schema), - ]; - // Input is ordered by a,b,c,d - let input = csv_exec_sorted(&test_schema, sort_exprs, true); - let test_data = vec![ - (vec!["a", "b"], vec![0, 1]), - (vec!["b", "a"], vec![1, 0]), - (vec!["b", "a", "c"], vec![1, 0, 2]), - (vec!["d", "b", "a"], vec![2, 1]), - (vec!["d", "e", "a"], vec![2]), - ]; - for (pb_names, expected) in test_data { - let pb_exprs = pb_names - .iter() - .map(|name| col(name, &test_schema)) - .collect::>>()?; - assert_eq!( - get_ordered_partition_by_indices(&pb_exprs, &input), - expected - ); - } - Ok(()) - } - - #[tokio::test] - async fn test_calc_requirements() -> Result<()> { - let schema = create_test_schema2()?; - let test_data = vec![ - // PARTITION BY a, ORDER BY b ASC NULLS FIRST - ( - vec!["a"], - vec![("b", true, true)], - vec![("a", None), ("b", Some((true, true)))], - ), - // PARTITION BY a, ORDER BY a ASC NULLS FIRST - (vec!["a"], vec![("a", true, true)], vec![("a", None)]), - // PARTITION BY a, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST - ( - vec!["a"], - vec![("b", true, true), ("c", false, false)], - vec![ - ("a", None), - ("b", Some((true, true))), - ("c", Some((false, false))), - ], - ), - // PARTITION BY a, c, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST - ( - vec!["a", "c"], - vec![("b", true, true), ("c", false, false)], - vec![("a", None), ("c", None), ("b", Some((true, true)))], - ), - ]; - for (pb_params, ob_params, expected_params) in test_data { - let mut partitionbys = vec![]; - for col_name in pb_params { - partitionbys.push(col(col_name, &schema)?); - } - - let mut orderbys = vec![]; - for (col_name, descending, nulls_first) in ob_params { - let expr = col(col_name, &schema)?; - let options = SortOptions { - descending, - nulls_first, - }; - orderbys.push(PhysicalSortExpr { expr, options }); - } - - let mut expected: Option> = None; - for (col_name, reqs) in expected_params { - let options = reqs.map(|(descending, nulls_first)| SortOptions { - descending, - nulls_first, - }); - let expr = col(col_name, &schema)?; - let res = PhysicalSortRequirement::new(expr, options); - if let Some(expected) = &mut expected { - expected.push(res); - } else { - expected = Some(vec![res]); - } - } - assert_eq!(calc_requirements(partitionbys, orderbys), expected); - } - Ok(()) - } - - #[tokio::test] - async fn window_function_with_udaf() -> Result<()> { - #[derive(Debug)] - struct MyCount(i64); - - impl Accumulator for MyCount { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.0))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.0 += (array.len() - array.null_count()) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts: &Int64Array = arrow::array::as_primitive_array(&states[0]); - if let Some(c) = &arrow::compute::sum(counts) { - self.0 += *c; - } - Ok(()) - } - - fn evaluate(&self) -> Result { - Ok(ScalarValue::Int64(Some(self.0))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - } - - let my_count = create_udaf( - "my_count", - DataType::Int64, - Arc::new(DataType::Int64), - Volatility::Immutable, - Arc::new(|_| Ok(Box::new(MyCount(0)))), - Arc::new(vec![DataType::Int64]), - ); - - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (input, schema) = create_test_schema(1)?; - - let window_exec = Arc::new(WindowAggExec::try_new( - vec![create_window_expr( - &WindowFunction::AggregateUDF(Arc::new(my_count)), - "my_count".to_owned(), - &[col("c3", &schema)?], - &[], - &[], - Arc::new(WindowFrame::new(false)), - schema.as_ref(), - )?], - input, - schema.clone(), - vec![], - )?); - - let result: Vec = collect(window_exec, task_ctx).await?; - assert_eq!(result.len(), 1); - - let n_schema_fields = schema.fields().len(); - let columns = result[0].columns(); - - let count: &Int64Array = as_primitive_array(&columns[n_schema_fields])?; - assert_eq!(count.value(0), 100); - assert_eq!(count.value(99), 100); - Ok(()) - } - - #[tokio::test] - async fn window_function() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let (input, schema) = create_test_schema(1)?; - - let window_exec = Arc::new(WindowAggExec::try_new( - vec![ - create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), - "count".to_owned(), - &[col("c3", &schema)?], - &[], - &[], - Arc::new(WindowFrame::new(false)), - schema.as_ref(), - )?, - create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Max), - "max".to_owned(), - &[col("c3", &schema)?], - &[], - &[], - Arc::new(WindowFrame::new(false)), - schema.as_ref(), - )?, - create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Min), - "min".to_owned(), - &[col("c3", &schema)?], - &[], - &[], - Arc::new(WindowFrame::new(false)), - schema.as_ref(), - )?, - ], - input, - schema.clone(), - vec![], - )?); - - let result: Vec = collect(window_exec, task_ctx).await?; - assert_eq!(result.len(), 1); - - let n_schema_fields = schema.fields().len(); - let columns = result[0].columns(); - - // c3 is small int - - let count: &Int64Array = as_primitive_array(&columns[n_schema_fields])?; - assert_eq!(count.value(0), 100); - assert_eq!(count.value(99), 100); - - let max: &Int8Array = as_primitive_array(&columns[n_schema_fields + 1])?; - assert_eq!(max.value(0), 125); - assert_eq!(max.value(99), 125); - - let min: &Int8Array = as_primitive_array(&columns[n_schema_fields + 2])?; - assert_eq!(min.value(0), -117); - assert_eq!(min.value(99), -117); - - Ok(()) - } - - #[tokio::test] - async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - - let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); - let refs = blocking_exec.refs(); - let window_agg_exec = Arc::new(WindowAggExec::try_new( - vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), - "count".to_owned(), - &[col("a", &schema)?], - &[], - &[], - Arc::new(WindowFrame::new(false)), - schema.as_ref(), - )?], - blocking_exec, - schema, - vec![], - )?); - - let fut = collect(window_agg_exec, task_ctx); - let mut fut = fut.boxed(); - - assert_is_pending(&mut fut); - drop(fut); - assert_strong_count_converges_to_zero(refs).await; - - Ok(()) - } -} diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_planner.rs similarity index 81% rename from datafusion/core/src/physical_plan/planner.rs rename to datafusion/core/src/physical_planner.rs index 6f45b7b5452d8..ab38b3ec6d2f3 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -15,15 +15,23 @@ // specific language governing permissions and limitations // under the License. -//! Physical query planner +//! Planner for [`LogicalPlan`] to [`ExecutionPlan`] -use super::analyze::AnalyzeExec; -use super::unnest::UnnestExec; -use super::{ - aggregates, empty::EmptyExec, joins::PartitionMode, udaf, union::UnionExec, - values::ValuesExec, windows, -}; +use std::collections::HashMap; +use std::fmt::Write; +use std::sync::Arc; + +use crate::datasource::file_format::arrow::ArrowFormat; +use crate::datasource::file_format::avro::AvroFormat; +use crate::datasource::file_format::csv::CsvFormat; +use crate::datasource::file_format::json::JsonFormat; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormat; +use crate::datasource::file_format::FileFormat; +use crate::datasource::listing::ListingTableUrl; +use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::source_as_provider; +use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ @@ -32,53 +40,65 @@ use crate::logical_expr::{ }; use crate::logical_expr::{ CrossJoin, Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, - Repartition, ToStringifiedPlan, Union, UserDefinedLogicalNode, + Repartition, Union, UserDefinedLogicalNode, }; use crate::logical_expr::{Limit, Values}; use crate::physical_expr::create_physical_expr; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use crate::physical_plan::analyze::AnalyzeExec; +use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::HashJoinExec; -use crate::physical_plan::joins::SortMergeJoinExec; -use crate::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec}; +use crate::physical_plan::joins::utils as join_utils; +use crate::physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, +}; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::windows::{ - BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, -}; -use crate::physical_plan::{joins::utils as join_utils, Partitioning}; -use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; -use crate::{ - error::{DataFusionError, Result}, - physical_plan::displayable, +use crate::physical_plan::union::UnionExec; +use crate::physical_plan::unnest::UnnestExec; +use crate::physical_plan::values::ValuesExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; +use crate::physical_plan::{ + aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; + use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use async_trait::async_trait; -use datafusion_common::{DFSchema, ScalarValue}; +use arrow_array::builder::StringBuilder; +use arrow_array::RecordBatch; +use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::file_options::FileTypeWriterOptions; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, plan_err, DFSchema, FileType, ScalarValue, +}; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField, - GroupingSet, InList, Like, ScalarUDF, TryCast, WindowFunction, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, + WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp}; -use datafusion_expr::{WindowFrame, WindowFrameBound}; -use datafusion_optimizer::utils::unalias; +use datafusion_expr::{ + DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, + WindowFrameBound, WriteOp, +}; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; + +use async_trait::async_trait; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; -use std::collections::HashMap; -use std::fmt::Write; -use std::sync::Arc; fn create_function_physical_name( fun: &str, @@ -110,7 +130,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(c.flat_name()) } } - Expr::Alias(_, name) => Ok(name.clone()), + Expr::Alias(Alias { name, .. }) => Ok(name.clone()), Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), Expr::Literal(value) => Ok(format!("{value:?}")), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { @@ -121,13 +141,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Case(case) => { let mut name = "CASE ".to_string(); if let Some(e) = &case.expr { - let _ = write!(name, "{e:?} "); + let _ = write!(name, "{e} "); } for (w, t) in &case.when_then_expr { - let _ = write!(name, "WHEN {w:?} THEN {t:?} "); + let _ = write!(name, "WHEN {w} THEN {t} "); } if let Some(e) = &case.else_expr { - let _ = write!(name, "ELSE {e:?} "); + let _ = write!(name, "ELSE {e} "); } name += "END"; Ok(name) @@ -180,48 +200,66 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let expr = create_physical_name(expr, false)?; Ok(format!("{expr} IS NOT UNKNOWN")) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { let expr = create_physical_name(expr, false)?; - Ok(format!("{expr}[{key}]")) - } - Expr::ScalarFunction(func) => { - create_function_physical_name(&func.fun.to_string(), false, &func.args) + let name = match field { + GetFieldAccess::NamedStructField { name } => format!("{expr}[{name}]"), + GetFieldAccess::ListIndex { key } => { + let key = create_physical_name(key, false)?; + format!("{expr}[{key}]") + } + GetFieldAccess::ListRange { start, stop } => { + let start = create_physical_name(start, false)?; + let stop = create_physical_name(stop, false)?; + format!("{expr}[{start}:{stop}]") + } + }; + + Ok(name) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(&fun.name, false, args) + Expr::ScalarFunction(fun) => { + // function should be resolved during `AnalyzerRule`s + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + create_function_physical_name(fun.name(), false, &fun.args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, - .. - }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, filter, order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return Err(DataFusionError::Execution( - "aggregate expression with filter is not supported".to_string(), - )); + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(..) => { + create_function_physical_name(func_def.name(), *distinct, args) } - if order_by.is_some() { - return Err(DataFusionError::Execution( - "aggregate expression with order_by is not supported".to_string(), - )); + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + Ok(format!("{}({})", fun.name(), names.join(","))) } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") } - Ok(format!("{}({})", fun.name, names.join(","))) - } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -266,15 +304,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(format!("{expr} IN ({list:?})")) } } - Expr::Exists { .. } => Err(DataFusionError::NotImplemented( - "EXISTS is not yet supported in the physical plan".to_string(), - )), - Expr::InSubquery(_) => Err(DataFusionError::NotImplemented( - "IN subquery is not yet supported in the physical plan".to_string(), - )), - Expr::ScalarSubquery(_) => Err(DataFusionError::NotImplemented( - "Scalar subqueries are not yet supported in the physical plan".to_string(), - )), + Expr::Exists { .. } => { + not_impl_err!("EXISTS is not yet supported in the physical plan") + } + Expr::InSubquery(_) => { + not_impl_err!("IN subquery is not yet supported in the physical plan") + } + Expr::ScalarSubquery(_) => { + not_impl_err!("Scalar subqueries are not yet supported in the physical plan") + } Expr::Between(Between { expr, negated, @@ -295,37 +333,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { expr, pattern, escape_char, + case_insensitive, }) => { let expr = create_physical_name(expr, false)?; let pattern = create_physical_name(pattern, false)?; + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; let escape = if let Some(char) = escape_char { format!("CHAR '{char}'") } else { "".to_string() }; if *negated { - Ok(format!("{expr} NOT LIKE {pattern}{escape}")) - } else { - Ok(format!("{expr} LIKE {pattern}{escape}")) - } - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let expr = create_physical_name(expr, false)?; - let pattern = create_physical_name(pattern, false)?; - let escape = if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - }; - if *negated { - Ok(format!("{expr} NOT ILIKE {pattern}{escape}")) + Ok(format!("{expr} NOT {op_name} {pattern}{escape}")) } else { - Ok(format!("{expr} ILIKE {pattern}{escape}")) + Ok(format!("{expr} {op_name} {pattern}{escape}")) } } Expr::SimilarTo(Like { @@ -333,6 +354,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { expr, pattern, escape_char, + case_insensitive: _, }) => { let expr = create_physical_name(expr, false)?; let pattern = create_physical_name(pattern, false)?; @@ -347,21 +369,18 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(format!("{expr} SIMILAR TO {pattern}{escape}")) } } - Expr::Sort { .. } => Err(DataFusionError::Internal( - "Create physical name does not support sort expression".to_string(), - )), - Expr::Wildcard => Err(DataFusionError::Internal( - "Create physical name does not support wildcard".to_string(), - )), - Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( - "Create physical name does not support qualified wildcard".to_string(), - )), - Expr::Placeholder(_) => Err(DataFusionError::Internal( - "Create physical name does not support placeholder".to_string(), - )), - Expr::OuterReferenceColumn(_, _) => Err(DataFusionError::Internal( - "Create physical name does not support OuterReferenceColumn".to_string(), - )), + Expr::Sort { .. } => { + internal_err!("Create physical name does not support sort expression") + } + Expr::Wildcard { .. } => { + internal_err!("Create physical name does not support wildcard") + } + Expr::Placeholder(_) => { + internal_err!("Create physical name does not support placeholder") + } + Expr::OuterReferenceColumn(_, _) => { + internal_err!("Create physical name does not support OuterReferenceColumn") + } } } @@ -542,12 +561,79 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await + source.scan(session_state, projection.as_ref(), &filters, *fetch).await + } + LogicalPlan::Copy(CopyTo{ + input, + output_url, + file_format, + single_file_output, + copy_options, + }) => { + let input_exec = self.create_initial_plan(input, session_state).await?; + + // TODO: make this behavior configurable via options (should copy to create path/file as needed?) + // TODO: add additional configurable options for if existing files should be overwritten or + // appended to + let parsed_url = ListingTableUrl::parse_create_local_if_not_exists(output_url, !*single_file_output)?; + let object_store_url = parsed_url.object_store(); + + let schema: Schema = (**input.schema()).clone().into(); + + let file_type_writer_options = match copy_options{ + CopyOptions::SQLOptions(statement_options) => { + FileTypeWriterOptions::build( + file_format, + session_state.config_options(), + statement_options)? + }, + CopyOptions::WriterOptions(writer_options) => *writer_options.clone() + }; + + // Set file sink related options + let config = FileSinkConfig { + object_store_url, + table_paths: vec![parsed_url], + file_groups: vec![], + output_schema: Arc::new(schema), + table_partition_cols: vec![], + unbounded_input: false, + single_file_output: *single_file_output, + overwrite: false, + file_type_writer_options + }; + + let sink_format: Arc = match file_format { + FileType::CSV => Arc::new(CsvFormat::default()), + #[cfg(feature = "parquet")] + FileType::PARQUET => Arc::new(ParquetFormat::default()), + FileType::JSON => Arc::new(JsonFormat::default()), + FileType::AVRO => Arc::new(AvroFormat {} ), + FileType::ARROW => Arc::new(ArrowFormat {}), + }; + + sink_format.create_writer_physical_plan(input_exec, session_state, config, None).await + } + LogicalPlan::Dml(DmlStatement { + table_name, + op: WriteOp::InsertInto, + input, + .. + }) => { + let name = table_name.table(); + let schema = session_state.schema_for_ref(table_name)?; + if let Some(provider) = schema.table(name).await { + let input_exec = self.create_initial_plan(input, session_state).await?; + provider.insert_into(session_state, input_exec, false).await + } else { + return exec_err!( + "Table '{table_name}' does not exist" + ); + } } LogicalPlan::Dml(DmlStatement { table_name, - op: WriteOp::Insert, + op: WriteOp::InsertOverwrite, input, .. }) => { @@ -555,11 +641,11 @@ impl DefaultPhysicalPlanner { let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { let input_exec = self.create_initial_plan(input, session_state).await?; - provider.insert_into(session_state, input_exec).await + provider.insert_into(session_state, input_exec, true).await } else { - return Err(DataFusionError::Execution(format!( + return exec_err!( "Table '{table_name}' does not exist" - ))); + ); } } LogicalPlan::Values(Values { @@ -590,9 +676,9 @@ impl DefaultPhysicalPlanner { input, window_expr, .. }) => { if window_expr.is_empty() { - return Err(DataFusionError::Internal( - "Impossibly got empty window expression".to_owned(), - )); + return internal_err!( + "Impossibly got empty window expression" + ); } let input_exec = self.create_initial_plan(input, session_state).await?; @@ -628,7 +714,7 @@ impl DefaultPhysicalPlanner { ref order_by, .. }) => generate_sort_key(partition_by, order_by), - Expr::Alias(expr, _) => { + Expr::Alias(Alias{expr,..}) => { // Convert &Box to &T match &**expr { Expr::WindowFunction(WindowFunction{ @@ -673,15 +759,13 @@ impl DefaultPhysicalPlanner { Arc::new(BoundedWindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, - PartitionSearchMode::Sorted, + InputOrderMode::Sorted, )?) } else { Arc::new(WindowAggExec::try_new( window_expr, input_exec, - physical_input_schema, physical_partition_keys, )?) }) @@ -715,14 +799,14 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter.into_iter()); + let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), aggregates.clone(), filters.clone(), - order_bys.clone(), + order_bys, input_exec, physical_input_schema.clone(), )?); @@ -734,16 +818,21 @@ impl DefaultPhysicalPlanner { && session_state.config().target_partitions() > 1 && session_state.config().repartition_aggregations(); - let (initial_aggr, next_partition_mode): ( - Arc, - AggregateMode, - ) = if can_repartition { + // Some aggregators may be modified during initialization for + // optimization purposes. For example, a FIRST_VALUE may turn + // into a LAST_VALUE with the reverse ordering requirement. + // To reflect such changes to subsequent stages, use the updated + // `AggregateExpr`/`PhysicalSortExpr` objects. + let updated_aggregates = initial_aggr.aggr_expr().to_vec(); + let updated_order_bys = initial_aggr.order_by_expr().to_vec(); + + let next_partition_mode = if can_repartition { // construct a second aggregation with 'AggregateMode::FinalPartitioned' - (initial_aggr, AggregateMode::FinalPartitioned) + AggregateMode::FinalPartitioned } else { // construct a second aggregation, keeping the final column name equal to the // first aggregation and the expressions corresponding to the respective aggregate - (initial_aggr, AggregateMode::Final) + AggregateMode::Final }; let final_grouping_set = PhysicalGroupBy::new_single( @@ -757,9 +846,9 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(AggregateExec::try_new( next_partition_mode, final_grouping_set, - aggregates, + updated_aggregates, filters, - order_bys, + updated_order_bys, initial_aggr, physical_input_schema.clone(), )?)) @@ -827,19 +916,14 @@ impl DefaultPhysicalPlanner { &input_schema, session_state, )?; - Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?)) + let selectivity = session_state.config().options().optimizer.default_filter_selectivity; + let filter = FilterExec::try_new(runtime_expr, physical_input)?; + Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(Union { inputs, .. }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; - if schema.fields().len() < physical_plans[0].schema().fields().len() { - // `schema` could be a subset of the child schema. For example - // for query "select count(*) from (select a from t union all select a from t)" - // `schema` is empty but child schema contains one field `a`. - Ok(Arc::new(UnionExec::try_new_with_schema(physical_plans, schema.clone())?)) - } else { - Ok(Arc::new(UnionExec::new(physical_plans))) - } + Ok(Arc::new(UnionExec::new(physical_plans))) } LogicalPlan::Repartition(Repartition { input, @@ -867,7 +951,7 @@ impl DefaultPhysicalPlanner { Partitioning::Hash(runtime_expr, *n) } LogicalPartitioning::DistributeBy(_) => { - return Err(DataFusionError::NotImplemented("Physical plan does not support DistributeBy partitioning".to_string())); + return not_impl_err!("Physical plan does not support DistributeBy partitioning"); } }; Ok(Arc::new(RepartitionExec::try_new( @@ -957,10 +1041,9 @@ impl DefaultPhysicalPlanner { }) .collect::>(); let projection = - logical_plan::Projection::try_new_with_schema( + Projection::try_new( final_join_result, Arc::new(join_plan), - join_schema.clone(), )?; LogicalPlan::Projection(projection) } else { @@ -1064,7 +1147,7 @@ impl DefaultPhysicalPlanner { // Sort-Merge join support currently is experimental if join_filter.is_some() { // TODO SortMergeJoinExec need to support join filter - Err(DataFusionError::NotImplemented("SortMergeJoinExec does not support join_filter now.".to_string())) + not_impl_err!("SortMergeJoinExec does not support join_filter now.") } else { let join_on_len = join_on.len(); Ok(Arc::new(SortMergeJoinExec::try_new( @@ -1114,10 +1197,15 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, + produce_one_row: false, schema, }) => Ok(Arc::new(EmptyExec::new( - *produce_one_row, + SchemaRef::new(schema.as_ref().to_owned().into()), + ))), + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema, + }) => Ok(Arc::new(PlaceholderRowExec::new( SchemaRef::new(schema.as_ref().to_owned().into()), ))), LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { @@ -1141,12 +1229,12 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } - LogicalPlan::Unnest(Unnest { input, column, schema }) => { + LogicalPlan::Unnest(Unnest { input, column, schema, options }) => { let input = self.create_initial_plan(input, session_state).await?; let column_exec = schema.index_of_column(column) .map(|idx| Column::new(&column.name, idx))?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); - Ok(Arc::new(UnnestExec::new(input, column_exec, schema))) + Ok(Arc::new(UnnestExec::new(input, column_exec, schema, options.clone()))) } LogicalPlan::Ddl(ddl) => { // There is no default plan for DDl statements -- @@ -1154,47 +1242,46 @@ impl DefaultPhysicalPlanner { // the appropriate table can be registered with // the context) let name = ddl.name(); - Err(DataFusionError::NotImplemented( - format!("Unsupported logical plan: {name}") - )) + not_impl_err!( + "Unsupported logical plan: {name}" + ) } LogicalPlan::Prepare(_) => { // There is no default plan for "PREPARE" -- it must be // handled at a higher level (so that the appropriate // statement can be prepared) - Err(DataFusionError::NotImplemented( - "Unsupported logical plan: Prepare".to_string(), - )) + not_impl_err!( + "Unsupported logical plan: Prepare" + ) } - LogicalPlan::Dml(_) => { + LogicalPlan::Dml(dml) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this - Err(DataFusionError::NotImplemented( - "Unsupported logical plan: Dml".to_string(), - )) + not_impl_err!( + "Unsupported logical plan: Dml({0})", dml.op + ) } LogicalPlan::Statement(statement) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this let name = statement.name(); - Err(DataFusionError::NotImplemented( - format!("Unsupported logical plan: Statement({name})") - )) + not_impl_err!( + "Unsupported logical plan: Statement({name})" + ) } - LogicalPlan::DescribeTable(_) => { - Err(DataFusionError::Internal( - "Unsupported logical plan: DescribeTable must be root of the plan".to_string(), - )) + LogicalPlan::DescribeTable(DescribeTable { schema, output_schema}) => { + let output_schema: Schema = output_schema.as_ref().into(); + self.plan_describe(schema.clone(), Arc::new(output_schema)) } - LogicalPlan::Explain(_) => Err(DataFusionError::Internal( - "Unsupported logical plan: Explain must be root of the plan".to_string(), - )), + LogicalPlan::Explain(_) => internal_err!( + "Unsupported logical plan: Explain must be root of the plan" + ), LogicalPlan::Distinct(_) => { - Err(DataFusionError::Internal( - "Unsupported logical plan: Distinct should be replaced to Aggregate".to_string(), - )) + internal_err!( + "Unsupported logical plan: Distinct should be replaced to Aggregate" + ) } - LogicalPlan::Analyze(_) => Err(DataFusionError::Internal( - "Unsupported logical plan: Analyze must be root of the plan".to_string(), - )), + LogicalPlan::Analyze(_) => internal_err!( + "Unsupported logical plan: Analyze must be root of the plan" + ), LogicalPlan::Extension(e) => { let physical_inputs = self.create_initial_plan_multi(e.node.inputs(), session_state).await?; @@ -1214,19 +1301,20 @@ impl DefaultPhysicalPlanner { ).await?; } - let plan = maybe_plan.ok_or_else(|| DataFusionError::Plan(format!( - "No installed planner was able to convert the custom node to an execution plan: {:?}", e.node - )))?; + let plan = match maybe_plan { + Some(v) => Ok(v), + _ => plan_err!("No installed planner was able to convert the custom node to an execution plan: {:?}", e.node) + }?; // Ensure the ExecutionPlan's schema matches the // declared logical schema to catch and warn about // logic errors when creating user defined plans. if !e.node.schema().matches_arrow_schema(&plan.schema()) { - Err(DataFusionError::Plan(format!( + plan_err!( "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", e.node, e.node.schema(), plan.schema() - ))) + ) } else { Ok(plan) } @@ -1563,10 +1651,10 @@ pub fn create_window_expr_with_name( }) .collect::>>()?; if !is_window_valid(window_frame) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound - ))); + ); } let window_frame = Arc::new(window_frame.clone()); @@ -1580,9 +1668,7 @@ pub fn create_window_expr_with_name( physical_input_schema, ) } - other => Err(DataFusionError::Plan(format!( - "Invalid window expression '{other:?}'" - ))), + other => plan_err!("Invalid window expression '{other:?}'"), } } @@ -1595,8 +1681,8 @@ pub fn create_window_expr( ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (physical_name(e)?, e), + Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), + _ => (e.display_name()?, e), }; create_window_expr_with_name( e, @@ -1625,7 +1711,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -1651,13 +1737,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - physical_input_schema, - name, - )?; let order_by = match order_by { Some(e) => Some( e.iter() @@ -1673,58 +1752,37 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ), None => None, }; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, + } }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) + Ok((agg_expr, filter, order_by)) } - other => Err(DataFusionError::Internal(format!( - "Invalid aggregate expression '{other:?}'" - ))), + other => internal_err!("Invalid aggregate expression '{other:?}'"), } } @@ -1737,7 +1795,7 @@ pub fn create_aggregate_expr_and_maybe_filter( ) -> Result { // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), _ => (physical_name(e)?, e), }; @@ -1776,9 +1834,7 @@ pub fn create_physical_sort_expr( }, }) } else { - Err(DataFusionError::Internal( - "Expects a sort expression".to_string(), - )) + internal_err!("Expects a sort expression") } } @@ -1812,25 +1868,58 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { + // This plan will includes statistics if show_statistics is on stringified_plans.push( displayable(input.as_ref()) - .to_stringified(InitialPhysicalPlan), + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, InitialPhysicalPlan), ); + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + match self.optimize_internal( input, session_state, |plan, optimizer| { let optimizer_name = optimizer.name().to_string(); let plan_type = OptimizedPhysicalPlan { optimizer_name }; - stringified_plans - .push(displayable(plan).to_stringified(plan_type)); + stringified_plans.push( + displayable(plan) + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, plan_type), + ); }, ) { - Ok(input) => stringified_plans.push( - displayable(input.as_ref()) - .to_stringified(FinalPhysicalPlan), - ), + Ok(input) => { + // This plan will includes statistics if show_statistics is on + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(config.show_statistics) + .to_stringified(e.verbose, FinalPhysicalPlan), + ); + + // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose + if e.verbose && !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + } Err(DataFusionError::Context(optimizer_name, e)) => { let plan_type = OptimizedPhysicalPlan { optimizer_name }; stringified_plans @@ -1852,7 +1941,13 @@ impl DefaultPhysicalPlanner { } else if let LogicalPlan::Analyze(a) = logical_plan { let input = self.create_physical_plan(&a.input, session_state).await?; let schema = SchemaRef::new((*a.schema).clone().into()); - Ok(Some(Arc::new(AnalyzeExec::new(a.verbose, input, schema)))) + let show_statistics = session_state.config_options().explain.show_statistics; + Ok(Some(Arc::new(AnalyzeExec::new( + a.verbose, + show_statistics, + input, + schema, + )))) } else { Ok(None) } @@ -1872,11 +1967,11 @@ impl DefaultPhysicalPlanner { let optimizers = session_state.physical_optimizers(); debug!( "Input physical plan:\n{}\n", - displayable(plan.as_ref()).indent() + displayable(plan.as_ref()).indent(false) ); trace!( "Detailed input physical plan:\n{}", - displayable(plan.as_ref()).indent() + displayable(plan.as_ref()).indent(true) ); let mut new_plan = plan; @@ -1902,17 +1997,54 @@ impl DefaultPhysicalPlanner { trace!( "Optimized physical plan by {}:\n{}\n", optimizer.name(), - displayable(new_plan.as_ref()).indent() + displayable(new_plan.as_ref()).indent(false) ); observer(new_plan.as_ref(), optimizer.as_ref()) } debug!( "Optimized physical plan:\n{}\n", - displayable(new_plan.as_ref()).indent() + displayable(new_plan.as_ref()).indent(false) ); trace!("Detailed optimized physical plan:\n{:?}", new_plan); Ok(new_plan) } + + // return an record_batch which describes a table's schema. + fn plan_describe( + &self, + table_schema: Arc, + output_schema: Arc, + ) -> Result> { + let mut column_names = StringBuilder::new(); + let mut data_types = StringBuilder::new(); + let mut is_nullables = StringBuilder::new(); + for field in table_schema.fields() { + column_names.append_value(field.name()); + + // "System supplied type" --> Use debug format of the datatype + let data_type = field.data_type(); + data_types.append_value(format!("{data_type:?}")); + + // "YES if the column is possibly nullable, NO if it is known not nullable. " + let nullable_str = if field.is_nullable() { "YES" } else { "NO" }; + is_nullables.append_value(nullable_str); + } + + let record_batch = RecordBatch::try_new( + output_schema, + vec![ + Arc::new(column_names.finish()), + Arc::new(data_types.finish()), + Arc::new(is_nullables.finish()), + ], + )?; + + let schema = record_batch.schema(); + let partitions = vec![vec![record_batch]]; + let projection = None; + let mem_exec = MemoryExec::try_new(&partitions, schema, projection)?; + Ok(Arc::new(mem_exec)) + } } fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { @@ -1929,10 +2061,9 @@ mod tests { use super::*; use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; - use crate::physical_plan::SendableRecordBatchStream; - use crate::physical_plan::{ - expressions, DisplayFormatType, Partitioning, PhysicalPlanner, Statistics, - }; + use crate::physical_plan::{expressions, DisplayFormatType, Partitioning}; + use crate::physical_plan::{DisplayAs, SendableRecordBatchStream}; + use crate::physical_planner::PhysicalPlanner; use crate::prelude::{SessionConfig, SessionContext}; use crate::scalar::ScalarValue; use crate::test_util::{scan_empty, scan_empty_with_partitions}; @@ -1950,14 +2081,14 @@ mod tests { use fmt::Debug; use std::collections::HashMap; use std::convert::TryFrom; - use std::ops::Not; + use std::ops::{BitAnd, Not}; use std::{any::Any, fmt}; fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::with_config_rt(config, runtime) + SessionState::new_with_config_rt(config, runtime) } async fn plan(logical_plan: &LogicalPlan) -> Result> { @@ -2134,18 +2265,17 @@ mod tests { async fn errors() -> Result<()> { let bool_expr = col("c1").eq(col("c1")); let cases = vec![ - // utf8 AND utf8 - col("c1").and(col("c1")), + // utf8 = utf8 + col("c1").eq(col("c1")), // u8 AND u8 - col("c3").and(col("c3")), - // utf8 = bool - col("c1").eq(bool_expr.clone()), - // u32 AND bool - col("c2").and(bool_expr), + col("c3").bitand(col("c3")), + // utf8 = u8 + col("c1").eq(col("c3")), + // bool AND bool + bool_expr.clone().and(bool_expr), ]; for case in cases { - let logical_plan = test_csv_scan().await?.project(vec![case.clone()]); - assert!(logical_plan.is_ok()); + test_csv_scan().await?.project(vec![case.clone()]).unwrap(); } Ok(()) } @@ -2198,7 +2328,7 @@ mod tests { dict_id: 0, \ dict_is_ordered: false, \ metadata: {} } }\ - ], metadata: {} }, \ + ], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ ExecutionPlan schema: Schema { fields: [\ Field { \ name: \"b\", \ @@ -2388,6 +2518,27 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(vec![col("c1")], vec![sum(col("c2"))])? + .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? + .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); + assert_eq!( + "total_salary", + physical_plan.schema().field(1).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -2418,10 +2569,11 @@ mod tests { } else { panic!( "Plan was not an explain plan: {}", - displayable(plan.as_ref()).indent() + displayable(plan.as_ref()).indent(true) ); } } + struct ErrorExtensionPlanner {} #[async_trait] @@ -2435,7 +2587,7 @@ mod tests { _physical_inputs: &[Arc], _session_state: &SessionState, ) -> Result>> { - Err(DataFusionError::Internal("BOOM".to_string())) + internal_err!("BOOM") } } /// An example extension node that doesn't do anything @@ -2495,6 +2647,16 @@ mod tests { schema: SchemaRef, } + impl DisplayAs for NoOpExecutionPlan { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "NoOpExecutionPlan") + } + } + } + } + impl ExecutionPlan for NoOpExecutionPlan { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -2531,18 +2693,6 @@ mod tests { ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } - - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "NoOpExecutionPlan") - } - } - } - - fn statistics(&self) -> Statistics { - unimplemented!("NoOpExecutionPlan::statistics"); - } } // Produces an execution plan where the schema is mismatched from @@ -2603,4 +2753,57 @@ mod tests { ctx.read_csv(path, options).await?.into_optimized_plan()?, )) } + + #[tokio::test] + async fn test_display_plan_in_graphviz_format() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let logical_plan = scan_empty(Some("employee"), &schema, None) + .unwrap() + .project(vec![col("id") + lit(2)]) + .unwrap() + .build() + .unwrap(); + + let plan = plan(&logical_plan).await.unwrap(); + + let expected_graph = r#" +// Begin DataFusion GraphViz Plan, +// display it online here: https://dreampuf.github.io/GraphvizOnline + +digraph { + 1[shape=box label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]", tooltip=""] + 2[shape=box label="EmptyExec", tooltip=""] + 1 -> 2 [arrowhead=none, arrowtail=normal, dir=back] +} +// End DataFusion GraphViz Plan +"#; + + let generated_graph = format!("{}", displayable(&*plan).graphviz()); + + assert_eq!(expected_graph, generated_graph); + } + + #[tokio::test] + async fn test_display_graphviz_with_statistics() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let logical_plan = scan_empty(Some("employee"), &schema, None) + .unwrap() + .project(vec![col("id") + lit(2)]) + .unwrap() + .build() + .unwrap(); + + let plan = plan(&logical_plan).await.unwrap(); + + let expected_tooltip = ", tooltip=\"statistics=["; + + let generated_graph = format!( + "{}", + displayable(&*plan).set_show_statistics(true).graphviz() + ); + + assert_contains!(generated_graph, expected_tooltip); + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index ed2c81a69ff12..5cd8b3870f818 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -13,9 +13,9 @@ // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations -// under the License.pub}, +// under the License. -//! A "prelude" for users of the datafusion crate. +//! DataFusion "prelude" to simplify importing common types. //! //! Like the standard library's prelude, this module simplifies importing of //! common items. Unlike the standard prelude, the contents of this module must @@ -26,7 +26,7 @@ //! ``` pub use crate::dataframe::DataFrame; -pub use crate::execution::context::{SessionConfig, SessionContext}; +pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, }; @@ -38,3 +38,8 @@ pub use datafusion_expr::{ logical_plan::{JoinType, Partitioning}, Expr, }; + +pub use std::ops::Not; +pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; +pub use std::ops::{BitAnd, BitOr, BitXor}; +pub use std::ops::{Shl, Shr}; diff --git a/datafusion/core/src/scalar.rs b/datafusion/core/src/scalar.rs index 29f75096aecea..c4f0d80616ee6 100644 --- a/datafusion/core/src/scalar.rs +++ b/datafusion/core/src/scalar.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! ScalarValue reimported from datafusion-common to easy migration -//! when datafusion was split into several different crates - +//! [`ScalarValue`] single value representation. +//! +//! Note this is reimported from the datafusion-common crate for easy +//! migration when datafusion was split into several different crates pub use datafusion_common::{ScalarType, ScalarValue}; diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 838c13f96856d..aad5c19044ea9 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -17,39 +17,42 @@ //! Common unit test utility methods -use crate::arrow::array::UInt32Array; -use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; +use std::any::Any; +use std::fs::File; +use std::io::prelude::*; +use std::io::{BufReader, BufWriter}; +use std::path::Path; +use std::sync::Arc; + +use crate::datasource::file_format::file_compression_type::{ + FileCompressionType, FileTypeExt, +}; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; use crate::datasource::{MemTable, TableProvider}; use crate::error::Result; use crate::logical_expr::LogicalPlan; -use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::ExecutionPlan; use crate::test::object_store::local_unpartitioned_file; use crate::test_util::{aggr_test_schema, arrow_test_data}; -use array::ArrayRef; -use arrow::array::{self, Array, Decimal128Builder, Int32Array}; + +use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{DataFusionError, FileType, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; + #[cfg(feature = "compression")] use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; -use datafusion_common::{DataFusionError, Statistics}; -use datafusion_physical_expr::PhysicalSortExpr; #[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] use flate2::Compression as GzCompression; -use futures::{Future, FutureExt}; -use std::fs::File; -use std::io::prelude::*; -use std::io::{BufReader, BufWriter}; -use std::pin::Pin; -use std::sync::Arc; -use tempfile::TempDir; #[cfg(feature = "compression")] use xz2::write::XzEncoder; #[cfg(feature = "compression")] @@ -73,7 +76,7 @@ pub fn create_table_dual() -> Arc { } /// Returns a [`CsvExec`] that scans "aggregate_test_100.csv" with `partitions` partitions -pub fn scan_partitioned_csv(partitions: usize) -> Result> { +pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result> { let schema = aggr_test_schema(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv", arrow_test_data()); @@ -83,12 +86,15 @@ pub fn scan_partitioned_csv(partitions: usize) -> Result> { partitions, FileType::CSV, FileCompressionType::UNCOMPRESSED, + work_dir, )?; let config = partitioned_csv_config(schema, file_groups)?; Ok(Arc::new(CsvExec::new( config, true, b',', + b'"', + None, FileCompressionType::UNCOMPRESSED, ))) } @@ -100,11 +106,10 @@ pub fn partitioned_file_groups( partitions: usize, file_type: FileType, file_compression_type: FileCompressionType, + work_dir: &Path, ) -> Result>> { let path = format!("{path}/{filename}"); - let tmp_dir = TempDir::new()?.into_path(); - let mut writers = vec![]; let mut files = vec![]; for i in 0..partitions { @@ -116,7 +121,7 @@ pub fn partitioned_file_groups( .get_ext_with_compression(file_compression_type.to_owned()) .unwrap() ); - let filename = tmp_dir.join(filename); + let filename = work_dir.join(filename); let file = File::create(&filename).unwrap(); @@ -171,7 +176,10 @@ pub fn partitioned_file_groups( writers[partition].write_all(b"\n").unwrap(); } } - for w in writers.iter_mut() { + + // Must drop the stream before creating ObjectMeta below as drop + // triggers finish for ZstdEncoder which writes additional data + for mut w in writers.into_iter() { w.flush().unwrap(); } @@ -188,9 +196,9 @@ pub fn partitioned_csv_config( ) -> Result { Ok(FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: schema, + file_schema: schema.clone(), file_groups, - statistics: Default::default(), + statistics: Statistics::new_unknown(&schema), projection: None, limit: None, table_partition_cols: vec![], @@ -209,40 +217,6 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } -/// returns record batch with 3 columns of i32 in memory -pub fn build_table_i32( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), -) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Int32, false), - Field::new(b.0, DataType::Int32, false), - Field::new(c.0, DataType::Int32, false), - ]); - - RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap() -} - -/// returns memory table scan wrapped around record batch with 3 columns of i32 -pub fn build_table_scan_i32( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), -) -> Arc { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) -} - /// Returns the column names on the schema pub fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() @@ -275,14 +249,6 @@ pub fn make_partition(sz: i32) -> RecordBatch { RecordBatch::try_new(schema, vec![arr]).unwrap() } -/// Return a RecordBatch with a single array with row_count sz -pub fn make_batch_no_column(sz: usize) -> RecordBatch { - let schema = Arc::new(Schema::empty()); - - let options = RecordBatchOptions::new().with_row_count(Option::from(sz)); - RecordBatch::try_new_with_options(schema, vec![], &options).unwrap() -} - /// Return a new table which provide this decimal column pub fn table_with_decimal() -> Arc { let batch_decimal = make_decimal(); @@ -307,25 +273,6 @@ fn make_decimal() -> RecordBatch { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() } -/// Asserts that given future is pending. -pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send + 'a>>) { - let waker = futures::task::noop_waker(); - let mut cx = futures::task::Context::from_waker(&waker); - let poll = fut.poll_unpin(&mut cx); - - assert!(poll.is_pending()); -} - -/// Create vector batches -pub fn create_vec_batches(schema: &Schema, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec -} - /// Created a sorted Csv exec pub fn csv_exec_sorted( schema: &SchemaRef, @@ -339,7 +286,7 @@ pub fn csv_exec_sorted( object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), file_schema: schema.clone(), file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), + statistics: Statistics::new_unknown(schema), projection: None, limit: None, table_partition_cols: vec![], @@ -348,19 +295,90 @@ pub fn csv_exec_sorted( }, false, 0, + 0, + None, FileCompressionType::UNCOMPRESSED, )) } -/// Create batch -fn create_batch(schema: &Schema) -> RecordBatch { - RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() +/// A mock execution plan that simply returns the provided statistics +#[derive(Debug, Clone)] +pub struct StatisticsExec { + stats: Statistics, + schema: Arc, +} +impl StatisticsExec { + pub fn new(stats: Statistics, schema: Schema) -> Self { + assert_eq!( + stats.column_statistics.len(), schema.fields().len(), + "if defined, the column statistics vector length should be the number of fields" + ); + Self { + stats, + schema: Arc::new(schema), + } + } +} + +impl DisplayAs for StatisticsExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "StatisticsExec: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } +} + +impl ExecutionPlan for StatisticsExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(2) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("This plan only serves for testing statistics") + } + + fn statistics(&self) -> Result { + Ok(self.stats.clone()) + } } -pub mod exec; pub mod object_store; pub mod variable; diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 425d0724ea4fa..d6f324a7f1f95 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. //! Object store implementation used for testing +use crate::execution::context::SessionState; use crate::prelude::SessionContext; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -25,11 +28,11 @@ use url::Url; pub fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) { let url = Url::parse("test://").unwrap(); ctx.runtime_env() - .register_object_store(&url, make_test_store(files)); + .register_object_store(&url, make_test_store_and_state(files).0); } /// Create a test object store with the provided files -pub fn make_test_store(files: &[(&str, u64)]) -> Arc { +pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, SessionState) { let memory = InMemory::new(); for (name, size) in files { @@ -40,7 +43,13 @@ pub fn make_test_store(files: &[(&str, u64)]) -> Arc { .unwrap(); } - Arc::new(memory) + ( + Arc::new(memory), + SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ), + ) } /// Helper method to fetch the file size and date at given path and create a `ObjectMeta` @@ -52,5 +61,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/src/test/variable.rs b/datafusion/core/src/test/variable.rs index a55513841561f..38207b42cb7b8 100644 --- a/datafusion/core/src/test/variable.rs +++ b/datafusion/core/src/test/variable.rs @@ -37,7 +37,7 @@ impl VarProvider for SystemVar { /// get system variable value fn get_value(&self, var_names: Vec) -> Result { let s = format!("{}-{}", "system-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } fn get_type(&self, _: &[String]) -> Option { @@ -61,7 +61,7 @@ impl VarProvider for UserDefinedVar { fn get_value(&self, var_names: Vec) -> Result { if var_names[0] != "@integer" { let s = format!("{}-{}", "user-defined-var", var_names.concat()); - Ok(ScalarValue::Utf8(Some(s))) + Ok(ScalarValue::from(s)) } else { Ok(ScalarValue::Int32(Some(41))) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 993ca9c186c5f..c6b43de0c18d5 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -17,200 +17,48 @@ //! Utility functions to make testing DataFusion based crates easier +#[cfg(feature = "parquet")] pub mod parquet; use std::any::Any; use std::collections::HashMap; +use std::fs::File; +use std::io::Write; use std::path::Path; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; -use std::{env, error::Error, path::PathBuf, sync::Arc}; -use crate::datasource::datasource::TableProviderFactory; +use tempfile::TempDir; + +use crate::dataframe::DataFrame; +use crate::datasource::provider::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; use crate::prelude::{CsvReadOptions, SessionContext}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use async_trait::async_trait; -use datafusion_common::{Statistics, TableReference}; +use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, TableType}; use datafusion_physical_expr::PhysicalSortExpr; -use futures::Stream; - -/// Compares formatted output of a record batch with an expected -/// vector of strings, with the result of pretty formatting record -/// batches. This is a macro so errors appear on the correct line -/// -/// Designed so that failure output can be directly copy/pasted -/// into the test code as expected results. -/// -/// Expects to be called about like this: -/// -/// `assert_batch_eq!(expected_lines: &[&str], batches: &[RecordBatch])` -#[macro_export] -macro_rules! assert_batches_eq { - ($EXPECTED_LINES: expr, $CHUNKS: expr) => { - let expected_lines: Vec = - $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); - - let formatted = arrow::util::pretty::pretty_format_batches($CHUNKS) - .unwrap() - .to_string(); - - let actual_lines: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; -} - -/// Compares formatted output of a record batch with an expected -/// vector of strings in a way that order does not matter. -/// This is a macro so errors appear on the correct line -/// -/// Designed so that failure output can be directly copy/pasted -/// into the test code as expected results. -/// -/// Expects to be called about like this: -/// -/// `assert_batch_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])` -#[macro_export] -macro_rules! assert_batches_sorted_eq { - ($EXPECTED_LINES: expr, $CHUNKS: expr) => { - let mut expected_lines: Vec = - $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); - - // sort except for header + footer - let num_lines = expected_lines.len(); - if num_lines > 3 { - expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() - } - - let formatted = arrow::util::pretty::pretty_format_batches($CHUNKS) - .unwrap() - .to_string(); - // fix for windows: \r\n --> - - let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); - - // sort except for header + footer - let num_lines = actual_lines.len(); - if num_lines > 3 { - actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() - } - - assert_eq!( - expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; -} -/// Returns the arrow test data directory, which is by default stored -/// in a git submodule rooted at `testing/data`. -/// -/// The default can be overridden by the optional environment -/// variable `ARROW_TEST_DATA` -/// -/// panics when the directory can not be found. -/// -/// Example: -/// ``` -/// let testdata = datafusion::test_util::arrow_test_data(); -/// let csvdata = format!("{}/csv/aggregate_test_100.csv", testdata); -/// assert!(std::path::PathBuf::from(csvdata).exists()); -/// ``` -pub fn arrow_test_data() -> String { - match get_data_dir("ARROW_TEST_DATA", "../../testing/data") { - Ok(pb) => pb.display().to_string(), - Err(err) => panic!("failed to get arrow data dir: {err}"), - } -} +use async_trait::async_trait; +use futures::Stream; -/// Returns the parquet test data directory, which is by default -/// stored in a git submodule rooted at -/// `parquet-testing/data`. -/// -/// The default can be overridden by the optional environment variable -/// `PARQUET_TEST_DATA` -/// -/// panics when the directory can not be found. -/// -/// Example: -/// ``` -/// let testdata = datafusion::test_util::parquet_test_data(); -/// let filename = format!("{}/binary.parquet", testdata); -/// assert!(std::path::PathBuf::from(filename).exists()); -/// ``` -pub fn parquet_test_data() -> String { - match get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data") { - Ok(pb) => pb.display().to_string(), - Err(err) => panic!("failed to get parquet data dir: {err}"), - } -} +// backwards compatibility +#[cfg(feature = "parquet")] +pub use datafusion_common::test_util::parquet_test_data; +pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; -/// Returns a directory path for finding test data. -/// -/// udf_env: name of an environment variable -/// -/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR) -/// -/// Returns either: -/// The path referred to in `udf_env` if that variable is set and refers to a directory -/// The submodule_data directory relative to CARGO_MANIFEST_PATH -pub fn get_data_dir( - udf_env: &str, - submodule_data: &str, -) -> Result> { - // Try user defined env. - if let Ok(dir) = env::var(udf_env) { - let trimmed = dir.trim().to_string(); - if !trimmed.is_empty() { - let pb = PathBuf::from(trimmed); - if pb.is_dir() { - return Ok(pb); - } else { - return Err(format!( - "the data dir `{}` defined by env {} not found", - pb.display(), - udf_env - ) - .into()); - } - } - } - - // The env is undefined or its value is trimmed to empty, let's try default dir. - - // env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package", - // set by `cargo run` or `cargo test`, see: - // https://doc.rust-lang.org/cargo/reference/environment-variables.html - let dir = env!("CARGO_MANIFEST_DIR"); - - let pb = PathBuf::from(dir).join(submodule_data); - if pb.is_dir() { - Ok(pb) - } else { - Err(format!( - "env `{}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\ - HINT: try running `git submodule update --init`", - udf_env, - pb.display(), - ).into()) - } -} +pub use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq}; /// Scan an empty data source, mainly used in tests pub fn scan_empty( @@ -240,9 +88,7 @@ pub fn scan_empty_with_partitions( /// Get the schema for the aggregate_test_* csv files pub fn aggr_test_schema() -> SchemaRef { let mut f1 = Field::new("c1", DataType::Utf8, false); - f1.set_metadata(HashMap::from_iter( - vec![("testing".into(), "test".into())].into_iter(), - )); + f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())])); let schema = Schema::new(vec![ f1, Field::new("c2", DataType::UInt32, false), @@ -262,30 +108,69 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } -/// Get the schema for the aggregate_test_* csv files with an additional filed not present in the files. -pub fn aggr_test_schema_with_missing_col() -> SchemaRef { - let mut f1 = Field::new("c1", DataType::Utf8, false); - f1.set_metadata(HashMap::from_iter( - vec![("testing".into(), "test".into())].into_iter(), - )); - let schema = Schema::new(vec![ - f1, - Field::new("c2", DataType::UInt32, false), - Field::new("c3", DataType::Int8, false), - Field::new("c4", DataType::Int16, false), - Field::new("c5", DataType::Int32, false), - Field::new("c6", DataType::Int64, false), - Field::new("c7", DataType::UInt8, false), - Field::new("c8", DataType::UInt16, false), - Field::new("c9", DataType::UInt32, false), - Field::new("c10", DataType::UInt64, false), - Field::new("c11", DataType::Float32, false), - Field::new("c12", DataType::Float64, false), - Field::new("c13", DataType::Utf8, false), - Field::new("missing_col", DataType::Int64, true), - ]); +/// Register session context for the aggregate_test_100.csv file +pub async fn register_aggregate_csv( + ctx: &mut SessionContext, + table_name: &str, +) -> Result<()> { + let schema = aggr_test_schema(); + let testdata = arrow_test_data(); + ctx.register_csv( + table_name, + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(schema.as_ref()), + ) + .await?; + Ok(()) +} - Arc::new(schema) +/// Create a table from the aggregate_test_100.csv file with the specified name +pub async fn test_table_with_name(name: &str) -> Result { + let mut ctx = SessionContext::new(); + register_aggregate_csv(&mut ctx, name).await?; + ctx.table(name).await +} + +/// Create a table from the aggregate_test_100.csv file with the name "aggregate_test_100" +pub async fn test_table() -> Result { + test_table_with_name("aggregate_test_100").await +} + +/// Execute SQL and return results +pub async fn plan_and_collect( + ctx: &SessionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// Generate CSV partitions within the supplied directory +pub fn populate_csv_partitions( + tmp_dir: &TempDir, + partition_count: usize, + file_extension: &str, +) -> Result { + // define schema for data source (csv file) + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::UInt32, false), + Field::new("c2", DataType::UInt64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + // generate a partitioned file + for partition in 0..partition_count { + let filename = format!("partition-{partition}.{file_extension}"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for i in 0..=10 { + let data = format!("{},{},{}\n", partition, i, i % 2 == 0); + file.write_all(data.as_bytes())?; + } + } + + Ok(schema) } /// TableFactory for tests @@ -363,6 +248,25 @@ impl UnboundedExec { } } } + +impl DisplayAs for UnboundedExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "UnboundableExec: unbounded={}", + self.batch_produce.is_none(), + ) + } + } + } +} + impl ExecutionPlan for UnboundedExec { fn as_any(&self) -> &dyn Any { self @@ -405,26 +309,6 @@ impl ExecutionPlan for UnboundedExec { batch: self.batch.clone(), })) } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "UnboundableExec: unbounded={}", - self.batch_produce.is_none(), - ) - } - } - } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } #[derive(Debug)] @@ -457,61 +341,6 @@ impl RecordBatchStream for UnboundedStream { } } -#[cfg(test)] -mod tests { - use super::*; - use std::env; - - #[test] - fn test_data_dir() { - let udf_env = "get_data_dir"; - let cwd = env::current_dir().unwrap(); - - let existing_pb = cwd.join(".."); - let existing = existing_pb.display().to_string(); - let existing_str = existing.as_str(); - - let non_existing = cwd.join("non-existing-dir").display().to_string(); - let non_existing_str = non_existing.as_str(); - - env::set_var(udf_env, non_existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_err()); - - env::set_var(udf_env, ""); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, " "); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::set_var(udf_env, existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - - env::remove_var(udf_env); - let res = get_data_dir(udf_env, non_existing_str); - assert!(res.is_err()); - - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); - } - - #[test] - fn test_happy() { - let res = arrow_test_data(); - assert!(PathBuf::from(res).is_dir()); - - let res = parquet_test_data(); - assert!(PathBuf::from(res).is_dir()); - } -} - /// This function creates an unbounded sorted file for testing purposes. pub async fn register_unbounded_file_with_ordering( ctx: &SessionContext, diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index d3a1f9c1ef7c0..f3c0d2987a46c 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -35,6 +35,9 @@ use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig}; + +use datafusion_common::Statistics; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; @@ -110,6 +113,7 @@ impl TestParquetFile { last_modified: Default::default(), size, e_tag: None, + version: None, }; Ok(Self { @@ -147,7 +151,7 @@ impl TestParquetFile { range: None, extensions: None, }]], - statistics: Default::default(), + statistics: Statistics::new_unknown(&self.schema), projection: None, limit: None, table_partition_cols: vec![], diff --git a/datafusion/core/src/variable/mod.rs b/datafusion/core/src/variable/mod.rs index 6efa8eb86211b..5ef165313ccf9 100644 --- a/datafusion/core/src/variable/mod.rs +++ b/datafusion/core/src/variable/mod.rs @@ -15,6 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Variable provider +//! Variable provider for `@name` and `@@name` style runtime values. pub use datafusion_physical_expr::var_provider::{VarProvider, VarType}; diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs index a420f5c9f5a35..a5a5a4524e609 100644 --- a/datafusion/core/tests/config_from_env.rs +++ b/datafusion/core/tests/config_from_env.rs @@ -36,7 +36,7 @@ fn from_env() { // for invalid testing env::set_var(env_key, "abc"); - let err = ConfigOptions::from_env().unwrap_err().to_string(); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); assert_eq!(err, "Error parsing abc as usize\ncaused by\nExternal error: invalid digit found in string"); env::remove_var(env_key); diff --git a/datafusion/core/tests/sql_integration.rs b/datafusion/core/tests/core_integration.rs similarity index 92% rename from datafusion/core/tests/sql_integration.rs rename to datafusion/core/tests/core_integration.rs index f01298ac6d7f2..af39e1e18abc6 100644 --- a/datafusion/core/tests/sql_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -18,6 +18,9 @@ /// Run all tests that are found in the `sql` directory mod sql; +/// Run all tests that are found in the `dataframe` directory +mod dataframe; + #[cfg(test)] #[ctor::ctor] fn init() { diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index b060480d64d07..a9ea5cc2a35c8 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -15,37 +15,39 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + use arrow::array::{Int32Array, Int64Array}; use arrow::compute::kernels::aggregate; use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; -use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ - project_schema, ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use datafusion::scalar::ScalarValue; -use datafusion::{ - datasource::{TableProvider, TableType}, - physical_plan::collect, -}; -use datafusion::{error::Result, physical_plan::DisplayFormatType}; - use datafusion_common::cast::as_primitive_array; -use futures::stream::Stream; -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +use datafusion_common::project_schema; +use datafusion_common::stats::Precision; use async_trait::async_trait; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use futures::stream::Stream; -//// Custom source dataframe tests //// +/// Also run all tests that are found in the `custom_sources_cases` directory +mod custom_sources_cases; + +//--- Custom source dataframe tests ---// struct CustomTableProvider; #[derive(Debug, Clone)] @@ -98,6 +100,20 @@ impl Stream for TestCustomRecordBatchStream { } } +impl DisplayAs for CustomExecutionPlan { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CustomExecutionPlan: projection={:#?}", self.projection) + } + } + } +} + impl ExecutionPlan for CustomExecutionPlan { fn as_any(&self) -> &dyn Any { self @@ -135,42 +151,28 @@ impl ExecutionPlan for CustomExecutionPlan { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "CustomExecutionPlan: projection={:#?}", self.projection) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); - Statistics { - is_exact: true, - num_rows: Some(batch.num_rows()), - total_byte_size: None, - column_statistics: Some( - self.projection - .clone() - .unwrap_or_else(|| (0..batch.columns().len()).collect()) - .iter() - .map(|i| ColumnStatistics { - null_count: Some(batch.column(*i).null_count()), - min_value: Some(ScalarValue::Int32(aggregate::min( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - max_value: Some(ScalarValue::Int32(aggregate::max( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - ..Default::default() - }) - .collect(), - ), - } + Ok(Statistics { + num_rows: Precision::Exact(batch.num_rows()), + total_byte_size: Precision::Absent, + column_statistics: self + .projection + .clone() + .unwrap_or_else(|| (0..batch.columns().len()).collect()) + .iter() + .map(|i| ColumnStatistics { + null_count: Precision::Exact(batch.column(*i).null_count()), + min_value: Precision::Exact(ScalarValue::Int32(aggregate::min( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + max_value: Precision::Exact(ScalarValue::Int32(aggregate::max( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + ..Default::default() + }) + .collect(), + }) } } @@ -254,15 +256,15 @@ async fn optimizers_catch_all_statistics() { let physical_plan = df.create_physical_plan().await.unwrap(); - // when the optimization kicks in, the source is replaced by an EmptyExec + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec assert!( - contains_empty_exec(Arc::clone(&physical_plan)), + contains_place_holder_exec(Arc::clone(&physical_plan)), "Expected aggregate_statistics optimizations missing: {physical_plan:?}" ); let expected = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("COUNT(UInt8(1))", DataType::Int64, false), + Field::new("COUNT(*)", DataType::Int64, false), Field::new("MIN(test.c1)", DataType::Int32, false), Field::new("MAX(test.c1)", DataType::Int32, false), ])), @@ -281,12 +283,12 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); } -fn contains_empty_exec(plan: Arc) -> bool { - if plan.as_any().is::() { +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { true } else if plan.children().len() != 1 { false } else { - contains_empty_exec(Arc::clone(&plan.children()[0])) + contains_place_holder_exec(Arc::clone(&plan.children()[0])) } } diff --git a/datafusion/core/tests/sqllogictests/src/engines/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs similarity index 92% rename from datafusion/core/tests/sqllogictests/src/engines/mod.rs rename to datafusion/core/tests/custom_sources_cases/mod.rs index a2657bb60017b..d5367c77d2b9c 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -15,7 +15,5 @@ // specific language governing permissions and limitations // under the License. -mod conversion; -pub mod datafusion; -mod output; -pub mod postgres; +mod provider_filter_pushdown; +mod statistics; diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs similarity index 82% rename from datafusion/core/tests/provider_filter_pushdown.rs rename to datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ac1eef850dfa7..e374abd6e8915 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -15,26 +15,29 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; +use std::sync::Arc; + use arrow::array::{Int32Builder, Int64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use async_trait::async_trait; -use datafusion::datasource::datasource::{TableProvider, TableType}; +use datafusion::datasource::provider::{TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; -use datafusion_common::DataFusionError; +use datafusion_common::{internal_err, not_impl_err, DataFusionError}; use datafusion_expr::expr::{BinaryExpr, Cast}; -use std::ops::Deref; -use std::sync::Arc; + +use async_trait::async_trait; fn create_batch(value: i32, num_rows: usize) -> Result { let mut builder = Int32Builder::with_capacity(num_rows); @@ -58,6 +61,20 @@ struct CustomPlan { batches: Vec, } +impl DisplayAs for CustomPlan { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CustomPlan: batch_size={}", self.batches.len(),) + } + } + } +} + impl ExecutionPlan for CustomPlan { fn as_any(&self) -> &dyn std::any::Any { self @@ -81,9 +98,14 @@ impl ExecutionPlan for CustomPlan { fn with_new_children( self: Arc, - _: Vec>, + children: Vec>, ) -> Result> { - unreachable!() + // CustomPlan has no children + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } } fn execute( @@ -97,22 +119,10 @@ impl ExecutionPlan for CustomPlan { ))) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "CustomPlan: batch_size={}", self.batches.len(),) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // here we could provide more accurate statistics // but we want to test the filter pushdown not the CBOs - Statistics::default() + Ok(Statistics::new_unknown(&self.schema())) } } @@ -139,10 +149,12 @@ impl TableProvider for CustomProvider { async fn scan( &self, _state: &SessionState, - _: Option<&Vec>, + projection: Option<&Vec>, filters: &[Expr], _: Option, ) -> Result> { + let empty = Vec::new(); + let projection = projection.unwrap_or(&empty); match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { @@ -157,26 +169,25 @@ impl TableProvider for CustomProvider { ScalarValue::Int32(Some(v)) => *v as i64, ScalarValue::Int64(Some(v)) => *v, other_value => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Do not support value {other_value:?}" - ))); + ); } }, other_expr => { - return Err(DataFusionError::NotImplemented(format!( - "Do not support expr {other_expr:?}" - ))); + return not_impl_err!("Do not support expr {other_expr:?}"); } }, other_expr => { - return Err(DataFusionError::NotImplemented(format!( - "Do not support expr {other_expr:?}" - ))); + return not_impl_err!("Do not support expr {other_expr:?}"); } }; Ok(Arc::new(CustomPlan { - schema: self.zero_batch.schema(), + schema: match projection.is_empty() { + true => Arc::new(Schema::empty()), + false => self.zero_batch.schema(), + }, batches: match int_value { 0 => vec![self.zero_batch.clone()], 1 => vec![self.one_batch.clone()], @@ -185,7 +196,10 @@ impl TableProvider for CustomProvider { })) } _ => Ok(Arc::new(CustomPlan { - schema: self.zero_batch.schema(), + schema: match projection.is_empty() { + true => Arc::new(Schema::empty()), + false => self.zero_batch.schema(), + }, batches: vec![], })), } diff --git a/datafusion/core/tests/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs similarity index 76% rename from datafusion/core/tests/statistics.rs rename to datafusion/core/tests/custom_sources_cases/statistics.rs index ca83ab1cf64bf..f0985f5546543 100644 --- a/datafusion/core/tests/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -25,9 +25,8 @@ use datafusion::{ error::Result, logical_expr::Expr, physical_plan::{ - expressions::PhysicalSortExpr, project_schema, ColumnStatistics, - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, - Statistics, + expressions::PhysicalSortExpr, ColumnStatistics, DisplayAs, DisplayFormatType, + ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, prelude::SessionContext, scalar::ScalarValue, @@ -35,6 +34,7 @@ use datafusion::{ use async_trait::async_trait; use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion_common::{project_schema, stats::Precision}; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -46,13 +46,10 @@ struct StatisticsValidation { impl StatisticsValidation { fn new(stats: Statistics, schema: SchemaRef) -> Self { - assert!( - stats - .column_statistics - .as_ref() - .map(|cols| cols.len() == schema.fields().len()) - .unwrap_or(true), - "if defined, the column statistics vector length should be the number of fields" + assert_eq!( + stats.column_statistics.len(), + schema.fields().len(), + "the column statistics vector length should be the number of fields" ); Self { stats, schema } } @@ -94,23 +91,41 @@ impl TableProvider for StatisticsValidation { let current_stat = self.stats.clone(); - let proj_col_stats = current_stat - .column_statistics - .map(|col_stat| projection.iter().map(|i| col_stat[*i].clone()).collect()); - + let proj_col_stats = projection + .iter() + .map(|i| current_stat.column_statistics[*i].clone()) + .collect(); Ok(Arc::new(Self::new( Statistics { - is_exact: current_stat.is_exact, num_rows: current_stat.num_rows, column_statistics: proj_col_stats, // TODO stats: knowing the type of the new columns we can guess the output size - total_byte_size: None, + total_byte_size: Precision::Absent, }, projected_schema, ))) } } +impl DisplayAs for StatisticsValidation { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "StatisticsValidation: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } +} + impl ExecutionPlan for StatisticsValidation { fn as_any(&self) -> &dyn Any { self @@ -147,25 +162,8 @@ impl ExecutionPlan for StatisticsValidation { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Statistics { - self.stats.clone() - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "StatisticsValidation: col_count={}, row_count={:?}", - self.schema.fields().len(), - self.stats.num_rows, - ) - } - } + fn statistics(&self) -> Result { + Ok(self.stats.clone()) } } @@ -180,23 +178,22 @@ fn init_ctx(stats: Statistics, schema: Schema) -> Result { fn fully_defined() -> (Statistics, Schema) { ( Statistics { - num_rows: Some(13), - is_exact: true, - total_byte_size: None, // ignore byte size for now - column_statistics: Some(vec![ + num_rows: Precision::Exact(13), + total_byte_size: Precision::Absent, // ignore byte size for now + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(2), - max_value: Some(ScalarValue::Int32(Some(1023))), - min_value: Some(ScalarValue::Int32(Some(-24))), - null_count: Some(0), + distinct_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(1023))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-24))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(13), - max_value: Some(ScalarValue::Int64(Some(5486))), - min_value: Some(ScalarValue::Int64(Some(-6783))), - null_count: Some(5), + distinct_count: Precision::Exact(13), + max_value: Precision::Exact(ScalarValue::Int64(Some(5486))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))), + null_count: Precision::Exact(5), }, - ]), + ], }, Schema::new(vec![ Field::new("c1", DataType::Int32, false), @@ -214,7 +211,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.statistics()); + assert_eq!(stats, physical_plan.statistics()?); Ok(()) } @@ -230,10 +227,8 @@ async fn sql_filter() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - - let stats = physical_plan.statistics(); - assert!(!stats.is_exact); - assert_eq!(stats.num_rows, Some(1)); + let stats = physical_plan.statistics()?; + assert_eq!(stats.num_rows, Precision::Inexact(1)); Ok(()) } @@ -241,6 +236,7 @@ async fn sql_filter() -> Result<()> { #[tokio::test] async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); + let col_stats = Statistics::unknown_column(&schema); let ctx = init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); @@ -249,11 +245,11 @@ async fn sql_limit() -> Result<()> { // we loose all statistics except the for number of rows which becomes the limit assert_eq!( Statistics { - num_rows: Some(5), - is_exact: true, - ..Default::default() + num_rows: Precision::Exact(5), + column_statistics: col_stats, + total_byte_size: Precision::Absent }, - physical_plan.statistics() + physical_plan.statistics()? ); let df = ctx @@ -262,7 +258,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.statistics()); + assert_eq!(stats, physical_plan.statistics()?); Ok(()) } @@ -279,13 +275,12 @@ async fn sql_window() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); - let result = physical_plan.statistics(); + let result = physical_plan.statistics()?; assert_eq!(stats.num_rows, result.num_rows); - assert!(result.column_statistics.is_some()); - let col_stats = result.column_statistics.unwrap(); + let col_stats = result.column_statistics; assert_eq!(2, col_stats.len()); - assert_eq!(stats.column_statistics.unwrap()[1], col_stats[0]); + assert_eq!(stats.column_statistics[1], col_stats[0]); Ok(()) } diff --git a/datafusion/core/tests/data/4.json b/datafusion/core/tests/data/4.json new file mode 100644 index 0000000000000..f0c67cd7cf0e3 --- /dev/null +++ b/datafusion/core/tests/data/4.json @@ -0,0 +1,4 @@ +{"a":1, "b":[2.0, 1.3, -6.1]} +{"a":2, "b":[3.0, 4.3]} +{"c":[false, true], "d":{"c1": 23, "c2": 32}} +{"e": {"e1": 2, "e2": 12.3}} \ No newline at end of file diff --git a/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv b/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv new file mode 100644 index 0000000000000..9cdf2f845e85c --- /dev/null +++ b/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv @@ -0,0 +1,101 @@ +c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 +a,1,-85,-15154,1171968280,1919439543497968449,77,52286,774637006,12101411955859039553,0.12285209,0.6864391962767343,0keZ5G8BffGwgF2RwQD59TFzMStxCB +a,3,13,12613,1299719633,2020498574254265315,191,17835,3998790955,14881411008939145569,0.041445434,0.8813167497816289,Amn2K87Db5Es3dFQO9cw9cvpAM6h35 +a,4,-38,20744,762932956,308913475857409919,7,45465,1787652631,878137512938218976,0.7459874,0.02182578039211991,ydkwycaISlYSlEq3TlkS2m15I2pcp8 +a,4,-54,-2376,434021400,5502271306323260832,113,15777,2502326480,7966148640299601101,0.5720931,0.30585375151301186,KJFcmTVjdkCMv94wYCtfHMFhzyRsmH +a,5,36,-16974,623103518,6834444206535996609,71,29458,141047417,17448660630302620693,0.17100024,0.04429073092078406,OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh +a,1,-25,15295,383352709,4980135132406487265,231,102,3276123488,12763583666216333412,0.53796273,0.17592486905979987,XemNcT1xp61xcM1Qz3wZ1VECCnq06O +a,5,-31,-12907,586844478,-4862189775214031241,170,28086,1013876852,11005002152861474932,0.35319167,0.05573662213439634,MeSTAXq8gVxVjbEjgkvU9YLte0X9uE +a,2,45,15673,-1899175111,398282800995316041,99,2555,145294611,8554426087132697832,0.17333257,0.6405262429561641,b3b9esRhTzFEawbs6XhpKnD9ojutHB +a,3,13,32064,912707948,3826618523497875379,42,21463,2214035726,10771380284714693539,0.6133468,0.7325106678655877,i6RQVXKUh7MzuGMDaNclUYnFUAireU +a,3,17,-22796,1337043149,-1282905594104562444,167,2809,754775609,732272194388185106,0.3884129,0.658671129040488,VDhtJkYjAYPykCgOU9x3v7v3t4SO1a +a,4,65,-28462,-1813935549,7602389238442209730,18,363,1865307672,11378396836996498283,0.09130204,0.5593249815276734,WHmjWk2AY4c6m7DA4GitUx6nmb1yYS +a,4,-101,11640,1993193190,2992662416070659899,230,40566,466439833,16778113360088370541,0.3991115,0.574210838214554,NEhyk8uIx4kEULJGa8qIyFjjBcP2G6 +a,2,-48,-18025,439738328,-313657814587041987,222,13763,3717551163,9135746610908713318,0.055064857,0.9800193410444061,ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8 +a,1,-56,8692,2106705285,-7811675384226570375,231,15573,1454057357,677091006469429514,0.42794758,0.2739938529235548,JN0VclewmjwYlSl8386MlWv5rEhWCz +a,1,-5,12636,794623392,2909750622865366631,15,24022,2669374863,4776679784701509574,0.29877836,0.2537253407987472,waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs +a,3,14,28162,397430452,-452851601758273256,57,14722,431948861,8164671015278284913,0.40199697,0.07260475960924484,TtDKUZxzVxsq758G6AWPSYuZgVgbcl +a,1,83,-14704,2143473091,-4387559599038777245,37,829,4015442341,4602675983996931623,0.89542526,0.9567595541247681,ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU +a,3,-12,-9168,1489733240,-1569376002217735076,206,33821,3959216334,16060348691054629425,0.9488028,0.9293883502480845,oLZ21P2JEDooxV1pU31cIxQHEeeoLu +a,3,-72,-11122,-2141451704,-2578916903971263854,83,30296,1995343206,17452974532402389080,0.94209343,0.3231750610081745,e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG +a,2,-43,13080,370975815,5881039805148485053,2,20120,2939920218,906367167997372130,0.42733806,0.16301110515739792,m6jD0LBIQWaMfenwRCTANI9eOdyyto +a,5,-101,-12484,-842693467,-6140627905445351305,57,57885,2496054700,2243924747182709810,0.59520596,0.9491397432856566,QJYm7YRA3YetcBHI5wkMZeLXVmfuNy +b,1,29,-18218,994303988,5983957848665088916,204,9489,3275293996,14857091259186476033,0.53840446,0.17909035118828576,AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz +b,5,-82,22080,1824882165,7373730676428214987,208,34331,3342719438,3330177516592499461,0.82634634,0.40975383525297016,Ig1QcuKsjHXkproePdERo2w0mYzIqd +b,4,-111,-1967,-4229382,1892872227362838079,67,9832,1243785310,8382489916947120498,0.06563997,0.152498292971736,Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH +b,1,54,-18410,1413111008,-7145106120930085900,249,5382,1842680163,17818611040257178339,0.8881188,0.24899794314659673,6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ +b,3,17,14457,670497898,-2390782464845307388,255,24770,1538863055,12662506238151717757,0.34077626,0.7614304100703713,6x93sxYioWuq5c9Kkk8oTAAORM7cH0 +b,5,-5,24896,1955646088,2430204191283109071,118,43655,2424630722,11429640193932435507,0.87989986,0.7328050041291218,JafwVLSVk5AVoXFuzclesQ000EE2k1 +b,2,63,21456,-2138770630,-2380041687053733364,181,57594,2705709344,13144161537396946288,0.09683716,0.3051364088814128,nYVJnVicpGRqKZibHyBAmtmzBXAFfT +b,5,68,21576,1188285940,5717755781990389024,224,27600,974297360,9865419128970328044,0.80895734,0.7973920072996036,ioEncce3mPOXD2hWhpZpCPWGATG6GU +b,2,31,23127,-800561771,-8706387435232961848,153,27034,1098639440,3343692635488765507,0.35692692,0.5590205548347534,okOkcWflkNXIy4R8LzmySyY1EC3sYd +b,4,17,-28070,-673237643,1904316899655860234,188,27744,933879086,3732692885824435932,0.41860116,0.40342283197779727,JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ +b,2,-60,-21739,-1908480893,-8897292622858103761,59,50009,2525744318,1719090662556698549,0.52930677,0.560333188635217,l7uwDoTepWwnAP0ufqtHJS3CRi7RfP +b,4,-117,19316,2051224722,-5534418579506232438,133,52046,3023531799,13684453606722360110,0.62608826,0.8506721053047003,mhjME0zBHbrK6NMkytMTQzOssOa1gF +b,5,62,16337,41423756,-2274773899098124524,121,34206,2307004493,10575647935385523483,0.23794776,0.1754261586710173,qnPOOmslCJaT45buUisMRnM0rc77EK +b,2,68,15874,49866617,1179733259727844435,121,23948,3455216719,3898128009708892708,0.6306253,0.9185813970744787,802bgTGl6Bk5TlkPYYTxp5JkKyaYUA +b,1,12,7652,-1448995523,-5332734971209541785,136,49283,4076864659,15449267433866484283,0.6214579,0.05636955101974106,akiiY5N0I44CMwEnBL6RTBk7BRkxEj +b,4,-59,25286,1423957796,2646602445954944051,0,61069,3570297463,15100310750150419896,0.49619365,0.04893135681998029,fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG +b,3,-101,-13217,-346989627,5456800329302529236,26,54276,243203849,17929716297117857676,0.05422181,0.09465635123783445,MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ +b,5,-44,15788,-629486480,5822642169425315613,13,11872,3457053821,2413406423648025909,0.44318348,0.32869374687050157,ALuRhobVWbnQTTWZdSOk0iVe8oYFhW +b,4,47,20690,-1009656194,-2027442591571700798,200,7781,326151275,2881913079548128905,0.57360977,0.2145232647388039,52mKlRE3aHCBZtjECq6sY9OqVf8Dze +c,2,1,18109,2033001162,-6513304855495910254,25,43062,1491205016,5863949479783605708,0.110830784,0.9294097332465232,6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW +c,1,103,-22186,431378678,1346564663822463162,146,12393,3766999078,10901819591635583995,0.064453244,0.7784918983501654,2T3wSlHdEmASmO0xcXHnndkKEt6bz8 +c,2,-29,25305,-537142430,-7683452043175617798,150,31648,598822671,11759014161799384683,0.8315913,0.946325164889271,9UbObCsVkmYpJGcGrgfK90qOnwb2Lj +c,4,123,16620,852509237,-3087630526856906991,196,33715,3566741189,4546434653720168472,0.07606989,0.819715865079681,8LIh0b6jmDGm87BmIyjdxNIpX4ugjD +c,2,-60,-16312,-1808210365,-3368300253197863813,71,39635,2844041986,7045482583778080653,0.805363,0.6425694115212065,BJqx5WokrmrrezZA0dUbleMYkG5U2O +c,1,41,-4667,-644225469,7049620391314639084,196,48099,2125812933,15419512479294091215,0.5780736,0.9255031346434324,mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS +c,3,73,-9565,-382483011,1765659477910680019,186,1535,1088543984,2906943497598597237,0.680652,0.6009475544728957,Ow5PGpfTm4dXCfTDsXAOTatXRoAydR +c,3,-2,-18655,-2141999138,-3154042970870838072,251,34970,3862393166,13062025193350212516,0.034291923,0.7697753383420857,IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr +c,3,22,13741,-2098805236,8604102724776612452,45,2516,1362369177,196777795886465166,0.94669616,0.0494924465469434,6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE +c,1,-24,-24085,-1882293856,7385529783747709716,41,48048,520189543,2402288956117186783,0.39761502,0.3600766362333053,Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u +c,2,-106,-1114,-1927628110,1080308211931669384,177,20421,141680161,7464432081248293405,0.56749094,0.565352842229935,Vp3gmWunM5A7wOC9YW2JroFqTWjvTi +c,4,-79,5281,-237425046,373011991904079451,121,55620,2818832252,2464584078983135763,0.49774808,0.9237877978193884,t6fQUjJejPcjc04wHvHTPe55S65B4V +c,1,70,27752,1325868318,1241882478563331892,63,61637,473294098,4976799313755010034,0.13801557,0.5081765563442366,Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn +c,5,-94,-15880,2025611582,-3348824099853919681,5,40622,4268716378,12849419495718510869,0.34163946,0.4830878559436823,RilTlL1tKkPOUFuzmLydHAVZwv1OGl +c,4,-90,-2935,1579876740,6733733506744649678,254,12876,3593959807,4094315663314091142,0.5708688,0.5603062368164834,Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV +c,2,-117,-30187,-1222533990,-191957437217035800,136,47061,2293105904,12659011877190539078,0.2047385,0.9706712283358269,pLk3i59bZwd5KBZrI1FiweYTd5hteG +c,2,29,-3855,1354539333,4742062657200940467,81,53815,3398507249,562977550464243101,0.7124534,0.991517828651004,Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0 +c,4,3,-30508,659422734,-6455460736227846736,133,59663,2306130875,8622584762448622224,0.16999894,0.4273123318932347,EcCuckwsF3gV1Ecgmh5v4KM8g1ozif +c,2,-107,-2904,-1011669561,782342092880993439,18,29527,1157161427,4403623840168496677,0.31988364,0.36936304600612724,QYlaIAnJA6r8rlAb6f59wcxvcPcWFf +c,5,118,19208,-134213907,-2120241105523909127,86,57751,1229567292,16493024289408725403,0.5536642,0.9723580396501548,TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX +c,3,97,29106,-903316089,2874859437662206732,207,42171,3473924576,8188072741116415408,0.32792538,0.2667177795079635,HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g +d,5,-40,22614,706441268,-7542719935673075327,155,14337,3373581039,11720144131976083864,0.69632107,0.3114712539863804,C2GT5KVyOPZpgKVl110TyZO0NcJ434 +d,1,38,18384,-335410409,-1632237090406591229,26,57510,2712615025,1842662804748246269,0.6064476,0.6404495093354053,4HX6feIvmNXBN7XGqgO4YVBkhu8GDI +d,1,57,28781,-1143802338,2662536767954229885,202,62167,879082834,4338034436871150616,0.7618384,0.42950521730777025,VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4 +d,2,113,3917,-108973366,-7220140168410319165,197,24380,63044568,4225581724448081782,0.11867094,0.2944158618048994,90gAtmGEeIqUTbo1ZrxCvWtsseukXC +d,1,-98,13630,-1991133944,1184110014998006843,220,2986,225513085,9634106610243643486,0.89651865,0.1640882545084913,y7C453hRWd4E7ImjNDWlpexB8nUqjh +d,3,77,15091,-1302295658,8795481303066536947,154,35477,2093538928,17419098323248948387,0.11952883,0.7035635283169166,O66j6PaYuZhEUtqV6fuU7TyjM2WxC5 +d,1,-99,5613,1213926989,-8863698443222021480,19,18736,4216440507,14933742247195536130,0.6067944,0.33639590659276175,aDxBtor7Icd9C5hnTvvw5NrIre740e +d,2,93,-12642,2053379412,6468763445799074329,147,50842,1000948272,5536487915963301239,0.4279275,0.28534428578703896,lqhzgLsXZ8JhtpeeUWWNbMz8PHI705 +d,4,102,-24558,1991172974,-7823479531661596016,14,36599,1534194097,2240998421986827216,0.028003037,0.8824879447595726,0og6hSkhbX8AC1ktFS4kounvTzy8Vo +d,1,-8,27138,-1383162419,7682021027078563072,36,64517,2861376515,9904216782086286050,0.80954456,0.9463098243875633,AFGCj7OWlEB5QfniEFgonMq90Tq5uH +d,1,125,31106,-1176490478,-4306856842351827308,90,17910,3625286410,17869394731126786457,0.8882508,0.7631239070049998,dVdvo6nUD5FgCgsbOZLds28RyGTpnx +d,5,-59,2045,-2117946883,1170799768349713170,189,63353,1365198901,2501626630745849169,0.75173044,0.18628859265874176,F7NSTjWvQJyBburN7CXRUlbgp2dIrA +d,4,55,-1471,1902023838,1252101628560265705,157,3691,811650497,1524771507450695976,0.2968701,0.5437595540422571,f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX +d,3,-76,8809,141218956,-9110406195556445909,58,5494,1824517658,12046662515387914426,0.8557294,0.6668423897406515,Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK +d,2,122,10130,-168758331,-3179091803916845592,30,794,4061635107,15695681119022625322,0.69592506,0.9748360509016578,OPwBqCEK5PWTjWaiOyL45u2NLTaDWv +d,1,-72,25590,1188089983,3090286296481837049,241,832,3542840110,5885937420286765261,0.41980565,0.21535402343780985,wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +d,4,5,-7688,702611616,6239356364381313700,4,39363,3126475872,35363005357834672,0.3766935,0.061029375346466685,H5j5ZHy1FGesOAHjkQEDYCucbpKWRu +d,3,123,29533,240273900,1176001466590906949,117,30972,2592330556,12883447461717956514,0.39075065,0.38870280983958583,1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO +e,3,104,-25136,1738331255,300633854973581194,139,20807,3577318119,13079037564113702254,0.40154034,0.7764360990307122,DuJNG8tufSqW0ZstHqWj3aGvFLMg4A +e,3,112,-6823,-421042466,8535335158538929274,129,32712,3759340273,9916295859593918600,0.6424343,0.6316565296547284,BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE +e,2,49,24495,-587831330,9178511478067509438,129,12757,1289293657,10948666249269100825,0.5610077,0.5991138115095911,bgK1r6v3BCTh0aejJUhkA1Hn6idXGp +e,2,97,18167,1593800404,-9112448817105133638,163,45185,3188005828,2792105417953811674,0.38175434,0.4094218353587008,ukOiFGGFnQJDHFgZxHMpvhD3zybF0M +e,4,-56,-31500,1544188174,3096047390018154410,220,417,557517119,2774306934041974261,0.15459597,0.19113293583306745,IZTkHMLvIKuiLjhDjYMmIHxh166we4 +e,4,-53,13788,2064155045,-691093532952651300,243,35106,2778168728,9463973906560740422,0.34515214,0.27159190516490006,0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm +e,4,97,-13181,2047637360,6176835796788944083,158,53000,2042457019,9726016502640071617,0.7085086,0.12357539988406441,oHJMNvWuunsIMIWFnYG31RCfkOo2V7 +e,1,36,-21481,-928766616,-3471238138418013024,150,52569,2610290479,7788847578701297242,0.2578469,0.7670021786149205,gpo8K5qtYePve6jyPt6xgJx4YOVjms +e,2,52,23388,715235348,605432070100399212,165,56980,3314983189,7386391799827871203,0.46076488,0.980809631269599,jQimhdepw3GKmioWUlVSWeBVRKFkY3 +e,4,73,-22501,1282464673,2541794052864382235,67,21119,538589788,9575476605699527641,0.48515016,0.296036538664718,4JznSdBajNWhu4hRQwjV1FjTTxY68i +e,2,-61,-2888,-1660426473,2553892468492435401,126,35429,4144173353,939909697866979632,0.4405142,0.9231889896940375,BPtQMxnuSPpxMExYV9YkDa6cAN7GP3 +e,4,74,-12612,-1885422396,1702850374057819332,130,3583,3198969145,10767179755613315144,0.5518061,0.5614503754617461,QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv +e,3,71,194,1436496767,-5639533800082367925,158,44507,3105312559,3998472996619161534,0.930117,0.6108938307533,pTeu0WMjBRTaNRT15rLCuEh3tBJVc5 +e,1,71,-5479,-1339586153,-3920238763788954243,123,53012,4229654142,10297218950720052365,0.73473036,0.5773498217058918,cBGc0kSm32ylBDnxogG727C0uhZEYZ +e,4,96,-30336,427197269,7506304308750926996,95,48483,3521368277,5437030162957481122,0.58104324,0.42073125331890115,3BEOHQsMEFZ58VcNTOJYShTBpAPzbt +e,2,52,-12056,-1090239422,9011500141803970147,238,4168,2013662838,12565360638488684051,0.6694766,0.39144436569161134,xipQ93429ksjNcXPX5326VSg1xJZcW +e,5,64,-26526,1689098844,8950618259486183091,224,45253,662099130,16127995415060805595,0.2897315,0.5759450483859969,56MZa5O1hVtX4c5sbnCfxuX5kDChqI +e,5,-86,32514,-467659022,-8012578250188146150,254,2684,2861911482,2126626171973341689,0.12559289,0.01479305307777301,gxfHWUF8XgY2KdFxigxvNEXe2V2XMl +e,1,120,10837,-1331533190,6342019705133850847,245,3975,2830981072,16439861276703750332,0.6623719,0.9965400387585364,LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW +e,3,-95,13611,2030965207,927403809957470678,119,59134,559847112,10966649192992996919,0.5301289,0.047343434291126085,gTpyQnEODMcpsPnJMZC66gh33i3m0b +e,4,30,-16110,61035129,-3356533792537910152,159,299,28774375,13526465947516666293,0.6999775,0.03968347085780355,cq4WSAIFwx3wwTUS5bp1wCe71R6U5I diff --git a/datafusion/core/tests/data/cars.csv b/datafusion/core/tests/data/cars.csv new file mode 100644 index 0000000000000..bc40f3b01e7a5 --- /dev/null +++ b/datafusion/core/tests/data/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion/core/tests/data/clickbench_hits_10.parquet b/datafusion/core/tests/data/clickbench_hits_10.parquet new file mode 100644 index 0000000000000..c57421d5834b3 Binary files /dev/null and b/datafusion/core/tests/data/clickbench_hits_10.parquet differ diff --git a/datafusion/core/tests/data/empty_0_byte.csv b/datafusion/core/tests/data/empty_0_byte.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/all_empty/empty0.csv b/datafusion/core/tests/data/empty_files/all_empty/empty0.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/all_empty/empty1.csv b/datafusion/core/tests/data/empty_files/all_empty/empty1.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/all_empty/empty2.csv b/datafusion/core/tests/data/empty_files/all_empty/empty2.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/some_empty/a_empty.csv b/datafusion/core/tests/data/empty_files/some_empty/a_empty.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/some_empty/b.csv b/datafusion/core/tests/data/empty_files/some_empty/b.csv new file mode 100644 index 0000000000000..195c0be7c031c --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty/b.csv @@ -0,0 +1,6 @@ +1 +1 +1 +1 +1 + diff --git a/datafusion/core/tests/data/empty_files/some_empty/c_empty.csv b/datafusion/core/tests/data/empty_files/some_empty/c_empty.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/empty_files/some_empty/d.csv b/datafusion/core/tests/data/empty_files/some_empty/d.csv new file mode 100644 index 0000000000000..195c0be7c031c --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty/d.csv @@ -0,0 +1,6 @@ +1 +1 +1 +1 +1 + diff --git a/datafusion/core/tests/data/empty_files/some_empty/e_empty.csv b/datafusion/core/tests/data/empty_files/some_empty/e_empty.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/data/escape.csv b/datafusion/core/tests/data/escape.csv new file mode 100644 index 0000000000000..331a1e697329f --- /dev/null +++ b/datafusion/core/tests/data/escape.csv @@ -0,0 +1,11 @@ +c1,c2 +"id0","value\"0" +"id1","value\"1" +"id2","value\"2" +"id3","value\"3" +"id4","value\"4" +"id5","value\"5" +"id6","value\"6" +"id7","value\"7" +"id8","value\"8" +"id9","value\"9" diff --git a/datafusion/core/tests/data/fixed_size_list_array.parquet b/datafusion/core/tests/data/fixed_size_list_array.parquet new file mode 100644 index 0000000000000..aafc5ce62f52a Binary files /dev/null and b/datafusion/core/tests/data/fixed_size_list_array.parquet differ diff --git a/datafusion/core/tests/data/one_col.csv b/datafusion/core/tests/data/one_col.csv new file mode 100644 index 0000000000000..00a21774480fc --- /dev/null +++ b/datafusion/core/tests/data/one_col.csv @@ -0,0 +1,10 @@ +5 +5 +5 +5 +5 +5 +5 +5 +5 +5 diff --git a/datafusion/core/tests/data/parquet_map.parquet b/datafusion/core/tests/data/parquet_map.parquet new file mode 100644 index 0000000000000..e7ffb5115c44f Binary files /dev/null and b/datafusion/core/tests/data/parquet_map.parquet differ diff --git a/datafusion/core/tests/data/quote.csv b/datafusion/core/tests/data/quote.csv new file mode 100644 index 0000000000000..d814884364095 --- /dev/null +++ b/datafusion/core/tests/data/quote.csv @@ -0,0 +1,11 @@ +c1,c2 +~id0~,~value0~ +~id1~,~value1~ +~id2~,~value2~ +~id3~,~value3~ +~id4~,~value4~ +~id5~,~value5~ +~id6~,~value6~ +~id7~,~value7~ +~id8~,~value8~ +~id9~,~value9~ diff --git a/datafusion/core/tests/data/wide_rows.csv b/datafusion/core/tests/data/wide_rows.csv new file mode 100644 index 0000000000000..22bfb4a0ec9bb --- /dev/null +++ b/datafusion/core/tests/data/wide_rows.csv @@ -0,0 +1,3 @@ +1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2 + diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs similarity index 95% rename from datafusion/core/tests/dataframe_functions.rs rename to datafusion/core/tests/dataframe/dataframe_functions.rs index fb10caf1b07d6..9677003ec226f 100644 --- a/datafusion/core/tests/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -80,7 +80,7 @@ macro_rules! assert_fn_batches { async fn test_fn_ascii() -> Result<()> { let expr = ascii(col("a")); - let expected = vec![ + let expected = [ "+---------------+", "| ascii(test.a) |", "+---------------+", @@ -97,7 +97,7 @@ async fn test_fn_ascii() -> Result<()> { async fn test_fn_bit_length() -> Result<()> { let expr = bit_length(col("a")); - let expected = vec![ + let expected = [ "+--------------------+", "| bit_length(test.a) |", "+--------------------+", @@ -116,7 +116,7 @@ async fn test_fn_bit_length() -> Result<()> { async fn test_fn_btrim() -> Result<()> { let expr = btrim(vec![lit(" a b c ")]); - let expected = vec![ + let expected = [ "+-----------------------------------------+", "| btrim(Utf8(\" a b c \")) |", "+-----------------------------------------+", @@ -133,7 +133,7 @@ async fn test_fn_btrim() -> Result<()> { async fn test_fn_btrim_with_chars() -> Result<()> { let expr = btrim(vec![col("a"), lit("ab")]); - let expected = vec![ + let expected = [ "+--------------------------+", "| btrim(test.a,Utf8(\"ab\")) |", "+--------------------------+", @@ -153,7 +153,7 @@ async fn test_fn_btrim_with_chars() -> Result<()> { async fn test_fn_approx_median() -> Result<()> { let expr = approx_median(col("b")); - let expected = vec![ + let expected = [ "+-----------------------+", "| APPROX_MEDIAN(test.b) |", "+-----------------------+", @@ -173,7 +173,7 @@ async fn test_fn_approx_median() -> Result<()> { async fn test_fn_approx_percentile_cont() -> Result<()> { let expr = approx_percentile_cont(col("b"), lit(0.5)); - let expected = vec![ + let expected = [ "+---------------------------------------------+", "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", "+---------------------------------------------+", @@ -194,7 +194,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a")); - let expected = vec![ + let expected = [ "+--------------------------+", "| character_length(test.a) |", "+--------------------------+", @@ -214,7 +214,7 @@ async fn test_fn_character_length() -> Result<()> { async fn test_fn_chr() -> Result<()> { let expr = chr(lit(128175)); - let expected = vec![ + let expected = [ "+--------------------+", "| chr(Int32(128175)) |", "+--------------------+", @@ -231,7 +231,7 @@ async fn test_fn_chr() -> Result<()> { async fn test_fn_initcap() -> Result<()> { let expr = initcap(col("a")); - let expected = vec![ + let expected = [ "+-----------------+", "| initcap(test.a) |", "+-----------------+", @@ -252,7 +252,7 @@ async fn test_fn_initcap() -> Result<()> { async fn test_fn_left() -> Result<()> { let expr = left(col("a"), lit(3)); - let expected = vec![ + let expected = [ "+-----------------------+", "| left(test.a,Int32(3)) |", "+-----------------------+", @@ -272,7 +272,7 @@ async fn test_fn_left() -> Result<()> { async fn test_fn_lower() -> Result<()> { let expr = lower(col("a")); - let expected = vec![ + let expected = [ "+---------------+", "| lower(test.a) |", "+---------------+", @@ -293,7 +293,7 @@ async fn test_fn_lower() -> Result<()> { async fn test_fn_lpad() -> Result<()> { let expr = lpad(vec![col("a"), lit(10)]); - let expected = vec![ + let expected = [ "+------------------------+", "| lpad(test.a,Int32(10)) |", "+------------------------+", @@ -314,7 +314,7 @@ async fn test_fn_lpad() -> Result<()> { async fn test_fn_lpad_with_string() -> Result<()> { let expr = lpad(vec![col("a"), lit(10), lit("*")]); - let expected = vec![ + let expected = [ "+----------------------------------+", "| lpad(test.a,Int32(10),Utf8(\"*\")) |", "+----------------------------------+", @@ -334,7 +334,7 @@ async fn test_fn_lpad_with_string() -> Result<()> { async fn test_fn_ltrim() -> Result<()> { let expr = ltrim(lit(" a b c ")); - let expected = vec![ + let expected = [ "+-----------------------------------------+", "| ltrim(Utf8(\" a b c \")) |", "+-----------------------------------------+", @@ -351,7 +351,7 @@ async fn test_fn_ltrim() -> Result<()> { async fn test_fn_ltrim_with_columns() -> Result<()> { let expr = ltrim(col("a")); - let expected = vec![ + let expected = [ "+---------------+", "| ltrim(test.a) |", "+---------------+", @@ -372,7 +372,7 @@ async fn test_fn_ltrim_with_columns() -> Result<()> { async fn test_fn_md5() -> Result<()> { let expr = md5(col("a")); - let expected = vec![ + let expected = [ "+----------------------------------+", "| md5(test.a) |", "+----------------------------------+", @@ -393,7 +393,7 @@ async fn test_fn_md5() -> Result<()> { async fn test_fn_regexp_match() -> Result<()> { let expr = regexp_match(vec![col("a"), lit("[a-z]")]); - let expected = vec![ + let expected = [ "+------------------------------------+", "| regexp_match(test.a,Utf8(\"[a-z]\")) |", "+------------------------------------+", @@ -414,7 +414,7 @@ async fn test_fn_regexp_match() -> Result<()> { async fn test_fn_regexp_replace() -> Result<()> { let expr = regexp_replace(vec![col("a"), lit("[a-z]"), lit("x"), lit("g")]); - let expected = vec![ + let expected = [ "+----------------------------------------------------------+", "| regexp_replace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\"),Utf8(\"g\")) |", "+----------------------------------------------------------+", @@ -434,7 +434,7 @@ async fn test_fn_regexp_replace() -> Result<()> { async fn test_fn_replace() -> Result<()> { let expr = replace(col("a"), lit("abc"), lit("x")); - let expected = vec![ + let expected = [ "+---------------------------------------+", "| replace(test.a,Utf8(\"abc\"),Utf8(\"x\")) |", "+---------------------------------------+", @@ -454,7 +454,7 @@ async fn test_fn_replace() -> Result<()> { async fn test_fn_repeat() -> Result<()> { let expr = repeat(col("a"), lit(2)); - let expected = vec![ + let expected = [ "+-------------------------+", "| repeat(test.a,Int32(2)) |", "+-------------------------+", @@ -475,7 +475,7 @@ async fn test_fn_repeat() -> Result<()> { async fn test_fn_reverse() -> Result<()> { let expr = reverse(col("a")); - let expected = vec![ + let expected = [ "+-----------------+", "| reverse(test.a) |", "+-----------------+", @@ -496,7 +496,7 @@ async fn test_fn_reverse() -> Result<()> { async fn test_fn_right() -> Result<()> { let expr = right(col("a"), lit(3)); - let expected = vec![ + let expected = [ "+------------------------+", "| right(test.a,Int32(3)) |", "+------------------------+", @@ -517,7 +517,7 @@ async fn test_fn_right() -> Result<()> { async fn test_fn_rpad() -> Result<()> { let expr = rpad(vec![col("a"), lit(11)]); - let expected = vec![ + let expected = [ "+------------------------+", "| rpad(test.a,Int32(11)) |", "+------------------------+", @@ -538,7 +538,7 @@ async fn test_fn_rpad() -> Result<()> { async fn test_fn_rpad_with_characters() -> Result<()> { let expr = rpad(vec![col("a"), lit(11), lit("x")]); - let expected = vec![ + let expected = [ "+----------------------------------+", "| rpad(test.a,Int32(11),Utf8(\"x\")) |", "+----------------------------------+", @@ -559,7 +559,7 @@ async fn test_fn_rpad_with_characters() -> Result<()> { async fn test_fn_sha224() -> Result<()> { let expr = sha224(col("a")); - let expected = vec![ + let expected = [ "+----------------------------------------------------------+", "| sha224(test.a) |", "+----------------------------------------------------------+", @@ -579,7 +579,7 @@ async fn test_fn_sha224() -> Result<()> { async fn test_fn_split_part() -> Result<()> { let expr = split_part(col("a"), lit("b"), lit(1)); - let expected = vec![ + let expected = [ "+---------------------------------------+", "| split_part(test.a,Utf8(\"b\"),Int32(1)) |", "+---------------------------------------+", @@ -598,7 +598,7 @@ async fn test_fn_split_part() -> Result<()> { async fn test_fn_starts_with() -> Result<()> { let expr = starts_with(col("a"), lit("abc")); - let expected = vec![ + let expected = [ "+---------------------------------+", "| starts_with(test.a,Utf8(\"abc\")) |", "+---------------------------------+", @@ -619,7 +619,7 @@ async fn test_fn_starts_with() -> Result<()> { async fn test_fn_strpos() -> Result<()> { let expr = strpos(col("a"), lit("f")); - let expected = vec![ + let expected = [ "+--------------------------+", "| strpos(test.a,Utf8(\"f\")) |", "+--------------------------+", @@ -639,7 +639,7 @@ async fn test_fn_strpos() -> Result<()> { async fn test_fn_substr() -> Result<()> { let expr = substr(col("a"), lit(2)); - let expected = vec![ + let expected = [ "+-------------------------+", "| substr(test.a,Int32(2)) |", "+-------------------------+", @@ -657,7 +657,7 @@ async fn test_fn_substr() -> Result<()> { #[tokio::test] async fn test_cast() -> Result<()> { let expr = cast(col("b"), DataType::Float64); - let expected = vec![ + let expected = [ "+--------+", "| test.b |", "+--------+", @@ -677,7 +677,7 @@ async fn test_cast() -> Result<()> { async fn test_fn_to_hex() -> Result<()> { let expr = to_hex(col("b")); - let expected = vec![ + let expected = [ "+----------------+", "| to_hex(test.b) |", "+----------------+", @@ -697,7 +697,7 @@ async fn test_fn_to_hex() -> Result<()> { async fn test_fn_translate() -> Result<()> { let expr = translate(col("a"), lit("bc"), lit("xx")); - let expected = vec![ + let expected = [ "+-----------------------------------------+", "| translate(test.a,Utf8(\"bc\"),Utf8(\"xx\")) |", "+-----------------------------------------+", @@ -716,7 +716,7 @@ async fn test_fn_translate() -> Result<()> { async fn test_fn_upper() -> Result<()> { let expr = upper(col("a")); - let expected = vec![ + let expected = [ "+---------------+", "| upper(test.a) |", "+---------------+", diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs new file mode 100644 index 0000000000000..da7589072bed4 --- /dev/null +++ b/datafusion/core/tests/dataframe/describe.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{ + assert_batches_eq, + prelude::{ParquetReadOptions, SessionContext}, +}; +use datafusion_common::{test_util::parquet_test_data, Result}; + +#[tokio::test] +async fn describe() -> Result<()> { + let ctx = parquet_context().await; + + let describe_record_batch = ctx + .table("alltypes_tiny_pages") + .await? + .describe() + .await? + .collect() + .await?; + + #[rustfmt::skip] + let expected = [ + "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", + "| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | year | month |", + "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", + "| count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", + "| null_count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", + "| mean | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999964237213 | 45.45000000000001 | null | null | null | 2009.5 | 6.526027397260274 |", + "| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |", + "| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |", + "| max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 |", + "| median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 |", + "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+" + ]; + assert_batches_eq!(expected, &describe_record_batch); + Ok(()) +} + +#[tokio::test] +async fn describe_boolean_binary() -> Result<()> { + let ctx = parquet_context().await; + + //add test case for only boolean boolean/binary column + let result = ctx + .sql("select 'a' as a,true as b") + .await? + .describe() + .await? + .collect() + .await?; + #[rustfmt::skip] + let expected = [ + "+------------+------+------+", + "| describe | a | b |", + "+------------+------+------+", + "| count | 1 | 1 |", + "| null_count | 1 | 1 |", + "| mean | null | null |", + "| std | null | null |", + "| min | a | null |", + "| max | a | null |", + "| median | null | null |", + "+------------+------+------+" + ]; + assert_batches_eq!(expected, &result); + Ok(()) +} + +/// Return a SessionContext with parquet file registered +async fn parquet_context() -> SessionContext { + let ctx = SessionContext::new(); + let testdata = parquet_test_data(); + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await + .unwrap(); + ctx +} diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe/mod.rs similarity index 69% rename from datafusion/core/tests/dataframe.rs rename to datafusion/core/tests/dataframe/mod.rs index f19f6d3c59083..c6b8e0e01b4f2 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -15,15 +15,20 @@ // specific language governing permissions and limitations // under the License. +// Include tests in dataframe_functions +mod dataframe_functions; +mod describe; + use arrow::datatypes::{DataType, Field, Schema}; use arrow::util::pretty::pretty_format_batches; use arrow::{ array::{ - ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder, - StructBuilder, UInt32Array, UInt32Builder, + ArrayRef, FixedSizeListBuilder, Int32Array, Int32Builder, ListBuilder, + StringArray, StringBuilder, StructBuilder, UInt32Array, UInt32Builder, }, record_batch::RecordBatch, }; +use arrow_schema::ArrowError; use std::sync::Arc; use datafusion::dataframe::DataFrame; @@ -34,15 +39,13 @@ use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_common::{DataFusionError, ScalarValue, UnnestOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ - avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, - sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, + scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunction, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -60,8 +63,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let df_results = ctx .table("t1") .await? - .aggregate(vec![col("b")], vec![count(Wildcard)])? - .sort(vec![count(Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .explain(false, false)? .collect() .await?; @@ -95,8 +98,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -132,8 +135,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .filter(exists(Arc::new( ctx.table("t2") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .into_unoptimized_plan(), // Usually, into_optimized_plan() should be used here, but due to // https://github.com/apache/arrow-datafusion/issues/5771, @@ -168,7 +171,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -198,17 +201,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![count(Expr::Wildcard)])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; - // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. + // add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("t1") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? - .select(vec![count(Expr::Wildcard)])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![count(wildcard())])? .explain(false, false)? .collect() .await?; @@ -234,7 +237,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 - // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here + // for compare difference between sql and df logical plan, we need to create a new SessionContext here let ctx = create_join_context()?; let df_results = ctx .table("t1") @@ -244,8 +247,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { ctx.table("t2") .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? - .aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])? - .select(vec![count(lit(COUNT_STAR_EXPANSION))])? + .aggregate(vec![], vec![count(wildcard())])? + .select(vec![col(count(wildcard()).to_string())])? .into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), @@ -264,68 +267,6 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { Ok(()) } -#[tokio::test] -async fn describe() -> Result<()> { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_tiny_pages", - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) - .await?; - - let describe_record_batch = ctx - .table("alltypes_tiny_pages") - .await? - .describe() - .await? - .collect() - .await?; - - #[rustfmt::skip] - let expected = vec![ - "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", - "| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | year | month |", - "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", - "| count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", - "| null_count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", - "| mean | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999964237213 | 45.45000000000001 | null | null | null | 2009.5 | 6.526027397260274 |", - "| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |", - "| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |", - "| max | 7299.0 | null | 9.0 | 9.0 | 9.0 | 90.0 | 9.899999618530273 | 90.89999999999999 | 12/31/10 | 9 | 2010-12-31T04:09:13.860 | 2010.0 | 12.0 |", - "| median | 3649.0 | null | 4.0 | 4.0 | 4.0 | 45.0 | 4.949999809265137 | 45.45 | null | null | null | 2009.0 | 7.0 |", - "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", - ]; - assert_batches_eq!(expected, &describe_record_batch); - - //add test case for only boolean boolean/binary column - let result = ctx - .sql("select 'a' as a,true as b") - .await? - .describe() - .await? - .collect() - .await?; - #[rustfmt::skip] - let expected = vec![ - "+------------+------+------+", - "| describe | a | b |", - "+------------+------+------+", - "| count | 1 | 1 |", - "| null_count | 1 | 1 |", - "| mean | null | null |", - "| std | null | null |", - "| min | a | null |", - "| max | a | null |", - "| median | null | null |", - "+------------+------+------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) -} - #[tokio::test] async fn join() -> Result<()> { let schema1 = Arc::new(Schema::new(vec![ @@ -403,16 +344,14 @@ async fn sort_on_unprojected_columns() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ - "+-----+", + let expected = ["+-----+", "| a |", "+-----+", "| 100 |", "| 10 |", "| 10 |", "| 1 |", - "+-----+", - ]; + "+-----+"]; assert_batches_eq!(expected, &results); Ok(()) @@ -449,15 +388,13 @@ async fn sort_on_distinct_columns() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ - "+-----+", + let expected = ["+-----+", "| a |", "+-----+", "| 100 |", "| 10 |", "| 1 |", - "+-----+", - ]; + "+-----+"]; assert_batches_eq!(expected, &results); Ok(()) } @@ -487,7 +424,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { .distinct()? .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) .unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); + assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); Ok(()) } @@ -506,7 +443,7 @@ async fn sort_on_ambiguous_column() -> Result<()> { .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -525,7 +462,7 @@ async fn group_by_ambiguous_column() -> Result<()> { .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -544,7 +481,7 @@ async fn filter_on_ambiguous_column() -> Result<()> { .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -563,7 +500,7 @@ async fn select_ambiguous_column() -> Result<()> { .unwrap_err(); let expected = "Schema error: Ambiguous reference to unqualified field b"; - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); Ok(()) } @@ -591,14 +528,12 @@ async fn filter_with_alias_overwrite() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ - "+------+", + let expected = ["+------+", "| a |", "+------+", "| true |", "| true |", - "+------+", - ]; + "+------+"]; assert_batches_eq!(expected, &results); Ok(()) @@ -626,16 +561,14 @@ async fn select_with_alias_overwrite() -> Result<()> { let results = df.collect().await?; #[rustfmt::skip] - let expected = vec![ - "+-------+", + let expected = ["+-------+", "| a |", "+-------+", "| false |", "| true |", "| true |", "| false |", - "+-------+", - ]; + "+-------+"]; assert_batches_eq!(expected, &results); Ok(()) @@ -943,22 +876,20 @@ async fn unnest_columns() -> Result<()> { const NUM_ROWS: usize = 4; let df = table_with_nested_types(NUM_ROWS).await?; let results = df.collect().await?; - let expected = vec![ - "+----------+------------------------------------------------+--------------------+", + let expected = ["+----------+------------------------------------------------+--------------------+", "| shape_id | points | tags |", "+----------+------------------------------------------------+--------------------+", "| 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | [tag1] |", "| 2 | | [tag1, tag2] |", "| 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | |", "| 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | [tag1, tag2, tag3] |", - "+----------+------------------------------------------------+--------------------+", - ]; + "+----------+------------------------------------------------+--------------------+"]; assert_batches_sorted_eq!(expected, &results); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; let results = df.unnest_column("tags")?.collect().await?; - let expected = vec![ + let expected = [ "+----------+------------------------------------------------+------+", "| shape_id | points | tags |", "+----------+------------------------------------------------+------+", @@ -981,7 +912,7 @@ async fn unnest_columns() -> Result<()> { // Unnest points let df = table_with_nested_types(NUM_ROWS).await?; let results = df.unnest_column("points")?.collect().await?; - let expected = vec![ + let expected = [ "+----------+-----------------+--------------------+", "| shape_id | points | tags |", "+----------+-----------------+--------------------+", @@ -1042,13 +973,236 @@ async fn unnest_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_column_nulls() -> Result<()> { + let df = table_with_lists_and_nulls().await?; + let results = df.clone().collect().await?; + let expected = [ + "+--------+----+", + "| list | id |", + "+--------+----+", + "| [1, 2] | A |", + "| | B |", + "| [] | C |", + "| [3] | D |", + "+--------+----+", + ]; + assert_batches_eq!(expected, &results); + + // Unnest, preserving nulls (row with B is preserved) + let options = UnnestOptions::new().with_preserve_nulls(true); + + let results = df + .clone() + .unnest_column_with_options("list", options)? + .collect() + .await?; + let expected = [ + "+------+----+", + "| list | id |", + "+------+----+", + "| 1 | A |", + "| 2 | A |", + "| | B |", + "| 3 | D |", + "+------+----+", + ]; + assert_batches_eq!(expected, &results); + + let options = UnnestOptions::new().with_preserve_nulls(false); + let results = df + .unnest_column_with_options("list", options)? + .collect() + .await?; + let expected = [ + "+------+----+", + "| list | id |", + "+------+----+", + "| 1 | A |", + "| 2 | A |", + "| 3 | D |", + "+------+----+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_fixed_list() -> Result<()> { + let batch = get_fixed_list_batch()?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = [ + "+----------+----------------+", + "| shape_id | tags |", + "+----------+----------------+", + "| 1 | |", + "| 2 | [tag21, tag22] |", + "| 3 | [tag31, tag32] |", + "| 4 | |", + "| 5 | [tag51, tag52] |", + "| 6 | [tag61, tag62] |", + "+----------+----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + let options = UnnestOptions::new().with_preserve_nulls(true); + + let results = df + .unnest_column_with_options("tags", options)? + .collect() + .await?; + let expected = vec![ + "+----------+-------+", + "| shape_id | tags |", + "+----------+-------+", + "| 1 | |", + "| 2 | tag21 |", + "| 2 | tag22 |", + "| 3 | tag31 |", + "| 3 | tag32 |", + "| 4 | |", + "| 5 | tag51 |", + "| 5 | tag52 |", + "| 6 | tag61 |", + "| 6 | tag62 |", + "+----------+-------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_fixed_list_drop_nulls() -> Result<()> { + let batch = get_fixed_list_batch()?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = [ + "+----------+----------------+", + "| shape_id | tags |", + "+----------+----------------+", + "| 1 | |", + "| 2 | [tag21, tag22] |", + "| 3 | [tag31, tag32] |", + "| 4 | |", + "| 5 | [tag51, tag52] |", + "| 6 | [tag61, tag62] |", + "+----------+----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + let options = UnnestOptions::new().with_preserve_nulls(false); + + let results = df + .unnest_column_with_options("tags", options)? + .collect() + .await?; + let expected = [ + "+----------+-------+", + "| shape_id | tags |", + "+----------+-------+", + "| 2 | tag21 |", + "| 2 | tag22 |", + "| 3 | tag31 |", + "| 3 | tag32 |", + "| 5 | tag51 |", + "| 5 | tag52 |", + "| 6 | tag61 |", + "| 6 | tag62 |", + "+----------+-------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_fixed_list_nonull() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tags_builder = FixedSizeListBuilder::new(StringBuilder::new(), 2); + + for idx in 0..6 { + // Append shape id. + shape_id_builder.append_value(idx as u32 + 1); + + tags_builder + .values() + .append_value(format!("tag{}1", idx + 1)); + tags_builder + .values() + .append_value(format!("tag{}2", idx + 1)); + tags_builder.append(true); + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tags", Arc::new(tags_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = [ + "+----------+----------------+", + "| shape_id | tags |", + "+----------+----------------+", + "| 1 | [tag11, tag12] |", + "| 2 | [tag21, tag22] |", + "| 3 | [tag31, tag32] |", + "| 4 | [tag41, tag42] |", + "| 5 | [tag51, tag52] |", + "| 6 | [tag61, tag62] |", + "+----------+----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + let options = UnnestOptions::new().with_preserve_nulls(true); + let results = df + .unnest_column_with_options("tags", options)? + .collect() + .await?; + let expected = vec![ + "+----------+-------+", + "| shape_id | tags |", + "+----------+-------+", + "| 1 | tag11 |", + "| 1 | tag12 |", + "| 2 | tag21 |", + "| 2 | tag22 |", + "| 3 | tag31 |", + "| 3 | tag32 |", + "| 4 | tag41 |", + "| 4 | tag42 |", + "| 5 | tag51 |", + "| 5 | tag52 |", + "| 6 | tag61 |", + "| 6 | tag62 |", + "+----------+-------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + #[tokio::test] async fn unnest_aggregate_columns() -> Result<()> { const NUM_ROWS: usize = 5; let df = table_with_nested_types(NUM_ROWS).await?; let results = df.select_columns(&["tags"])?.collect().await?; - let expected = vec![ + let expected = [ r#"+--------------------+"#, r#"| tags |"#, r#"+--------------------+"#, @@ -1067,7 +1221,7 @@ async fn unnest_aggregate_columns() -> Result<()> { .aggregate(vec![], vec![count(col("tags"))])? .collect() .await?; - let expected = vec![ + let expected = [ r#"+--------------------+"#, r#"| COUNT(shapes.tags) |"#, r#"+--------------------+"#, @@ -1079,6 +1233,181 @@ async fn unnest_aggregate_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn unnest_array_agg() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let results = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("tag_id")).alias("tag_id")], + )? + .collect() + .await?; + let expected = [ + "+----------+--------------+", + "| shape_id | tag_id |", + "+----------+--------------+", + "| 1 | [11, 12, 13] |", + "| 2 | [21, 22, 23] |", + "| 3 | [31, 32, 33] |", + "+----------+--------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Unnesting again should produce the original batch. + let results = ctx + .table("shapes") + .await? + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("tag_id")).alias("tag_id")], + )? + .unnest_column("tag_id")? + .collect() + .await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn unnest_with_redundant_columns() -> Result<()> { + let mut shape_id_builder = UInt32Builder::new(); + let mut tag_id_builder = UInt32Builder::new(); + + for shape_id in 1..=3 { + for tag_id in 1..=3 { + shape_id_builder.append_value(shape_id as u32); + tag_id_builder.append_value((shape_id * 10 + tag_id) as u32); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + let df = ctx.table("shapes").await?; + + let results = df.clone().collect().await?; + let expected = vec![ + "+----------+--------+", + "| shape_id | tag_id |", + "+----------+--------+", + "| 1 | 11 |", + "| 1 | 12 |", + "| 1 | 13 |", + "| 2 | 21 |", + "| 2 | 22 |", + "| 2 | 23 |", + "| 3 | 31 |", + "| 3 | 32 |", + "| 3 | 33 |", + "+----------+--------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + // Doing an `array_agg` by `shape_id` produces: + let df = df + .clone() + .aggregate( + vec![col("shape_id")], + vec![array_agg(col("shape_id")).alias("shape_id2")], + )? + .unnest_column("shape_id2")? + .select(vec![col("shape_id")])?; + + let optimized_plan = df.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: shapes.shape_id [shape_id:UInt32]", + " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", + ]; + + let formatted = optimized_plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let results = df.collect().await?; + let expected = [ + "+----------+", + "| shape_id |", + "+----------+", + "| 1 |", + "| 1 |", + "| 1 |", + "| 2 |", + "| 2 |", + "| 2 |", + "| 3 |", + "| 3 |", + "| 3 |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + async fn create_test_table(name: &str) -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), @@ -1222,6 +1551,71 @@ async fn table_with_nested_types(n: usize) -> Result { ctx.table("shapes").await } +fn get_fixed_list_batch() -> Result { + let mut shape_id_builder = UInt32Builder::new(); + let mut tags_builder = FixedSizeListBuilder::new(StringBuilder::new(), 2); + + for idx in 0..6 { + // Append shape id. + shape_id_builder.append_value(idx as u32 + 1); + + if idx % 3 != 0 { + tags_builder + .values() + .append_value(format!("tag{}1", idx + 1)); + tags_builder + .values() + .append_value(format!("tag{}2", idx + 1)); + tags_builder.append(true); + } else { + tags_builder.values().append_null(); + tags_builder.values().append_null(); + tags_builder.append(false); + } + } + + let batch = RecordBatch::try_from_iter(vec![ + ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef), + ("tags", Arc::new(tags_builder.finish()) as ArrayRef), + ])?; + + Ok(batch) +} + +/// A a data frame that a list of integers and string IDs +async fn table_with_lists_and_nulls() -> Result { + let mut list_builder = ListBuilder::new(UInt32Builder::new()); + let mut id_builder = StringBuilder::new(); + + // [1, 2], A + list_builder.values().append_value(1); + list_builder.values().append_value(2); + list_builder.append(true); + id_builder.append_value("A"); + + // NULL, B + list_builder.append(false); + id_builder.append_value("B"); + + // [], C + list_builder.append(true); + id_builder.append_value("C"); + + // [3], D + list_builder.values().append_value(3); + list_builder.append(true); + id_builder.append_value("D"); + + let batch = RecordBatch::try_from_iter(vec![ + ("list", Arc::new(list_builder.finish()) as ArrayRef), + ("id", Arc::new(id_builder.finish()) as ArrayRef), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("shapes", batch)?; + ctx.table("shapes").await +} + pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> { let testdata = parquet_test_data(); ctx.register_parquet( @@ -1257,7 +1651,7 @@ async fn use_var_provider() -> Result<()> { let config = SessionConfig::new() .with_target_partitions(4) .set_bool("datafusion.optimizer.skip_failed_rules", false); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_table("csv_table", mem_table)?; ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {})); @@ -1268,3 +1662,23 @@ async fn use_var_provider() -> Result<()> { dataframe.collect().await?; Ok(()) } + +#[tokio::test] +async fn test_array_agg() -> Result<()> { + let df = create_test_table("test") + .await? + .aggregate(vec![], vec![array_agg(col("a"))])?; + + let results = df.collect().await?; + + let expected = [ + "+-------------------------------------+", + "| ARRAY_AGG(test.a) |", + "+-------------------------------------+", + "| [abcDEF, abc123, CBAdef, 123AbcDef] |", + "+-------------------------------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 76bc487eabb9c..93c7f7368065c 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -17,48 +17,54 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! -#[cfg(not(target_os = "windows"))] +#[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use arrow::array::Array; - use arrow::csv::ReaderBuilder; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::test_util::register_unbounded_file_with_ordering; - use datafusion::{ - prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, - }; - use datafusion_common::{DataFusionError, Result}; - use futures::StreamExt; - use itertools::enumerate; - use nix::sys::stat; - use nix::unistd; - use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; - use std::thread::JoinHandle; use std::time::{Duration, Instant}; + + use arrow::array::Array; + use arrow::csv::ReaderBuilder; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SchemaRef; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; use tempfile::TempDir; + use tokio::task::{spawn_blocking, JoinHandle}; - // ! For the sake of the test, do not alter the numbers. ! - // Session batch size - const TEST_BATCH_SIZE: usize = 20; - // Number of lines written to FIFO - const TEST_DATA_SIZE: usize = 20_000; - // Number of lines what can be joined. Each joinable key produced 20 lines with - // aggregate_test_100 dataset. We will use these joinable keys for understanding - // incremental execution. - const TEST_JOIN_RATIO: f64 = 0.01; + use datafusion::datasource::stream::{StreamConfig, StreamTable}; + use datafusion::datasource::TableProvider; + use datafusion::{ + prelude::{CsvReadOptions, SessionConfig, SessionContext}, + test_util::{aggr_test_schema, arrow_test_data}, + }; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::Expr; + + /// Makes a TableProvider for a fifo file + fn fifo_table( + schema: SchemaRef, + path: impl Into, + sort: Vec>, + ) -> Arc { + let config = StreamConfig::new_file(schema, path.into()) + .with_order(sort) + .with_batch_size(TEST_BATCH_SIZE) + .with_header(true); + Arc::new(StreamTable::new(Arc::new(config))) + } fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> Result { let file_path = tmp_dir.path().join(file_name); // Simulate an infinite environment via a FIFO file if let Err(e) = unistd::mkfifo(&file_path, stat::Mode::S_IRWXU) { - Err(DataFusionError::Execution(e.to_string())) + exec_err!("{}", e) } else { Ok(file_path) } @@ -81,31 +87,62 @@ mod unix_test { continue; } } - return Err(DataFusionError::Execution(e.to_string())); + return exec_err!("{}", e); } Ok(()) } + fn create_writing_thread( + file_path: PathBuf, + header: String, + lines: Vec, + waiting_lock: Arc, + wait_until: usize, + ) -> JoinHandle<()> { + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(10); + let sa = file_path.clone(); + // Spawn a new thread to write to the FIFO file + spawn_blocking(move || { + let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Reference time to use when deciding to fail the test + let execution_start = Instant::now(); + write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + for (cnt, line) in lines.iter().enumerate() { + while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { + thread::sleep(Duration::from_millis(50)); + } + write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + } + drop(file); + }) + } + + // ! For the sake of the test, do not alter the numbers. ! + // Session batch size + const TEST_BATCH_SIZE: usize = 20; + // Number of lines written to FIFO + const TEST_DATA_SIZE: usize = 20_000; + // Number of lines what can be joined. Each joinable key produced 20 lines with + // aggregate_test_100 dataset. We will use these joinable keys for understanding + // incremental execution. + const TEST_JOIN_RATIO: f64 = 0.01; + // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn unbounded_file_with_swapped_join( - #[values(true, false)] unbounded_file: bool, - ) -> Result<()> { + async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) .with_collect_statistics(false) .with_target_partitions(1); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // To make unbounded deterministic - let waiting = Arc::new(AtomicBool::new(unbounded_file)); + let waiting = Arc::new(AtomicBool::new(true)); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = - create_fifo_file(&tmp_dir, &format!("fifo_{unbounded_file:?}.csv"))?; + let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?; // Execution can calculated at least one RecordBatch after the number of // "joinable_lines_length" lines are read. let joinable_lines_length = @@ -129,7 +166,7 @@ mod unix_test { "a1,a2\n".to_owned(), lines, waiting.clone(), - joinable_lines_length, + joinable_lines_length * 2, ); // Data Schema @@ -137,15 +174,10 @@ mod unix_test { Field::new("a1", DataType::Utf8, false), Field::new("a2", DataType::UInt32, false), ])); - // Create a file with bounded or unbounded flag. - ctx.register_csv( - "left", - fifo_path.as_os_str().to_str().unwrap(), - CsvReadOptions::new() - .schema(schema.as_ref()) - .mark_infinite(unbounded_file), - ) - .await?; + + let provider = fifo_table(schema, fifo_path, vec![]); + ctx.register_table("left", provider).unwrap(); + // Register right table let schema = aggr_test_schema(); let test_data = arrow_test_data(); @@ -161,7 +193,7 @@ mod unix_test { while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); } - task.join().unwrap(); + task.await.unwrap(); Ok(()) } @@ -172,46 +204,17 @@ mod unix_test { Equal, } - fn create_writing_thread( - file_path: PathBuf, - header: String, - lines: Vec, - waiting_lock: Arc, - wait_until: usize, - ) -> JoinHandle<()> { - // Timeout for a long period of BrokenPipe error - let broken_pipe_timeout = Duration::from_secs(10); - // Spawn a new thread to write to the FIFO file - thread::spawn(move || { - let file = OpenOptions::new().write(true).open(file_path).unwrap(); - // Reference time to use when deciding to fail the test - let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); - for (cnt, line) in enumerate(lines) { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); - } - write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) - .unwrap(); - } - drop(file); - }) - } - // This test provides a relatively realistic end-to-end scenario where // we change the join into a [SymmetricHashJoin] to accommodate two // unbounded (FIFO) sources. - #[rstest] - #[timeout(std::time::Duration::from_secs(30))] - #[tokio::test(flavor = "multi_thread")] - #[ignore] + #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context let config = SessionConfig::new() .with_batch_size(TEST_BATCH_SIZE) .set_bool("datafusion.execution.coalesce_batches", false) .with_target_partitions(1); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // Tasks let mut tasks: Vec> = vec![]; @@ -254,47 +257,30 @@ mod unix_test { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); + // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + // Set unbounded sorted files read configuration - register_unbounded_file_with_ordering( - &ctx, - schema.clone(), - &left_fifo, - "left", - file_sort_order.clone(), - true, - ) - .await?; - register_unbounded_file_with_ordering( - &ctx, - schema, - &right_fifo, - "right", - file_sort_order, - true, - ) - .await?; + let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + ctx.register_table("left", provider)?; + + let provider = fifo_table(schema.clone(), right_fifo, order); + ctx.register_table("right", provider)?; + // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 - FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + t1.a1, + t1.a2, + t2.a1, + t2.a2 + FROM + left as t1 FULL + JOIN right as t2 ON t1.a2 = t2.a2 + AND t1.a1 > t2.a1 + 4 + AND t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; @@ -313,7 +299,8 @@ mod unix_test { }; operations.push(op); } - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); + // The SymmetricHashJoin executor produces FULL join results at every // pruning, which happens before it reaches the end of input and more // than once. In this test, we feed partially joinable data to both @@ -342,7 +329,7 @@ mod unix_test { let waiting_thread = waiting.clone(); // create local execution context let config = SessionConfig::new().with_batch_size(TEST_BATCH_SIZE); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; let source_fifo_path = create_fifo_file(&tmp_dir, "source.csv")?; @@ -368,8 +355,9 @@ mod unix_test { // Prevent move let (sink_fifo_path_thread, sink_display_fifo_path) = (sink_fifo_path.clone(), sink_fifo_path.display()); + // Spawn a new thread to read sink EXTERNAL TABLE. - tasks.push(thread::spawn(move || { + tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::Utf8, false), @@ -377,7 +365,6 @@ mod unix_test { ])); let mut reader = ReaderBuilder::new(schema) - .has_header(true) .with_batch_size(TEST_BATCH_SIZE) .build(file) .map_err(|e| DataFusionError::Internal(e.to_string())) @@ -389,38 +376,35 @@ mod unix_test { })); // register second csv file with the SQL (create an empty file if not found) ctx.sql(&format!( - "CREATE EXTERNAL TABLE source_table ( + "CREATE UNBOUNDED EXTERNAL TABLE source_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{source_display_fifo_path}'" )) .await?; // register csv file with the SQL ctx.sql(&format!( - "CREATE EXTERNAL TABLE sink_table ( + "CREATE UNBOUNDED EXTERNAL TABLE sink_table ( a1 VARCHAR NOT NULL, a2 INT NOT NULL ) STORED AS CSV WITH HEADER ROW - OPTIONS ('UNBOUNDED' 'TRUE') LOCATION '{sink_display_fifo_path}'" )) .await?; let df = ctx - .sql( - "INSERT INTO sink_table - SELECT a1, a2 FROM source_table", - ) + .sql("INSERT INTO sink_table SELECT a1, a2 FROM source_table") .await?; + + // Start execution df.collect().await?; - tasks.into_iter().for_each(|jh| jh.join().unwrap()); + futures::future::try_join_all(tasks).await.unwrap(); Ok(()) } } diff --git a/datafusion/core/tests/fuzz.rs b/datafusion/core/tests/fuzz.rs new file mode 100644 index 0000000000000..92646e8b37636 --- /dev/null +++ b/datafusion/core/tests/fuzz.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Run all tests that are found in the `fuzz_cases` directory +mod fuzz_cases; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); +} diff --git a/datafusion/core/tests/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs similarity index 76% rename from datafusion/core/tests/aggregate_fuzz.rs rename to datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 74370049e81fb..821f236af87b5 100644 --- a/datafusion/core/tests/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -28,45 +28,40 @@ use datafusion::physical_plan::aggregates::{ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test(flavor = "multi_thread", worker_threads = 8)] - async fn aggregate_test() { - let test_cases = vec![ - vec!["a"], - vec!["b", "a"], - vec!["c", "a"], - vec!["c", "b", "a"], - vec!["d", "a"], - vec!["d", "b", "a"], - vec!["d", "c", "a"], - vec!["d", "c", "b", "a"], - ]; - let n = 300; - let distincts = vec![10, 20]; - for distinct in distincts { - let mut handles = Vec::new(); - for i in 0..n { - let test_idx = i % test_cases.len(); - let group_by_columns = test_cases[test_idx].clone(); - let job = tokio::spawn(run_aggregate_test( - make_staggered_batches::(1000, distinct, i as u64), - group_by_columns, - )); - handles.push(job); - } - for job in handles { - job.await.unwrap(); - } +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +async fn aggregate_test() { + let test_cases = vec![ + vec!["a"], + vec!["b", "a"], + vec!["c", "a"], + vec!["c", "b", "a"], + vec!["d", "a"], + vec!["d", "b", "a"], + vec!["d", "c", "a"], + vec!["d", "c", "b", "a"], + ]; + let n = 300; + let distincts = vec![10, 20]; + for distinct in distincts { + let mut handles = Vec::new(); + for i in 0..n { + let test_idx = i % test_cases.len(); + let group_by_columns = test_cases[test_idx].clone(); + let job = tokio::spawn(run_aggregate_test( + make_staggered_batches::(1000, distinct, i as u64), + group_by_columns, + )); + handles.push(job); + } + for job in handles { + job.await.unwrap(); } } } @@ -77,7 +72,7 @@ mod tests { async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str>) { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let mut sort_keys = vec![]; for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { @@ -94,7 +89,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(sort_keys), + .with_sort_information(vec![sort_keys]), ); let aggregate_expr = vec![Arc::new(Sum::new( @@ -107,6 +102,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) .collect::>(); let group_by = PhysicalGroupBy::new_single(expr); + let aggregate_exec_running = Arc::new( AggregateExec::try_new( AggregateMode::Partial, @@ -118,7 +114,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as _; + ) as Arc; let aggregate_exec_usual = Arc::new( AggregateExec::try_new( @@ -131,14 +127,14 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str schema.clone(), ) .unwrap(), - ) as _; + ) as Arc; let task_ctx = ctx.task_ctx(); - let collected_usual = collect(aggregate_exec_usual, task_ctx.clone()) + let collected_usual = collect(aggregate_exec_usual.clone(), task_ctx.clone()) .await .unwrap(); - let collected_running = collect(aggregate_exec_running, task_ctx.clone()) + let collected_running = collect(aggregate_exec_running.clone(), task_ctx.clone()) .await .unwrap(); assert!(collected_running.len() > 2); @@ -162,7 +158,25 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .zip(&running_formatted_sorted) .enumerate() { - assert_eq!((i, usual_line), (i, running_line), "Inconsistent result"); + assert_eq!( + (i, usual_line), + (i, running_line), + "Inconsistent result\n\n\ + Aggregate_expr: {aggregate_expr:?}\n\ + group_by: {group_by:?}\n\ + Left Plan:\n{}\n\ + Right Plan:\n{}\n\ + schema:\n{schema}\n\ + Left Ouptut:\n{}\n\ + Right Output:\n{}\n\ + input:\n{}\n\ + ", + displayable(aggregate_exec_usual.as_ref()).indent(false), + displayable(aggregate_exec_running.as_ref()).indent(false), + usual_formatted, + running_formatted, + pretty_format_batches(&input1).unwrap(), + ); } } @@ -192,7 +206,7 @@ pub(crate) fn make_staggered_batches( let input1 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0)); let input2 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.1)); let input3 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.2)); - let input4 = Int64Array::from_iter_values(input4.into_iter()); + let input4 = Int64Array::from_iter_values(input4); // split into several record batches let mut remainder = RecordBatch::try_from_iter(vec![ diff --git a/datafusion/core/tests/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs similarity index 97% rename from datafusion/core/tests/join_fuzz.rs rename to datafusion/core/tests/fuzz_cases/join_fuzz.rs index 48e3da1886782..ac86364f42551 100644 --- a/datafusion/core/tests/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -102,7 +102,7 @@ async fn run_join_test( let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; for batch_size in batch_sizes { let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); let schema1 = input1[0].schema(); @@ -195,8 +195,8 @@ fn make_staggered_batches(len: usize) -> Vec { input12.sort_unstable(); let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1)); - let input3 = Int32Array::from_iter_values(input3.into_iter()); - let input4 = Int32Array::from_iter_values(input4.into_iter()); + let input3 = Int32Array::from_iter_values(input3); + let input4 = Int32Array::from_iter_values(input4); // split into several record batches let batch = RecordBatch::try_from_iter(vec![ diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs new file mode 100644 index 0000000000000..9889ce2ae562a --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for Sort + Fetch/Limit (TopK!) + +use arrow::compute::concat_batches; +use arrow::util::pretty::pretty_format_batches; +use arrow::{array::Int32Array, record_batch::RecordBatch}; +use arrow_array::{Float64Array, Int64Array, StringArray}; +use arrow_schema::SchemaRef; +use datafusion::datasource::MemTable; +use datafusion::prelude::SessionContext; +use datafusion_common::assert_contains; +use rand::{thread_rng, Rng}; +use std::sync::Arc; +use test_utils::stagger_batch; + +#[tokio::test] +async fn test_sort_topk_i32() { + run_limit_fuzz_test(SortedData::new_i32).await +} + +#[tokio::test] +async fn test_sort_topk_f64() { + run_limit_fuzz_test(SortedData::new_f64).await +} + +#[tokio::test] +async fn test_sort_topk_str() { + run_limit_fuzz_test(SortedData::new_str).await +} + +#[tokio::test] +async fn test_sort_topk_i64str() { + run_limit_fuzz_test(SortedData::new_i64str).await +} + +/// Run TopK fuzz tests the specified input data with different +/// different test functions so they can run in parallel) +async fn run_limit_fuzz_test(make_data: F) +where + F: Fn(usize) -> SortedData, +{ + let mut rng = thread_rng(); + for size in [10, 1_0000, 10_000, 100_000] { + let data = make_data(size); + // test various limits including some random ones + for limit in [1, 3, 7, 17, 10000, rng.gen_range(1..size * 2)] { + // limit can be larger than the number of rows in the input + run_limit_test(limit, &data).await; + } + } +} + +/// The data column(s) to use for the TopK test +/// +/// Each variants stores the input batches and the expected sorted values +/// compute the expected output for a given fetch (limit) value. +#[derive(Debug)] +enum SortedData { + // single Int32 column + I32 { + batches: Vec, + sorted: Vec>, + }, + /// Single Float64 column + F64 { + batches: Vec, + sorted: Vec>, + }, + /// Single sorted String column + Str { + batches: Vec, + sorted: Vec>, + }, + /// (i64, string) columns + I64Str { + batches: Vec, + sorted: Vec<(Option, Option)>, + }, +} + +impl SortedData { + /// Create an i32 column of random values, with the specified number of + /// rows, sorted the default + fn new_i32(size: usize) -> Self { + let mut rng = thread_rng(); + // have some repeats (approximately 1/3 of the values are the same) + let max = size as i32 / 3; + let data: Vec> = (0..size) + .map(|_| { + // no nulls for now + Some(rng.gen_range(0..max)) + }) + .collect(); + + let batches = stagger_batch(int32_batch(data.iter().cloned())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::I32 { batches, sorted } + } + + /// Create an f64 column of random values, with the specified number of + /// rows, sorted the default + fn new_f64(size: usize) -> Self { + let mut rng = thread_rng(); + let mut data: Vec> = (0..size / 3) + .map(|_| { + // no nulls for now + Some(rng.gen_range(0.0..1.0f64)) + }) + .collect(); + + // have some repeats (approximately 1/3 of the values are the same) + while data.len() < size { + data.push(data[rng.gen_range(0..data.len())]); + } + + let batches = stagger_batch(f64_batch(data.iter().cloned())); + + let mut sorted = data; + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + Self::F64 { batches, sorted } + } + + /// Create an string column of random values, with the specified number of + /// rows, sorted the default + fn new_str(size: usize) -> Self { + let mut rng = thread_rng(); + let mut data: Vec> = (0..size / 3) + .map(|_| { + // no nulls for now + Some(get_random_string(16)) + }) + .collect(); + + // have some repeats (approximately 1/3 of the values are the same) + while data.len() < size { + data.push(data[rng.gen_range(0..data.len())].clone()); + } + + let batches = stagger_batch(string_batch(data.iter())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::Str { batches, sorted } + } + + /// Create two columns of random values (int64, string), with the specified number of + /// rows, sorted the default + fn new_i64str(size: usize) -> Self { + let mut rng = thread_rng(); + + // 100 distinct values + let strings: Vec> = (0..100) + .map(|_| { + // no nulls for now + Some(get_random_string(16)) + }) + .collect(); + + // form inputs, with only 10 distinct integer values , to force collision checks + let data = (0..size) + .map(|_| { + ( + Some(rng.gen_range(0..10)), + strings[rng.gen_range(0..strings.len())].clone(), + ) + }) + .collect::>(); + + let batches = stagger_batch(i64string_batch(data.iter())); + + let mut sorted = data; + sorted.sort_unstable(); + + Self::I64Str { batches, sorted } + } + + /// Return top top `limit` values as a RecordBatch + fn topk_values(&self, limit: usize) -> RecordBatch { + match self { + Self::I32 { sorted, .. } => int32_batch(sorted.iter().take(limit).cloned()), + Self::F64 { sorted, .. } => f64_batch(sorted.iter().take(limit).cloned()), + Self::Str { sorted, .. } => string_batch(sorted.iter().take(limit)), + Self::I64Str { sorted, .. } => i64string_batch(sorted.iter().take(limit)), + } + } + + /// Return the input data to sort + fn batches(&self) -> Vec { + match self { + Self::I32 { batches, .. } => batches.clone(), + Self::F64 { batches, .. } => batches.clone(), + Self::Str { batches, .. } => batches.clone(), + Self::I64Str { batches, .. } => batches.clone(), + } + } + + /// Return the schema of the input data + fn schema(&self) -> SchemaRef { + match self { + Self::I32 { batches, .. } => batches[0].schema(), + Self::F64 { batches, .. } => batches[0].schema(), + Self::Str { batches, .. } => batches[0].schema(), + Self::I64Str { batches, .. } => batches[0].schema(), + } + } + + /// Return the sort expression to use for this data, depending on the type + fn sort_expr(&self) -> Vec { + match self { + Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. } => { + vec![datafusion_expr::col("x").sort(true, true)] + } + Self::I64Str { .. } => { + vec![ + datafusion_expr::col("x").sort(true, true), + datafusion_expr::col("y").sort(true, true), + ] + } + } + } +} + +/// Create a record batch with a single column of type `Int32` named "x" +fn int32_batch(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Int32Array::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with a single column of type `Float64` named "x" +fn f64_batch(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Float64Array::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with a single column of type `StringArray` named "x" +fn string_batch<'a>(values: impl IntoIterator>) -> RecordBatch { + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(StringArray::from_iter(values.into_iter())) as _, + )]) + .unwrap() +} + +/// Create a record batch with i64 column "x" and utf8 column "y" +fn i64string_batch<'a>( + values: impl IntoIterator, Option)> + Clone, +) -> RecordBatch { + let ints = values.clone().into_iter().map(|(i, _)| *i); + let strings = values.into_iter().map(|(_, s)| s); + RecordBatch::try_from_iter(vec![ + ("x", Arc::new(Int64Array::from_iter(ints)) as _), + ("y", Arc::new(StringArray::from_iter(strings)) as _), + ]) + .unwrap() +} + +/// Run the TopK test, sorting the input batches with the specified ftch +/// (limit) and compares the results to the expected values. +async fn run_limit_test(fetch: usize, data: &SortedData) { + let input = data.batches(); + let schema = data.schema(); + + let table = MemTable::try_new(schema, vec![input]).unwrap(); + + let ctx = SessionContext::new(); + let df = ctx + .read_table(Arc::new(table)) + .unwrap() + .sort(data.sort_expr()) + .unwrap() + .limit(0, Some(fetch)) + .unwrap(); + + // Verify the plan contains a TopK node + { + let explain = df + .clone() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + let plan_text = pretty_format_batches(&explain).unwrap().to_string(); + let expected = format!("TopK(fetch={fetch})"); + assert_contains!(plan_text, expected); + } + + let results = df.collect().await.unwrap(); + let expected = data.topk_values(fetch); + + // Verify that all output batches conform to the specified batch size + let max_batch_size = ctx.copied_config().batch_size(); + for batch in &results { + assert!(batch.num_rows() <= max_batch_size); + } + + let results = concat_batches(&results[0].schema(), &results).unwrap(); + + let results = [results]; + let expected = [expected]; + + assert_eq!( + &expected, + &results, + "TopK mismatch fetch {fetch} \n\ + expected rows {}, actual rows {}.\ + \n\nExpected:\n{}\n\nActual:\n{}", + expected[0].num_rows(), + results[0].num_rows(), + pretty_format_batches(&expected).unwrap(), + pretty_format_batches(&results).unwrap(), + ); +} + +/// Return random ASCII String with len +fn get_random_string(len: usize) -> String { + rand::thread_rng() + .sample_iter(rand::distributions::Alphanumeric) + .take(len) + .map(char::from) + .collect() +} diff --git a/datafusion/core/tests/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs similarity index 98% rename from datafusion/core/tests/merge_fuzz.rs rename to datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 6411f31be0cee..c38ff41f5783a 100644 --- a/datafusion/core/tests/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -118,7 +118,7 @@ async fn run_merge_test(input: Vec>) { let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let task_ctx = ctx.task_ctx(); let collected = collect(merge, task_ctx).await.unwrap(); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs new file mode 100644 index 0000000000000..83ec928ae229c --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_fuzz; +mod join_fuzz; +mod merge_fuzz; +mod sort_fuzz; + +mod limit_fuzz; +mod sort_preserving_repartition_fuzz; +mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs new file mode 100644 index 0000000000000..f4b4f16aa1601 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill + +use arrow::{ + array::{ArrayRef, Int32Array}, + compute::SortOptions, + record_batch::RecordBatch, +}; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_execution::memory_pool::GreedyMemoryPool; +use datafusion_physical_expr::expressions::col; +use rand::Rng; +use std::sync::Arc; +use test_utils::{batches_to_vec, partitions_to_sorted_vec}; + +const KB: usize = 1 << 10; +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn test_sort_1k_mem() { + for (batch_size, should_spill) in [(5, false), (20000, true), (1000000, true)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(10 * KB) + .with_should_spill(should_spill) + .run() + .await; + } +} + +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn test_sort_100k_mem() { + for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, true)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(100 * KB) + .with_should_spill(should_spill) + .run() + .await; + } +} + +#[tokio::test] +async fn test_sort_unlimited_mem() { + for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, false)] { + SortTest::new() + .with_int32_batches(batch_size) + .with_pool_size(usize::MAX) + .with_should_spill(should_spill) + .run() + .await; + } +} +#[derive(Debug, Default)] +struct SortTest { + input: Vec>, + /// GreedyMemoryPool size, if specified + pool_size: Option, + /// If true, expect the sort to spill + should_spill: bool, +} + +impl SortTest { + fn new() -> Self { + Default::default() + } + + /// Create batches of int32 values of rows + fn with_int32_batches(mut self, rows: usize) -> Self { + self.input = vec![make_staggered_i32_batches(rows)]; + self + } + + /// specify that this test should use a memory pool of the specifeid size + fn with_pool_size(mut self, pool_size: usize) -> Self { + self.pool_size = Some(pool_size); + self + } + + fn with_should_spill(mut self, should_spill: bool) -> Self { + self.should_spill = should_spill; + self + } + + /// Sort the input using SortExec and ensure the results are + /// correct according to `Vec::sort` both with and without spilling + async fn run(&self) { + let input = self.input.clone(); + let first_batch = input + .iter() + .flat_map(|p| p.iter()) + .next() + .expect("at least one batch"); + let schema = first_batch.schema(); + + let sort = vec![PhysicalSortExpr { + expr: col("x", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let exec = MemoryExec::try_new(&input, schema, None).unwrap(); + let sort = Arc::new(SortExec::new(sort, Arc::new(exec))); + + let session_config = SessionConfig::new(); + let session_ctx = if let Some(pool_size) = self.pool_size { + // Make sure there is enough space for the initial spill + // reservation + let pool_size = pool_size.saturating_add( + session_config + .options() + .execution + .sort_spill_reservation_bytes, + ); + + let runtime_config = RuntimeConfig::new() + .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))); + let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + SessionContext::new_with_config_rt(session_config, runtime) + } else { + SessionContext::new_with_config(session_config) + }; + + let task_ctx = session_ctx.task_ctx(); + let collected = collect(sort.clone(), task_ctx).await.unwrap(); + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + + if self.should_spill { + assert_ne!( + sort.metrics().unwrap().spill_count().unwrap(), + 0, + "Expected spill, but did not: {self:?}" + ); + } else { + assert_eq!( + sort.metrics().unwrap().spill_count().unwrap(), + 0, + "Expected no spill, but did: {self:?}" + ); + } + + assert_eq!( + session_ctx.runtime_env().memory_pool.reserved(), + 0, + "The sort should have returned all memory used back to the memory pool" + ); + assert_eq!(expected, actual, "failure in @ pool_size {self:?}"); + } +} + +/// Return randomly sized record batches in a field named 'x' of type `Int32` +/// with randomized i32 content +fn make_staggered_i32_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let max_batch = 1024; + + let mut batches = vec![]; + let mut remaining = len; + while remaining != 0 { + let to_read = rng.gen_range(0..=remaining.min(max_batch)); + remaining -= to_read; + + batches.push( + RecordBatch::try_from_iter(vec![( + "x", + Arc::new(Int32Array::from_iter_values( + (0..to_read).map(|_| rng.gen()), + )) as ArrayRef, + )]) + .unwrap(), + ) + } + batches +} diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs new file mode 100644 index 0000000000000..df6499e9b1e47 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -0,0 +1,488 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod sp_repartition_fuzz_tests { + use std::sync::Arc; + + use arrow::compute::{concat_batches, lexsort, SortColumn}; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; + use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; + + use datafusion::physical_plan::{ + collect, + memory::MemoryExec, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + repartition::RepartitionExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, + sorts::streaming_merge::streaming_merge, + stream::RecordBatchStreamAdapter, + ExecutionPlan, Partitioning, + }; + use datafusion::prelude::SessionContext; + use datafusion_common::Result; + use datafusion_execution::{ + config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, + }; + use datafusion_physical_expr::{ + expressions::{col, Column}, + EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + }; + use test_utils::add_empty_batches; + + use datafusion_physical_expr::equivalence::EquivalenceClass; + use itertools::izip; + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as u64) + .collect(); + Arc::new(UInt64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in eq_properties.constants() { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class().iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group().iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } + + // This test checks for whether during sort preserving merge we can preserve all of the valid orderings + // successfully. If at the input we have orderings [a ASC, b ASC], [c ASC, d ASC] + // After sort preserving merge orderings [a ASC, b ASC], [c ASC, d ASC] should still be valid. + #[tokio::test] + async fn stream_merge_multi_order_preserve() -> Result<()> { + const N_PARTITION: usize = 8; + const N_ELEM: usize = 25; + const N_DISTINCT: usize = 5; + const N_DIFF_SCHEMA: usize = 20; + + use datafusion::physical_plan::common::collect; + for seed in 0..N_DIFF_SCHEMA { + // Create a schema with random equivalence properties + let (_test_schema, eq_properties) = create_random_schema(seed as u64)?; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; + let schema = table_data_with_properties.schema(); + let streams: Vec = (0..N_PARTITION) + .map(|_idx| { + let batch = table_data_with_properties.clone(); + Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(async { Ok(batch) }), + )) as SendableRecordBatchStream + }) + .collect::>(); + + // Returns concatenated version of the all available orderings + let exprs = eq_properties + .oeq_class() + .output_ordering() + .unwrap_or_default(); + + let context = SessionContext::new().task_ctx(); + let mem_reservation = + MemoryConsumer::new("test".to_string()).register(context.memory_pool()); + + // Internally SortPreservingMergeExec uses this function for merging. + let res = streaming_merge( + streams, + schema, + &exprs, + BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), + 1, + None, + mem_reservation, + )?; + let res = collect(res).await?; + // Contains the merged result. + let res = concat_batches(&res[0].schema(), &res)?; + + for ordering in eq_properties.oeq_class().iter() { + let err_msg = format!("error in eq properties: {:?}", eq_properties); + let sort_solumns = ordering + .iter() + .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) + .collect::>>()?; + let orig_columns = sort_solumns + .iter() + .map(|sort_column| sort_column.values.clone()) + .collect::>(); + let sorted_columns = lexsort(&sort_solumns, None)?; + + // Make sure after merging ordering is still valid. + assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert!( + izip!(orig_columns.into_iter(), sorted_columns.into_iter()) + .all(|(lhs, rhs)| { lhs == rhs }), + "{}", + err_msg + ) + } + } + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + async fn sort_preserving_repartition_test() { + let seed_start = 0; + let seed_end = 100; + let n_row = 1000; + // Since ordering in the test (ORDER BY a,b,c) + // covers all the table (table consists of a,b,c columns). + // Result doesn't depend on the stable/unstable sort + // behaviour. We can choose, n_distinct as we like. However, + // we chose it a large number to decrease probability of having same rows in the table. + let n_distinct = 1_000_000; + for (is_first_roundrobin, is_first_sort_preserving) in + [(false, false), (false, true), (true, false), (true, true)] + { + for is_second_roundrobin in [false, true] { + let mut handles = Vec::new(); + + for seed in seed_start..seed_end { + let job = tokio::spawn(run_sort_preserving_repartition_test( + make_staggered_batches::(n_row, n_distinct, seed as u64), + is_first_roundrobin, + is_first_sort_preserving, + is_second_roundrobin, + )); + handles.push(job); + } + + for job in handles { + job.await.unwrap(); + } + } + } + } + + /// Check whether physical plan below + /// "SortPreservingMergeExec: [a@0 ASC,b@1 ASC,c@2 ASC]", + /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=2", (Partitioning can be roundrobin also) + /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=1", (Partitioning can be roundrobin also) + /// " MemoryExec: partitions=1, partition_sizes=[75]", + /// and / or + /// "SortPreservingMergeExec: [a@0 ASC,b@1 ASC,c@2 ASC]", + /// " SortPreservingRepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=2", (Partitioning can be roundrobin also) + /// " RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 2), input_partitions=1", (Partitioning can be roundrobin also) + /// " MemoryExec: partitions=1, partition_sizes=[75]", + /// preserves ordering. Input fed to the plan above should be same with the output of the plan. + async fn run_sort_preserving_repartition_test( + input1: Vec, + // If `true`, first repartition executor after `MemoryExec` will be in `RoundRobin` mode + // else it will be in `Hash` mode + is_first_roundrobin: bool, + // If `true`, first repartition executor after `MemoryExec` will be `SortPreservingRepartitionExec` + // If `false`, first repartition executor after `MemoryExec` will be `RepartitionExec` (Since its input + // partition number is 1, `RepartitionExec` also preserves ordering.). + is_first_sort_preserving: bool, + // If `true`, second repartition executor after `MemoryExec` will be in `RoundRobin` mode + // else it will be in `Hash` mode + is_second_roundrobin: bool, + ) { + let schema = input1[0].schema(); + let session_config = SessionConfig::new().with_batch_size(50); + let ctx = SessionContext::new_with_config(session_config); + let mut sort_keys = vec![]; + for ordering_col in ["a", "b", "c"] { + sort_keys.push(PhysicalSortExpr { + expr: col(ordering_col, &schema).unwrap(), + options: SortOptions::default(), + }) + } + + let concat_input_record = concat_batches(&schema, &input1).unwrap(); + + let running_source = Arc::new( + MemoryExec::try_new(&[input1.clone()], schema.clone(), None) + .unwrap() + .with_sort_information(vec![sort_keys.clone()]), + ); + let hash_exprs = vec![col("c", &schema).unwrap()]; + + let intermediate = match (is_first_roundrobin, is_first_sort_preserving) { + (true, true) => sort_preserving_repartition_exec_round_robin(running_source), + (true, false) => repartition_exec_round_robin(running_source), + (false, true) => { + sort_preserving_repartition_exec_hash(running_source, hash_exprs.clone()) + } + (false, false) => repartition_exec_hash(running_source, hash_exprs.clone()), + }; + + let intermediate = if is_second_roundrobin { + sort_preserving_repartition_exec_round_robin(intermediate) + } else { + sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone()) + }; + + let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate); + let task_ctx = ctx.task_ctx(); + + let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap(); + let concat_res = concat_batches(&schema, &collected_running).unwrap(); + assert_eq!(concat_res, concat_input_record); + } + + fn sort_preserving_repartition_exec_round_robin( + input: Arc, + ) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2)) + .unwrap() + .with_preserve_order(), + ) + } + + fn repartition_exec_round_robin( + input: Arc, + ) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2)).unwrap(), + ) + } + + fn sort_preserving_repartition_exec_hash( + input: Arc, + hash_expr: Vec>, + ) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2)) + .unwrap() + .with_preserve_order(), + ) + } + + fn repartition_exec_hash( + input: Arc, + hash_expr: Vec>, + ) -> Arc { + Arc::new( + RepartitionExec::try_new(input, Partitioning::Hash(hash_expr, 2)).unwrap(), + ) + } + + fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + } + + /// Return randomly sized record batches with: + /// three sorted int64 columns 'a', 'b', 'c' ranged from 0..'n_distinct' as columns + pub(crate) fn make_staggered_batches( + len: usize, + n_distinct: usize, + random_seed: u64, + ) -> Vec { + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(random_seed); + let mut input123: Vec<(i64, i64, i64)> = vec![(0, 0, 0); len]; + input123.iter_mut().for_each(|v| { + *v = ( + rng.gen_range(0..n_distinct) as i64, + rng.gen_range(0..n_distinct) as i64, + rng.gen_range(0..n_distinct) as i64, + ) + }); + input123.sort(); + let input1 = + Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0)); + let input2 = + Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.1)); + let input3 = + Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.2)); + + // split into several record batches + let mut remainder = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(input1) as ArrayRef), + ("b", Arc::new(input2) as ArrayRef), + ("c", Arc::new(input3) as ArrayRef), + ]) + .unwrap(); + + let mut batches = vec![]; + if STREAM { + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..50); + if remainder.num_rows() < batch_size { + break; + } + batches.push(remainder.slice(0, batch_size)); + remainder = + remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + } else { + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + batches.push(remainder.slice(0, batch_size)); + remainder = + remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + } + add_empty_batches(batches, &mut rng) + } +} diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs similarity index 77% rename from datafusion/core/tests/window_fuzz.rs rename to datafusion/core/tests/fuzz_cases/window_fuzz.rs index 77b6e0a5d11bb..44ff71d023928 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,127 +22,121 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use hashbrown::HashMap; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; - use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, PartitionSearchMode, WindowAggExec, + create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{collect, ExecutionPlan, InputOrderMode}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; - -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_physical_expr::expressions::{col, lit}; +use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; -#[cfg(test)] -mod tests { - use super::*; - use datafusion::physical_plan::windows::PartitionSearchMode::{ - Linear, PartiallySorted, Sorted, - }; +use hashbrown::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; - #[tokio::test(flavor = "multi_thread", worker_threads = 16)] - async fn window_bounded_window_random_comparison() -> Result<()> { - // make_staggered_batches gives result sorted according to a, b, c - // In the test cases first entry represents partition by columns - // Second entry represents order by columns. - // Third entry represents search mode. - // In sorted mode physical plans are in the form for WindowAggExec - //``` - // WindowAggExec - // MemoryExec] - // ``` - // and in the form for BoundedWindowAggExec - // ``` - // BoundedWindowAggExec - // MemoryExec - // ``` - // In Linear and PartiallySorted mode physical plans are in the form for WindowAggExec - //``` - // WindowAggExec - // SortExec(required by window function) - // MemoryExec] - // ``` - // and in the form for BoundedWindowAggExec - // ``` - // BoundedWindowAggExec - // MemoryExec - // ``` - let test_cases = vec![ - (vec!["a"], vec!["a"], Sorted), - (vec!["a"], vec!["b"], Sorted), - (vec!["a"], vec!["a", "b"], Sorted), - (vec!["a"], vec!["b", "c"], Sorted), - (vec!["a"], vec!["a", "b", "c"], Sorted), - (vec!["b"], vec!["a"], Linear), - (vec!["b"], vec!["a", "b"], Linear), - (vec!["b"], vec!["a", "c"], Linear), - (vec!["b"], vec!["a", "b", "c"], Linear), - (vec!["c"], vec!["a"], Linear), - (vec!["c"], vec!["a", "b"], Linear), - (vec!["c"], vec!["a", "c"], Linear), - (vec!["c"], vec!["a", "b", "c"], Linear), - (vec!["b", "a"], vec!["a"], Sorted), - (vec!["b", "a"], vec!["b"], Sorted), - (vec!["b", "a"], vec!["c"], Sorted), - (vec!["b", "a"], vec!["a", "b"], Sorted), - (vec!["b", "a"], vec!["b", "c"], Sorted), - (vec!["b", "a"], vec!["a", "c"], Sorted), - (vec!["b", "a"], vec!["a", "b", "c"], Sorted), - (vec!["c", "b"], vec!["a"], Linear), - (vec!["c", "b"], vec!["a", "b"], Linear), - (vec!["c", "b"], vec!["a", "c"], Linear), - (vec!["c", "b"], vec!["a", "b", "c"], Linear), - (vec!["c", "a"], vec!["a"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["b"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["c"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["a", "b"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["b", "c"], PartiallySorted(vec![1])), - (vec!["c", "a"], vec!["a", "c"], PartiallySorted(vec![1])), - ( - vec!["c", "a"], - vec!["a", "b", "c"], - PartiallySorted(vec![1]), - ), - (vec!["c", "b", "a"], vec!["a"], Sorted), - (vec!["c", "b", "a"], vec!["b"], Sorted), - (vec!["c", "b", "a"], vec!["c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "b"], Sorted), - (vec!["c", "b", "a"], vec!["b", "c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "c"], Sorted), - (vec!["c", "b", "a"], vec!["a", "b", "c"], Sorted), - ]; - let n = 300; - let n_distincts = vec![10, 20]; - for n_distinct in n_distincts { - let mut handles = Vec::new(); - for i in 0..n { - let idx = i % test_cases.len(); - let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); - let job = tokio::spawn(run_window_test( - make_staggered_batches::(1000, n_distinct, i as u64), - i as u64, - pb_cols, - ob_cols, - search_mode, - )); - handles.push(job); - } - for job in handles { - job.await.unwrap()?; - } +use datafusion_physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 16)] +async fn window_bounded_window_random_comparison() -> Result<()> { + // make_staggered_batches gives result sorted according to a, b, c + // In the test cases first entry represents partition by columns + // Second entry represents order by columns. + // Third entry represents search mode. + // In sorted mode physical plans are in the form for WindowAggExec + //``` + // WindowAggExec + // MemoryExec] + // ``` + // and in the form for BoundedWindowAggExec + // ``` + // BoundedWindowAggExec + // MemoryExec + // ``` + // In Linear and PartiallySorted mode physical plans are in the form for WindowAggExec + //``` + // WindowAggExec + // SortExec(required by window function) + // MemoryExec] + // ``` + // and in the form for BoundedWindowAggExec + // ``` + // BoundedWindowAggExec + // MemoryExec + // ``` + let test_cases = vec![ + (vec!["a"], vec!["a"], Sorted), + (vec!["a"], vec!["b"], Sorted), + (vec!["a"], vec!["a", "b"], Sorted), + (vec!["a"], vec!["b", "c"], Sorted), + (vec!["a"], vec!["a", "b", "c"], Sorted), + (vec!["b"], vec!["a"], Linear), + (vec!["b"], vec!["a", "b"], Linear), + (vec!["b"], vec!["a", "c"], Linear), + (vec!["b"], vec!["a", "b", "c"], Linear), + (vec!["c"], vec!["a"], Linear), + (vec!["c"], vec!["a", "b"], Linear), + (vec!["c"], vec!["a", "c"], Linear), + (vec!["c"], vec!["a", "b", "c"], Linear), + (vec!["b", "a"], vec!["a"], Sorted), + (vec!["b", "a"], vec!["b"], Sorted), + (vec!["b", "a"], vec!["c"], Sorted), + (vec!["b", "a"], vec!["a", "b"], Sorted), + (vec!["b", "a"], vec!["b", "c"], Sorted), + (vec!["b", "a"], vec!["a", "c"], Sorted), + (vec!["b", "a"], vec!["a", "b", "c"], Sorted), + (vec!["c", "b"], vec!["a"], Linear), + (vec!["c", "b"], vec!["a", "b"], Linear), + (vec!["c", "b"], vec!["a", "c"], Linear), + (vec!["c", "b"], vec!["a", "b", "c"], Linear), + (vec!["c", "a"], vec!["a"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["b"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["c"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["a", "b"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["b", "c"], PartiallySorted(vec![1])), + (vec!["c", "a"], vec!["a", "c"], PartiallySorted(vec![1])), + ( + vec!["c", "a"], + vec!["a", "b", "c"], + PartiallySorted(vec![1]), + ), + (vec!["c", "b", "a"], vec!["a"], Sorted), + (vec!["c", "b", "a"], vec!["b"], Sorted), + (vec!["c", "b", "a"], vec!["c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "b"], Sorted), + (vec!["c", "b", "a"], vec!["b", "c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "c"], Sorted), + (vec!["c", "b", "a"], vec!["a", "b", "c"], Sorted), + ]; + let n = 300; + let n_distincts = vec![10, 20]; + for n_distinct in n_distincts { + let mut handles = Vec::new(); + for i in 0..n { + let idx = i % test_cases.len(); + let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone(); + let job = tokio::spawn(run_window_test( + make_staggered_batches::(1000, n_distinct, i as u64), + i as u64, + pb_cols, + ob_cols, + search_mode, + )); + handles.push(job); + } + for job in handles { + job.await.unwrap()?; } - Ok(()) } + Ok(()) } fn get_random_function( @@ -208,6 +202,13 @@ fn get_random_function( vec![], ), ); + window_fn_map.insert( + "dense_rank", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + vec![], + ), + ); window_fn_map.insert( "lead", ( @@ -254,6 +255,14 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; + if let WindowFunction::AggregateFunction(f) = window_fn { + let a = args[0].clone(); + let dt = a.data_type(schema.as_ref()).unwrap(); + let sig = f.signature(); + let coerced = coerce_types(f, &[dt], &sig).unwrap(); + args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + } + for new_arg in new_args { args.push(new_arg.clone()); } @@ -374,13 +383,13 @@ async fn run_window_test( random_seed: u64, partition_by_columns: Vec<&str>, orderby_columns: Vec<&str>, - search_mode: PartitionSearchMode, + search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, PartitionSearchMode::Sorted); + let is_linear = !matches!(search_mode, InputOrderMode::Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); - let ctx = SessionContext::with_config(session_config); + let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); @@ -425,7 +434,7 @@ async fn run_window_test( ]; let memory_exec = MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None).unwrap(); - let memory_exec = memory_exec.with_sort_information(source_sort_keys.clone()); + let memory_exec = memory_exec.with_sort_information(vec![source_sort_keys.clone()]); let mut exec1 = Arc::new(memory_exec) as Arc; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. @@ -445,7 +454,6 @@ async fn run_window_test( ) .unwrap()], exec1, - schema.clone(), vec![], ) .unwrap(), @@ -453,7 +461,7 @@ async fn run_window_test( let exec2 = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(source_sort_keys.clone()), + .with_sort_information(vec![source_sort_keys.clone()]), ); let running_window_exec = Arc::new( BoundedWindowAggExec::try_new( @@ -468,7 +476,6 @@ async fn run_window_test( ) .unwrap()], exec2, - schema.clone(), vec![], search_mode, ) @@ -534,7 +541,7 @@ fn make_staggered_batches( let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1)); let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2)); - let input4 = Int32Array::from_iter_values(input4.into_iter()); + let input4 = Int32Array::from_iter_values(input4); // split into several record batches let mut remainder = RecordBatch::try_from_iter(vec![ diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index f2e1223dc6ec8..a98d097856fb3 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -17,20 +17,30 @@ //! This module contains tests for limiting memory at runtime in DataFusion -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Int32Type, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, DictionaryArray}; +use arrow_schema::SortOptions; +use async_trait::async_trait; +use datafusion::assert_batches_eq; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::streaming::PartitionStream; +use datafusion_expr::{Expr, TableType}; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; -use std::sync::Arc; +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use datafusion::datasource::streaming::{PartitionStream, StreamingTable}; -use datafusion::datasource::MemTable; +use datafusion::datasource::streaming::StreamingTable; +use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::physical_optimizer::pipeline_fixer::PipelineFixer; +use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion_common::assert_contains; +use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use datafusion_common::{assert_contains, Result}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_execution::TaskContext; @@ -45,110 +55,110 @@ fn init() { #[tokio::test] async fn oom_sort() { - run_limit_test( - "select * from t order by host DESC", - vec![ + TestCase::new() + .with_query("select * from t order by host DESC") + .with_expected_errors(vec![ "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", - ], - 200_000, - ) - .await + ]) + .with_memory_limit(200_000) + .run() + .await } #[tokio::test] async fn group_by_none() { - run_limit_test( - "select median(image) from t", - vec![ + TestCase::new() + .with_query("select median(request_bytes) from t") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "AggregateStream", - ], - 20_000, - ) - .await + ]) + .with_memory_limit(2_000) + .run() + .await } #[tokio::test] async fn group_by_row_hash() { - run_limit_test( - "select count(*) from t GROUP BY response_bytes", - vec![ + TestCase::new() + .with_query("select count(*) from t GROUP BY response_bytes") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "GroupedHashAggregateStream", - ], - 2_000, - ) - .await + ]) + .with_memory_limit(2_000) + .run() + .await } #[tokio::test] async fn group_by_hash() { - run_limit_test( + TestCase::new() // group by dict column - "select count(*) from t GROUP BY service, host, pod, container", - vec![ + .with_query("select count(*) from t GROUP BY service, host, pod, container") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "GroupedHashAggregateStream", - ], - 1_000, - ) - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] async fn join_by_key_multiple_partitions() { let config = SessionConfig::new().with_target_partitions(2); - run_limit_test_with_config( - "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput[0]", - ], - 1_000, - config, - ) - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] async fn join_by_key_single_partition() { let config = SessionConfig::new().with_target_partitions(1); - run_limit_test_with_config( - "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput", - ], - 1_000, - config, - ) - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] async fn join_by_expression() { - run_limit_test( - "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "NestedLoopJoinLoad[0]", - ], - 1_000, - ) - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] async fn cross_join() { - run_limit_test( - "select t1.* from t t1 CROSS JOIN t t2", - vec![ + TestCase::new() + .with_query("select t1.* from t t1 CROSS JOIN t t2") + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "CrossJoinExec", - ], - 1_000, - ) - .await + ]) + .with_memory_limit(1_000) + .run() + .await } #[tokio::test] @@ -158,94 +168,504 @@ async fn merge_join() { .with_target_partitions(2) .set_bool("datafusion.optimizer.prefer_hash_join", false); - run_limit_test_with_config( - "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", - vec![ + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "SMJStream", - ], - 1_000, - config, - ) - .await + ]) + .with_memory_limit(1_000) + .with_config(config) + .run() + .await } #[tokio::test] -async fn test_limit_symmetric_hash_join() { - let config = SessionConfig::new(); - - run_streaming_test_with_config( - "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", - vec![ +async fn symmetric_hash_join() { + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "SymmetricHashJoinStream", - ], - 1_000, - config, - ) - .await + ]) + .with_memory_limit(1_000) + .with_scenario(Scenario::AccessLogStreaming) + .run() + .await } -/// 50 byte memory limit -const MEMORY_FRACTION: f64 = 0.95; +#[tokio::test] +async fn sort_preserving_merge() { + let scenario = Scenario::new_dictionary_strings(2); + let partition_size = scenario.partition_size(); + + TestCase::new() + // This query uses the exact same ordering as the input table + // so only a merge is needed + .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") + .with_expected_errors(vec![ + "Resources exhausted: Failed to allocate additional", + "SortPreservingMergeExec", + ]) + // provide insufficient memory to merge + .with_memory_limit(partition_size / 2) + // two partitions of data, so a merge is required + .with_scenario(scenario) + .with_expected_plan( + // It is important that this plan only has + // SortPreservingMergeExec (not a Sort which would compete + // with the SortPreservingMergeExec for memory) + &[ + "+---------------+-------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+-------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Limit: skip=0, fetch=10 |", + "| | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | GlobalLimitExec: skip=0, fetch=10 |", + "| | SortPreservingMergeExec: [a@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 |", + "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", + "| | |", + "+---------------+-------------------------------------------------------------------------------------------------------------+", + ] + ) + .run() + .await +} -/// runs the specified query against 1000 rows with specified -/// memory limit and no disk manager enabled with default SessionConfig. -async fn run_limit_test( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, -) { - let config = SessionConfig::new(); - run_limit_test_with_config(query, expected_error_contains, memory_limit, config).await +#[tokio::test] +async fn sort_spill_reservation() { + let scenario = Scenario::new_dictionary_strings(1); + let partition_size = scenario.partition_size(); + + let base_config = SessionConfig::new() + // do not allow the sort to use the 'concat in place' path + .with_sort_in_place_threshold_bytes(10); + + // This test case shows how sort_spill_reservation works by + // purposely sorting data that requires non trivial memory to + // sort/merge. + let test = TestCase::new() + // This query uses a different order than the input table to + // force a sort. It also needs to have multiple columns to + // force RowFormat / interner that makes merge require + // substantial memory + .with_query("select * from t ORDER BY a , b DESC") + // enough memory to sort if we don't try to merge it all at once + .with_memory_limit(partition_size) + // use a single partiton so only a sort is needed + .with_scenario(scenario) + .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_expected_plan( + // It is important that this plan only has a SortExec, not + // also merge, so we can ensure the sort could finish + // given enough merging memory + &[ + "+---------------+--------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+--------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b DESC NULLS FIRST |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] |", + "| | MemoryExec: partitions=1, partition_sizes=[5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", + "| | |", + "+---------------+--------------------------------------------------------------------------------------------------------+", + ] + ); + + let config = base_config + .clone() + // provide insufficient reserved space for merging, + // the sort will fail while trying to merge + .with_sort_spill_reservation_bytes(1024); + + test.clone() + .with_expected_errors(vec![ + "Resources exhausted: Failed to allocate additional", + "ExternalSorterMerge", // merging in sort fails + ]) + .with_config(config) + .run() + .await; + + let config = base_config + // reserve sufficient space up front for merge and this time, + // which will force the spills to happen with less buffered + // input and thus with enough to merge. + .with_sort_spill_reservation_bytes(partition_size / 2); + + test.with_config(config).with_expected_success().run().await; } -/// runs the specified query against 1000 rows with a 50 -/// byte memory limit and no disk manager enabled -/// with specified SessionConfig instance -async fn run_limit_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, +/// Run the query with the specified memory limit, +/// and verifies the expected errors are returned +#[derive(Clone, Debug)] +struct TestCase { + query: Option, + expected_errors: Vec, memory_limit: usize, config: SessionConfig, -) { - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); + scenario: Scenario, + /// How should the disk manager (that allows spilling) be + /// configured? Defaults to `Disabled` + disk_manager_config: DiskManagerConfig, + /// Expected explain plan, if non emptry + expected_plan: Vec, + /// Is the plan expected to pass? Defaults to false + expected_success: bool, +} - let table = MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); +impl TestCase { + fn new() -> Self { + Self { + query: None, + expected_errors: vec![], + memory_limit: 0, + config: SessionConfig::new(), + scenario: Scenario::AccessLog, + disk_manager_config: DiskManagerConfig::Disabled, + expected_plan: vec![], + expected_success: false, + } + } + + /// Set the query to run + fn with_query(mut self, query: impl Into) -> Self { + self.query = Some(query.into()); + self + } + + /// Set a list of expected strings that must appear in any errors + fn with_expected_errors<'a>( + mut self, + expected_errors: impl IntoIterator, + ) -> Self { + self.expected_errors = + expected_errors.into_iter().map(|s| s.to_string()).collect(); + self + } - let rt_config = RuntimeConfig::new() - // do not allow spilling - .with_disk_manager(DiskManagerConfig::Disabled) - .with_memory_limit(memory_limit, MEMORY_FRACTION); + /// Set the amount of memory that can be used + fn with_memory_limit(mut self, memory_limit: usize) -> Self { + self.memory_limit = memory_limit; + self + } - let runtime = RuntimeEnv::new(rt_config).unwrap(); + /// Specify the configuration to use + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.config = config; + self + } - // Disabling physical optimizer rules to avoid sorts / repartitions - // (since RepartitionExec / SortExec also has a memory budget which we'll likely hit first) - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![]); + /// Mark that the test expects the query to run successfully + pub fn with_expected_success(mut self) -> Self { + self.expected_success = true; + self + } - let ctx = SessionContext::with_state(state); - ctx.register_table("t", Arc::new(table)) - .expect("registering table"); + /// Specify the scenario to run + pub fn with_scenario(mut self, scenario: Scenario) -> Self { + self.scenario = scenario; + self + } - let df = ctx.sql(query).await.expect("Planning query"); + /// Specify if the disk manager should be enabled. If true, + /// operators that support it can spill + pub fn with_disk_manager_config( + mut self, + disk_manager_config: DiskManagerConfig, + ) -> Self { + self.disk_manager_config = disk_manager_config; + self + } - match df.collect().await { - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") + /// Specify an expected plan to review + pub fn with_expected_plan(mut self, expected_plan: &[&str]) -> Self { + self.expected_plan = expected_plan.iter().map(|s| s.to_string()).collect(); + self + } + + /// Run the test, panic'ing on error + async fn run(self) { + let Self { + query, + expected_errors, + memory_limit, + config, + scenario, + disk_manager_config, + expected_plan, + expected_success, + } = self; + + let table = scenario.table(); + + let rt_config = RuntimeConfig::new() + // do not allow spilling + .with_disk_manager(disk_manager_config) + .with_memory_limit(memory_limit, MEMORY_FRACTION); + + let runtime = RuntimeEnv::new(rt_config).unwrap(); + + // Configure execution + let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + let state = match scenario.rules() { + Some(rules) => state.with_physical_optimizer_rules(rules), + None => state, + }; + + let ctx = SessionContext::new_with_state(state); + ctx.register_table("t", table).expect("registering table"); + + let query = query.expect("Test error: query not specified"); + let df = ctx.sql(&query).await.expect("Planning query"); + + if !expected_plan.is_empty() { + let expected_plan: Vec<_> = + expected_plan.iter().map(|s| s.as_str()).collect(); + let actual_plan = df + .clone() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + assert_batches_eq!(expected_plan, &actual_plan); } - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); + + match df.collect().await { + Ok(_batches) => { + if !expected_success { + panic!( + "Unexpected success when running, expected memory limit failure" + ) + } + } + Err(e) => { + if expected_success { + panic!( + "Unexpected failure when running, expected success but got: {e}" + ) + } else { + for error_substring in expected_errors { + assert_contains!(e.to_string(), error_substring); + } + } } } } } +/// 50 byte memory limit +const MEMORY_FRACTION: f64 = 0.95; + +/// Different data scenarios +#[derive(Clone, Debug)] +enum Scenario { + /// 1000 rows of access log data with batches of 50 rows + AccessLog, + + /// 1000 rows of access log data with batches of 50 rows in a + /// [`StreamingTable`] + AccessLogStreaming, + + /// N partitions of of sorted, dictionary encoded strings. + DictionaryStrings { + partitions: usize, + /// If true, splits all input batches into 1 row each + single_row_batches: bool, + }, +} + +impl Scenario { + /// Create a new DictionaryStrings scenario with the number of partitions + fn new_dictionary_strings(partitions: usize) -> Self { + Self::DictionaryStrings { + partitions, + single_row_batches: false, + } + } + + /// return the size, in bytes, of each partition + fn partition_size(&self) -> usize { + if let Self::DictionaryStrings { + single_row_batches, .. + } = self + { + batches_byte_size(&maybe_split_batches(dict_batches(), *single_row_batches)) + } else { + panic!("Scenario does not support partition size"); + } + } + + /// return a TableProvider with data for the test + fn table(&self) -> Arc { + match self { + Self::AccessLog => { + let batches = access_log_batches(); + let table = + MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); + Arc::new(table) + } + Self::AccessLogStreaming => { + let batches = access_log_batches(); + + // Create a new streaming table with the generated schema and batches + let table = StreamingTable::try_new( + batches[0].schema(), + vec![Arc::new(DummyStreamPartition { + schema: batches[0].schema(), + batches: batches.clone(), + })], + ) + .unwrap() + .with_infinite_table(true); + Arc::new(table) + } + Self::DictionaryStrings { + partitions, + single_row_batches, + } => { + use datafusion::physical_expr::expressions::col; + let batches: Vec> = std::iter::repeat(maybe_split_batches( + dict_batches(), + *single_row_batches, + )) + .take(*partitions) + .collect(); + + let schema = batches[0][0].schema(); + let options = SortOptions { + descending: false, + nulls_first: false, + }; + let sort_information = vec![vec![ + PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options, + }, + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options, + }, + ]]; + + let table = SortedTableProvider::new(batches, sort_information); + Arc::new(table) + } + } + } + + /// return specific physical optimizer rules to use + fn rules(&self) -> Option>> { + match self { + Self::AccessLog => { + // Disabling physical optimizer rules to avoid sorts / + // repartitions (since RepartitionExec / SortExec also + // has a memory budget which we'll likely hit first) + Some(vec![]) + } + Self::AccessLogStreaming => { + // Disable all physical optimizer rules except the + // JoinSelection rule to avoid sorts or repartition, + // as they also have memory budgets that may be hit + // first + Some(vec![Arc::new(JoinSelection::new())]) + } + Self::DictionaryStrings { .. } => { + // Use default rules + None + } + } + } +} + +fn access_log_batches() -> Vec { + AccessLogGenerator::new() + .with_row_limit(1000) + .with_max_batch_size(50) + .collect() +} + +/// If `one_row_batches` is true, then returns new record batches that +/// are one row in size +fn maybe_split_batches( + batches: Vec, + one_row_batches: bool, +) -> Vec { + if !one_row_batches { + return batches; + } + + batches + .into_iter() + .flat_map(|mut batch| { + let mut batches = vec![]; + while batch.num_rows() > 1 { + batches.push(batch.slice(0, 1)); + batch = batch.slice(1, batch.num_rows() - 1); + } + batches + }) + .collect() +} + +static DICT_BATCHES: OnceLock> = OnceLock::new(); + +/// Returns 5 sorted string dictionary batches each with 50 rows with +/// this schema. +/// +/// a: Dictionary, +/// b: Dictionary, +fn dict_batches() -> Vec { + DICT_BATCHES.get_or_init(make_dict_batches).clone() +} + +fn make_dict_batches() -> Vec { + let batch_size = 50; + + let mut i = 0; + let gen = std::iter::from_fn(move || { + // create values like + // 0000000001 + // 0000000002 + // ... + // 0000000002 + + let values: Vec<_> = (i..i + batch_size) + .map(|x| format!("{:010}", x / 16)) + .collect(); + //println!("values: \n{values:?}"); + let array: DictionaryArray = + values.iter().map(|s| s.as_str()).collect(); + let array = Arc::new(array) as ArrayRef; + let batch = + RecordBatch::try_from_iter(vec![("a", array.clone()), ("b", array)]).unwrap(); + + i += batch_size; + Some(batch) + }); + + let num_batches = 5; + + let batches: Vec<_> = gen.take(num_batches).collect(); + + batches.iter().enumerate().for_each(|(i, batch)| { + println!("Dict batch[{i}] size is: {}", batch.get_array_memory_size()); + }); + + batches +} + +// How many bytes does the memory from dict_batches consume? +fn batches_byte_size(batches: &[RecordBatch]) -> usize { + batches.iter().map(|b| b.get_array_memory_size()).sum() +} + struct DummyStreamPartition { schema: SchemaRef, batches: Vec, @@ -266,65 +686,49 @@ impl PartitionStream for DummyStreamPartition { } } -async fn run_streaming_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, - config: SessionConfig, -) { - // Generate a set of access logs with a row limit of 1000 and a max batch size of 50 - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); - - // Create a new streaming table with the generated schema and batches - let table = StreamingTable::try_new( - batches[0].schema(), - vec![Arc::new(DummyStreamPartition { - schema: batches[0].schema(), - batches: batches.clone(), - })], - ) - .unwrap() - .with_infinite_table(true); - - // Configure the runtime environment with custom settings - let rt_config = RuntimeConfig::new() - // Disable disk manager to disallow spilling - .with_disk_manager(DiskManagerConfig::Disabled) - // Set memory limit to 50 bytes - .with_memory_limit(memory_limit, MEMORY_FRACTION); - - // Create a new runtime environment with the configured settings - let runtime = RuntimeEnv::new(rt_config).unwrap(); - - // Create a new session state with the given configuration and runtime environment - // Disable all physical optimizer rules except the PipelineFixer rule to avoid sorts or - // repartition, as they also have memory budgets that may be hit first - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![Arc::new(PipelineFixer::new())]); - - // Create a new session context with the session state - let ctx = SessionContext::with_state(state); - // Register the streaming table with the session context - ctx.register_table("t", Arc::new(table)) - .expect("registering table"); - - // Execute the SQL query and get a DataFrame - let df = ctx.sql(query).await.expect("Planning query"); - - // Collect the results of the DataFrame execution - match df.collect().await { - // If the execution succeeds, panic as we expect memory limit failure - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") - } - // If the execution fails, verify if the error contains the expected substrings - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); - } +/// Wrapper over a TableProvider that can provide ordering information +struct SortedTableProvider { + schema: SchemaRef, + batches: Vec>, + sort_information: Vec, +} + +impl SortedTableProvider { + fn new(batches: Vec>, sort_information: Vec) -> Self { + let schema = batches[0][0].schema(); + Self { + schema, + batches, + sort_information, } } } + +#[async_trait] +impl TableProvider for SortedTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let mem_exec = + MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())? + .with_sort_information(self.sort_information.clone()); + + Ok(Arc::new(mem_exec)) + } +} diff --git a/datafusion/core/tests/order_spill_fuzz.rs b/datafusion/core/tests/order_spill_fuzz.rs deleted file mode 100644 index 1f72e0fcb45bf..0000000000000 --- a/datafusion/core/tests/order_spill_fuzz.rs +++ /dev/null @@ -1,128 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill - -use arrow::{ - array::{ArrayRef, Int32Array}, - compute::SortOptions, - record_batch::RecordBatch, -}; -use datafusion::execution::memory_pool::GreedyMemoryPool; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; -use datafusion::physical_plan::memory::MemoryExec; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::{collect, ExecutionPlan}; -use datafusion::prelude::{SessionConfig, SessionContext}; -use rand::Rng; -use std::sync::Arc; -use test_utils::{batches_to_vec, partitions_to_sorted_vec}; - -#[tokio::test] -#[cfg_attr(tarpaulin, ignore)] -async fn test_sort_1k_mem() { - run_sort(10240, vec![(5, false), (20000, true), (1000000, true)]).await -} - -#[tokio::test] -#[cfg_attr(tarpaulin, ignore)] -async fn test_sort_100k_mem() { - run_sort(102400, vec![(5, false), (20000, false), (1000000, true)]).await -} - -#[tokio::test] -async fn test_sort_unlimited_mem() { - run_sort( - usize::MAX, - vec![(5, false), (2000, false), (1000000, false)], - ) - .await -} - -/// Sort the input using SortExec and ensure the results are correct according to `Vec::sort` -async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { - for (size, spill) in size_spill { - let input = vec![make_staggered_batches(size)]; - let first_batch = input - .iter() - .flat_map(|p| p.iter()) - .next() - .expect("at least one batch"); - let schema = first_batch.schema(); - - let sort = vec![PhysicalSortExpr { - expr: col("x", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }]; - - let exec = MemoryExec::try_new(&input, schema, None).unwrap(); - let sort = Arc::new(SortExec::new(sort, Arc::new(exec))); - - let runtime_config = RuntimeConfig::new() - .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))); - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); - - let task_ctx = session_ctx.task_ctx(); - let collected = collect(sort.clone(), task_ctx).await.unwrap(); - - let expected = partitions_to_sorted_vec(&input); - let actual = batches_to_vec(&collected); - - if spill { - assert_ne!(sort.metrics().unwrap().spill_count().unwrap(), 0); - } else { - assert_eq!(sort.metrics().unwrap().spill_count().unwrap(), 0); - } - - assert_eq!( - session_ctx.runtime_env().memory_pool.reserved(), - 0, - "The sort should have returned all memory used back to the memory pool" - ); - assert_eq!(expected, actual, "failure in @ pool_size {pool_size}"); - } -} - -/// Return randomly sized record batches in a field named 'x' of type `Int32` -/// with randomized i32 content -fn make_staggered_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); - let max_batch = 1024; - - let mut batches = vec![]; - let mut remaining = len; - while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); - remaining -= to_read; - - batches.push( - RecordBatch::try_from_iter(vec![( - "x", - Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), - )) as ArrayRef, - )]) - .unwrap(), - ) - } - batches -} diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 7d73b4a618818..3752d42dbf43a 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::io::Cursor; +use std::ops::Range; +use std::sync::Arc; +use std::time::SystemTime; + use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; @@ -30,6 +35,7 @@ use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::{collect, Statistics}; use datafusion::prelude::SessionContext; use datafusion_common::Result; + use futures::future::BoxFuture; use futures::{FutureExt, TryFutureExt}; use object_store::memory::InMemory; @@ -39,10 +45,6 @@ use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; -use std::io::Cursor; -use std::ops::Range; -use std::sync::Arc; -use std::time::SystemTime; const EXPECTED_USER_DEFINED_METADATA: &str = "some-user-defined-metadata"; @@ -77,8 +79,8 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { // just any url that doesn't point to in memory object store object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], @@ -96,7 +98,7 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { let task_ctx = session_ctx.task_ctx(); let read = collect(Arc::new(parquet_exec), task_ctx).await.unwrap(); - let expected = vec![ + let expected = [ "+-----+----+----+", "| c1 | c2 | c3 |", "+-----+----+----+", @@ -186,6 +188,7 @@ async fn store_parquet_in_memory( last_modified: chrono::DateTime::from(SystemTime::now()), size: buf.len(), e_tag: None, + version: None, }; (meta, Bytes::from(buf)) diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs new file mode 100644 index 0000000000000..9f94a59a3e598 --- /dev/null +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fs; +use std::sync::Arc; + +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::datasource::TableProvider; +use datafusion::execution::context::SessionState; +use datafusion::prelude::SessionContext; +use datafusion_common::stats::Precision; +use datafusion_execution::cache::cache_manager::CacheManagerConfig; +use datafusion_execution::cache::cache_unit; +use datafusion_execution::cache::cache_unit::{ + DefaultFileStatisticsCache, DefaultListFilesCache, +}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + +use tempfile::tempdir; + +#[tokio::test] +async fn load_table_stats_with_session_level_cache() { + let testdata = datafusion::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let (cache1, _, state1) = get_cache_runtime_state(); + + // Create a separate DefaultFileStatisticsCache + let (cache2, _, state2) = get_cache_runtime_state(); + + let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + + let table1 = get_listing_table(&table_path, Some(cache1), &opt).await; + let table2 = get_listing_table(&table_path, Some(cache2), &opt).await; + + //Session 1 first time list files + assert_eq!(get_static_cache_size(&state1), 0); + let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); + + assert_eq!(exec1.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec1.statistics().unwrap().total_byte_size, + Precision::Exact(671) + ); + assert_eq!(get_static_cache_size(&state1), 1); + + //Session 2 first time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_static_cache_size(&state2), 0); + let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); + assert_eq!(exec2.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec2.statistics().unwrap().total_byte_size, + Precision::Exact(671) + ); + assert_eq!(get_static_cache_size(&state2), 1); + + //Session 1 second time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_static_cache_size(&state1), 1); + let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); + assert_eq!(exec3.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec3.statistics().unwrap().total_byte_size, + Precision::Exact(671) + ); + // List same file no increase + assert_eq!(get_static_cache_size(&state1), 1); +} + +#[tokio::test] +async fn list_files_with_session_level_cache() { + let p_name = "alltypes_plain.parquet"; + let testdata = datafusion::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, p_name); + + let temp_path1 = tempdir() + .unwrap() + .into_path() + .into_os_string() + .into_string() + .unwrap(); + let temp_filename1 = format!("{}/{}", temp_path1, p_name); + + let temp_path2 = tempdir() + .unwrap() + .into_path() + .into_os_string() + .into_string() + .unwrap(); + let temp_filename2 = format!("{}/{}", temp_path2, p_name); + + fs::copy(filename.clone(), temp_filename1).expect("panic"); + fs::copy(filename, temp_filename2).expect("panic"); + + let table_path = ListingTableUrl::parse(temp_path1).unwrap(); + + let (_, _, state1) = get_cache_runtime_state(); + + // Create a separate DefaultFileStatisticsCache + let (_, _, state2) = get_cache_runtime_state(); + + let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + + let table1 = get_listing_table(&table_path, None, &opt).await; + let table2 = get_listing_table(&table_path, None, &opt).await; + + //Session 1 first time list files + assert_eq!(get_list_file_cache_size(&state1), 0); + let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); + let parquet1 = exec1.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state1), 1); + let fg = &parquet1.base_config().file_groups; + assert_eq!(fg.len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); + + //Session 2 first time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_list_file_cache_size(&state2), 0); + let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); + let parquet2 = exec2.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state2), 1); + let fg2 = &parquet2.base_config().file_groups; + assert_eq!(fg2.len(), 1); + assert_eq!(fg2.first().unwrap().len(), 1); + + //Session 1 second time list files + //check session 1 cache result not show in session 2 + assert_eq!(get_list_file_cache_size(&state1), 1); + let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); + let parquet3 = exec3.as_any().downcast_ref::().unwrap(); + + assert_eq!(get_list_file_cache_size(&state1), 1); + let fg = &parquet3.base_config().file_groups; + assert_eq!(fg.len(), 1); + assert_eq!(fg.first().unwrap().len(), 1); + // List same file no increase + assert_eq!(get_list_file_cache_size(&state1), 1); +} + +async fn get_listing_table( + table_path: &ListingTableUrl, + static_cache: Option>, + opt: &ListingOptions, +) -> ListingTable { + let schema = opt + .infer_schema( + &SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ), + table_path, + ) + .await + .unwrap(); + let config1 = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt.clone()) + .with_schema(schema); + let table = ListingTable::try_new(config1).unwrap(); + if let Some(c) = static_cache { + table.with_cache(Some(c)) + } else { + table + } +} + +fn get_cache_runtime_state() -> ( + Arc, + Arc, + SessionState, +) { + let cache_config = CacheManagerConfig::default(); + let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + + let cache_config = cache_config + .with_files_statistics_cache(Some(file_static_cache.clone())) + .with_list_files_cache(Some(list_file_cache.clone())); + + let rt = Arc::new( + RuntimeEnv::new(RuntimeConfig::new().with_cache_manager(cache_config)).unwrap(), + ); + let state = SessionContext::new_with_config_rt(SessionConfig::default(), rt).state(); + + (file_static_cache, list_file_cache, state) +} + +fn get_static_cache_size(state1: &SessionState) -> usize { + state1 + .runtime_env() + .cache_manager + .get_file_statistic_cache() + .unwrap() + .len() +} + +fn get_list_file_cache_size(state1: &SessionState) -> usize { + state1 + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .len() +} diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 885834f939791..f214e8903a4f8 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -34,7 +34,7 @@ use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_optimizer::utils::{conjunction, disjunction, split_conjunction}; +use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; @@ -507,7 +507,7 @@ impl<'a> TestCase<'a> { ) -> RecordBatch { println!(" scan options: {scan_options:?}"); println!(" reading with filter {filter:?}"); - let ctx = SessionContext::with_config(scan_options.config()); + let ctx = SessionContext::new_with_config(scan_options.config()); let exec = self .test_parquet_file .create_scan(Some(filter.clone())) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 6f289e0c064bd..3f003c077d6a0 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -40,6 +40,7 @@ use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; +mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; @@ -153,7 +154,7 @@ impl ContextWithParquet { let parquet_path = file.path().to_string_lossy(); // now, setup a the file as a data source and run a query against it - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); ctx.register_parquet("t", &parquet_path, ParquetReadOptions::default()) .await @@ -290,7 +291,8 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { offset_nanos + t.parse::() .unwrap() - .timestamp_nanos() + .timestamp_nanos_opt() + .unwrap() }) }) .collect::>(); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 4337259c1e622..e1e8b8e66edd0 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -17,6 +17,7 @@ use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; + use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::PartitionedFile; @@ -30,6 +31,7 @@ use datafusion_common::{ScalarValue, Statistics, ToDFSchema}; use datafusion_expr::{col, lit, Expr}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; + use futures::StreamExt; use object_store::path::Path; use object_store::ObjectMeta; @@ -48,6 +50,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, }; let schema = ParquetFormat::default() @@ -71,8 +74,8 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { FileScanConfig { object_store_url, file_groups: vec![vec![partitioned_file]], - file_schema: schema, - statistics: Statistics::default(), + file_schema: schema.clone(), + statistics: Statistics::new_unknown(&schema), // file has 10 cols so index 12 should be month projection: None, limit: None, @@ -240,10 +243,11 @@ async fn test_prune( expected_row_pages_pruned: Option, expected_results: usize, ) { - let output = ContextWithParquet::new(case_data_type, Page) - .await - .query(sql) - .await; + let output: crate::parquet::TestOutput = + ContextWithParquet::new(case_data_type, Page) + .await + .query(sql) + .await; println!("{}", output.description()); assert_eq!(output.predicate_evaluation_errors(), expected_errors); diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 4d028c6f1b31d..25c62f18f5ba1 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::types::Int32Type; @@ -24,14 +26,13 @@ use datafusion::assert_batches_sorted_eq; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::physical_plan::collect; use datafusion::prelude::SessionContext; -use datafusion_common::Result; -use datafusion_common::Statistics; +use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; -use std::sync::Arc; use tempfile::NamedTempFile; /// Test for reading data from multiple parquet files with different schemas and coercing them into a single schema. @@ -62,8 +63,8 @@ async fn multi_parquet_coercion() { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: None, limit: None, table_partition_cols: vec![], @@ -78,7 +79,7 @@ async fn multi_parquet_coercion() { let task_ctx = session_ctx.task_ctx(); let read = collect(Arc::new(parquet_exec), task_ctx).await.unwrap(); - let expected = vec![ + let expected = [ "+-------+----+------+", "| c1 | c2 | c3 |", "+-------+----+------+", @@ -126,8 +127,8 @@ async fn multi_parquet_coercion_projection() { FileScanConfig { object_store_url: ObjectStoreUrl::local_filesystem(), file_groups: vec![file_groups], + statistics: Statistics::new_unknown(&file_schema), file_schema, - statistics: Statistics::default(), projection: Some(vec![1, 0, 2]), limit: None, table_partition_cols: vec![], @@ -142,7 +143,7 @@ async fn multi_parquet_coercion_projection() { let task_ctx = session_ctx.task_ctx(); let read = collect(Arc::new(parquet_exec), task_ctx).await.unwrap(); - let expected = vec![ + let expected = [ "+----+-------+------+", "| c2 | c1 | c3 |", "+----+-------+------+", @@ -193,5 +194,6 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), size: metadata.len() as usize, e_tag: None, + version: None, } } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 894ceb1b9800c..abe6ab283aff4 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -17,16 +17,13 @@ //! Test queries on partitioned datasets -use arrow::datatypes::DataType; use std::collections::BTreeSet; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; use std::ops::Range; use std::sync::Arc; -use async_trait::async_trait; -use bytes::Bytes; -use chrono::{TimeZone, Utc}; +use arrow::datatypes::DataType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::{ assert_batches_sorted_eq, @@ -39,11 +36,17 @@ use datafusion::{ prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; +use datafusion_common::stats::Precision; use datafusion_common::ScalarValue; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{TimeZone, Utc}; use futures::stream; use futures::stream::BoxStream; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, + path::Path, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, + ObjectMeta, ObjectStore, PutOptions, PutResult, }; use tokio::io::AsyncWrite; use url::Url; @@ -81,7 +84,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+------+-------+-----+", "| year | month | day |", "+------+-------+-----+", @@ -124,7 +127,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { let mut max_limit = match ScalarValue::try_from_array(results[0].column(0), 0)? { ScalarValue::Int64(Some(count)) => count, - s => panic!("Expected count as Int64 found {}", s.get_datatype()), + s => panic!("Expected count as Int64 found {}", s.data_type()), }; max_limit += 1; @@ -135,7 +138,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { let mut min_limit = match ScalarValue::try_from_array(last_batch.column(0), last_row_idx)? { ScalarValue::Int64(Some(count)) => count, - s => panic!("Expected count as Int64 found {}", s.get_datatype()), + s => panic!("Expected count as Int64 found {}", s.data_type()), }; min_limit -= 1; @@ -165,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + let month = match s { + ScalarValue::Utf8(Some(month)) => month, + s => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -188,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new(); @@ -217,7 +211,7 @@ async fn csv_filter_with_file_col() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+----+----+", "| c1 | c2 |", "+----+----+", @@ -253,7 +247,7 @@ async fn csv_filter_with_file_nonstring_col() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+----+----+------------+", "| c1 | c2 | date |", "+----+----+------------+", @@ -289,7 +283,7 @@ async fn csv_projection_on_partition() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+----+------------+", "| c1 | date |", "+----+------------+", @@ -326,13 +320,13 @@ async fn csv_grouping_by_partition() -> Result<()> { .collect() .await?; - let expected = vec![ - "+------------+-----------------+----------------------+", - "| date | COUNT(UInt8(1)) | COUNT(DISTINCT t.c1) |", - "+------------+-----------------+----------------------+", - "| 2021-10-26 | 100 | 5 |", - "| 2021-10-27 | 100 | 5 |", - "+------------+-----------------+----------------------+", + let expected = [ + "+------------+----------+----------------------+", + "| date | COUNT(*) | COUNT(DISTINCT t.c1) |", + "+------------+----------+----------------------+", + "| 2021-10-26 | 100 | 5 |", + "| 2021-10-27 | 100 | 5 |", + "+------------+----------+----------------------+", ]; assert_batches_sorted_eq!(expected, &result); @@ -366,7 +360,7 @@ async fn parquet_multiple_partitions() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+----+-----+", "| id | day |", "+----+-----+", @@ -412,7 +406,7 @@ async fn parquet_multiple_nonstring_partitions() -> Result<()> { .collect() .await?; - let expected = vec![ + let expected = [ "+----+-----+", "| id | day |", "+----+-----+", @@ -457,34 +451,30 @@ async fn parquet_statistics() -> Result<()> { //// NO PROJECTION //// let dataframe = ctx.sql("SELECT * FROM t").await?; let physical_plan = dataframe.create_physical_plan().await?; - assert_eq!(physical_plan.schema().fields().len(), 4); + let schema = physical_plan.schema(); + assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan - .statistics() - .column_statistics - .expect("col stats should be defined"); + let stat_cols = physical_plan.statistics()?.column_statistics; assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file - assert_eq!(stat_cols[0].null_count, Some(3)); + assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); // TODO assert partition column (1,2,3) stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::default()); - assert_eq!(stat_cols[2], ColumnStatistics::default()); - assert_eq!(stat_cols[3], ColumnStatistics::default()); + assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); + assert_eq!(stat_cols[2], ColumnStatistics::new_unknown(),); + assert_eq!(stat_cols[3], ColumnStatistics::new_unknown(),); //// WITH PROJECTION //// let dataframe = ctx.sql("SELECT mycol, day FROM t WHERE day='28'").await?; let physical_plan = dataframe.create_physical_plan().await?; - assert_eq!(physical_plan.schema().fields().len(), 2); + let schema = physical_plan.schema(); + assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan - .statistics() - .column_statistics - .expect("col stats should be defined"); + let stat_cols = physical_plan.statistics()?.column_statistics; assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file - assert_eq!(stat_cols[0].null_count, Some(1)); + assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::default()); + assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); Ok(()) } @@ -621,7 +611,12 @@ impl MirroringObjectStore { #[async_trait] impl ObjectStore for MirroringObjectStore { - async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> { + async fn put_opts( + &self, + _location: &Path, + _bytes: Bytes, + _opts: PutOptions, + ) -> object_store::Result { unimplemented!() } @@ -648,7 +643,20 @@ impl ObjectStore for MirroringObjectStore { self.files.iter().find(|x| *x == location).unwrap(); let path = std::path::PathBuf::from(&self.mirrored_file); let file = File::open(&path).unwrap(); - Ok(GetResult::File(file, path)) + let metadata = file.metadata().unwrap(); + let meta = ObjectMeta { + location: location.clone(), + last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), + size: metadata.len() as usize, + e_tag: None, + version: None, + }; + + Ok(GetResult { + range: 0..meta.size, + payload: GetResultPayload::File(file, path), + meta, + }) } async fn get_range( @@ -669,26 +677,16 @@ impl ObjectStore for MirroringObjectStore { Ok(data.into()) } - async fn head(&self, location: &Path) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - Ok(ObjectMeta { - location: location.clone(), - last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, - e_tag: None, - }) - } - async fn delete(&self, _location: &Path) -> object_store::Result<()> { unimplemented!() } - async fn list( + fn list( &self, prefix: Option<&Path>, - ) -> object_store::Result>> { + ) -> BoxStream<'_, object_store::Result> { let prefix = prefix.cloned().unwrap_or_default(); - Ok(Box::pin(stream::iter(self.files.iter().filter_map( + Box::pin(stream::iter(self.files.iter().filter_map( move |location| { // Don't return for exact prefix match let filter = location @@ -702,10 +700,11 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }) }) }, - )))) + ))) } async fn list_with_delimiter( @@ -739,6 +738,7 @@ impl ObjectStore for MirroringObjectStore { last_modified: Utc.timestamp_nanos(0), size: self.file_size as usize, e_tag: None, + version: None, }; objects.push(object); } diff --git a/datafusion/core/tests/row.rs b/datafusion/core/tests/row.rs deleted file mode 100644 index c68b422a4f063..0000000000000 --- a/datafusion/core/tests/row.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; -use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; -use datafusion::error::Result; -use datafusion::execution::context::SessionState; -use datafusion::physical_plan::{collect, ExecutionPlan}; -use datafusion::prelude::SessionContext; -use datafusion_row::reader::read_as_batch; -use datafusion_row::writer::write_batch_unchecked; -use object_store::{local::LocalFileSystem, path::Path, ObjectStore}; -use std::sync::Arc; - -#[tokio::test] -async fn test_with_parquet() -> Result<()> { - let ctx = SessionContext::new(); - let state = ctx.state(); - let task_ctx = state.task_ctx(); - let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7]); - let exec = - get_exec(&state, "alltypes_plain.parquet", projection.as_ref(), None).await?; - let schema = exec.schema().clone(); - - let batches = collect(exec, task_ctx).await?; - assert_eq!(1, batches.len()); - let batch = &batches[0]; - - let mut vector = vec![0; 20480]; - let row_offsets = { write_batch_unchecked(&mut vector, 0, batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&vector, schema, &row_offsets)? }; - assert_eq!(*batch, output_batch); - - Ok(()) -} - -async fn get_exec( - state: &SessionState, - file_name: &str, - projection: Option<&Vec>, - limit: Option, -) -> Result> { - let testdata = datafusion::test_util::parquet_test_data(); - let filename = format!("{testdata}/{file_name}"); - - let path = Path::from_filesystem_path(filename).unwrap(); - - let format = ParquetFormat::default(); - let object_store = Arc::new(LocalFileSystem::new()) as Arc; - let object_store_url = ObjectStoreUrl::local_filesystem(); - - let meta = object_store.head(&path).await.unwrap(); - - let file_schema = format - .infer_schema(state, &object_store, &[meta.clone()]) - .await - .expect("Schema inference"); - let statistics = format - .infer_stats(state, &object_store, file_schema.clone(), &meta) - .await - .expect("Stats inference"); - let file_groups = vec![vec![meta.into()]]; - let exec = format - .create_physical_plan( - state, - FileScanConfig { - object_store_url, - file_schema, - file_groups, - statistics, - projection: projection.cloned(), - limit, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }, - None, - ) - .await?; - Ok(exec) -} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 251063e396b60..af6d0d5f4e245 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -17,8 +17,6 @@ use super::*; use datafusion::scalar::ScalarValue; -use datafusion::test_util::scan_empty; -use datafusion_common::cast::as_float64_array; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -47,346 +45,24 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); - if let ScalarValue::List(Some(mut v), _) = ScalarValue::try_from_array(column, 0)? { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - v.sort_by(cmp); - assert_eq!( - *v, - vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::UInt32(Some(2)), - ScalarValue::UInt32(Some(3)), - ScalarValue::UInt32(Some(4)), - ScalarValue::UInt32(Some(5)) - ] - ); - } else { - unreachable!(); - } - - Ok(()) -} - -#[tokio::test] -async fn aggregate() -> Result<()> { - let results = execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| 60 | 220 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_empty() -> Result<()> { - // The predicate on this query purposely generates no results - let results = - execute_with_partition("SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000", 4) - .await - .unwrap(); - - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+--------------+", - "| SUM(test.c1) | SUM(test.c2) |", - "+--------------+--------------+", - "| | |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg() -> Result<()> { - let results = execute_with_partition("SELECT AVG(c1), AVG(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+--------------+", - "| AVG(test.c1) | AVG(test.c2) |", - "+--------------+--------------+", - "| 1.5 | 5.5 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_max() -> Result<()> { - let results = execute_with_partition("SELECT MAX(c1), MAX(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+--------------+", - "| MAX(test.c1) | MAX(test.c2) |", - "+--------------+--------------+", - "| 3 | 10 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min() -> Result<()> { - let results = execute_with_partition("SELECT MIN(c1), MIN(c2) FROM test", 4).await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+--------------+", - "| MIN(test.c1) | MIN(test.c2) |", - "+--------------+--------------+", - "| 0 | 1 |", - "+--------------+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped() -> Result<()> { - let results = - execute_with_partition("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; - - let expected = vec![ - "+----+--------------+", - "| c1 | SUM(test.c2) |", - "+----+--------------+", - "| 0 | 55 |", - "| 1 | 55 |", - "| 2 | 55 |", - "| 3 | 55 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_avg() -> Result<()> { - let results = - execute_with_partition("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; - - let expected = vec![ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "| 0 | 5.5 |", - "| 1 | 5.5 |", - "| 2 | 5.5 |", - "| 3 | 5.5 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_empty() -> Result<()> { - let results = execute_with_partition( - "SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", - 4, - ) - .await?; - - let expected = vec![ - "+----+--------------+", - "| c1 | AVG(test.c2) |", - "+----+--------------+", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_max() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; - - let expected = vec![ - "+----+--------------+", - "| c1 | MAX(test.c2) |", - "+----+--------------+", - "| 0 | 10 |", - "| 1 | 10 |", - "| 2 | 10 |", - "| 3 | 10 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_grouped_min() -> Result<()> { - let results = - execute_with_partition("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; - - let expected = vec![ - "+----+--------------+", - "| c1 | MIN(test.c2) |", - "+----+--------------+", - "| 0 | 1 |", - "| 1 | 1 |", - "| 2 | 1 |", - "| 3 | 1 |", - "+----+--------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9706712283358269 |", - "| 0.2667177795079635 | 0.9965400387585364 |", - "| 0.3600766362333053 | 0.9706712283358269 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_min_max_w_custom_window_frames_unbounded_start() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = - "SELECT - MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, - MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 - FROM aggregate_test_100 - ORDER BY C9 - LIMIT 5"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+---------------------+--------------------+", - "| min1 | max1 |", - "+---------------------+--------------------+", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "| 0.01479305307777301 | 0.9965400387585364 |", - "| 0.01479305307777301 | 0.9800193410444061 |", - "+---------------------+--------------------+", - ]; - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn aggregate_avg_add() -> Result<()> { - let results = execute_with_partition( - "SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test", - 4, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+--------------+-------------------------+-------------------------+-------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |", - "+--------------+-------------------------+-------------------------+-------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+-------------------------+-------------------------+-------------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn case_sensitive_identifiers_aggregates() { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = vec![ - "+----------+", - "| MAX(t.i) |", - "+----------+", - "| 1 |", - "+----------+", - ]; - - let results = plan_and_collect(&ctx, "SELECT max(i) FROM t") - .await - .unwrap(); - - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT MAX(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - // Using double quotes allows specifying the function name with capitalization - let err = plan_and_collect(&ctx, "SELECT \"MAX\"(i) FROM t") - .await - .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function 'MAX'")); - - let results = plan_and_collect(&ctx, "SELECT \"max\"(i) FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); -} - -#[tokio::test] -async fn count_basic() -> Result<()> { - let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; - assert_eq!(results.len(), 1); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; + let mut scalars = scalar_vec[0].clone(); + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + scalars.sort_by(cmp); + assert_eq!( + scalars, + vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::UInt32(Some(2)), + ScalarValue::UInt32(Some(3)), + ScalarValue::UInt32(Some(4)), + ScalarValue::UInt32(Some(5)) + ] + ); - let expected = vec![ - "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", - "+----------------+----------------+", - "| 10 | 10 |", - "+----------------+----------------+", - ]; - assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -396,7 +72,7 @@ async fn count_partitioned() -> Result<()> { execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let expected = vec![ + let expected = [ "+----------------+----------------+", "| COUNT(test.c1) | COUNT(test.c2) |", "+----------------+----------------+", @@ -412,7 +88,7 @@ async fn count_aggregated() -> Result<()> { let results = execute_with_partition("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?; - let expected = vec![ + let expected = [ "+----+----------------+", "| c1 | COUNT(test.c2) |", "+----+----------------+", @@ -499,162 +175,6 @@ async fn count_aggregated_cube() -> Result<()> { Ok(()) } -#[tokio::test] -async fn count_multi_expr() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT count(c1, c2) FROM test"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+------------------------+", - "| COUNT(test.c1,test.c2) |", - "+------------------------+", - "| 2 |", - "+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn count_multi_expr_group_by() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - Field::new("c3", DataType::Int32, true), - ])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![ - Some(0), - None, - Some(1), - Some(2), - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - Some(0), - None, - None, - ])), - Arc::new(Int32Array::from(vec![ - Some(10), - Some(10), - Some(10), - Some(10), - Some(10), - ])), - ], - )?; - - let ctx = SessionContext::new(); - ctx.register_batch("test", data)?; - let sql = "SELECT c3, count(c1, c2) FROM test group by c3"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+----+------------------------+", - "| c3 | COUNT(test.c1,test.c2) |", - "+----+------------------------+", - "| 10 | 2 |", - "+----+------------------------+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn simple_avg() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT AVG(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // avg(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - -#[tokio::test] -async fn simple_mean() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], - )?; - - let ctx = SessionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - let result = plan_and_collect(&ctx, "SELECT MEAN(a) FROM t").await?; - - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - - let values = as_float64_array(batch.column(0)).expect("failed to cast version"); - assert_eq!(values.len(), 1); - // mean(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); - Ok(()) -} - async fn run_count_distinct_integers_aggregated_scenario( partitions: Vec>, ) -> Result> { @@ -739,15 +259,13 @@ async fn count_distinct_integers_aggregated_single_partition() -> Result<()> { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; - let expected = vec![ - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - ]; + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+"]; assert_batches_sorted_eq!(expected, &results); Ok(()) @@ -765,49 +283,22 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; - let expected = vec![ - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", + let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", - "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - ]; + "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+"]; assert_batches_sorted_eq!(expected, &results); Ok(()) } -#[tokio::test] -async fn aggregate_with_alias() -> Result<()> { - let ctx = SessionContext::new(); - let state = ctx.state(); - - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - ])); - - let plan = scan_empty(None, schema.as_ref(), None)? - .aggregate(vec![col("c1")], vec![sum(col("c2"))])? - .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])? - .build()?; - - let plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&Arc::new(plan)).await?; - assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); - assert_eq!( - "total_salary", - physical_plan.schema().field(1).name().as_str() - ); - Ok(()) -} - #[tokio::test] async fn test_accumulator_row_accumulator() -> Result<()> { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, COUNT(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 @@ -817,8 +308,7 @@ async fn test_accumulator_row_accumulator() -> Result<()> { LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----+----+--------------------------------+-----------+--------------------------------+------------+--------------------+--------------------------------+------+--------------+", + let expected = ["+----+----+--------------------------------+-----------+--------------------------------+------------+--------------------+--------------------------------+------+--------------+", "| c1 | c2 | min1 | min2 | max1 | max2 | avg1 | min3 | cnt1 | sum1 |", "+----+----+--------------------------------+-----------+--------------------------------+------------+--------------------+--------------------------------+------+--------------+", "| a | 1 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | 774637006 | waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs | 4015442341 | 2437927011.0 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | 5 | 6094771121.5 |", @@ -826,8 +316,7 @@ async fn test_accumulator_row_accumulator() -> Result<()> { "| a | 3 | Amn2K87Db5Es3dFQO9cw9cvpAM6h35 | 431948861 | oLZ21P2JEDooxV1pU31cIxQHEeeoLu | 3998790955 | 2225685115.1666665 | Amn2K87Db5Es3dFQO9cw9cvpAM6h35 | 6 | 6676994872.5 |", "| a | 4 | KJFcmTVjdkCMv94wYCtfHMFhzyRsmH | 466439833 | ydkwycaISlYSlEq3TlkS2m15I2pcp8 | 2502326480 | 1655431654.0 | KJFcmTVjdkCMv94wYCtfHMFhzyRsmH | 4 | 3310812222.5 |", "| a | 5 | MeSTAXq8gVxVjbEjgkvU9YLte0X9uE | 141047417 | QJYm7YRA3YetcBHI5wkMZeLXVmfuNy | 2496054700 | 1216992989.6666667 | MeSTAXq8gVxVjbEjgkvU9YLte0X9uE | 3 | 1825431770.0 |", - "+----+----+--------------------------------+-----------+--------------------------------+------------+--------------------+--------------------------------+------+--------------+", - ]; + "+----+----+--------------------------------+-----------+--------------------------------+------------+--------------------+--------------------------------+------+--------------+"]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/core/tests/sql/arrow_files.rs b/datafusion/core/tests/sql/arrow_files.rs deleted file mode 100644 index e74294b312904..0000000000000 --- a/datafusion/core/tests/sql/arrow_files.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use datafusion::execution::options::ArrowReadOptions; - -use super::*; - -async fn register_arrow(ctx: &mut SessionContext) { - ctx.register_arrow( - "arrow_simple", - "tests/data/example.arrow", - ArrowReadOptions::default(), - ) - .await - .unwrap(); -} - -#[tokio::test] -async fn arrow_query() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "SELECT * FROM arrow_simple"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----+-----+-------+", - "| f0 | f1 | f2 |", - "+----+-----+-------+", - "| 1 | foo | true |", - "| 2 | bar | |", - "| 3 | baz | false |", - "| 4 | | true |", - "+----+-----+-------+", - ]; - - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn arrow_explain() { - let mut ctx = SessionContext::new(); - register_arrow(&mut ctx).await; - let sql = "EXPLAIN SELECT * FROM arrow_simple"; - let actual = execute(&ctx, sql).await; - let actual = normalize_vec_for_explain(actual); - let expected = vec![ - vec![ - "logical_plan", - "TableScan: arrow_simple projection=[f0, f1, f2]", - ], - vec![ - "physical_plan", - "ArrowExec: file_groups={1 group: [[WORKING_DIR/tests/data/example.arrow]]}, projection=[f0, f1, f2]\n", - ], - ]; - - assert_eq!(expected, actual); -} diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index aa34552044d46..b1434dddee50f 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -26,11 +26,11 @@ async fn create_custom_table() -> Result<()> { let cfg = RuntimeConfig::new(); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; ctx.sql(sql).await.unwrap(); @@ -48,11 +48,11 @@ async fn create_external_table_with_ddl() -> Result<()> { let cfg = RuntimeConfig::new(); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); - let ctx = SessionContext::with_state(state); + let ctx = SessionContext::new_with_state(state); let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS MOCKTABLE LOCATION 'mockprotocol://path/to/table';"; ctx.sql(sql).await.unwrap(); diff --git a/datafusion/core/tests/sql/csv_files.rs b/datafusion/core/tests/sql/csv_files.rs new file mode 100644 index 0000000000000..5ed0068d61359 --- /dev/null +++ b/datafusion/core/tests/sql/csv_files.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[tokio::test] +async fn csv_custom_quote() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Utf8, false), + ])); + let filename = format!("partition.{}", "csv"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value{index:}"); + let data = format!("~{text1}~,~{text2}~\r\n"); + file.write_all(data.as_bytes())?; + } + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .quote(b'~'), + ) + .await?; + + let results = plan_and_collect(&ctx, "SELECT * from test").await?; + + let expected = vec![ + "+-----+--------+", + "| c1 | c2 |", + "+-----+--------+", + "| id0 | value0 |", + "| id1 | value1 |", + "| id2 | value2 |", + "| id3 | value3 |", + "| id4 | value4 |", + "| id5 | value5 |", + "| id6 | value6 |", + "| id7 | value7 |", + "| id8 | value8 |", + "| id9 | value9 |", + "+-----+--------+", + ]; + + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + +#[tokio::test] +async fn csv_custom_escape() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Utf8, false), + ])); + let filename = format!("partition.{}", "csv"); + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for index in 0..10 { + let text1 = format!("id{index:}"); + let text2 = format!("value\\\"{index:}"); + let data = format!("\"{text1}\",\"{text2}\"\r\n"); + file.write_all(data.as_bytes())?; + } + + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .escape(b'\\'), + ) + .await?; + + let results = plan_and_collect(&ctx, "SELECT * from test").await?; + + let expected = vec![ + "+-----+---------+", + "| c1 | c2 |", + "+-----+---------+", + "| id0 | value\"0 |", + "| id1 | value\"1 |", + "| id2 | value\"2 |", + "| id3 | value\"3 |", + "| id4 | value\"4 |", + "| id5 | value\"5 |", + "| id6 | value\"6 |", + "| id7 | value\"7 |", + "| id8 | value\"8 |", + "| id9 | value\"9 |", + "+-----+---------+", + ]; + + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 971dea81283fe..37f8cefc90809 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -16,8 +16,10 @@ // under the License. use super::*; + use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::metrics::Timestamp; #[tokio::test] async fn explain_analyze_baseline_metrics() { @@ -26,7 +28,7 @@ async fn explain_analyze_baseline_metrics() { let config = SessionConfig::new() .with_target_partitions(3) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv_by_sql(&ctx).await; // a query with as many operators as we have metrics for let sql = "EXPLAIN ANALYZE \ @@ -78,7 +80,7 @@ async fn explain_analyze_baseline_metrics() { ); assert_metrics!( &formatted, - "ProjectionExec: expr=[COUNT(UInt8(1))", + "ProjectionExec: expr=[COUNT(*)", "metrics=[output_rows=1, elapsed_compute=" ); assert_metrics!( @@ -142,11 +144,11 @@ async fn explain_analyze_baseline_metrics() { metrics.iter().for_each(|m| match m.value() { MetricValue::StartTimestamp(ts) => { saw_start = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); + assert!(nanos_from_timestamp(ts) > 0); } MetricValue::EndTimestamp(ts) => { saw_end = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); + assert!(nanos_from_timestamp(ts) > 0); } _ => {} }); @@ -161,7 +163,9 @@ async fn explain_analyze_baseline_metrics() { datafusion::physical_plan::accept(physical_plan.as_ref(), &mut TimeValidator {}) .unwrap(); } - +fn nanos_from_timestamp(ts: &Timestamp) -> i64 { + ts.value().unwrap().timestamp_nanos_opt().unwrap() +} #[tokio::test] async fn csv_explain_plans() { // This test verify the look of each plan in its full cycle plan creation @@ -210,7 +214,9 @@ async fn csv_explain_plans() { // // verify the grahviz format of the plan let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "// Begin DataFusion GraphViz Plan,", + "// display it online here: https://dreampuf.github.io/GraphvizOnline", + "", "digraph {", " subgraph cluster_1", " {", @@ -282,7 +288,9 @@ async fn csv_explain_plans() { // // verify the grahviz format of the plan let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "// Begin DataFusion GraphViz Plan,", + "// display it online here: https://dreampuf.github.io/GraphvizOnline", + "", "digraph {", " subgraph cluster_1", " {", @@ -427,7 +435,9 @@ async fn csv_explain_verbose_plans() { // // verify the grahviz format of the plan let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "// Begin DataFusion GraphViz Plan,", + "// display it online here: https://dreampuf.github.io/GraphvizOnline", + "", "digraph {", " subgraph cluster_1", " {", @@ -499,7 +509,9 @@ async fn csv_explain_verbose_plans() { // // verify the grahviz format of the plan let expected = vec![ - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)", + "// Begin DataFusion GraphViz Plan,", + "// display it online here: https://dreampuf.github.io/GraphvizOnline", + "", "digraph {", " subgraph cluster_1", " {", @@ -548,7 +560,7 @@ async fn csv_explain_verbose_plans() { // Since the plan contains path that are environmentally // dependant(e.g. full path of the test file), only verify // important content - assert_contains!(&actual, "logical_plan after push_down_projection"); + assert_contains!(&actual, "logical_plan after optimize_projections"); assert_contains!(&actual, "physical_plan"); assert_contains!(&actual, "FilterExec: c2@1 > 10"); assert_contains!(actual, "ProjectionExec: expr=[c1@0 as c1]"); @@ -563,7 +575,7 @@ async fn explain_analyze_runs_optimizers() { // This happens as an optimization pass where count(*) can be // answered using statistics only. - let expected = "EmptyExec: produce_one_row=true"; + let expected = "PlaceholderRowExec"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; @@ -587,7 +599,7 @@ async fn test_physical_plan_display_indent() { let config = SessionConfig::new() .with_target_partitions(9000) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ FROM aggregate_test_100 \ @@ -599,12 +611,12 @@ async fn test_physical_plan_display_indent() { let physical_plan = dataframe.create_physical_plan().await.unwrap(); let expected = vec![ "GlobalLimitExec: skip=0, fetch=10", - " SortPreservingMergeExec: [the_min@2 DESC]", - " SortExec: fetch=10, expr=[the_min@2 DESC]", + " SortPreservingMergeExec: [the_min@2 DESC], fetch=10", + " SortExec: TopK(fetch=10), expr=[the_min@2 DESC]", " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 9000), input_partitions=9000", + " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", " CoalesceBatchesExec: target_batch_size=4096", " FilterExec: c12@1 < 10", @@ -613,7 +625,7 @@ async fn test_physical_plan_display_indent() { ]; let normalizer = ExplainNormalizer::new(); - let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) + let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) .trim() .lines() // normalize paths @@ -631,7 +643,7 @@ async fn test_physical_plan_display_indent_multi_children() { let config = SessionConfig::new() .with_target_partitions(9000) .with_batch_size(4096); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); // ensure indenting works for nodes with multiple children register_aggregate_csv(&ctx).await.unwrap(); let sql = "SELECT c1 \ @@ -646,20 +658,20 @@ async fn test_physical_plan_display_indent_multi_children() { let expected = vec![ "ProjectionExec: expr=[c1@0 as c1]", " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 0 })]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 9000), input_partitions=9000", + " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", " CsvExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c2\", index: 0 }], 9000), input_partitions=9000", + " RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000", " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", " ProjectionExec: expr=[c1@0 as c2]", " CsvExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true", ]; let normalizer = ExplainNormalizer::new(); - let actual = format!("{}", displayable(physical_plan.as_ref()).indent()) + let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) .trim() .lines() // normalize paths @@ -687,7 +699,7 @@ async fn csv_explain_analyze() { // Only test basic plumbing and try to avoid having to change too // many things. explain_analyze_baseline_metrics covers the values // in greater depth - let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))], metrics=[output_rows=5"; + let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(*)], metrics=[output_rows=5"; assert_contains!(&formatted, needle); let verbose_needle = "Output Rows"; @@ -766,7 +778,7 @@ async fn csv_explain_analyze_verbose() { async fn explain_logical_plan_only() { let mut config = ConfigOptions::new(); config.explain.logical_plan_only = true; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); @@ -774,9 +786,9 @@ async fn explain_logical_plan_only() { let expected = vec![ vec![ "logical_plan", - "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]\ \n SubqueryAlias: t\ - \n Projection: column1\ + \n Projection: \ \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" ]]; assert_eq!(expected, actual); @@ -786,16 +798,37 @@ async fn explain_logical_plan_only() { async fn explain_physical_plan_only() { let mut config = ConfigOptions::new(); config.explain.physical_plan_only = true; - let ctx = SessionContext::with_config(config.into()); + let ctx = SessionContext::new_with_config(config.into()); let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); let expected = vec![vec![ "physical_plan", - "ProjectionExec: expr=[2 as COUNT(UInt8(1))]\ - \n EmptyExec: produce_one_row=true\ + "ProjectionExec: expr=[2 as COUNT(*)]\ + \n PlaceholderRowExec\ \n", ]]; assert_eq!(expected, actual); } + +#[tokio::test] +async fn csv_explain_analyze_with_statistics() { + let mut config = ConfigOptions::new(); + config.explain.physical_plan_only = true; + config.explain.show_statistics = true; + let ctx = SessionContext::new_with_config(config.into()); + register_aggregate_csv_by_sql(&ctx).await; + + let sql = "EXPLAIN ANALYZE SELECT c1 FROM aggregate_test_100"; + let actual = execute_to_batches(&ctx, sql).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual) + .unwrap() + .to_string(); + + // should contain scan statistics + assert_contains!( + &formatted, + ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" + ); +} diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 6783670545c3c..7d41ad4a881c5 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -60,9 +60,82 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("atan2(NULL, NULL)", "NULL"); test_expression!("atan2(1, NULL)", "NULL"); test_expression!("atan2(NULL, 1)", "NULL"); + test_expression!("nanvl(NULL, NULL)", "NULL"); + test_expression!("nanvl(1, NULL)", "NULL"); + test_expression!("nanvl(NULL, 1)", "NULL"); + test_expression!("isnan(NULL)", "NULL"); + test_expression!("iszero(NULL)", "NULL"); Ok(()) } +#[tokio::test] +#[cfg_attr(not(feature = "crypto_expressions"), ignore)] +async fn test_encoding_expressions() -> Result<()> { + // Input Utf8 + test_expression!("encode('tom','base64')", "dG9t"); + test_expression!("arrow_cast(decode('dG9t','base64'), 'Utf8')", "tom"); + test_expression!("encode('tom','hex')", "746f6d"); + test_expression!("arrow_cast(decode('746f6d','hex'), 'Utf8')", "tom"); + + // Input LargeUtf8 + test_expression!("encode(arrow_cast('tom', 'LargeUtf8'),'base64')", "dG9t"); + test_expression!( + "arrow_cast(decode(arrow_cast('dG9t', 'LargeUtf8'),'base64'), 'Utf8')", + "tom" + ); + test_expression!("encode(arrow_cast('tom', 'LargeUtf8'),'hex')", "746f6d"); + test_expression!( + "arrow_cast(decode(arrow_cast('746f6d', 'LargeUtf8'),'hex'), 'Utf8')", + "tom" + ); + + // Input Binary + test_expression!("encode(arrow_cast('tom', 'Binary'),'base64')", "dG9t"); + test_expression!( + "arrow_cast(decode(arrow_cast('dG9t', 'Binary'),'base64'), 'Utf8')", + "tom" + ); + test_expression!("encode(arrow_cast('tom', 'Binary'),'hex')", "746f6d"); + test_expression!( + "arrow_cast(decode(arrow_cast('746f6d', 'Binary'),'hex'), 'Utf8')", + "tom" + ); + + // Input LargeBinary + test_expression!("encode(arrow_cast('tom', 'LargeBinary'),'base64')", "dG9t"); + test_expression!( + "arrow_cast(decode(arrow_cast('dG9t', 'LargeBinary'),'base64'), 'Utf8')", + "tom" + ); + test_expression!("encode(arrow_cast('tom', 'LargeBinary'),'hex')", "746f6d"); + test_expression!( + "arrow_cast(decode(arrow_cast('746f6d', 'LargeBinary'),'hex'), 'Utf8')", + "tom" + ); + + // NULL + test_expression!("encode(NULL,'base64')", "NULL"); + test_expression!("decode(NULL,'base64')", "NULL"); + test_expression!("encode(NULL,'hex')", "NULL"); + test_expression!("decode(NULL,'hex')", "NULL"); + + // Empty string + test_expression!("encode('','base64')", ""); + test_expression!("decode('','base64')", ""); + test_expression!("encode('','hex')", ""); + test_expression!("decode('','hex')", ""); + + Ok(()) +} + +#[should_panic(expected = "Invalid timezone \\\"Foo\\\": 'Foo' is not a valid timezone")] +#[tokio::test] +async fn test_array_cast_invalid_timezone_will_panic() { + let ctx = SessionContext::new(); + let sql = "SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some(\"Foo\"))')"; + execute(&ctx, sql).await; +} + #[tokio::test] #[cfg_attr(not(feature = "crypto_expressions"), ignore)] async fn test_crypto_expressions() -> Result<()> { @@ -223,10 +296,11 @@ async fn test_interval_expressions() -> Result<()> { "interval '0.5 minute'", "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" ); - test_expression!( - "interval '.5 minute'", - "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" - ); + // https://github.com/apache/arrow-rs/issues/4424 + // test_expression!( + // "interval '.5 minute'", + // "0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs" + // ); test_expression!( "interval '5 minute'", "0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs" @@ -384,8 +458,10 @@ async fn test_substring_expr() -> Result<()> { Ok(()) } +/// Test string expressions test split into two batches +/// to prevent stack overflow error #[tokio::test] -async fn test_string_expressions() -> Result<()> { +async fn test_string_expressions_batch1() -> Result<()> { test_expression!("ascii('')", "0"); test_expression!("ascii('x')", "120"); test_expression!("ascii(NULL)", "NULL"); @@ -437,6 +513,13 @@ async fn test_string_expressions() -> Result<()> { test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); test_expression!("rtrim('testxxzx', 'xyz')", "test"); test_expression!("rtrim(NULL, 'xyz')", "NULL"); + Ok(()) +} + +/// Test string expressions test split into two batches +/// to prevent stack overflow error +#[tokio::test] +async fn test_string_expressions_batch2() -> Result<()> { test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); test_expression!("split_part(NULL, '~@~', 20)", "NULL"); @@ -512,21 +595,28 @@ async fn test_regex_expressions() -> Result<()> { #[tokio::test] async fn test_cast_expressions() -> Result<()> { + test_expression!("CAST('0' AS INT)", "0"); + test_expression!("CAST(NULL AS INT)", "NULL"); + test_expression!("TRY_CAST('0' AS INT)", "0"); + test_expression!("TRY_CAST('x' AS INT)", "NULL"); + Ok(()) +} + +#[tokio::test] +#[ignore] +// issue: https://github.com/apache/arrow-datafusion/issues/6596 +async fn test_array_cast_expressions() -> Result<()> { test_expression!("CAST([1,2,3,4] AS INT[])", "[1, 2, 3, 4]"); test_expression!( "CAST([1,2,3,4] AS NUMERIC(10,4)[])", "[1.0000, 2.0000, 3.0000, 4.0000]" ); - test_expression!("CAST('0' AS INT)", "0"); - test_expression!("CAST(NULL AS INT)", "NULL"); - test_expression!("TRY_CAST('0' AS INT)", "0"); - test_expression!("TRY_CAST('x' AS INT)", "NULL"); Ok(()) } #[tokio::test] async fn test_random_expression() -> Result<()> { - let ctx = create_ctx(); + let ctx = SessionContext::new(); let sql = "SELECT random() r1"; let actual = execute(&ctx, sql).await; let r1 = actual[0][0].parse::().unwrap(); @@ -537,7 +627,7 @@ async fn test_random_expression() -> Result<()> { #[tokio::test] async fn test_uuid_expression() -> Result<()> { - let ctx = create_ctx(); + let ctx = SessionContext::new(); let sql = "SELECT uuid()"; let actual = execute(&ctx, sql).await; let uuid = actual[0][0].parse::().unwrap(); @@ -549,7 +639,7 @@ async fn test_uuid_expression() -> Result<()> { async fn test_extract_date_part() -> Result<()> { test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000.0"); test_expression!( - "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00')", "2020.0" ); test_expression!("date_part('QUARTER', CAST('2000-01-01' AS DATE))", "1.0"); @@ -596,37 +686,56 @@ async fn test_extract_date_part() -> Result<()> { "12.0" ); test_expression!( - "EXTRACT(second FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", + "12.12345678" + ); + test_expression!( + "EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", + "12123.45678" + ); + test_expression!( + "EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", + "12123456.78" + ); + test_expression!( + "EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00')", + "1.212345678e10" + ); + test_expression!( + "date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00')", "12.12345678" ); test_expression!( - "EXTRACT(millisecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", "12123.45678" ); test_expression!( - "EXTRACT(microsecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", "12123456.78" ); test_expression!( - "EXTRACT(nanosecond FROM to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00')", "1.212345678e10" ); + + // Keep precision when coercing Utf8 to Timestamp test_expression!( - "date_part('second', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('second', '2020-09-08T12:00:12.12345678+00:00')", "12.12345678" ); test_expression!( - "date_part('millisecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00')", "12123.45678" ); test_expression!( - "date_part('microsecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00')", "12123456.78" ); test_expression!( - "date_part('nanosecond', to_timestamp('2020-09-08T12:00:12.12345678+00:00'))", + "date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00')", "1.212345678e10" ); + Ok(()) } @@ -796,18 +905,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> { Ok(()) } -#[tokio::test] -async fn csv_query_avg_sqrt() -> Result<()> { - let ctx = create_ctx(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let mut actual = execute(&ctx, sql).await; - actual.sort(); - let expected = vec![vec!["0.6706002946036462"]]; - assert_float_eq(&expected, &actual); - Ok(()) -} - #[tokio::test] async fn nested_subquery() -> Result<()> { let ctx = SessionContext::new(); @@ -826,13 +923,11 @@ async fn nested_subquery() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; // the purpose of this test is just to make sure the query produces a valid plan #[rustfmt::skip] - let expected = vec![ - "+-----+", + let expected = ["+-----+", "| cnt |", "+-----+", "| 0 |", - "+-----+" - ]; + "+-----+"]; assert_batches_eq!(expected, &actual); Ok(()) } diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index b4a92db3fc371..58f0ac21d951c 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -16,6 +16,8 @@ // under the License. use super::*; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{DataType, TimeUnit}; #[tokio::test] async fn group_by_date_trunc() -> Result<()> { @@ -55,7 +57,7 @@ async fn group_by_date_trunc() -> Result<()> { "SELECT date_trunc('week', t1) as week, SUM(c2) FROM test GROUP BY date_trunc('week', t1)", ).await?; - let expected = vec![ + let expected = [ "+---------------------+--------------+", "| week | SUM(test.c2) |", "+---------------------+--------------+", @@ -68,6 +70,95 @@ async fn group_by_date_trunc() -> Result<()> { Ok(()) } +#[tokio::test] +async fn group_by_limit() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = create_groupby_context(&tmp_dir).await?; + + let sql = "SELECT trace_id, MAX(ts) from traces group by trace_id order by MAX(ts) desc limit 4"; + let dataframe = ctx.sql(sql).await?; + + // ensure we see `lim=[4]` + let physical_plan = dataframe.create_physical_plan().await?; + let mut expected_physical_plan = r#" +GlobalLimitExec: skip=0, fetch=4 + SortExec: TopK(fetch=4), expr=[MAX(traces.ts)@1 DESC] + AggregateExec: mode=Single, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.ts)], lim=[4] + "#.trim().to_string(); + let actual_phys_plan = + format_plan(physical_plan.clone(), &mut expected_physical_plan); + assert_eq!(actual_phys_plan, expected_physical_plan); + + let batches = collect(physical_plan, ctx.task_ctx()).await?; + let expected = r#" ++----------+----------------------+ +| trace_id | MAX(traces.ts) | ++----------+----------------------+ +| 9 | 2020-12-01T00:00:18Z | +| 8 | 2020-12-01T00:00:17Z | +| 7 | 2020-12-01T00:00:16Z | +| 6 | 2020-12-01T00:00:15Z | ++----------+----------------------+ +"# + .trim(); + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + + Ok(()) +} + +fn format_plan( + physical_plan: Arc, + expected_phys_plan: &mut String, +) -> String { + let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + let last_line = actual_phys_plan + .as_str() + .lines() + .last() + .expect("Plan should not be empty"); + + expected_phys_plan.push('\n'); + expected_phys_plan.push_str(last_line); + expected_phys_plan.push('\n'); + actual_phys_plan +} + +async fn create_groupby_context(tmp_dir: &TempDir) -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + false, + ), + ])); + + // generate a file + let filename = "traces.csv"; + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for trace_id in 0..10 { + for ts in 0..10 { + let ts = trace_id + ts; + let data = format!("\"{trace_id}\",2020-12-01T00:00:{ts:02}.000Z\n"); + file.write_all(data.as_bytes())?; + } + } + + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + ctx.register_csv( + "traces", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + ) + .await?; + Ok(ctx) +} + #[tokio::test] async fn group_by_dictionary() { async fn run_test_case() { @@ -103,7 +194,7 @@ async fn group_by_dictionary() { .await .expect("ran plan correctly"); - let expected = vec![ + let expected = [ "+------+--------------+", "| dict | COUNT(t.val) |", "+------+--------------+", @@ -120,7 +211,7 @@ async fn group_by_dictionary() { .await .expect("ran plan correctly"); - let expected = vec![ + let expected = [ "+-----+---------------+", "| val | COUNT(t.dict) |", "+-----+---------------+", @@ -139,14 +230,14 @@ async fn group_by_dictionary() { .await .expect("ran plan correctly"); - let expected = vec![ - "+-------+------------------------+", - "| t.val | COUNT(DISTINCT t.dict) |", - "+-------+------------------------+", - "| 1 | 2 |", - "| 2 | 2 |", - "| 4 | 1 |", - "+-------+------------------------+", + let expected = [ + "+-----+------------------------+", + "| val | COUNT(DISTINCT t.dict) |", + "+-----+------------------------+", + "| 1 | 2 |", + "| 2 | 2 |", + "| 4 | 1 |", + "+-----+------------------------+", ]; assert_batches_sorted_eq!(expected, &results); } diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs deleted file mode 100644 index 68ac6c5d62904..0000000000000 --- a/datafusion/core/tests/sql/information_schema.rs +++ /dev/null @@ -1,220 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use async_trait::async_trait; -use datafusion::execution::context::SessionState; -use datafusion::{ - catalog::{ - catalog::{CatalogProvider, MemoryCatalogProvider}, - schema::{MemorySchemaProvider, SchemaProvider}, - }, - datasource::{TableProvider, TableType}, -}; -use datafusion_expr::Expr; - -use super::*; - -#[tokio::test] -async fn information_schema_tables_tables_with_multiple_catalogs() { - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - schema - .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog - .register_schema("my_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog - .register_schema("my_other_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - - let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+------------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+------------------+--------------------+-------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| my_catalog | information_schema | columns | VIEW |", - "| my_catalog | information_schema | df_settings | VIEW |", - "| my_catalog | information_schema | tables | VIEW |", - "| my_catalog | information_schema | views | VIEW |", - "| my_catalog | my_schema | t1 | BASE TABLE |", - "| my_catalog | my_schema | t2 | BASE TABLE |", - "| my_other_catalog | information_schema | columns | VIEW |", - "| my_other_catalog | information_schema | df_settings | VIEW |", - "| my_other_catalog | information_schema | tables | VIEW |", - "| my_other_catalog | information_schema | views | VIEW |", - "| my_other_catalog | my_other_schema | t3 | BASE TABLE |", - "+------------------+--------------------+-------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -#[tokio::test] -async fn information_schema_tables_table_types() { - struct TestTable(TableType); - - #[async_trait] - impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn table_type(&self) -> TableType { - self.0 - } - - fn schema(&self) -> SchemaRef { - unimplemented!() - } - - async fn scan( - &self, - _state: &SessionState, - _: Option<&Vec>, - _: &[Expr], - _: Option, - ) -> Result> { - unimplemented!() - } - } - - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - - ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) - .unwrap(); - ctx.register_table("query", Arc::new(TestTable(TableType::View))) - .unwrap(); - ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary))) - .unwrap(); - - let result = plan_and_collect(&ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+-------------+-----------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+-----------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | public | physical | BASE TABLE |", - "| datafusion | public | query | VIEW |", - "| datafusion | public | temp | LOCAL TEMPORARY |", - "+---------------+--------------------+-------------+-----------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -fn table_with_many_types() -> Arc { - let schema = Schema::new(vec![ - Field::new("int32_col", DataType::Int32, false), - Field::new("float64_col", DataType::Float64, true), - Field::new("utf8_col", DataType::Utf8, true), - Field::new("large_utf8_col", DataType::LargeUtf8, false), - Field::new("binary_col", DataType::Binary, false), - Field::new("large_binary_col", DataType::LargeBinary, false), - Field::new( - "timestamp_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Float64Array::from(vec![1.0])), - Arc::new(StringArray::from(vec![Some("foo")])), - Arc::new(LargeStringArray::from(vec![Some("bar")])), - Arc::new(BinaryArray::from(vec![b"foo" as &[u8]])), - Arc::new(LargeBinaryArray::from(vec![b"foo" as &[u8]])), - Arc::new(TimestampNanosecondArray::from(vec![Some(123)])), - ], - ) - .unwrap(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); - Arc::new(provider) -} - -#[tokio::test] -async fn information_schema_columns() { - let ctx = - SessionContext::with_config(SessionConfig::new().with_information_schema(true)); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - - schema - .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) - .unwrap(); - - schema - .register_table("t2".to_owned(), table_with_many_types()) - .unwrap(); - catalog - .register_schema("my_schema", Arc::new(schema)) - .unwrap(); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let result = plan_and_collect(&ctx, "SELECT * from information_schema.columns") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 | binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |", - "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |", - "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |", - "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |", - "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); -} - -/// Execute SQL and return results -async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { - ctx.sql(sql).await?.collect().await -} diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 118aeb043e44b..528bde632355b 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,161 +15,9 @@ // specific language governing permissions and limitations // under the License. -use super::*; - -#[tokio::test] -async fn test_join_timestamp() -> Result<()> { - let ctx = SessionContext::new(); - - // register time table - let timestamp_schema = Arc::new(Schema::new(vec![Field::new( - "time", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - )])); - let timestamp_data = RecordBatch::try_new( - timestamp_schema.clone(), - vec![Arc::new(TimestampNanosecondArray::from(vec![ - 131964190213133, - 131964190213134, - 131964190213135, - ]))], - )?; - ctx.register_batch("timestamp", timestamp_data)?; - - let sql = "SELECT * \ - FROM timestamp as a \ - JOIN (SELECT * FROM timestamp) as b \ - ON a.time = b.time \ - ORDER BY a.time"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+-------------------------------+-------------------------------+", - "| time | time |", - "+-------------------------------+-------------------------------+", - "| 1970-01-02T12:39:24.190213133 | 1970-01-02T12:39:24.190213133 |", - "| 1970-01-02T12:39:24.190213134 | 1970-01-02T12:39:24.190213134 |", - "| 1970-01-02T12:39:24.190213135 | 1970-01-02T12:39:24.190213135 |", - "+-------------------------------+-------------------------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn test_join_float32() -> Result<()> { - let ctx = SessionContext::new(); - - // register population table - let population_schema = Arc::new(Schema::new(vec![ - Field::new("city", DataType::Utf8, true), - Field::new("population", DataType::Float32, true), - ])); - let population_data = RecordBatch::try_new( - population_schema.clone(), - vec![ - Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])), - Arc::new(Float32Array::from(vec![838.698, 1778.934, 626.443])), - ], - )?; - ctx.register_batch("population", population_data)?; - - let sql = "SELECT * \ - FROM population as a \ - JOIN (SELECT * FROM population) as b \ - ON a.population = b.population \ - ORDER BY a.population"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+------+------------+------+------------+", - "| city | population | city | population |", - "+------+------------+------+------------+", - "| c | 626.443 | c | 626.443 |", - "| a | 838.698 | a | 838.698 |", - "| b | 1778.934 | b | 1778.934 |", - "+------+------------+------+------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn test_join_float64() -> Result<()> { - let ctx = SessionContext::new(); - - // register population table - let population_schema = Arc::new(Schema::new(vec![ - Field::new("city", DataType::Utf8, true), - Field::new("population", DataType::Float64, true), - ])); - let population_data = RecordBatch::try_new( - population_schema.clone(), - vec![ - Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])), - Arc::new(Float64Array::from(vec![838.698, 1778.934, 626.443])), - ], - )?; - ctx.register_batch("population", population_data)?; +use datafusion::test_util::register_unbounded_file_with_ordering; - let sql = "SELECT * \ - FROM population as a \ - JOIN (SELECT * FROM population) as b \ - ON a.population = b.population \ - ORDER BY a.population"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+------+------------+------+------------+", - "| city | population | city | population |", - "+------+------------+------+------------+", - "| c | 626.443 | c | 626.443 |", - "| a | 838.698 | a | 838.698 |", - "| b | 1778.934 | b | 1778.934 |", - "+------+------------+------+------------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -// TODO Tests to prove correct implementation of INNER JOIN's with qualified names. -// https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. -#[tokio::test] -#[ignore] -async fn inner_join_qualified_names() -> Result<()> { - // Setup the statements that test qualified names function correctly. - let equivalent_sql = [ - "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c - FROM t1 - INNER JOIN t2 ON t1.a = t2.a - ORDER BY t1.a", - "SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c - FROM t1 - INNER JOIN t2 ON t2.a = t1.a - ORDER BY t1.a", - ]; - - let expected = vec![ - "+---+----+----+---+-----+-----+", - "| a | b | c | a | b | c |", - "+---+----+----+---+-----+-----+", - "| 1 | 10 | 50 | 1 | 100 | 500 |", - "| 2 | 20 | 60 | 2 | 200 | 600 |", - "| 4 | 40 | 80 | 4 | 400 | 800 |", - "+---+----+----+---+-----+-----+", - ]; - - for sql in equivalent_sql.iter() { - let ctx = create_join_context_qualified("t1", "t2")?; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} +use super::*; #[tokio::test] #[ignore] @@ -177,24 +25,7 @@ async fn inner_join_qualified_names() -> Result<()> { async fn nestedjoin_with_alias() -> Result<()> { // repro case for https://github.com/apache/arrow-datafusion/issues/2867 let sql = "select * from ((select 1 as a, 2 as b) c INNER JOIN (select 1 as a, 3 as d) e on c.a = e.a) f;"; - let expected = vec![ - "+---+---+---+---+", - "| a | b | a | d |", - "+---+---+---+---+", - "| 1 | 2 | 1 | 3 |", - "+---+---+---+---+", - ]; - let ctx = SessionContext::new(); - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn nestedjoin_without_alias() -> Result<()> { - let sql = "select * from (select 1 as a, 2 as b) c INNER JOIN (select 1 as a, 3 as d) e on c.a = e.a;"; - let expected = vec![ + let expected = [ "+---+---+---+---+", "| a | b | a | d |", "+---+---+---+---+", @@ -208,275 +39,6 @@ async fn nestedjoin_without_alias() -> Result<()> { Ok(()) } -#[tokio::test] -async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Result<()> { - let ctx = SessionContext::new(); - let batch = RecordBatch::try_from_iter(vec![ - ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), - ( - "country", - Arc::new(StringArray::from(vec!["Germany", "Sweden", "Japan"])) as _, - ), - ]) - .unwrap(); - ctx.register_batch("countries", batch)?; - - let batch = RecordBatch::try_from_iter(vec![ - ( - "id", - Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])) as _, - ), - ( - "city", - Arc::new(StringArray::from(vec![ - "Hamburg", - "Stockholm", - "Osaka", - "Berlin", - "Göteborg", - "Tokyo", - "Kyoto", - ])) as _, - ), - ( - "country_id", - Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3, 3])) as _, - ), - ]) - .unwrap(); - - ctx.register_batch("cities", batch)?; - - // city.id is not in the on constraint, but the output result will contain both city.id and - // country.id - let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----+----+-----------+---------+", - "| id | id | city | country |", - "+----+----+-----------+---------+", - "| 1 | 1 | Hamburg | Germany |", - "| 2 | 2 | Stockholm | Sweden |", - "| 3 | 3 | Osaka | Japan |", - "| 4 | 1 | Berlin | Germany |", - "| 5 | 2 | Göteborg | Sweden |", - "| 6 | 3 | Tokyo | Japan |", - "| 7 | 3 | Kyoto | Japan |", - "+----+----+-----------+---------+", - ]; - - assert_batches_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn join_timestamp() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let expected = vec![ - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| nanos | micros | millis | secs | name | nanos | micros | millis | secs | name |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123 | 2011-12-13T11:13:10 | Row 1 | 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123 | 2011-12-13T11:13:10 | Row 1 |", - "| 2018-11-13T17:11:10.011375885 | 2018-11-13T17:11:10.011375 | 2018-11-13T17:11:10.011 | 2018-11-13T17:11:10 | Row 0 | 2018-11-13T17:11:10.011375885 | 2018-11-13T17:11:10.011375 | 2018-11-13T17:11:10.011 | 2018-11-13T17:11:10 | Row 0 |", - "| 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10 | Row 3 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10 | Row 3 |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - ]; - - let results = execute_to_batches( - &ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.nanos = t2.nanos", - ) - .await; - - assert_batches_sorted_eq!(expected, &results); - - let results = execute_to_batches( - &ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.micros = t2.micros", - ) - .await; - - assert_batches_sorted_eq!(expected, &results); - - let results = execute_to_batches( - &ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.millis = t2.millis", - ) - .await; - - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn left_join_should_not_panic_with_empty_side() -> Result<()> { - let ctx = SessionContext::new(); - - let t1_schema = Schema::new(vec![ - Field::new("t1_id", DataType::Int64, true), - Field::new("t1_value", DataType::Utf8, false), - ]); - let t1_data = RecordBatch::try_new( - Arc::new(t1_schema), - vec![ - Arc::new(Int64Array::from(vec![5247, 3821, 6321, 8821, 7748])), - Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Schema::new(vec![ - Field::new("t2_id", DataType::Int64, true), - Field::new("t2_value", DataType::Boolean, true), - ]); - let t2_data = RecordBatch::try_new( - Arc::new(t2_schema), - vec![ - Arc::new(Int64Array::from(vec![358, 2820, 3804, 7748])), - Arc::new(BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - ])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - let expected_left_join = vec![ - "+-------+----------+-------+----------+", - "| t1_id | t1_value | t2_id | t2_value |", - "+-------+----------+-------+----------+", - "| 5247 | a | | |", - "| 3821 | b | | |", - "| 6321 | c | | |", - "| 8821 | d | | |", - "| 7748 | e | 7748 | |", - "+-------+----------+-------+----------+", - ]; - - let results_left_join = - execute_to_batches(&ctx, "SELECT * FROM t1 LEFT JOIN t2 ON t1_id = t2_id").await; - assert_batches_sorted_eq!(expected_left_join, &results_left_join); - - let expected_right_join = vec![ - "+-------+----------+-------+----------+", - "| t2_id | t2_value | t1_id | t1_value |", - "+-------+----------+-------+----------+", - "| | | 3821 | b |", - "| | | 5247 | a |", - "| | | 6321 | c |", - "| | | 8821 | d |", - "| 7748 | | 7748 | e |", - "+-------+----------+-------+----------+", - ]; - - let result_right_join = - execute_to_batches(&ctx, "SELECT * FROM t2 RIGHT JOIN t1 ON t1_id = t2_id").await; - assert_batches_sorted_eq!(expected_right_join, &result_right_join); - - Ok(()) -} - -#[tokio::test] -async fn left_join_using_2() -> Result<()> { - let results = execute_with_partition( - "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", - 1, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - - assert_batches_eq!(expected, &results); - Ok(()) -} - -#[tokio::test] -async fn left_join_using_join_key_projection() -> Result<()> { - let results = execute_with_partition( - "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", - 1, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+----+----+----+", - "| c1 | c2 | c2 |", - "+----+----+----+", - "| 0 | 1 | 1 |", - "| 0 | 2 | 2 |", - "| 0 | 3 | 3 |", - "| 0 | 4 | 4 |", - "| 0 | 5 | 5 |", - "| 0 | 6 | 6 |", - "| 0 | 7 | 7 |", - "| 0 | 8 | 8 |", - "| 0 | 9 | 9 |", - "| 0 | 10 | 10 |", - "+----+----+----+", - ]; - - assert_batches_eq!(expected, &results); - Ok(()) -} - -#[tokio::test] -async fn left_join_2() -> Result<()> { - let results = execute_with_partition( - "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2", - 1, - ) - .await?; - assert_eq!(results.len(), 1); - - let expected = vec![ - "+----+----+----+", - "| c1 | c2 | c2 |", - "+----+----+----+", - "| 0 | 1 | 1 |", - "| 0 | 2 | 2 |", - "| 0 | 3 | 3 |", - "| 0 | 4 | 4 |", - "| 0 | 5 | 5 |", - "| 0 | 6 | 6 |", - "| 0 | 7 | 7 |", - "| 0 | 8 | 8 |", - "| 0 | 9 | 9 |", - "| 0 | 10 | 10 |", - "+----+----+----+", - ]; - - assert_batches_eq!(expected, &results); - Ok(()) -} - #[tokio::test] async fn join_partitioned() -> Result<()> { // self join on partition id (workaround for duplicate column name) @@ -495,2193 +57,186 @@ async fn join_partitioned() -> Result<()> { } #[tokio::test] -async fn hash_join_with_date32() -> Result<()> { - let ctx = create_hashjoin_datatype_context()?; - - // inner join on hash supported data type (Date32) - let sql = "select * from t1 join t2 on t1.c1 = t2.c1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.c1 = t2.c1 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - "| 1970-01-02 | 1970-01-02T00:00:00 | 1.23 | abc | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "| 1970-01-04 | | -123.12 | jkl | 1970-01-04 | | 789.00 | |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn hash_join_with_date64() -> Result<()> { - let ctx = create_hashjoin_datatype_context()?; - - // left join on hash supported data type (Date64) - let sql = "select * from t1 left join t2 on t1.c2 = t2.c2"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Left Join: t1.c2 = t2.c2 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+---------+-----+------------+---------------------+---------+--------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+--------+", - "| | 1970-01-04T00:00:00 | 789.00 | ghi | | 1970-01-04T00:00:00 | 0.00 | qwerty |", - "| 1970-01-02 | 1970-01-02T00:00:00 | 1.23 | abc | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "| 1970-01-03 | 1970-01-03T00:00:00 | 456.00 | def | | | | |", - "| 1970-01-04 | | -123.12 | jkl | | | | |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn hash_join_with_decimal() -> Result<()> { - let ctx = create_hashjoin_datatype_context()?; - - // right join on hash supported data type (Decimal) - let sql = "select * from t1 right join t2 on t1.c3 = t2.c3"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - "| | | | | | | 100000.00 | abcdefg |", - "| | | | | | 1970-01-04T00:00:00 | 0.00 | qwerty |", - "| | 1970-01-04T00:00:00 | 789.00 | ghi | 1970-01-04 | | 789.00 | |", - "| 1970-01-04 | | -123.12 | jkl | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn hash_join_with_dictionary() -> Result<()> { - let ctx = create_hashjoin_datatype_context()?; - - // inner join on hash supported data type (Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))) - let sql = "select * from t1 join t2 on t1.c4 = t2.c4"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.c4 = t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", - " TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+------+-----+------------+---------------------+---------+-----+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+------+-----+------------+---------------------+---------+-----+", - "| 1970-01-02 | 1970-01-02T00:00:00 | 1.23 | abc | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "+------------+---------------------+------+-----+------------+---------------------+---------+-----+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn reduce_left_join_1() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to inner join - let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_id < 100"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_left_join_2() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to inner join - let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` - // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` - // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.t1_id = t2.t2_id Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_left_join_3() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; +#[ignore = "Test ignored, will be enabled after fixing the NAAJ bug"] +// https://github.com/apache/arrow-datafusion/issues/4211 +async fn null_aware_left_anti_join() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_left_semi_anti_join_context_with_null_ids( + "t1_id", + "t2_id", + repartition_joins, + ) + .unwrap(); - // reduce subquery to inner join - let sql = "select * from (select t1.* from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 3) t3 left join t2 on t3.t1_int = t2.t2_int where t3.t1_id < 100"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N]", - " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id NOT IN (SELECT t2_id FROM t2) ORDER BY t1_id"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = ["++", "++"]; + assert_batches_eq!(expected, &actual); + } Ok(()) } #[tokio::test] -async fn reduce_right_join_1() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; +async fn join_change_in_planner() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::new_with_config(config); + let tmp_dir = TempDir::new().unwrap(); + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone()).unwrap(); + // Create schema + let schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + // Specify the ordering: + let file_sort_order = vec![[datafusion_expr::col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>()]; + register_unbounded_file_with_ordering( + &ctx, + schema.clone(), + &left_file_path, + "left", + file_sort_order.clone(), + true, + ) + .await?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone()).unwrap(); + register_unbounded_file_with_ordering( + &ctx, + schema, + &right_file_path, + "right", + file_sort_order, + true, + ) + .await?; + let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); + let expected = { + [ + "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" + ] + }; + let mut actual: Vec<&str> = formatted.trim().lines().collect(); + // Remove CSV lines + actual.remove(4); + actual.remove(7); - // reduce to inner join - let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where t1.t1_int is not null"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_int IS NOT NULL [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); assert_eq!( - expected, actual, + expected, + actual[..], "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - Ok(()) } #[tokio::test] -async fn reduce_right_join_2() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to inner join - let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where not(t1.t1_int = t2.t2_int)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.t1_id = t2.t2_id Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_full_join_to_right_join() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to right join - let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t2.t2_name is not null"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Right Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_name IS NOT NULL [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_full_join_to_left_join() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to left join - let sql = - "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b'"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Left Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_full_join_to_inner_join() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to inner join - let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b' and t2.t2_name = 'x'"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_name = Utf8(\"x\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn sort_merge_equijoin() -> Result<()> { - let ctx = create_sort_merge_join_context("t1_id", "t2_id")?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn sort_merge_join_on_date32() -> Result<()> { - let ctx = create_sort_merge_join_datatype_context()?; - - // inner sort merge join on data type (Date32) - let sql = "select * from t1 join t2 on t1.c1 = t2.c1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = vec![ - "SortMergeJoin: join_type=Inner, on=[(Column { name: \"c1\", index: 0 }, Column { name: \"c1\", index: 0 })]", - " SortExec: expr=[c1@0 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " SortExec: expr=[c1@0 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - "| 1970-01-02 | 1970-01-02T00:00:00 | 1.23 | abc | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "| 1970-01-04 | | -123.12 | jkl | 1970-01-04 | | 789.00 | |", - "+------------+---------------------+---------+-----+------------+---------------------+---------+-----+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn sort_merge_join_on_decimal() -> Result<()> { - let ctx = create_sort_merge_join_datatype_context()?; - - // right join on data type (Decimal) - let sql = "select * from t1 right join t2 on t1.c3 = t2.c3"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = vec![ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]", - " SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]", - " SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2), input_partitions=2", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " SortExec: expr=[c3@2 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - "| c1 | c2 | c3 | c4 | c1 | c2 | c3 | c4 |", - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - "| | | | | | | 100000.00 | abcdefg |", - "| | | | | | 1970-01-04T00:00:00 | 0.00 | qwerty |", - "| | 1970-01-04T00:00:00 | 789.00 | ghi | 1970-01-04 | | 789.00 | |", - "| 1970-01-04 | | -123.12 | jkl | 1970-01-02 | 1970-01-02T00:00:00 | -123.12 | abc |", - "+------------+---------------------+---------+-----+------------+---------------------+-----------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn left_semi_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_left_semi_anti_join_context_with_null_ids( - "t1_id", - "t2_id", - repartition_joins, - ) - .unwrap(); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id IN (SELECT t2_id FROM t2) ORDER BY t1_id"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = if repartition_joins { - vec![ - "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " MemoryExec: partitions=1, partition_sizes=[1]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 11 | a |", - "| 11 | a |", - "| 22 | b |", - "| 44 | d |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t1_id = t2_id) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 11 | a |", - "| 11 | a |", - "| 22 | b |", - "| 44 | d |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id FROM t1 INTERSECT SELECT t2_id FROM t2 ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+", - "| t1_id |", - "+-------+", - "| 11 |", - "| 22 |", - "| 44 |", - "| |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id, t1_name FROM t1 LEFT SEMI JOIN t2 ON (t1_id = t2_id) ORDER BY t1_id"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = if repartition_joins { - vec![ - "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " MemoryExec: partitions=1, partition_sizes=[1]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 11 | a |", - "| 11 | a |", - "| 22 | b |", - "| 44 | d |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn left_semi_join_pushdown() -> Result<()> { - let ctx = create_left_semi_anti_join_context_with_null_ids("t1_id", "t2_id", false) - .unwrap(); - - // assert logical plan - let sql = "SELECT t1.t1_id, t1.t1_name FROM t1 LEFT SEMI JOIN t2 ON (t1.t1_id = t2.t2_id and t2.t2_int > 1)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: t2.t2_int > UInt32(1) [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn left_anti_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_left_semi_anti_join_context_with_null_ids( - "t1_id", - "t2_id", - repartition_joins, - ) - .unwrap(); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT 1 FROM t2 WHERE t1_id = t2_id) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 33 | c |", - "| | e |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id FROM t1 EXCEPT SELECT t2_id FROM t2 ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+", - "| t1_id |", - "+-------+", - "| 33 |", - "+-------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id, t1_name FROM t1 LEFT ANTI JOIN t2 ON (t1_id = t2_id) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 33 | c |", - "| | e |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn error_left_anti_join() -> Result<()> { - // https://github.com/apache/arrow-datafusion/issues/4366 - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_left_semi_anti_join_context_with_null_ids( - "t1_id", - "t2_id", - repartition_joins, - ) - .unwrap(); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT 1 FROM t2 WHERE t1_id = t2_id and t1_id > 11) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+", - "| t1_id | t1_name |", - "+-------+---------+", - "| 11 | a |", - "| 11 | a |", - "| 33 | c |", - "| | e |", - "+-------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -#[ignore = "Test ignored, will be enabled after fixing the NAAJ bug"] -// https://github.com/apache/arrow-datafusion/issues/4211 -async fn null_aware_left_anti_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_left_semi_anti_join_context_with_null_ids( - "t1_id", - "t2_id", - repartition_joins, - ) - .unwrap(); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id NOT IN (SELECT t2_id FROM t2) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec!["++", "++"]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn right_semi_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_right_semi_anti_join_context_with_null_ids( - "t1_id", - "t2_id", - repartition_joins, - ) - .unwrap(); - - let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE EXISTS (SELECT * FROM t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = if repartition_joins { - vec![ - "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=t2_name@1 != t1_name@0", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=t2_name@1 != t1_name@0", - " MemoryExec: partitions=1, partition_sizes=[1]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - assert_batches_eq!(expected, &actual); - - let sql = "SELECT t1_id, t1_name, t1_int FROM t2 RIGHT SEMI JOIN t1 on (t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = if repartition_joins { - vec![ - "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=t2_name@0 != t1_name@1", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=t2_name@0 != t1_name@1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn join_and_aggregate_on_same_key() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - let sql = "select distinct(t1.t1_id) from t1 inner join t2 on t1.t1_id = t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Aggregate: groupBy=[[t1.t1_id]], aggr=[[]] [t1_id:UInt32;N]", - " Projection: t1.t1_id [t1_id:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let msg = format!("Creating physical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = - vec![ - "AggregateExec: mode=Single, gby=[t1_id@0 as t1_id], aggr=[]", - " ProjectionExec: expr=[t1_id@0 as t1_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let sql = "select count(*) from (select * from t1 inner join t2 on t1.t1_id = t2.t2_id) group by t1_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]", - " Aggregate: groupBy=[[t1.t1_id]], aggr=[[COUNT(UInt8(1))]] [t1_id:UInt32;N, COUNT(UInt8(1)):Int64;N]", - " Projection: t1.t1_id [t1_id:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let msg = format!("Creating physical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = - vec![ - "ProjectionExec: expr=[COUNT(UInt8(1))@1 as COUNT(UInt8(1))]", - " AggregateExec: mode=Single, gby=[t1_id@0 as t1_id], aggr=[COUNT(UInt8(1))]", - " ProjectionExec: expr=[t1_id@0 as t1_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let sql = - "select count(distinct t1.t1_id) from t1 inner join t2 on t1.t1_id = t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: COUNT(alias1) AS COUNT(DISTINCT t1.t1_id) [COUNT(DISTINCT t1.t1_id):Int64;N]", - " Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]", - " Aggregate: groupBy=[[t1.t1_id AS alias1]], aggr=[[]] [alias1:UInt32;N]", - " Projection: t1.t1_id [t1_id:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // the Partial and FinalPartitioned Aggregate are not combined to Single Aggregate due to group by exprs are different - // TODO improve ReplaceDistinctWithAggregate rule to avoid unnecessary alias Cast - let msg = format!("Creating physical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = - vec![ - "ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT t1.t1_id)]", - " AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)]", - " CoalescePartitionsExec", - " AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)]", - " AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]", - " AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]", - " ProjectionExec: expr=[t1_id@0 as t1_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - // reduce to inner join - let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) = CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = - "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]", - " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "select *,t1.t1_id+11 from t1,t2 where t1.t1_id+11=t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t1.t1_id + Int64(11):Int64;N]", - " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]" - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert physical plan - let msg = format!("Creating physical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - let expected = if repartition_joins { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name: \"CAST(t2.t2_id AS Int64)\", index: 3 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + Int64(11)\", index: 3 }], 2), input_partitions=2", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"CAST(t2.t2_id AS Int64)\", index: 3 }], 2), input_partitions=2", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS Int64)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name: \"CAST(t2.t2_id AS Int64)\", index: 3 })]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS Int64)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn both_side_expr_key_inner_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \ - FROM t1 \ - INNER JOIN t2 \ - ON t1.t1_id + cast(12 as INT UNSIGNED) = t2.t2_id + cast(1 as INT UNSIGNED)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - - let expected = if repartition_joins { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }, Column { name: \"t2.t2_id + UInt32(1)\", index: 1 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }], 2), input_partitions=2", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as t1.t1_id + UInt32(12)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id + UInt32(1)\", index: 1 }], 2), input_partitions=2", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as t2.t2_id + UInt32(1)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }, Column { name: \"t2.t2_id + UInt32(1)\", index: 1 })]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as t1.t1_id + UInt32(12)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as t2.t2_id + UInt32(1)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn left_side_expr_key_inner_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \ - FROM t1 \ - INNER JOIN t2 \ - ON t1.t1_id + cast(11 as INT UNSIGNED) = t2.t2_id"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - - let expected = if repartition_joins { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }], 2), input_partitions=2", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as t1.t1_id + UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as t1.t1_id + UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn right_side_expr_key_inner_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \ - FROM t1 \ - INNER JOIN t2 \ - ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - - let expected = if repartition_joins { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 1 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - UInt32(11)\", index: 1 }], 2), input_partitions=2", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as t2.t2_id - UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 1 })]", - " MemoryExec: partitions=1, partition_sizes=[1]", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as t2.t2_id - UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn select_wildcard_with_expr_key_inner_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "SELECT * \ - FROM t1 \ - INNER JOIN t2 \ - ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - - let expected = if repartition_joins { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 3 })]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - UInt32(11)\", index: 3 }], 2), input_partitions=2", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as t2.t2_id - UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - } else { - vec![ - "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 3 })]", - " MemoryExec: partitions=1, partition_sizes=[1]", - " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as t2.t2_id - UInt32(11)]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ] - }; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } - - Ok(()) -} - -#[tokio::test] -async fn join_with_type_coercion_for_equi_expr() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn join_only_with_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ - from t1 \ - inner join t2 \ - on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()> { - // Regression test for GH #4873 - let col1 = Arc::new(StringArray::from(vec![ - "A", "A", "A", "A", "A", "A", "A", "A", "BB", "BB", "BB", "BB", - ])) as ArrayRef; - - let col2 = - Arc::new(UInt64Array::from(vec![1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) as ArrayRef; - - let col3 = - Arc::new(UInt64Array::from(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])) as ArrayRef; - +async fn join_change_in_planner_without_sort() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::new_with_config(config); + let tmp_dir = TempDir::new()?; + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone())?; let schema = Arc::new(Schema::new(vec![ - Field::new("col1", DataType::Utf8, true), - Field::new("col2", DataType::UInt64, true), - Field::new("col3", DataType::UInt64, true), - ])) as SchemaRef; - - let batch = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap(); - let mem_table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); - - // Create context and register table - let ctx = SessionContext::new(); - ctx.register_table("tbl", Arc::new(mem_table)).unwrap(); - - let sql = "select col1, col2, coalesce(sum_col3, 0) as sum_col3 \ - from (select distinct col2 from tbl) AS q1 \ - cross join (select distinct col1 from tbl) AS q2 \ - left outer join (SELECT col1, col2, sum(col3) as sum_col3 FROM tbl GROUP BY col1, col2) AS q3 \ - USING(col2, col1) \ - ORDER BY col1, col2"; - - let expected = vec![ - "+------+------+----------+", - "| col1 | col2 | sum_col3 |", - "+------+------+----------+", - "| A | 1 | 2.0 |", - "| A | 2 | 2.0 |", - "| A | 3 | 2.0 |", - "| A | 4 | 2.0 |", - "| A | 5 | 0.0 |", - "| A | 6 | 0.0 |", - "| BB | 1 | 0.0 |", - "| BB | 2 | 0.0 |", - "| BB | 3 | 0.0 |", - "| BB | 4 | 0.0 |", - "| BB | 5 | 2.0 |", - "| BB | 6 | 2.0 |", - "+------+------+----------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_to_join_with_both_side_expr() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "| 33 | c | 3 |", - "| 44 | d | 4 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_to_join_with_muti_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in - (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", - " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "| 33 | c | 3 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn three_projection_exprs_subquery_to_join() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in - (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", - " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "| 33 | c | 3 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in - (select t2.t2_id + 1 from t2 where t1.t1_int > 0)"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - Ok(()) -} - -#[tokio::test] -async fn not_in_subquery_to_join_with_correlated_outer_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 not in - (select t2.t2_id + 1 from t2 where t1.t1_int > 0)"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftAnti Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - Ok(()) -} - -#[tokio::test] -async fn in_subquery_to_join_with_outer_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in - (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "| 33 | c | 3 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", false)?; - - let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in - (select t2.t2_id + 1 from t2) - and t1.t1_int in(select t2.t2_int + 1 from t2) - and t1.t1_id > 0"; - - // assert logical plan - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan().unwrap(); - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - " SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]", - " Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]", - " TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]", - ]; - - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 44 | d | 4 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn right_as_inner_table_nested_loop_join() -> Result<()> { - let ctx = create_nested_loop_join_context()?; - - // Distribution: left is `UnspecifiedDistribution`, right is `SinglePartition`. - let sql = "SELECT t1.t1_id, t2.t2_id - FROM t1 INNER JOIN t2 ON t1.t1_id > t2.t2_id - WHERE t1.t1_id > 10 AND t2.t2_int > 1"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await?; - - // right is single partition side, so it will be visited many times. - let expected = vec![ - "NestedLoopJoinExec: join_type=Inner, filter=BinaryExpr { left: Column { name: \"t1_id\", index: 0 }, op: Gt, right: Column { name: \"t2_id\", index: 1 } }", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: t1_id@0 > 10", - " RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[t2_id@0 as t2_id]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: t2_int@1 > 1", - " RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+-------+", - "| t1_id | t2_id |", - "+-------+-------+", - "| 22 | 11 |", - "| 33 | 11 |", - "| 44 | 11 |", - "+-------+-------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn left_as_inner_table_nested_loop_join() -> Result<()> { - let ctx = create_nested_loop_join_context()?; - - // Distribution: left is `SinglePartition`, right is `UnspecifiedDistribution`. - let sql = "SELECT t1.t1_id,t2.t2_id FROM (select t1_id from t1 where t1.t1_id > 22) as t1 - RIGHT JOIN (select t2_id from t2 where t2.t2_id > 11) as t2 - ON t1.t1_id < t2.t2_id"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + ctx.register_csv( + "left", + left_file_path.as_os_str().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).mark_infinite(true), + ) + .await?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone())?; + ctx.register_csv( + "right", + right_file_path.as_os_str().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).mark_infinite(true), + ) + .await?; + let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; + let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; - - // left is single partition side, so it will be visited many times. - let expected = vec![ - "NestedLoopJoinExec: join_type=Right, filter=BinaryExpr { left: Column { name: \"t1_id\", index: 0 }, op: Lt, right: Column { name: \"t2_id\", index: 1 } }", - " CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: t1_id@0 > 22", - " RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: t2_id@0 > 11", - " RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected = vec![ - "+-------+-------+", - "| t1_id | t2_id |", - "+-------+-------+", - "| | 22 |", - "| 33 | 44 |", - "| 33 | 55 |", - "| 44 | 55 |", - "+-------+-------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn exists_subquery_to_join_expr_filter() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // exists subquery to LeftSemi join - let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 22 | b | 2 |", - "| 33 | c | 3 |", - "| 44 | d | 4 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn exists_subquery_to_join_inner_filter() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // exists subquery to LeftSemi join - let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - // `t2.t2_int < 3` will be kept in the subquery filter. - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 44 | d | 4 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn exists_subquery_to_join_outer_filter() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // exists subquery to LeftSemi join - let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - // `t1.t1_int < 3` will be moved to the filter of t1. - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 22 | b | 2 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn not_exists_subquery_to_join_expr_filter() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // not exists subquery to LeftAnti join - let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn exists_distinct_subquery_to_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn exists_distinct_subquery_to_join_with_expr() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // `t2_id + t2_int` is in the subquery project. - let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_id + t2_int, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn exists_distinct_subquery_to_join_with_literal() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // `1` is in the subquery project. - let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT 1, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+", - "| t1_id | t1_name | t1_int |", - "+-------+---------+--------+", - "| 11 | a | 1 |", - "+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); + let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); + let expected = { + [ + "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" + ] + }; + let mut actual: Vec<&str> = formatted.trim().lines().collect(); + // Remove CSV lines + actual.remove(4); + actual.remove(7); + + assert_eq!( + expected, + actual[..], + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) +} + +#[tokio::test] +async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { + let config = SessionConfig::new() + .with_target_partitions(8) + .with_allow_symmetric_joins_without_pruning(false); + let ctx = SessionContext::new_with_config(config); + let tmp_dir = TempDir::new()?; + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone())?; + let schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + ctx.register_csv( + "left", + left_file_path.as_os_str().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).mark_infinite(true), + ) + .await?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone())?; + ctx.register_csv( + "right", + right_file_path.as_os_str().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).mark_infinite(true), + ) + .await?; + let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; + match df.create_physical_plan().await { + Ok(_) => panic!("Expecting error."), + Err(e) => { + assert_eq!(e.strip_backtrace(), "PipelineChecker\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + } } - Ok(()) } diff --git a/datafusion/core/tests/sql/limit.rs b/datafusion/core/tests/sql/limit.rs deleted file mode 100644 index a4247492b4158..0000000000000 --- a/datafusion/core/tests/sql/limit.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; - -#[tokio::test] -async fn limit() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - ctx.register_table("t", table_with_sequence(1, 1000).unwrap()) - .unwrap(); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i DESC limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = vec![ - "+------+", - "| i |", - "+------+", - "| 1000 |", - "| 999 |", - "| 998 |", - "+------+", - ]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t ORDER BY i limit 3") - .await - .unwrap(); - - #[rustfmt::skip] - let expected = vec![ - "+---+", - "| i |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; - - assert_batches_eq!(expected, &results); - - let results = plan_and_collect(&ctx, "SELECT i FROM t limit 3") - .await - .unwrap(); - - // the actual rows are not guaranteed, so only check the count (should be 3) - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 3); - - Ok(()) -} - -#[tokio::test] -async fn limit_multi_partitions() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; - - let partitions = vec![ - vec![make_partition(0)], - vec![make_partition(1)], - vec![make_partition(2)], - vec![make_partition(3)], - vec![make_partition(4)], - vec![make_partition(5)], - ]; - let schema = partitions[0][0].schema(); - let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); - - ctx.register_table("t", provider).unwrap(); - - // select all rows - let results = plan_and_collect(&ctx, "SELECT i FROM t").await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 15); - - for limit in 1..10 { - let query = format!("SELECT i FROM t limit {limit}"); - let results = plan_and_collect(&ctx, &query).await.unwrap(); - - let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, limit, "mismatch with query {query}"); - } - - Ok(()) -} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 943254ca463b4..94fc8015a78a4 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::convert::TryFrom; use std::sync::Arc; use arrow::{ @@ -25,8 +24,8 @@ use arrow::{ use chrono::prelude::*; use chrono::Duration; -use datafusion::config::ConfigOptions; use datafusion::datasource::TableProvider; +use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; @@ -35,14 +34,9 @@ use datafusion::prelude::*; use datafusion::test_util; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{datasource::MemTable, physical_plan::collect}; -use datafusion::{ - error::{DataFusionError, Result}, - physical_plan::ColumnarValue, -}; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; -use datafusion_common::cast::as_float64_array; +use datafusion_common::plan_err; use datafusion_common::{assert_contains, assert_not_contains}; -use datafusion_expr::Volatility; use object_store::path::Path; use std::fs::File; use std::io::Write; @@ -78,81 +72,29 @@ macro_rules! test_expression { } pub mod aggregates; -pub mod arrow_files; -#[cfg(feature = "avro")] pub mod create_drop; +pub mod csv_files; pub mod explain_analyze; pub mod expr; pub mod group_by; pub mod joins; -pub mod limit; pub mod order; pub mod parquet; +pub mod parquet_schema; +pub mod partitioned_csv; pub mod predicates; -pub mod projection; pub mod references; +pub mod repartition; pub mod select; +mod sql_api; pub mod timestamp; -pub mod udf; - -pub mod information_schema; -pub mod parquet_schema; -pub mod partitioned_csv; -pub mod subqueries; - -fn assert_float_eq(expected: &[Vec], received: &[Vec]) -where - T: AsRef, -{ - expected - .iter() - .flatten() - .zip(received.iter().flatten()) - .for_each(|(l, r)| { - let (l, r) = ( - l.as_ref().parse::().unwrap(), - r.as_str().parse::().unwrap(), - ); - if l.is_nan() || r.is_nan() { - assert!(l.is_nan() && r.is_nan()); - } else if (l - r).abs() > 2.0 * f64::EPSILON { - panic!("{l} != {r}") - } - }); -} - -fn create_ctx() -> SessionContext { - let ctx = SessionContext::new(); - - // register a custom UDF - ctx.register_udf(create_udf( - "custom_sqrt", - vec![DataType::Float64], - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(custom_sqrt), - )); - - ctx -} - -fn custom_sqrt(args: &[ColumnarValue]) -> Result { - let arg = &args[0]; - if let ColumnarValue::Array(v) = arg { - let input = as_float64_array(v).expect("cast failed"); - let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); - Ok(ColumnarValue::Array(Arc::new(array))) - } else { - unimplemented!() - } -} fn create_join_context( column_left: &str, column_right: &str, repartition_joins: bool, ) -> Result { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new() .with_repartition_joins(repartition_joins) .with_target_partitions(2) @@ -202,88 +144,12 @@ fn create_join_context( Ok(ctx) } -fn create_sub_query_join_context( - column_outer: &str, - column_inner_left: &str, - column_inner_right: &str, - repartition_joins: bool, -) -> Result { - let ctx = SessionContext::with_config( - SessionConfig::new() - .with_repartition_joins(repartition_joins) - .with_target_partitions(2) - .with_batch_size(4096), - ); - - let t0_schema = Arc::new(Schema::new(vec![ - Field::new(column_outer, DataType::UInt32, true), - Field::new("t0_name", DataType::Utf8, true), - Field::new("t0_int", DataType::UInt32, true), - ])); - let t0_data = RecordBatch::try_new( - t0_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - ], - )?; - ctx.register_batch("t0", t0_data)?; - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_inner_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - Field::new("t1_int", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_inner_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - Field::new("t2_int", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), - Arc::new(StringArray::from(vec![ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - Arc::new(UInt32Array::from(vec![3, 1, 3, 3])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - fn create_left_semi_anti_join_context_with_null_ids( column_left: &str, column_right: &str, repartition_joins: bool, ) -> Result { - let ctx = SessionContext::with_config( + let ctx = SessionContext::new_with_config( SessionConfig::new() .with_repartition_joins(repartition_joins) .with_target_partitions(2) @@ -351,348 +217,6 @@ fn create_left_semi_anti_join_context_with_null_ids( Ok(ctx) } -fn create_right_semi_anti_join_context_with_null_ids( - column_left: &str, - column_right: &str, - repartition_joins: bool, -) -> Result { - let ctx = SessionContext::with_config( - SessionConfig::new() - .with_repartition_joins(repartition_joins) - .with_target_partitions(2) - .with_batch_size(4096), - ); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - Field::new("t1_int", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![ - Some(11), - Some(22), - Some(33), - Some(44), - None, - ])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - Some("e"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 0])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - ])); - // t2 data size is smaller than t1 - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![Some(11), Some(11), None])), - Arc::new(StringArray::from(vec![Some("a"), Some("x"), None])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - -fn create_join_context_qualified( - left_name: &str, - right_name: &str, -) -> Result { - let ctx = SessionContext::new(); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, true), - Field::new("b", DataType::UInt32, true), - Field::new("c", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - Arc::new(UInt32Array::from(vec![10, 20, 30, 40])), - Arc::new(UInt32Array::from(vec![50, 60, 70, 80])), - ], - )?; - ctx.register_batch(left_name, t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, true), - Field::new("b", DataType::UInt32, true), - Field::new("c", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![1, 2, 9, 4])), - Arc::new(UInt32Array::from(vec![100, 200, 300, 400])), - Arc::new(UInt32Array::from(vec![500, 600, 700, 800])), - ], - )?; - ctx.register_batch(right_name, t2_data)?; - - Ok(ctx) -} - -fn create_hashjoin_datatype_context() -> Result { - let ctx = SessionContext::new(); - - let t1_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal128(5, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - ]); - let dict1: DictionaryArray = - vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); - let t1_data = RecordBatch::try_new( - Arc::new(t1_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - Some(172800000), - Some(259200000), - None, - ])), - Arc::new( - Decimal128Array::from_iter_values([123, 45600, 78900, -12312]) - .with_precision_and_scale(5, 2) - .unwrap(), - ), - Arc::new(dict1), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal128(10, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - ]); - let dict2: DictionaryArray = - vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); - let t2_data = RecordBatch::try_new( - Arc::new(t2_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - None, - Some(259200000), - None, - ])), - Arc::new( - Decimal128Array::from_iter_values([-12312, 10000000, 0, 78900]) - .with_precision_and_scale(10, 2) - .unwrap(), - ), - Arc::new(dict2), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - -fn create_sort_merge_join_context( - column_left: &str, - column_right: &str, -) -> Result { - let mut config = ConfigOptions::new(); - config.optimizer.prefer_hash_join = false; - - let ctx = SessionContext::with_config(config.into()); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new(column_left, DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - Field::new("t1_int", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new(column_right, DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - Field::new("t2_int", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), - Arc::new(StringArray::from(vec![ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - Arc::new(UInt32Array::from(vec![3, 1, 3, 3])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - -fn create_sort_merge_join_datatype_context() -> Result { - let mut config = ConfigOptions::new(); - config.optimizer.prefer_hash_join = false; - config.execution.target_partitions = 2; - config.execution.batch_size = 4096; - - let ctx = SessionContext::with_config(config.into()); - - let t1_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal128(5, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - ]); - let dict1: DictionaryArray = - vec!["abc", "def", "ghi", "jkl"].into_iter().collect(); - let t1_data = RecordBatch::try_new( - Arc::new(t1_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - Some(172800000), - Some(259200000), - None, - ])), - Arc::new( - Decimal128Array::from_iter_values([123, 45600, 78900, -12312]) - .with_precision_and_scale(5, 2) - .unwrap(), - ), - Arc::new(dict1), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Schema::new(vec![ - Field::new("c1", DataType::Date32, true), - Field::new("c2", DataType::Date64, true), - Field::new("c3", DataType::Decimal128(10, 2), true), - Field::new( - "c4", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - ), - ]); - let dict2: DictionaryArray = - vec!["abc", "abcdefg", "qwerty", ""].into_iter().collect(); - let t2_data = RecordBatch::try_new( - Arc::new(t2_schema), - vec![ - Arc::new(Date32Array::from(vec![Some(1), None, None, Some(3)])), - Arc::new(Date64Array::from(vec![ - Some(86400000), - None, - Some(259200000), - None, - ])), - Arc::new( - Decimal128Array::from_iter_values([-12312, 10000000, 0, 78900]) - .with_precision_and_scale(10, 2) - .unwrap(), - ), - Arc::new(dict2), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - -fn create_nested_loop_join_context() -> Result { - let ctx = SessionContext::with_config( - SessionConfig::new() - .with_target_partitions(4) - .with_batch_size(4096), - ); - - let t1_schema = Arc::new(Schema::new(vec![ - Field::new("t1_id", DataType::UInt32, true), - Field::new("t1_name", DataType::Utf8, true), - Field::new("t1_int", DataType::UInt32, true), - ])); - let t1_data = RecordBatch::try_new( - t1_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 33, 44])), - Arc::new(StringArray::from(vec![ - Some("a"), - Some("b"), - Some("c"), - Some("d"), - ])), - Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - ], - )?; - ctx.register_batch("t1", t1_data)?; - - let t2_schema = Arc::new(Schema::new(vec![ - Field::new("t2_id", DataType::UInt32, true), - Field::new("t2_name", DataType::Utf8, true), - Field::new("t2_int", DataType::UInt32, true), - ])); - let t2_data = RecordBatch::try_new( - t2_schema, - vec![ - Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), - Arc::new(StringArray::from(vec![ - Some("z"), - Some("y"), - Some("x"), - Some("w"), - ])), - Arc::new(UInt32Array::from(vec![3, 1, 3, 3])), - ], - )?; - ctx.register_batch("t2", t2_data)?; - - Ok(ctx) -} - fn get_tpch_table_schema(table: &str) -> Schema { match table { "customer" => Schema::new(vec![ @@ -824,10 +348,7 @@ async fn register_tpch_csv_data( DataType::Decimal128(_, _) => { cols.push(Box::new(Decimal128Builder::with_capacity(records.len()))) } - _ => { - let msg = format!("Not implemented: {}", field.data_type()); - Err(DataFusionError::Plan(msg))? - } + _ => plan_err!("Not implemented: {}", field.data_type())?, } } @@ -865,10 +386,7 @@ async fn register_tpch_csv_data( let value_i128 = val.parse::().unwrap(); sb.append_value(value_i128); } - _ => Err(DataFusionError::Plan(format!( - "Not implemented: {}", - field.data_type() - )))?, + _ => plan_err!("Not implemented: {}", field.data_type())?, } } } @@ -933,23 +451,6 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { ); } -async fn register_aggregate_simple_csv(ctx: &SessionContext) -> Result<()> { - // It's not possible to use aggregate_test_100 as it doesn't have enough similar values to test grouping on floats. - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Float32, false), - Field::new("c2", DataType::Float64, false), - Field::new("c3", DataType::Boolean, false), - ])); - - ctx.register_csv( - "aggregate_simple", - "tests/data/aggregate_simple.csv", - CsvReadOptions::new().schema(&schema), - ) - .await?; - Ok(()) -} - async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); @@ -998,7 +499,8 @@ async fn create_ctx_with_partition( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -1042,21 +544,12 @@ fn populate_csv_partitions( Ok(schema) } -/// Return a RecordBatch with a single Int32 array with values (0..sz) -pub fn make_partition(sz: i32) -> RecordBatch { - let seq_start = 0; - let seq_end = sz; - let values = (seq_start..seq_end).collect::>(); - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let arr = Arc::new(Int32Array::from(values)); - let arr = arr as ArrayRef; - - RecordBatch::try_new(schema, vec![arr]).unwrap() -} - /// Specialised String representation fn col_str(column: &ArrayRef, row_index: usize) -> String { - if column.is_null(row_index) { + // NullArray::is_null() does not work on NullArray. + // can remove check for DataType::Null when + // https://github.com/apache/arrow-rs/issues/4835 is fixed + if column.data_type() == &DataType::Null || column.is_null(row_index) { return "NULL".to_string(); } @@ -1247,93 +740,6 @@ fn normalize_vec_for_explain(v: Vec>) -> Vec> { .collect::>() } -/// Return a new table provider containing all of the supported timestamp types -pub fn table_with_timestamps() -> Arc { - let batch = make_timestamps(); - let schema = batch.schema(); - let partitions = vec![vec![batch]]; - Arc::new(MemTable::try_new(schema, partitions).unwrap()) -} - -/// Return record batch with all of the supported timestamp types -/// values -/// -/// Columns are named: -/// "nanos" --> TimestampNanosecondArray -/// "micros" --> TimestampMicrosecondArray -/// "millis" --> TimestampMillisecondArray -/// "secs" --> TimestampSecondArray -/// "names" --> StringArray -pub fn make_timestamps() -> RecordBatch { - let ts_strings = vec![ - Some("2018-11-13T17:11:10.011375885995"), - Some("2011-12-13T11:13:10.12345"), - None, - Some("2021-1-1T05:11:10.432"), - ]; - - let ts_nanos = ts_strings - .into_iter() - .map(|t| { - t.map(|t| { - t.parse::() - .unwrap() - .timestamp_nanos() - }) - }) - .collect::>(); - - let ts_micros = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) - .collect::>(); - - let ts_millis = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) - .collect::>(); - - let ts_secs = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) - .collect::>(); - - let names = ts_nanos - .iter() - .enumerate() - .map(|(i, _)| format!("Row {i}")) - .collect::>(); - - let arr_nanos = TimestampNanosecondArray::from(ts_nanos); - let arr_micros = TimestampMicrosecondArray::from(ts_micros); - let arr_millis = TimestampMillisecondArray::from(ts_millis); - let arr_secs = TimestampSecondArray::from(ts_secs); - - let names = names.iter().map(|s| s.as_str()).collect::>(); - let arr_names = StringArray::from(names); - - let schema = Schema::new(vec![ - Field::new("nanos", arr_nanos.data_type().clone(), true), - Field::new("micros", arr_micros.data_type().clone(), true), - Field::new("millis", arr_millis.data_type().clone(), true), - Field::new("secs", arr_secs.data_type().clone(), true), - Field::new("name", arr_names.data_type().clone(), true), - ]); - let schema = Arc::new(schema); - - RecordBatch::try_new( - schema, - vec![ - Arc::new(arr_nanos), - Arc::new(arr_micros), - Arc::new(arr_millis), - Arc::new(arr_secs), - Arc::new(arr_names), - ], - ) - .unwrap() -} - #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files diff --git a/datafusion/core/tests/sql/order.rs b/datafusion/core/tests/sql/order.rs index 100e5b8c44fb7..6e3f6319e1190 100644 --- a/datafusion/core/tests/sql/order.rs +++ b/datafusion/core/tests/sql/order.rs @@ -16,9 +16,9 @@ // under the License. use super::*; -use datafusion::datasource::datasource::TableProviderFactory; use datafusion::datasource::listing::ListingTable; use datafusion::datasource::listing_table_factory::ListingTableFactory; +use datafusion::datasource::provider::TableProviderFactory; use datafusion_expr::logical_plan::DdlStatement; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -48,7 +48,9 @@ async fn sort_with_lots_of_repetition_values() -> Result<()> { async fn create_external_table_with_order() -> Result<()> { let ctx = SessionContext::new(); let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a_id ASC) LOCATION 'file://path/to/table';"; - let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = ctx.state().create_logical_plan(sql).await? else { + let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = + ctx.state().create_logical_plan(sql).await? + else { panic!("Wrong command") }; @@ -68,7 +70,7 @@ async fn create_external_table_with_ddl_ordered_non_cols() -> Result<()> { Ok(_) => panic!("Expecting error."), Err(e) => { assert_eq!( - e.to_string(), + e.strip_backtrace(), "Error during planning: Column a is not in schema" ) } @@ -83,7 +85,7 @@ async fn create_external_table_with_ddl_ordered_without_schema() -> Result<()> { match ctx.state().create_logical_plan(sql).await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.to_string(), "Error during planning: Provide a schema before specifying the order while creating a table.") + assert_eq!(e.strip_backtrace(), "Error during planning: Provide a schema before specifying the order while creating a table.") } } Ok(()) @@ -123,7 +125,7 @@ async fn sort_with_duplicate_sort_exprs() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let expected = vec![ + let expected = [ "+----+------+", "| id | name |", "+----+------+", @@ -154,7 +156,7 @@ async fn sort_with_duplicate_sort_exprs() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let expected = vec![ + let expected = [ "+----+------+", "| id | name |", "+----+------+", @@ -178,7 +180,7 @@ async fn test_issue5970_mini() -> Result<()> { let config = SessionConfig::new() .with_target_partitions(2) .with_repartition_sorts(true); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let sql = " WITH m0(t) AS ( @@ -206,21 +208,21 @@ ORDER BY 1, 2; " ProjectionExec: expr=[Int64(0)@0 as m, t@1 as t]", " AggregateExec: mode=FinalPartitioned, gby=[Int64(0)@0 as Int64(0), t@1 as t], aggr=[]", " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"Int64(0)\", index: 0 }, Column { name: \"t\", index: 1 }], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[0 as Int64(0), t@0 as t], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " RepartitionExec: partitioning=Hash([Int64(0)@0, t@1], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " AggregateExec: mode=Partial, gby=[0 as Int64(0), t@0 as t], aggr=[]", " ProjectionExec: expr=[column1@0 as t]", " ValuesExec", " ProjectionExec: expr=[Int64(1)@0 as m, t@1 as t]", " AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1), t@1 as t], aggr=[]", " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([Column { name: \"Int64(1)\", index: 0 }, Column { name: \"t\", index: 1 }], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " RepartitionExec: partitioning=Hash([Int64(1)@0, t@1], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[]", " ProjectionExec: expr=[column1@0 as t]", " ValuesExec", ]; - let formatted = displayable(plan.as_ref()).indent().to_string(); + let formatted = displayable(plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); assert_eq!( expected, actual, @@ -233,7 +235,7 @@ ORDER BY 1, 2; let actual = execute_to_batches(&ctx, sql).await; // in https://github.com/apache/arrow-datafusion/issues/5970 the order of the output was sometimes not right - let expected = vec![ + let expected = [ "+---+---+", "| m | t |", "+---+---+", diff --git a/datafusion/core/tests/sql/parquet.rs b/datafusion/core/tests/sql/parquet.rs index 907a2c9506727..8f810a929df3b 100644 --- a/datafusion/core/tests/sql/parquet.rs +++ b/datafusion/core/tests/sql/parquet.rs @@ -32,7 +32,7 @@ async fn parquet_query() { // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----+---------------------------+", "| id | alltypes_plain.string_col |", "+----+---------------------------+", @@ -263,7 +263,7 @@ async fn parquet_list_columns() { assert_eq!( as_string_array(&utf8_list_array.value(0)).unwrap(), - &StringArray::try_from(vec![Some("abc"), Some("efg"), Some("hij"),]).unwrap() + &StringArray::from(vec![Some("abc"), Some("efg"), Some("hij"),]) ); assert_eq!( @@ -335,7 +335,7 @@ async fn parquet_query_with_max_min() { let sql = "SELECT max(c1) FROM foo"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------+", "| MAX(foo.c1) |", "+-------------+", @@ -347,7 +347,7 @@ async fn parquet_query_with_max_min() { let sql = "SELECT min(c2) FROM foo"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------+", "| MIN(foo.c2) |", "+-------------+", @@ -359,7 +359,7 @@ async fn parquet_query_with_max_min() { let sql = "SELECT max(c3) FROM foo"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------+", "| MAX(foo.c3) |", "+-------------+", @@ -371,7 +371,7 @@ async fn parquet_query_with_max_min() { let sql = "SELECT min(c4) FROM foo"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------+", "| MIN(foo.c4) |", "+-------------+", diff --git a/datafusion/core/tests/sql/parquet_schema.rs b/datafusion/core/tests/sql/parquet_schema.rs index 1d96f2b1ff979..bc1578da2c58a 100644 --- a/datafusion/core/tests/sql/parquet_schema.rs +++ b/datafusion/core/tests/sql/parquet_schema.rs @@ -58,7 +58,7 @@ async fn schema_merge_ignores_metadata_by_default() { write_files(table_dir.as_path(), schemas); // can be any order - let expected = vec![ + let expected = [ "+----+------+", "| id | name |", "+----+------+", @@ -119,7 +119,7 @@ async fn schema_merge_can_preserve_metadata() { write_files(table_dir.as_path(), schemas); // can be any order - let expected = vec![ + let expected = [ "+----+------+", "| id | name |", "+----+------+", diff --git a/datafusion/core/tests/sql/partitioned_csv.rs b/datafusion/core/tests/sql/partitioned_csv.rs index 98cb3b1893612..b77557a66cd89 100644 --- a/datafusion/core/tests/sql/partitioned_csv.rs +++ b/datafusion/core/tests/sql/partitioned_csv.rs @@ -19,31 +19,13 @@ use std::{io::Write, sync::Arc}; -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::{ error::Result, prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; -/// Execute SQL and return results -async fn plan_and_collect( - ctx: &mut SessionContext, - sql: &str, -) -> Result> { - ctx.sql(sql).await?.collect().await -} - -/// Execute SQL and return results -pub async fn execute(sql: &str, partition_count: usize) -> Result> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, partition_count).await?; - plan_and_collect(&mut ctx, sql).await -} - /// Generate CSV partitions within the supplied directory fn populate_csv_partitions( tmp_dir: &TempDir, @@ -78,7 +60,8 @@ pub async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, ) -> Result { - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 498952a808c46..fe735bf6b8282 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -34,7 +34,7 @@ async fn string_coercion() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let expected = vec![ + let expected = [ "+----------------+----------------+", "| vendor_id_utf8 | vendor_id_dict |", "+----------------+----------------+", @@ -175,13 +175,11 @@ where // assert data let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-----------+-------------------------------+--------------------------+-------------------------------------+", + let expected = ["+-----------+-------------------------------+--------------------------+-------------------------------------+", "| p_partkey | SUM(lineitem.l_extendedprice) | AVG(lineitem.l_discount) | COUNT(DISTINCT partsupp.ps_suppkey) |", "+-----------+-------------------------------+--------------------------+-------------------------------------+", "| 63700 | 13309.60 | 0.100000 | 1 |", - "+-----------+-------------------------------+--------------------------+-------------------------------------+", - ]; + "+-----------+-------------------------------+--------------------------+-------------------------------------+"]; assert_batches_eq!(expected, &results); Ok(()) diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs deleted file mode 100644 index a90cf1a2c202a..0000000000000 --- a/datafusion/core/tests/sql/projection.rs +++ /dev/null @@ -1,377 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::provider_as_source; -use datafusion::test_util::scan_empty; -use datafusion_expr::{when, LogicalPlanBuilder, UNNAMED_TABLE}; -use tempfile::TempDir; - -use super::*; - -#[tokio::test] -async fn projection_same_fields() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&ctx, sql).await; - - #[rustfmt::skip] - let expected = vec![ - "+---+", - "| a |", - "+---+", - "| 2 |", - "+---+" - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn projection_type_alias() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_simple_csv(&ctx).await?; - - // Query that aliases one column to the name of a different column - // that also has a different type (c1 == float32, c3 == boolean) - let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec![ - "+---------+", - "| c3 |", - "+---------+", - "| 0.00001 |", - "| 0.00002 |", - "+---------+", - ]; - assert_batches_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_csv(&ctx).await?; - let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-----------------------------+----+", - "| AVG(aggregate_test_100.c12) | c1 |", - "+-----------------------------+----+", - "| 0.41040709263815384 | b |", - "| 0.48600669271341534 | e |", - "| 0.48754517466109415 | a |", - "| 0.48855379387549824 | d |", - "| 0.6600456536439784 | c |", - "+-----------------------------+----+", - ]; - assert_batches_sorted_eq!(expected, &actual); - Ok(()) -} - -#[tokio::test] -async fn parallel_projection() -> Result<()> { - let partition_count = 4; - let results = - partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn subquery_alias_case_insensitive() -> Result<()> { - let partition_count = 1; - let results = - partitioned_csv::execute("SELECT V1.c1, v1.C2 FROM (SELECT test.C1, TEST.c2 FROM test) V1 ORDER BY v1.c1, V1.C2 LIMIT 1", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 0 | 1 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - - let table = ctx.table("test").await?; - let logical_plan = LogicalPlanBuilder::from(table.into_optimized_plan()?) - .project(vec![col("c2")])? - .build()?; - - let state = ctx.state(); - let optimized_plan = state.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - } - - let expected = "TableScan: test projection=[c2]"; - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) -} - -#[tokio::test] -async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").await.unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = scan_empty(None, &schema, None)? - .project(vec![col("c1")])? - .build()?; - - let dataframe = DataFrame::new(ctx.state(), plan); - let physical_plan = dataframe.create_physical_plan().await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) -} - -#[tokio::test] -async fn project_cast_dictionary() { - let ctx = SessionContext::new(); - - let host: DictionaryArray = vec![Some("host1"), None, Some("host2")] - .into_iter() - .collect(); - - let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap(); - - let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap(); - - // Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast - let expr = when(col("host").is_null(), lit("")) - .otherwise(col("host")) - .unwrap(); - - let projection = None; - let builder = LogicalPlanBuilder::scan( - "cpu_load_short", - provider_as_source(Arc::new(t)), - projection, - ) - .unwrap(); - - let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap(); - let df = DataFrame::new(ctx.state(), logical_plan); - let actual = df.collect().await.unwrap(); - - let expected = vec![ - "+----------------------------------------------------------------------------------+", - "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |", - "+----------------------------------------------------------------------------------+", - "| host1 |", - "| |", - "| host2 |", - "+----------------------------------------------------------------------------------+", - ]; - assert_batches_eq!(expected, &actual); -} - -#[tokio::test] -async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - Arc::new(Int32Array::from(vec![3, 12, 12, 120])), - ], - )?]]; - - let provider = Arc::new(MemTable::try_new(schema, partitions)?); - let plan = - LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = SessionContext::new(); - let state = ctx.state(); - let optimized_plan = state.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - } - - let expected = format!("TableScan: {UNNAMED_TABLE} projection=[b]"); - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, state.task_ctx()).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) -} - -fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { - let actual: Vec = plan - .schema() - .fields() - .iter() - .map(|f| f.name().clone()) - .collect(); - assert_eq!(actual, expected); -} - -#[tokio::test] -async fn project_column_with_same_name_as_relation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select a.a from (select 1 as a) as a;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_false() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec!["++", "++"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_column_with_filters_that_cant_pushed_down_always_true() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select * from (select 1 as a) f where f.a=1;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} - -#[tokio::test] -async fn project_columns_in_memory_without_propagation() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = "select column1 as a from (values (1), (2)) f where f.column1 = 2;"; - let actual = execute_to_batches(&ctx, sql).await; - - let expected = vec!["+---+", "| a |", "+---+", "| 2 |", "+---+"]; - assert_batches_sorted_eq!(expected, &actual); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/references.rs b/datafusion/core/tests/sql/references.rs index 60191e521340a..f465e8a2dacc5 100644 --- a/datafusion/core/tests/sql/references.rs +++ b/datafusion/core/tests/sql/references.rs @@ -29,12 +29,12 @@ async fn qualified_table_references() -> Result<()> { ] { let sql = format!("SELECT COUNT(*) FROM {table_ref}"); let actual = execute_to_batches(&ctx, &sql).await; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 100 |", - "+-----------------+", + let expected = [ + "+----------+", + "| COUNT(*) |", + "+----------+", + "| 100 |", + "+----------+", ]; assert_batches_eq!(expected, &actual); } @@ -73,7 +73,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // however, enclosing it in double quotes is ok let sql = r#"SELECT "f.c1" from test"#; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| f.c1 |", "+--------+", @@ -91,7 +91,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // check that duplicated table name and column name are ok let sql = r#"SELECT "test.c2" as expr1, test."test.c2" as expr2 from test"#; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+-------+", "| expr1 | expr2 |", "+-------+-------+", @@ -107,7 +107,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // this let sql = r#"SELECT "....", "...." as c3 from test order by "....""#; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+------+----+", "| .... | c3 |", "+------+----+", @@ -124,7 +124,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { async fn test_partial_qualified_name() -> Result<()> { let ctx = create_join_context("t1_id", "t2_id", true)?; let sql = "SELECT t1.t1_id, t1_name FROM public.t1"; - let expected = vec![ + let expected = [ "+-------+---------+", "| t1_id | t1_name |", "+-------+---------+", diff --git a/datafusion/core/tests/repartition.rs b/datafusion/core/tests/sql/repartition.rs similarity index 98% rename from datafusion/core/tests/repartition.rs rename to datafusion/core/tests/sql/repartition.rs index 20e64b2eeefc2..332f18e941aaa 100644 --- a/datafusion/core/tests/repartition.rs +++ b/datafusion/core/tests/sql/repartition.rs @@ -33,7 +33,7 @@ use std::sync::Arc; #[tokio::test] async fn unbounded_repartition() -> Result<()> { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let task = ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a2", DataType::UInt32, false)])); let batch = RecordBatch::try_new( diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 52d338d9e9eca..cbdea9d729487 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -29,7 +29,7 @@ async fn query_get_indexed_field() -> Result<()> { )])); let builder = PrimitiveBuilder::::with_capacity(3); let mut lb = ListBuilder::new(builder); - for int_vec in vec![vec![0, 1, 2], vec![4, 5, 6], vec![7, 8, 9]] { + for int_vec in [[0, 1, 2], [4, 5, 6], [7, 8, 9]] { let builder = lb.values(); for int in int_vec { builder.append_value(int); @@ -45,15 +45,13 @@ async fn query_get_indexed_field() -> Result<()> { let sql = "SELECT some_list[1] as i0 FROM ints LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; #[rustfmt::skip] - let expected = vec![ - "+----+", + let expected = ["+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 7 |", - "+----+", - ]; + "+----+"]; assert_batches_eq!(expected, &actual); Ok(()) } @@ -72,10 +70,10 @@ async fn query_nested_get_indexed_field() -> Result<()> { let builder = PrimitiveBuilder::::with_capacity(3); let nested_lb = ListBuilder::new(builder); let mut lb = ListBuilder::new(nested_lb); - for int_vec_vec in vec![ - vec![vec![0, 1], vec![2, 3], vec![3, 4]], - vec![vec![5, 6], vec![7, 8], vec![9, 10]], - vec![vec![11, 12], vec![13, 14], vec![15, 16]], + for int_vec_vec in [ + [[0, 1], [2, 3], [3, 4]], + [[5, 6], [7, 8], [9, 10]], + [[11, 12], [13, 14], [15, 16]], ] { let nested_builder = lb.values(); for int_vec in int_vec_vec { @@ -95,7 +93,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_list[1] as i0 FROM ints LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------+", "| i0 |", "+----------+", @@ -108,15 +106,13 @@ async fn query_nested_get_indexed_field() -> Result<()> { let sql = "SELECT some_list[1][1] as i0 FROM ints LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; #[rustfmt::skip] - let expected = vec![ - "+----+", + let expected = ["+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", - "+----+", - ]; + "+----+"]; assert_batches_eq!(expected, &actual); Ok(()) } @@ -136,7 +132,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { let builder = PrimitiveBuilder::::with_capacity(3); let nested_lb = ListBuilder::new(builder); let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]); - for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]] { + for int_vec in [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] { let lb = sb.field_builder::>(0).unwrap(); for int in int_vec { lb.values().append_value(int); @@ -152,7 +148,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_struct['bar'] as l0 FROM structs LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------------+", "| l0 |", "+----------------+", @@ -166,7 +162,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { // Access to field of struct by CompoundIdentifier let sql = "SELECT some_struct.bar as l0 FROM structs LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------------+", "| l0 |", "+----------------+", @@ -180,21 +176,18 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { let sql = "SELECT some_struct['bar'][1] as i0 FROM structs LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; #[rustfmt::skip] - let expected = vec![ - "+----+", + let expected = ["+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", - "+----+", - ]; + "+----+"]; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] -#[cfg(feature = "dictionary_expressions")] async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types // Use StringDictionary (32 bit indexes = keys) @@ -220,7 +213,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Basic SELECT let sql = "SELECT d1 FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -234,7 +227,7 @@ async fn query_on_string_dictionary() -> Result<()> { // basic filtering let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -247,7 +240,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -259,7 +252,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d2"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -271,7 +264,7 @@ async fn query_on_string_dictionary() -> Result<()> { // order comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 <= d2"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -283,7 +276,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with a non dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d3"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -295,7 +288,7 @@ async fn query_on_string_dictionary() -> Result<()> { // filtering with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+", "| d1 |", "+-------+", @@ -307,7 +300,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+------------------------------+", "| concat(test.d1,Utf8(\"-foo\")) |", "+------------------------------+", @@ -321,7 +314,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation with two dictionaries let sql = "SELECT concat(d1, d2) FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------------------+", "| concat(test.d1,test.d2) |", "+-------------------------+", @@ -335,7 +328,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation let sql = "SELECT COUNT(d1) FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------------+", "| COUNT(test.d1) |", "+----------------+", @@ -347,7 +340,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation min let sql = "SELECT MIN(d1) FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------------+", "| MIN(test.d1) |", "+--------------+", @@ -359,7 +352,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation max let sql = "SELECT MAX(d1) FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------------+", "| MAX(test.d1) |", "+--------------+", @@ -371,21 +364,21 @@ async fn query_on_string_dictionary() -> Result<()> { // grouping let sql = "SELECT d1, COUNT(*) FROM test group by d1"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+-----------------+", - "| d1 | COUNT(UInt8(1)) |", - "+-------+-----------------+", - "| one | 1 |", - "| | 1 |", - "| three | 1 |", - "+-------+-----------------+", + let expected = [ + "+-------+----------+", + "| d1 | COUNT(*) |", + "+-------+----------+", + "| | 1 |", + "| one | 1 |", + "| three | 1 |", + "+-------+----------+", ]; assert_batches_sorted_eq!(expected, &actual); // window functions let sql = "SELECT d1, row_number() OVER (partition by d1) as rn1 FROM test"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------+-----+", "| d1 | rn1 |", "+-------+-----+", @@ -414,7 +407,8 @@ async fn sort_on_window_null_string() -> Result<()> { ]) .unwrap(); - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(1)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(1)); ctx.register_batch("test", batch)?; let sql = @@ -422,7 +416,7 @@ async fn sort_on_window_null_string() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; // NULLS LAST - let expected = vec![ + let expected = [ "+-------+-----+", "| d1 | rn1 |", "+-------+-----+", @@ -437,7 +431,7 @@ async fn sort_on_window_null_string() -> Result<()> { "SELECT d2, row_number() OVER (partition by d2) as rn1 FROM test ORDER BY d2 asc"; let actual = execute_to_batches(&ctx, sql).await; // NULLS LAST - let expected = vec![ + let expected = [ "+-------+-----+", "| d2 | rn1 |", "+-------+-----+", @@ -453,7 +447,7 @@ async fn sort_on_window_null_string() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; // NULLS FIRST - let expected = vec![ + let expected = [ "+-------+-----+", "| d2 | rn1 |", "+-------+-----+", @@ -531,6 +525,53 @@ async fn test_prepare_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_query_parameters() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 + let results = ctx + .sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo") + .await? + .with_param_values(vec![ + ("foo", ScalarValue::UInt32(Some(3))), + ("coo", ScalarValue::UInt32(Some(0))), + ])? + .collect() + .await?; + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; @@ -578,7 +619,7 @@ async fn boolean_literal() -> Result<()> { execute_with_partition("SELECT c1, c3 FROM test WHERE c1 > 2 AND c3 = true", 4) .await?; - let expected = vec![ + let expected = [ "+----+------+", "| c1 | c3 |", "+----+------+", @@ -597,7 +638,7 @@ async fn boolean_literal() -> Result<()> { #[tokio::test] async fn unprojected_filter() { let config = SessionConfig::new(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let df = ctx.read_table(table_with_sequence(1, 3).unwrap()).unwrap(); let df = df @@ -611,7 +652,7 @@ async fn unprojected_filter() { let results = df.collect().await.unwrap(); - let expected = vec![ + let expected = [ "+-----------------------+", "| ?table?.i + ?table?.i |", "+-----------------------+", diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs new file mode 100644 index 0000000000000..d7adc9611b2ff --- /dev/null +++ b/datafusion/core/tests/sql/sql_api.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::prelude::*; +use tempfile::TempDir; + +#[tokio::test] +async fn unsupported_ddl_returns_error() { + // Verify SessionContext::with_sql_options errors appropriately + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + // disallow ddl + let options = SQLOptions::new().with_allow_ddl(false); + + let sql = "create view test_view as select * from test"; + let df = ctx.sql_with_options(sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: DDL not supported: CreateView" + ); + + // allow ddl + let options = options.with_allow_ddl(true); + ctx.sql_with_options(sql, options).await.unwrap(); +} + +#[tokio::test] +async fn unsupported_dml_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let options = SQLOptions::new().with_allow_dml(false); + + let sql = "insert into test values (1)"; + let df = ctx.sql_with_options(sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: DML not supported: Insert Into" + ); + + let options = options.with_allow_dml(true); + ctx.sql_with_options(sql, options).await.unwrap(); +} + +#[tokio::test] +async fn unsupported_copy_returns_error() { + let tmpdir = TempDir::new().unwrap(); + let tmpfile = tmpdir.path().join("foo.parquet"); + + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let options = SQLOptions::new().with_allow_dml(false); + + let sql = format!("copy (values(1)) to '{}'", tmpfile.to_string_lossy()); + let df = ctx.sql_with_options(&sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: DML not supported: COPY" + ); + + let options = options.with_allow_dml(true); + ctx.sql_with_options(&sql, options).await.unwrap(); +} + +#[tokio::test] +async fn unsupported_statement_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let options = SQLOptions::new().with_allow_statements(false); + + let sql = "set datafusion.execution.batch_size = 5"; + let df = ctx.sql_with_options(sql, options).await; + assert_eq!( + df.unwrap_err().strip_backtrace(), + "Error during planning: Statement not supported: SetVariable" + ); + + let options = options.with_allow_statements(true); + ctx.sql_with_options(sql, options).await.unwrap(); +} + +#[tokio::test] +async fn ddl_can_not_be_planned_by_session_state() { + let ctx = SessionContext::new(); + + // make a table via SQL + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let state = ctx.state(); + + // can not create a logical plan for catalog DDL + let sql = "drop table test"; + let plan = state.create_logical_plan(sql).await.unwrap(); + let physical_plan = state.create_physical_plan(&plan).await; + assert_eq!( + physical_plan.unwrap_err().strip_backtrace(), + "This feature is not implemented: Unsupported logical plan: DropTable" + ); +} diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs deleted file mode 100644 index 574553a938da5..0000000000000 --- a/datafusion/core/tests/sql/subqueries.rs +++ /dev/null @@ -1,761 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::sql::execute_to_batches; -use datafusion::assert_batches_eq; -use datafusion::prelude::SessionContext; -use log::debug; - -#[tokio::test] -async fn correlated_recursive_scalar_subquery() -> Result<()> { - let ctx = SessionContext::new(); - register_tpch_csv(&ctx, "customer").await?; - register_tpch_csv(&ctx, "orders").await?; - register_tpch_csv(&ctx, "lineitem").await?; - - let sql = r#" -select c_custkey from customer -where c_acctbal < ( - select sum(o_totalprice) from orders - where o_custkey = c_custkey - and o_totalprice < ( - select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey - ) -) order by c_custkey;"#; - - // assert plan - let dataframe = ctx.sql(sql).await.unwrap(); - debug!("input:\n{}", dataframe.logical_plan().display_indent()); - - let plan = dataframe.into_optimized_plan().unwrap(); - let actual = format!("{}", plan.display_indent()); - let expected = "Sort: customer.c_custkey ASC NULLS LAST\ - \n Projection: customer.c_custkey\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\ - \n TableScan: customer projection=[c_custkey, c_acctbal]\ - \n SubqueryAlias: __scalar_sq_1\ - \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\ - \n Projection: orders.o_custkey, orders.o_totalprice\ - \n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\ - \n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\ - \n SubqueryAlias: __scalar_sq_2\ - \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\ - \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\ - \n TableScan: lineitem projection=[l_orderkey, l_extendedprice]"; - assert_eq!(actual, expected); - - Ok(()) -} - -#[tokio::test] -async fn correlated_where_in() -> Result<()> { - let orders = r#"1,3691,O,194029.55,1996-01-02,5-LOW,Clerk#000000951,0, -65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0, -"#; - let lineitems = r#"1,15519,785,1,17,24386.67,0.04,0.02,N,O,1996-03-13,1996-02-12,1996-03-22,DELIVER IN PERSON,TRUCK, -1,6731,732,2,36,58958.28,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL, -65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK, -65,7382,897,2,22,28366.36,0,0.05,N,O,1995-07-17,1995-06-04,1995-07-19,COLLECT COD,FOB, -"#; - - let ctx = SessionContext::new(); - register_tpch_csv_data(&ctx, "orders", orders).await?; - register_tpch_csv_data(&ctx, "lineitem", lineitems).await?; - - let sql = r#"select o_orderkey from orders -where o_orderstatus in ( - select l_linestatus from lineitem where l_orderkey = orders.o_orderkey -);"#; - - // assert plan - let dataframe = ctx.sql(sql).await.unwrap(); - let plan = dataframe.into_optimized_plan().unwrap(); - let actual = format!("{}", plan.display_indent()); - - let expected = "Projection: orders.o_orderkey\ - \n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\ - \n TableScan: orders projection=[o_orderkey, o_orderstatus]\ - \n SubqueryAlias: __correlated_sq_1\ - \n Projection: lineitem.l_linestatus, lineitem.l_orderkey\ - \n TableScan: lineitem projection=[l_orderkey, l_linestatus]"; - assert_eq!(actual, expected); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+------------+", - "| o_orderkey |", - "+------------+", - "| 1 |", - "+------------+", - ]; - assert_batches_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn exists_subquery_with_same_table() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - // Subquery and outer query refer to the same table. - // It will not be rewritten to join because it is not a correlated subquery. - let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE EXISTS(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Subquery: [t1_int:UInt32;N]", - " Projection: t1.t1_int [t1_int:UInt32;N]", - " Filter: t1.t1_id > t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn in_subquery_with_same_table() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - // Subquery and outer query refer to the same table. - // It will be rewritten to join because in-subquery has extra predicate(`t1.t1_id = __correlated_sq_1.t1_int`). - let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t1_int:UInt32;N]", - " Projection: t1.t1_int [t1_int:UInt32;N]", - " Filter: t1.t1_id > t1.t1_int [t1_id:UInt32;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn in_subquery_nested_exist_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int))"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - " SubqueryAlias: __correlated_sq_2 [t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_int] [t1_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn invalid_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t1.t1_int) FROM t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Scalar subquery should only return one column, but found 2: t2.t2_id, t2.t2_name"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn subquery_not_allowed() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - // In/Exist Subquery is not allowed in ORDER BY clause. - let sql = "SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn non_aggregated_correlated_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#, - &format!("{err:?}") - ); - - let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn non_aggregated_correlated_scalar_subquery_with_limit() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 2) as t2_int from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery must be aggregated to return at most one row"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn non_aggregated_correlated_scalar_subquery_with_single_row() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) as t2_int from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_id, () AS t2_int [t1_id:UInt32;N, t2_int:UInt32;N]", - " Subquery: [t2_int:UInt32;N]", - " Limit: skip=0, fetch=1 [t2_int:UInt32;N]", - " Projection: t2.t2_int [t2_int:UInt32;N]", - " Filter: t2.t2_int = outer_ref(t1.t1_int) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let sql = "SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_id [t1_id:UInt32;N]", - " Filter: t1.t1_int = () [t1_id:UInt32;N, t1_int:UInt32;N]", - " Subquery: [t2_int:UInt32;N]", - " Limit: skip=0, fetch=1 [t2_int:UInt32;N]", - " Projection: t2.t2_int [t2_int:UInt32;N]", - " Filter: t2.t2_int = outer_ref(t1.t1_int) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let sql = "SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_id, () AS t2_int [t1_id:UInt32;N, t2_int:Int64]", - " Subquery: [a:Int64]", - " Projection: a [a:Int64]", - " Filter: a = CAST(outer_ref(t1.t1_int) AS Int64) [a:Int64]", - " Projection: Int64(1) AS a [a:Int64]", - " EmptyRelation []", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn non_equal_correlated_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: t2.t2_id < outer_ref(t1.t1_id)"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn aggregated_correlated_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_id, __scalar_sq_1.__value AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", - " Left Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t2_id:UInt32;N, __value:UInt64;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N, __value:UInt64;N]", - " Projection: t2.t2_id, SUM(t2.t2_int) AS __value [t2_id:UInt32;N, __value:UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn aggregated_correlated_scalar_subquery_with_extra_group_by_columns() -> Result<()> -{ - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_name) as t2_sum from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn aggregated_correlated_scalar_subquery_with_extra_group_by_constant() -> Result<()> -{ - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: t1.t1_id, __scalar_sq_1.__value AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", - " Left Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t2_id:UInt32;N, __value:UInt64;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N, __value:UInt64;N]", - " Projection: t2.t2_id, SUM(t2.t2_int) AS __value [t2_id:UInt32;N, __value:UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn group_by_correlated_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - let sql = "SELECT sum(t1_int) from t1 GROUP BY (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id)"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let err = dataframe.into_optimized_plan().err().unwrap(); - - assert_eq!( - r#"Context("check_analyzed_plan", Plan("Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"))"#, - &format!("{err:?}") - ); - - Ok(()) -} - -#[tokio::test] -async fn support_agg_correlated_columns() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", - " Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", - " Aggregate: groupBy=[[]], aggr=[[SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", - " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn support_agg_correlated_columns2() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [COUNT(UInt8(1)):Int64;N]", - " Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]", - " Filter: CAST(SUM(outer_ref(t1.t1_int) + t2.t2_id) AS Int64) > Int64(0) [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", - " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)), SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", - " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn support_join_correlated_columns() -> Result<()> { - let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?; - let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name))"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t0_id:UInt32;N, t0_name:Utf8;N]", - " Subquery: [Int64(1):Int64]", - " Projection: Int64(1) [Int64(1):Int64]", - " Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn support_join_correlated_columns2() -> Result<()> { - let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?; - let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN (select * from t2 where t2.t2_name = t0.t0_name) as t2 ON(t1.t1_id = t2.t2_id ))"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t0_id:UInt32;N, t0_name:Utf8;N]", - " Subquery: [Int64(1):Int64]", - " Projection: Int64(1) [Int64(1):Int64]", - " Inner Join: Filter: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " SubqueryAlias: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_name = outer_ref(t0.t0_name) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn support_order_by_correlated_columns() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id >= t1_id order by t1_id)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Sort: outer_ref(t1.t1_id) ASC NULLS LAST [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_id >= outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -// TODO: issue https://github.com/apache/arrow-datafusion/issues/6263 -#[ignore] -#[tokio::test] -async fn support_limit_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Limit: skip=0, fetch=1 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where t1_name = t2_name limit 10)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: t1.t1_id IN () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [t2_id:UInt32;N]", - " Limit: skip=0, fetch=10 [t2_id:UInt32;N]", - " Projection: t2.t2_id [t2_id:UInt32;N]", - " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn support_union_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS \ - (SELECT * FROM t2 WHERE t2_id = t1_id UNION ALL \ - SELECT * FROM t2 WHERE upper(t2_name) = upper(t1.t1_name))"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", - " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Union [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: upper(t2.t2_name) = upper(outer_ref(t1.t1_name)) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - Ok(()) -} - -#[tokio::test] -async fn simple_uncorrelated_scalar_subquery() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "select (select count(*) from t1) as b"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: __scalar_sq_1.__value AS b [b:Int64;N]", - " SubqueryAlias: __scalar_sq_1 [__value:Int64;N]", - " Projection: COUNT(UInt8(1)) AS __value [__value:Int64;N]", - " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = vec!["+---+", "| b |", "+---+", "| 4 |", "+---+"]; - assert_batches_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn simple_uncorrelated_scalar_subquery2() -> Result<()> { - let ctx = create_join_context("t1_id", "t2_id", true)?; - - let sql = "select (select count(*) from t1) as b, (select count(1) from t2) as c"; - - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - - let expected = vec![ - "Projection: __scalar_sq_1.__value AS b, __scalar_sq_2.__value AS c [b:Int64;N, c:Int64;N]", - " CrossJoin: [__value:Int64;N, __value:Int64;N]", - " SubqueryAlias: __scalar_sq_1 [__value:Int64;N]", - " Projection: COUNT(UInt8(1)) AS __value [__value:Int64;N]", - " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", - " SubqueryAlias: __scalar_sq_2 [__value:Int64;N]", - " Projection: COUNT(Int64(1)) AS __value [__value:Int64;N]", - " Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] [COUNT(Int64(1)):Int64;N]", - " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - // assert data - let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+---+---+", - "| b | c |", - "+---+---+", - "| 4 | 4 |", - "+---+---+", - ]; - assert_batches_eq!(expected, &results); - - Ok(()) -} diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 2058d8ed1fd60..ada66503a1816 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -138,7 +138,7 @@ async fn timestamp_minmax() -> Result<()> { let sql = "SELECT MIN(table_a.ts), MAX(table_b.ts) FROM table_a, table_b"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------------------+-----------------------------+", "| MIN(table_a.ts) | MAX(table_b.ts) |", "+-------------------------+-----------------------------+", @@ -516,7 +516,7 @@ async fn group_by_timestamp_millis() -> Result<()> { let sql = "SELECT timestamp, SUM(count) FROM t1 GROUP BY timestamp ORDER BY timestamp ASC"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+---------------------+---------------+", "| timestamp | SUM(t1.count) |", "+---------------------+---------------+", @@ -565,28 +565,32 @@ async fn timestamp_sub_interval_days() -> Result<()> { } #[tokio::test] -#[ignore] // https://github.com/apache/arrow-datafusion/issues/3327 async fn timestamp_add_interval_months() -> Result<()> { let ctx = SessionContext::new(); + let table_a = + make_timestamp_tz_table::(Some("+00:00".into()))?; + ctx.register_table("table_a", table_a)?; - let sql = "SELECT NOW(), NOW() + INTERVAL '17' MONTH;"; + let sql = "SELECT ts, ts + INTERVAL '17' MONTH FROM table_a;"; let results = execute_to_batches(&ctx, sql).await; - let actual = result_vec(&results); + let actual_vec = result_vec(&results); - let res1 = actual[0][0].as_str(); - let res2 = actual[0][1].as_str(); + for actual in actual_vec { + let res1 = actual[0].as_str(); + let res2 = actual[1].as_str(); - let format = "%Y-%m-%d %H:%M:%S%.6f"; - let t1_naive = chrono::NaiveDateTime::parse_from_str(res1, format).unwrap(); - let t2_naive = chrono::NaiveDateTime::parse_from_str(res2, format).unwrap(); + let format = "%Y-%m-%dT%H:%M:%S%.6fZ"; + let t1_naive = NaiveDateTime::parse_from_str(res1, format).unwrap(); + let t2_naive = NaiveDateTime::parse_from_str(res2, format).unwrap(); - let year = t1_naive.year() + (t1_naive.month() as i32 + 17) / 12; - let month = (t1_naive.month() + 17) % 12; + let year = t1_naive.year() + (t1_naive.month0() as i32 + 17) / 12; + let month = (t1_naive.month0() + 17) % 12 + 1; - assert_eq!( - t1_naive.with_year(year).unwrap().with_month(month).unwrap(), - t2_naive - ); + assert_eq!( + t1_naive.with_year(year).unwrap().with_month(month).unwrap(), + t2_naive + ); + } Ok(()) } @@ -618,7 +622,7 @@ async fn timestamp_array_add_interval() -> Result<()> { let sql = "SELECT ts, ts - INTERVAL '8' MILLISECONDS FROM table_a"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------------------------+----------------------------------------------+", "| ts | table_a.ts - IntervalMonthDayNano(\"8000000\") |", "+----------------------------+----------------------------------------------+", @@ -631,41 +635,35 @@ async fn timestamp_array_add_interval() -> Result<()> { let sql = "SELECT ts, ts + INTERVAL '1' SECOND FROM table_b"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+-------------------------------------------------+", + let expected = ["+----------------------------+-------------------------------------------------+", "| ts | table_b.ts + IntervalMonthDayNano(\"1000000000\") |", "+----------------------------+-------------------------------------------------+", "| 2020-09-08T13:42:29.190855 | 2020-09-08T13:42:30.190855 |", "| 2020-09-08T12:42:29.190855 | 2020-09-08T12:42:30.190855 |", "| 2020-09-08T11:42:29.190855 | 2020-09-08T11:42:30.190855 |", - "+----------------------------+-------------------------------------------------+", - ]; + "+----------------------------+-------------------------------------------------+"]; assert_batches_eq!(expected, &actual); let sql = "SELECT ts, ts + INTERVAL '2' MONTH FROM table_b"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+---------------------------------------------------------------------+", + let expected = ["+----------------------------+---------------------------------------------------------------------+", "| ts | table_b.ts + IntervalMonthDayNano(\"158456325028528675187087900672\") |", "+----------------------------+---------------------------------------------------------------------+", "| 2020-09-08T13:42:29.190855 | 2020-11-08T13:42:29.190855 |", "| 2020-09-08T12:42:29.190855 | 2020-11-08T12:42:29.190855 |", "| 2020-09-08T11:42:29.190855 | 2020-11-08T11:42:29.190855 |", - "+----------------------------+---------------------------------------------------------------------+", - ]; + "+----------------------------+---------------------------------------------------------------------+"]; assert_batches_eq!(expected, &actual); let sql = "SELECT ts, ts - INTERVAL '16' YEAR FROM table_b"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+----------------------------+-----------------------------------------------------------------------+", + let expected = ["+----------------------------+-----------------------------------------------------------------------+", "| ts | table_b.ts - IntervalMonthDayNano(\"15211807202738752817960438464512\") |", "+----------------------------+-----------------------------------------------------------------------+", "| 2020-09-08T13:42:29.190855 | 2004-09-08T13:42:29.190855 |", "| 2020-09-08T12:42:29.190855 | 2004-09-08T12:42:29.190855 |", "| 2020-09-08T11:42:29.190855 | 2004-09-08T11:42:29.190855 |", - "+----------------------------+-----------------------------------------------------------------------+", - ]; + "+----------------------------+-----------------------------------------------------------------------+"]; assert_batches_eq!(expected, &actual); Ok(()) } @@ -677,7 +675,7 @@ async fn cast_timestamp_before_1970() -> Result<()> { let sql = "select cast('1969-01-01T00:00:00Z' as timestamp);"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+------------------------------+", "| Utf8(\"1969-01-01T00:00:00Z\") |", "+------------------------------+", @@ -689,7 +687,7 @@ async fn cast_timestamp_before_1970() -> Result<()> { let sql = "select cast('1969-01-01T00:00:00.1Z' as timestamp);"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------------------------------+", "| Utf8(\"1969-01-01T00:00:00.1Z\") |", "+--------------------------------+", @@ -709,7 +707,7 @@ async fn test_arrow_typeof() -> Result<()> { let sql = "select arrow_typeof(date_trunc('minute', to_timestamp_seconds(61)));"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"minute\"),to_timestamp_seconds(Int64(61)))) |", "+--------------------------------------------------------------------------+", @@ -720,7 +718,7 @@ async fn test_arrow_typeof() -> Result<()> { let sql = "select arrow_typeof(date_trunc('second', to_timestamp_millis(61)));"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-------------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"second\"),to_timestamp_millis(Int64(61)))) |", "+-------------------------------------------------------------------------+", @@ -731,18 +729,16 @@ async fn test_arrow_typeof() -> Result<()> { let sql = "select arrow_typeof(date_trunc('millisecond', to_timestamp_micros(61)));"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+------------------------------------------------------------------------------+", + let expected = ["+------------------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"millisecond\"),to_timestamp_micros(Int64(61)))) |", "+------------------------------------------------------------------------------+", "| Timestamp(Microsecond, None) |", - "+------------------------------------------------------------------------------+", - ]; + "+------------------------------------------------------------------------------+"]; assert_batches_eq!(expected, &actual); let sql = "select arrow_typeof(date_trunc('microsecond', to_timestamp(61)));"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-----------------------------------------------------------------------+", "| arrow_typeof(date_trunc(Utf8(\"microsecond\"),to_timestamp(Int64(61)))) |", "+-----------------------------------------------------------------------+", @@ -764,7 +760,7 @@ async fn cast_timestamp_to_timestamptz() -> Result<()> { let sql = "SELECT ts::timestamptz, arrow_typeof(ts::timestamptz) FROM table_a;"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+-----------------------------+---------------------------------------+", "| table_a.ts | arrow_typeof(table_a.ts) |", "+-----------------------------+---------------------------------------+", @@ -784,7 +780,7 @@ async fn test_cast_to_time() -> Result<()> { let sql = "SELECT 0::TIME"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------+", "| Int64(0) |", "+----------+", @@ -804,7 +800,7 @@ async fn test_cast_to_time_with_time_zone_should_not_work() -> Result<()> { let results = plan_and_collect(&ctx, sql).await.unwrap_err(); assert_eq!( - results.to_string(), + results.strip_backtrace(), "This feature is not implemented: Unsupported SQL type Time(None, WithTimeZone)" ); @@ -817,7 +813,7 @@ async fn test_cast_to_time_without_time_zone() -> Result<()> { let sql = "SELECT 0::TIME WITHOUT TIME ZONE"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+----------+", "| Int64(0) |", "+----------+", @@ -837,7 +833,7 @@ async fn test_cast_to_timetz_should_not_work() -> Result<()> { let results = plan_and_collect(&ctx, sql).await.unwrap_err(); assert_eq!( - results.to_string(), + results.strip_backtrace(), "This feature is not implemented: Unsupported SQL type Time(None, Tz)" ); Ok(()) @@ -862,7 +858,7 @@ async fn test_current_date() -> Result<()> { let sql = "select case when current_date() = cast(now() as date) then 'OK' else 'FAIL' end result"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -894,7 +890,7 @@ async fn test_current_time() -> Result<()> { let sql = "select case when current_time() = (now()::bigint % 86400000000000)::time then 'OK' else 'FAIL' end result"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -914,7 +910,7 @@ async fn test_ts_dt_binary_ops() -> Result<()> { "select count(1) result from (select now() as n) a where n = '2000-01-01'::date"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -929,7 +925,7 @@ async fn test_ts_dt_binary_ops() -> Result<()> { "select count(1) result from (select now() as n) a where n >= '2000-01-01'::date"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -943,7 +939,7 @@ async fn test_ts_dt_binary_ops() -> Result<()> { let sql = "select now() = '2000-01-01'::date as result"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -957,7 +953,7 @@ async fn test_ts_dt_binary_ops() -> Result<()> { let sql = "select now() >= '2000-01-01'::date as result"; let results = execute_to_batches(&ctx, sql).await; - let expected = vec![ + let expected = [ "+--------+", "| result |", "+--------+", @@ -1035,14 +1031,14 @@ async fn timestamp_sub_with_tz() -> Result<()> { let sql = "SELECT val, ts1 - ts2 AS ts_diff FROM table_a ORDER BY ts2 - ts1"; let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-----+---------------------------------------------------+", - "| val | ts_diff |", - "+-----+---------------------------------------------------+", - "| 3 | 0 years 0 mons 0 days 10 hours 0 mins 30.000 secs |", - "| 1 | 0 years 0 mons 0 days 10 hours 0 mins 20.000 secs |", - "| 2 | 0 years 0 mons 0 days 10 hours 0 mins 10.000 secs |", - "+-----+---------------------------------------------------+", + let expected = [ + "+-----+-----------------------------------+", + "| val | ts_diff |", + "+-----+-----------------------------------+", + "| 3 | 0 days 0 hours 0 mins 30.000 secs |", + "| 1 | 0 days 0 hours 0 mins 20.000 secs |", + "| 2 | 0 days 0 hours 0 mins 10.000 secs |", + "+-----+-----------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sqllogictests/.gitignore b/datafusion/core/tests/sqllogictests/.gitignore deleted file mode 100644 index 8e5bbf044f1e0..0000000000000 --- a/datafusion/core/tests/sqllogictests/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.py \ No newline at end of file diff --git a/datafusion/core/tests/sqllogictests/MOVED.md b/datafusion/core/tests/sqllogictests/MOVED.md new file mode 100644 index 0000000000000..dd70dab9d11f2 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/MOVED.md @@ -0,0 +1,20 @@ + + +The SQL Logic Test code has moved to `datafusion/sqllogictest` diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs deleted file mode 100644 index 41dfbad394245..0000000000000 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; -use datafusion_common::config::ConfigOptions; -use datafusion_common::TableReference; -use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource}; -use datafusion_sql::planner::ContextProvider; -use std::sync::Arc; - -pub struct LogicTestContextProvider {} - -// Only a mock, don't need to implement -impl ContextProvider for LogicTestContextProvider { - fn get_table_provider( - &self, - _name: TableReference, - ) -> datafusion_common::Result> { - todo!() - } - - fn get_function_meta(&self, _name: &str) -> Option> { - todo!() - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - todo!() - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - todo!() - } - - fn options(&self) -> &ConfigOptions { - todo!() - } -} diff --git a/datafusion/core/tests/sqllogictests/src/setup.rs b/datafusion/core/tests/sqllogictests/src/setup.rs deleted file mode 100644 index 8072a0f74f5fa..0000000000000 --- a/datafusion/core/tests/sqllogictests/src/setup.rs +++ /dev/null @@ -1,214 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::{ - arrow::{ - array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, - }, - datasource::MemTable, - prelude::{CsvReadOptions, SessionContext}, - test_util, -}; -use std::sync::Arc; - -use crate::utils; - -#[cfg(feature = "avro")] -pub async fn register_avro_tables(ctx: &mut crate::TestContext) { - use datafusion::prelude::AvroReadOptions; - - ctx.enable_testdir(); - - let table_path = ctx.testdir_path().join("avro"); - std::fs::create_dir(&table_path).expect("failed to create avro table path"); - - let testdata = datafusion::test_util::arrow_test_data(); - let alltypes_plain_file = format!("{testdata}/avro/alltypes_plain.avro"); - std::fs::copy( - &alltypes_plain_file, - format!("{}/alltypes_plain1.avro", table_path.display()), - ) - .unwrap(); - std::fs::copy( - &alltypes_plain_file, - format!("{}/alltypes_plain2.avro", table_path.display()), - ) - .unwrap(); - - ctx.session_ctx() - .register_avro( - "alltypes_plain_multi_files", - table_path.display().to_string().as_str(), - AvroReadOptions::default(), - ) - .await - .unwrap(); -} - -pub async fn register_aggregate_tables(ctx: &SessionContext) { - register_aggregate_test_100(ctx).await; - register_decimal_table(ctx); - register_median_test_tables(ctx); - register_test_data(ctx); -} - -fn register_median_test_tables(ctx: &SessionContext) { - // Register median tables - let items: Vec<(&str, DataType, ArrayRef)> = vec![ - ( - "i8", - DataType::Int8, - Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])), - ), - ( - "i16", - DataType::Int16, - Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])), - ), - ( - "i32", - DataType::Int32, - Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])), - ), - ( - "i64", - DataType::Int64, - Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])), - ), - ( - "u8", - DataType::UInt8, - Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])), - ), - ( - "u16", - DataType::UInt16, - Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])), - ), - ( - "u32", - DataType::UInt32, - Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])), - ), - ( - "u64", - DataType::UInt64, - Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])), - ), - ( - "f32", - DataType::Float32, - Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), - ), - ( - "f64", - DataType::Float64, - Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])), - ), - ( - "f64_nan", - DataType::Float64, - Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])), - ), - ]; - - for (name, data_type, values) in items { - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new("a", data_type, false)])), - vec![values], - ) - .unwrap(); - let table_name = &format!("median_{name}"); - ctx.register_batch(table_name, batch).unwrap(); - } -} - -fn register_test_data(ctx: &SessionContext) { - let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int64, true), - Field::new("c2", DataType::Int64, true), - ])); - - let data = RecordBatch::try_new( - schema, - vec![ - Arc::new(Int64Array::from(vec![ - Some(0), - Some(1), - None, - Some(3), - Some(3), - ])), - Arc::new(Int64Array::from(vec![ - None, - Some(1), - Some(1), - Some(2), - Some(2), - ])), - ], - ) - .unwrap(); - - ctx.register_batch("test", data).unwrap(); -} - -fn register_decimal_table(ctx: &SessionContext) { - let batch_decimal = utils::make_decimal(); - let schema = batch_decimal.schema(); - let partitions = vec![vec![batch_decimal]]; - let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); - ctx.register_table("d_table", provider).unwrap(); -} - -async fn register_aggregate_test_100(ctx: &SessionContext) { - let test_data = datafusion::test_util::arrow_test_data(); - let schema = test_util::aggr_test_schema(); - ctx.register_csv( - "aggregate_test_100", - &format!("{test_data}/csv/aggregate_test_100.csv"), - CsvReadOptions::new().schema(&schema), - ) - .await - .unwrap(); -} - -pub async fn register_scalar_tables(ctx: &SessionContext) { - register_nan_table(ctx) -} - -/// Register a table with a NaN value (different than NULL, and can -/// not be created via SQL) -fn register_nan_table(ctx: &SessionContext) { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); - - let data = RecordBatch::try_new( - schema, - vec![Arc::new(Float64Array::from(vec![ - Some(1.0), - None, - Some(f64::NAN), - ]))], - ) - .unwrap(); - ctx.register_batch("test_float", data).unwrap(); -} diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt deleted file mode 100644 index ac4e223ab69ac..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ /dev/null @@ -1,1906 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -statement ok -CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT, - c5 INT, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 BIGINT UNSIGNED NOT NULL, - c10 VARCHAR NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL -) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv' - -####### -# Error tests -####### - -# https://github.com/apache/arrow-datafusion/issues/3353 -statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "APPROX_DISTINCT\(aggregate_test_100\.c9\)" -SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 - -# csv_query_approx_percentile_cont_with_weight -statement error Error during planning: The function ApproxPercentileContWithWeight does not support inputs of type Utf8. -SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 - -statement error Error during planning: The weight argument for ApproxPercentileContWithWeight does not support inputs of type Utf8 -SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 - -statement error Error during planning: The percentile argument for ApproxPercentileContWithWeight must be Float64, not Utf8. -SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 - -# csv_query_approx_percentile_cont_with_histogram_bins -statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\). -SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 - -statement error Error during planning: The percentile sample points count for ApproxPercentileCont must be integer, not Utf8. -SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 - -statement error Error during planning: The percentile sample points count for ApproxPercentileCont must be integer, not Float64. -SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 - -# csv_query_array_agg_unsupported -statement error This feature is not implemented: Order-sensitive aggregators is not supported on multiple partitions -SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100 - -statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 -SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 - - -# FIX: custom absolute values -# csv_query_avg_multi_batch - -# csv_query_avg -query R -SELECT avg(c12) FROM aggregate_test_100 ----- -0.508972509913 - -# csv_query_bit_and -query IIIII -SELECT bit_and(c5), bit_and(c6), bit_and(c7), bit_and(c8), bit_and(c9) FROM aggregate_test_100 ----- -0 0 0 0 0 - -# csv_query_bit_and_distinct -query IIIII -SELECT bit_and(distinct c5), bit_and(distinct c6), bit_and(distinct c7), bit_and(distinct c8), bit_and(distinct c9) FROM aggregate_test_100 ----- -0 0 0 0 0 - -# csv_query_bit_or -query IIIII -SELECT bit_or(c5), bit_or(c6), bit_or(c7), bit_or(c8), bit_or(c9) FROM aggregate_test_100 ----- --1 -1 255 65535 4294967295 - -# csv_query_bit_or_distinct -query IIIII -SELECT bit_or(distinct c5), bit_or(distinct c6), bit_or(distinct c7), bit_or(distinct c8), bit_or(distinct c9) FROM aggregate_test_100 ----- --1 -1 255 65535 4294967295 - -# csv_query_bit_xor -query IIIII -SELECT bit_xor(c5), bit_xor(c6), bit_xor(c7), bit_xor(c8), bit_xor(c9) FROM aggregate_test_100 ----- -1632751011 5960911605712039654 148 54789 169634700 - -# csv_query_bit_xor_distinct (should be different than above) -query IIIII -SELECT bit_xor(distinct c5), bit_xor(distinct c6), bit_xor(distinct c7), bit_xor(distinct c8), bit_xor(distinct c9) FROM aggregate_test_100 ----- -1632751011 5960911605712039654 196 54789 169634700 - -# csv_query_bit_xor_distinct_expr -query I -SELECT bit_xor(distinct c5 % 2) FROM aggregate_test_100 ----- --2 - -# csv_query_covariance_1 -query R -SELECT covar_pop(c2, c12) FROM aggregate_test_100 ----- --0.079169322354 - -# csv_query_covariance_2 -query R -SELECT covar(c2, c12) FROM aggregate_test_100 ----- --0.079969012479 - -# single_row_query_covar_1 -query R -select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq ----- -NULL - -# single_row_query_covar_2 -query R -select covar_pop(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq ----- -0 - -# all_nulls_query_covar -query RR -with data as ( - select null::int as f, null::int as b - union all - select null::int as f, null::int as b -) -select covar_samp(f, b), covar_pop(f, b) -from data ----- -NULL NULL - -# covar_query_with_nulls -query RR -with data as ( - select 1 as f, 4 as b - union all - select null as f, 99 as b - union all - select 2 as f, 5 as b - union all - select 98 as f, null as b - union all - select 3 as f, 6 as b - union all - select null as f, null as b -) -select covar_samp(f, b), covar_pop(f, b) -from data ----- -1 0.666666666667 - -# csv_query_correlation -query R -SELECT corr(c2, c12) FROM aggregate_test_100 ----- --0.190645441906 - -# single_row_query_correlation -query R -select corr(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq ----- -0 - -# all_nulls_query_correlation -query R -with data as ( - select null::int as f, null::int as b - union all - select null::int as f, null::int as b -) -select corr(f, b) -from data ----- -NULL - -# correlation_query_with_nulls -query R -with data as ( - select 1 as f, 4 as b - union all - select null as f, 99 as b - union all - select 2 as f, 5 as b - union all - select 98 as f, null as b - union all - select 3 as f, 6 as b - union all - select null as f, null as b -) -select corr(f, b) -from data ----- -1 - -# csv_query_variance_1 -query R -SELECT var_pop(c2) FROM aggregate_test_100 ----- -1.8675 - -# csv_query_variance_2 -query R -SELECT var_pop(c6) FROM aggregate_test_100 ----- -26156334342021890000000000000000000000 - -# csv_query_variance_3 -query R -SELECT var_pop(c12) FROM aggregate_test_100 ----- -0.092342237216 - -# csv_query_variance_4 -query R -SELECT var(c2) FROM aggregate_test_100 ----- -1.886363636364 - -# csv_query_variance_5 -query R -SELECT var_samp(c2) FROM aggregate_test_100 ----- -1.886363636364 - -# csv_query_stddev_1 -query R -SELECT stddev_pop(c2) FROM aggregate_test_100 ----- -1.366565036872 - -# csv_query_stddev_2 -query R -SELECT stddev_pop(c6) FROM aggregate_test_100 ----- -5114326382039172000 - -# csv_query_stddev_3 -query R -SELECT stddev_pop(c12) FROM aggregate_test_100 ----- -0.303878655413 - -# csv_query_stddev_4 -query R -SELECT stddev(c12) FROM aggregate_test_100 ----- -0.305409539941 - -# csv_query_stddev_5 -query R -SELECT stddev_samp(c12) FROM aggregate_test_100 ----- -0.305409539941 - -# csv_query_stddev_6 -query R -select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq ----- -0.950438495292 - -# csv_query_approx_median_1 -query I -SELECT approx_median(c2) FROM aggregate_test_100 ----- -3 - -# csv_query_approx_median_2 -query I -SELECT approx_median(c6) FROM aggregate_test_100 ----- -1146409980542786560 - -# csv_query_approx_median_3 -query R -SELECT approx_median(c12) FROM aggregate_test_100 ----- -0.555006541052 - -# csv_query_median_1 -query I -SELECT median(c2) FROM aggregate_test_100 ----- -3 - -# csv_query_median_2 -query I -SELECT median(c6) FROM aggregate_test_100 ----- -1125553990140691277 - -# csv_query_median_3 -query R -SELECT median(c12) FROM aggregate_test_100 ----- -0.551390054439 - -# median_i8 -query I -SELECT median(a) FROM median_i8 ----- --14 - -# median_i16 -query I -SELECT median(a) FROM median_i16 ----- --16334 - -# median_i32 -query I -SELECT median(a) FROM median_i32 ----- --1073741774 - -# median_i64 -query I -SELECT median(a) FROM median_i64 ----- --4611686018427387854 - -# median_u8 -query I -SELECT median(a) FROM median_u8 ----- -50 - -# median_u16 -query I -SELECT median(a) FROM median_u16 ----- -50 - -# median_u32 -query I -SELECT median(a) FROM median_u32 ----- -50 - -# median_u64 -query I -SELECT median(a) FROM median_u64 ----- -50 - -# median_f32 -query R -SELECT median(a) FROM median_f32 ----- -3.3 - -# median_f64 -query R -SELECT median(a) FROM median_f64 ----- -3.3 - -# median_f64_nan -query R -SELECT median(a) FROM median_f64_nan ----- -NaN - -# approx_median_f64_nan -query R -SELECT approx_median(a) FROM median_f64_nan ----- -NaN - -# median_multi -# test case for https://github.com/apache/arrow-datafusion/issues/3105 -# has an intermediate grouping -statement ok -create table cpu (host string, usage float) as select * from (values -('host0', 90.1), -('host1', 90.2), -('host1', 90.4) -); - -query TR rowsort -select host, median(usage) from cpu group by host; ----- -host0 90.1 -host1 90.3 - -statement ok -drop table cpu; - -# this test is to show create table as and select into works in the same way -statement ok -SELECT * INTO cpu -FROM (VALUES - ('host0', 90.1), - ('host1', 90.2), - ('host1', 90.4) - ) AS cpu (host, usage); - -query TR rowsort -select host, median(usage) from cpu group by host; ----- -host0 90.1 -host1 90.3 - -query R -select median(usage) from cpu; ----- -90.2 - -statement ok -drop table cpu; - -# median_multi_odd - -# data is not sorted and has an odd number of values per group -statement ok -create table cpu (host string, usage float) as select * from (values - ('host0', 90.2), - ('host1', 90.1), - ('host1', 90.5), - ('host0', 90.5), - ('host1', 90.0), - ('host1', 90.3), - ('host0', 87.9), - ('host1', 89.3) -); - -query TR rowsort -select host, median(usage) from cpu group by host; ----- -host0 90.2 -host1 90.1 - - -statement ok -drop table cpu; - -# median_multi_even -# data is not sorted and has an odd number of values per group -statement ok -create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3)); - -query TR rowsort -select host, median(usage) from cpu group by host; ----- -host0 90.35 -host1 90.25 - -statement ok -drop table cpu - -# csv_query_external_table_count -query I -SELECT COUNT(c12) FROM aggregate_test_100 ----- -100 - -# csv_query_external_table_sum -query II -SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100 ----- -13060 3017641 - -# csv_query_count -query I -SELECT count(c12) FROM aggregate_test_100 ----- -100 - -# csv_query_count_distinct -query I -SELECT count(distinct c2) FROM aggregate_test_100 ----- -5 - -# csv_query_count_distinct_expr -query I -SELECT count(distinct c2 % 2) FROM aggregate_test_100 ----- -2 - -# csv_query_count_star -query I -SELECT COUNT(*) FROM aggregate_test_100 ----- -100 - -# csv_query_count_literal -query I -SELECT COUNT(2) FROM aggregate_test_100 ----- -100 - -# csv_query_approx_count -# FIX: https://github.com/apache/arrow-datafusion/issues/3353 -# query II -# SELECT approx_distinct(c9) AS count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 -# ---- -# 100 99 - -# csv_query_approx_count_dupe_expr_aliased -query II -SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_100 ----- -100 100 - -## This test executes the APPROX_PERCENTILE_CONT aggregation against the test -## data, asserting the estimated quantiles are ±5% their actual values. -## -## Actual quantiles calculated with: -## -## ```r -## read_csv("./testing/data/csv/aggregate_test_100.csv") |> -## select_if(is.numeric) |> -## summarise_all(~ quantile(., c(0.1, 0.5, 0.9))) -## ``` -## -## Giving: -## -## ```text -## c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 -## -## 1 1 -95.3 -22925. -1882606710 -7.25e18 18.9 2671. 472608672. 1.83e18 0.109 0.0714 -## 2 3 15.5 4599 377164262 1.13e18 134. 30634 2365817608. 9.30e18 0.491 0.551 -## 3 5 102. 25334. 1991374996. 7.37e18 231 57518. 3776538487. 1.61e19 0.834 0.946 -## ``` -## -## Column `c12` is omitted due to a large relative error (~10%) due to the small -## float values. - -#csv_query_approx_percentile_cont (c2) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c3) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.1) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.5) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.9) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c4) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.1) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.5) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.9) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c5) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.1) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.9) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c6) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.1) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.5) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.9) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c7) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.1) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.5) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.9) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c8) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.1) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.5) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.9) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c9) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.1) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.5) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c10) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.1) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.5) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.9) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_approx_percentile_cont (c11) -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.1) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.5) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 ----- -true - -query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 ----- -true - -# csv_query_cube_avg -query TIR -SELECT c1, c2, AVG(c3) FROM aggregate_test_100_by_sql GROUP BY CUBE (c1, c2) ORDER BY c1, c2 ----- -a 1 -17.6 -a 2 -15.333333333333 -a 3 -4.5 -a 4 -32 -a 5 -32 -a NULL -18.333333333333 -b 1 31.666666666667 -b 2 25.5 -b 3 -42 -b 4 -44.6 -b 5 -0.2 -b NULL -5.842105263158 -c 1 47.5 -c 2 -55.571428571429 -c 3 47.5 -c 4 -10.75 -c 5 12 -c NULL -1.333333333333 -d 1 -8.142857142857 -d 2 109.333333333333 -d 3 41.333333333333 -d 4 54 -d 5 -49.5 -d NULL 25.444444444444 -e 1 75.666666666667 -e 2 37.8 -e 3 48 -e 4 37.285714285714 -e 5 -11 -e NULL 40.333333333333 -NULL 1 16.681818181818 -NULL 2 8.363636363636 -NULL 3 20.789473684211 -NULL 4 1.260869565217 -NULL 5 -13.857142857143 -NULL NULL 7.81 - -# csv_query_rollup_avg -query TIIR -SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100_by_sql WHERE c1 IN ('a', 'b', NULL) GROUP BY ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3 ----- -a 1 -85 -15154 -a 1 -56 8692 -a 1 -25 15295 -a 1 -5 12636 -a 1 83 -14704 -a 1 NULL 1353 -a 2 -48 -18025 -a 2 -43 13080 -a 2 45 15673 -a 2 NULL 3576 -a 3 -72 -11122 -a 3 -12 -9168 -a 3 13 22338.5 -a 3 14 28162 -a 3 17 -22796 -a 3 NULL 4958.833333333333 -a 4 -101 11640 -a 4 -54 -2376 -a 4 -38 20744 -a 4 65 -28462 -a 4 NULL 386.5 -a 5 -101 -12484 -a 5 -31 -12907 -a 5 36 -16974 -a 5 NULL -14121.666666666666 -a NULL NULL 306.047619047619 -b 1 12 7652 -b 1 29 -18218 -b 1 54 -18410 -b 1 NULL -9658.666666666666 -b 2 -60 -21739 -b 2 31 23127 -b 2 63 21456 -b 2 68 15874 -b 2 NULL 9679.5 -b 3 -101 -13217 -b 3 17 14457 -b 3 NULL 620 -b 4 -117 19316 -b 4 -111 -1967 -b 4 -59 25286 -b 4 17 -28070 -b 4 47 20690 -b 4 NULL 7051 -b 5 -82 22080 -b 5 -44 15788 -b 5 -5 24896 -b 5 62 16337 -b 5 68 21576 -b 5 NULL 20135.4 -b NULL NULL 7732.315789473684 -NULL NULL NULL 3833.525 - -# csv_query_groupingsets_avg -query TIIR -SELECT c1, c2, c3, AVG(c4) -FROM aggregate_test_100_by_sql -WHERE c1 IN ('a', 'b', NULL) -GROUP BY GROUPING SETS ((c1), (c1,c2), (c1,c2,c3)) -ORDER BY c1, c2, c3 ----- -a 1 -85 -15154 -a 1 -56 8692 -a 1 -25 15295 -a 1 -5 12636 -a 1 83 -14704 -a 1 NULL 1353 -a 2 -48 -18025 -a 2 -43 13080 -a 2 45 15673 -a 2 NULL 3576 -a 3 -72 -11122 -a 3 -12 -9168 -a 3 13 22338.5 -a 3 14 28162 -a 3 17 -22796 -a 3 NULL 4958.833333333333 -a 4 -101 11640 -a 4 -54 -2376 -a 4 -38 20744 -a 4 65 -28462 -a 4 NULL 386.5 -a 5 -101 -12484 -a 5 -31 -12907 -a 5 36 -16974 -a 5 NULL -14121.666666666666 -a NULL NULL 306.047619047619 -b 1 12 7652 -b 1 29 -18218 -b 1 54 -18410 -b 1 NULL -9658.666666666666 -b 2 -60 -21739 -b 2 31 23127 -b 2 63 21456 -b 2 68 15874 -b 2 NULL 9679.5 -b 3 -101 -13217 -b 3 17 14457 -b 3 NULL 620 -b 4 -117 19316 -b 4 -111 -1967 -b 4 -59 25286 -b 4 17 -28070 -b 4 47 20690 -b 4 NULL 7051 -b 5 -82 22080 -b 5 -44 15788 -b 5 -5 24896 -b 5 62 16337 -b 5 68 21576 -b 5 NULL 20135.4 -b NULL NULL 7732.315789473684 - -# csv_query_singlecol_with_rollup_avg -query TIIR -SELECT c1, c2, c3, AVG(c4) -FROM aggregate_test_100_by_sql -WHERE c1 IN ('a', 'b', NULL) -GROUP BY c1, ROLLUP (c2, c3) -ORDER BY c1, c2, c3 ----- -a 1 -85 -15154 -a 1 -56 8692 -a 1 -25 15295 -a 1 -5 12636 -a 1 83 -14704 -a 1 NULL 1353 -a 2 -48 -18025 -a 2 -43 13080 -a 2 45 15673 -a 2 NULL 3576 -a 3 -72 -11122 -a 3 -12 -9168 -a 3 13 22338.5 -a 3 14 28162 -a 3 17 -22796 -a 3 NULL 4958.833333333333 -a 4 -101 11640 -a 4 -54 -2376 -a 4 -38 20744 -a 4 65 -28462 -a 4 NULL 386.5 -a 5 -101 -12484 -a 5 -31 -12907 -a 5 36 -16974 -a 5 NULL -14121.666666666666 -a NULL NULL 306.047619047619 -b 1 12 7652 -b 1 29 -18218 -b 1 54 -18410 -b 1 NULL -9658.666666666666 -b 2 -60 -21739 -b 2 31 23127 -b 2 63 21456 -b 2 68 15874 -b 2 NULL 9679.5 -b 3 -101 -13217 -b 3 17 14457 -b 3 NULL 620 -b 4 -117 19316 -b 4 -111 -1967 -b 4 -59 25286 -b 4 17 -28070 -b 4 47 20690 -b 4 NULL 7051 -b 5 -82 22080 -b 5 -44 15788 -b 5 -5 24896 -b 5 62 16337 -b 5 68 21576 -b 5 NULL 20135.4 -b NULL NULL 7732.315789473684 - -# csv_query_approx_percentile_cont_with_weight -query TI -SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----- -a 73 -b 68 -c 122 -d 124 -e 115 - -# csv_query_approx_percentile_cont_with_weight (2) -query TI -SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----- -a 73 -b 68 -c 122 -d 124 -e 115 - -# csv_query_approx_percentile_cont_with_histogram_bins -query TI -SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----- -a 73 -b 68 -c 122 -d 124 -e 115 - -query TI -SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----- -a 74 -b 68 -c 123 -d 124 -e 115 - -# csv_query_sum_crossjoin -query TTI -SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 ----- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 - -# csv_query_cube_sum_crossjoin -query TTI -SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1 ----- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 -NULL a 5985 -NULL b 5415 -NULL c 5985 -NULL d 5130 -NULL e 5985 -NULL NULL 28500 - -# csv_query_cube_distinct_count -query TII -SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2 ----- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 -NULL 1 22 -NULL 2 20 -NULL 3 17 -NULL 4 23 -NULL 5 14 -NULL NULL 80 - -# csv_query_rollup_distinct_count -query TII -SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2 ----- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 -NULL NULL 80 - -# csv_query_rollup_sum_crossjoin -query TTI -SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1 ----- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 -NULL NULL 28500 - -# query_count_without_from -query I -SELECT count(1 + 1) ----- -1 - -# csv_query_array_agg -query ? -SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test ----- -[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] - -# csv_query_array_agg_empty -query ? -SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test ----- -[] - -# csv_query_array_agg_one -query ? -SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test ----- -[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] - -# csv_query_array_agg_with_overflow -query IIRIII -select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by c2 order by c2 ----- -1 367 16.681818181818 125 -99 22 -2 184 8.363636363636 122 -117 22 -3 395 20.789473684211 123 -101 19 -4 29 1.260869565217 123 -117 23 -5 -194 -13.857142857143 118 -101 14 - -# csv_query_array_agg_unsupported -statement error This feature is not implemented: Order-sensitive aggregators is not supported on multiple partitions -SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100; - -# csv_query_array_cube_agg_with_overflow -query TIIRIII -select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2 ----- -a 1 -88 -17.6 83 -85 5 -a 2 -46 -15.333333333333 45 -48 3 -a 3 -27 -4.5 17 -72 6 -a 4 -128 -32 65 -101 4 -a 5 -96 -32 36 -101 3 -a NULL -385 -18.333333333333 83 -101 21 -b 1 95 31.666666666667 54 12 3 -b 2 102 25.5 68 -60 4 -b 3 -84 -42 17 -101 2 -b 4 -223 -44.6 47 -117 5 -b 5 -1 -0.2 68 -82 5 -b NULL -111 -5.842105263158 68 -117 19 -c 1 190 47.5 103 -24 4 -c 2 -389 -55.571428571429 29 -117 7 -c 3 190 47.5 97 -2 4 -c 4 -43 -10.75 123 -90 4 -c 5 24 12 118 -94 2 -c NULL -28 -1.333333333333 123 -117 21 -d 1 -57 -8.142857142857 125 -99 7 -d 2 328 109.333333333333 122 93 3 -d 3 124 41.333333333333 123 -76 3 -d 4 162 54 102 5 3 -d 5 -99 -49.5 -40 -59 2 -d NULL 458 25.444444444444 125 -99 18 -e 1 227 75.666666666667 120 36 3 -e 2 189 37.8 97 -61 5 -e 3 192 48 112 -95 4 -e 4 261 37.285714285714 97 -56 7 -e 5 -22 -11 64 -86 2 -e NULL 847 40.333333333333 120 -95 21 -NULL 1 367 16.681818181818 125 -99 22 -NULL 2 184 8.363636363636 122 -117 22 -NULL 3 395 20.789473684211 123 -101 19 -NULL 4 29 1.260869565217 123 -117 23 -NULL 5 -194 -13.857142857143 118 -101 14 -NULL NULL 781 7.81 125 -117 100 - -# TODO this querys output is non determinisitic (the order of the elements -# differs run to run -# -# csv_query_array_agg_distinct -# query T -# SELECT array_agg(distinct c2) FROM aggregate_test_100 -# ---- -# [4, 2, 3, 5, 1] - -# aggregate_time_min_and_max -query TT -select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' union select '00:00:02') ----- -00:00:00 00:00:02 - -# aggregate_decimal_min -query RT -select min(c1), arrow_typeof(min(c1)) from d_table ----- --100.009 Decimal128(10, 3) - -# aggregate_decimal_max -query RT -select max(c1), arrow_typeof(max(c1)) from d_table ----- -110.009 Decimal128(10, 3) - -# aggregate_decimal_sum -query RT -select sum(c1), arrow_typeof(sum(c1)) from d_table ----- -100 Decimal128(20, 3) - -# aggregate_decimal_avg -query RT -select avg(c1), arrow_typeof(avg(c1)) from d_table ----- -5 Decimal128(14, 7) - -# FIX: different test table -# aggregate -# query I -# SELECT SUM(c1), SUM(c2) FROM test -# ---- -# 60 220 - -# TODO: aggregate_empty - -# TODO: aggregate_avg - -# TODO: aggregate_max - -# TODO: aggregate_min - -# TODO: aggregate_grouped - -# TODO: aggregate_grouped_avg - -# TODO: aggregate_grouped_empty - -# TODO: aggregate_grouped_max - -# TODO: aggregate_grouped_min - -# TODO: aggregate_avg_add - -# TODO: case_sensitive_identifiers_aggregates - -# TODO: count_basic - -# TODO: count_partitioned - -# TODO: count_aggregated - -# TODO: count_aggregated_cube - -# TODO: simple_avg - -# TODO: simple_mean - -# query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) -query RI -SELECT AVG(c1), SUM(DISTINCT c2) FROM test ----- -1.75 3 - -# query_sum_distinct - 2 sum(distinct) functions -query II -SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test ----- -4 3 - -# # query_count_distinct -query I -SELECT COUNT(DISTINCT c1) FROM test ----- -3 - -# TODO: count_distinct_integers_aggregated_single_partition - -# TODO: count_distinct_integers_aggregated_multiple_partitions - -# TODO: aggregate_with_alias - -# array_agg_zero -query ? -SELECT ARRAY_AGG([]) ----- -[] - -# array_agg_one -query ? -SELECT ARRAY_AGG([1]) ----- -[[1]] - -# test_approx_percentile_cont_decimal_support -query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ----- -a 4 -b 5 -c 4 -d 4 -e 4 - - -# array_agg_zero -query ? -SELECT ARRAY_AGG([]); ----- -[] - -# array_agg_one -query ? -SELECT ARRAY_AGG([1]); ----- -[[1]] - -# variance_single_value -query RRRR -select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; ----- -NULL 0 NULL 0 - -# variance_two_values -query RRRR -select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq; ----- -2 1 1.414213562373 1 - - -# sum / count for all nulls -statement ok -create table the_nulls as values (null::bigint, 1), (null::bigint, 1), (null::bigint, 2); - -# counts should be zeros (even for nulls) -query II -SELECT count(column1), column2 from the_nulls group by column2 order by column2; ----- -0 1 -0 2 - -# sums should be null -query II -SELECT sum(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# avg should be null -query RI -SELECT avg(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# bit_and should be null -query II -SELECT bit_and(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# bit_or should be null -query II -SELECT bit_or(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# bit_xor should be null -query II -SELECT bit_xor(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# min should be null -query II -SELECT min(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - -# max should be null -query II -SELECT max(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 - - -statement ok -drop table the_nulls; - -# All supported timestamp types - -# "nanos" --> TimestampNanosecondArray -# "micros" --> TimestampMicrosecondArray -# "millis" --> TimestampMillisecondArray -# "secs" --> TimestampSecondArray -# "names" --> StringArray - -statement ok -create table t_source -as values - ('2018-11-13T17:11:10.011375885995', 'Row 0'), - ('2011-12-13T11:13:10.12345', 'Row 1'), - (null, 'Row 2'), - ('2021-01-01T05:11:10.432', 'Row 3'); - -statement ok -create table bit_aggregate_functions ( - c1 SMALLINT NOT NULL, - c2 SMALLINT NOT NULL, - c3 SMALLINT, -) -as values - (5, 10, 11), - (33, 11, null), - (9, 12, null); - -# query_bit_and -query III -SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions ----- -1 8 11 - -# query_bit_or -query III -SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions ----- -45 15 11 - -# query_bit_xor -query III -SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions ----- -45 13 11 - -statement ok -create table bool_aggregate_functions ( - c1 boolean not null, - c2 boolean not null, - c3 boolean not null, - c4 boolean not null, - c5 boolean, - c6 boolean, - c7 boolean, - c8 boolean, -) -as values - (true, true, false, false, true, true, null, null), - (true, false, true, false, false, null, false, null), - (true, true, false, false, null, true, false, null); - -# query_bool_and -query BBBBBBBB -SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions ----- -true false false false false true false NULL - -# query_bool_and_distinct -query BBBBBBBB -SELECT bool_and(distinct c1), bool_and(distinct c2), bool_and(distinct c3), bool_and(distinct c4), bool_and(distinct c5), bool_and(distinct c6), bool_and(distinct c7), bool_and(distinct c8) FROM bool_aggregate_functions ----- -true false false false false true false NULL - -# query_bool_or -query BBBBBBBB -SELECT bool_or(c1), bool_or(c2), bool_or(c3), bool_or(c4), bool_or(c5), bool_or(c6), bool_or(c7), bool_or(c8) FROM bool_aggregate_functions ----- -true true true false true true false NULL - -# query_bool_or_distinct -query BBBBBBBB -SELECT bool_or(distinct c1), bool_or(distinct c2), bool_or(distinct c3), bool_or(distinct c4), bool_or(distinct c5), bool_or(distinct c6), bool_or(distinct c7), bool_or(distinct c8) FROM bool_aggregate_functions ----- -true true true false true true false NULL - -statement ok -create table t as -select - arrow_cast(column1, 'Timestamp(Nanosecond, None)') as nanos, - arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, - arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, - arrow_cast(column1, 'Timestamp(Second, None)') as secs, - column2 as names -from t_source; - -# Demonstate the contents -query PPPPT -select * from t; ----- -2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 -2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 -NULL NULL NULL NULL Row 2 -2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 - - -# aggregate_timestamps_sum -statement error Error during planning: The function Sum does not support inputs of type Timestamp\(Nanosecond, None\) -SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t; - -# aggregate_timestamps_count -query IIII -SELECT count(nanos), count(micros), count(millis), count(secs) FROM t; ----- -3 3 3 3 - - -# aggregate_timestamps_min -query PPPP -SELECT min(nanos), min(micros), min(millis), min(secs) FROM t; ----- -2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 - -# aggregate_timestamps_max -query PPPP -SELECT max(nanos), max(micros), max(millis), max(secs) FROM t; ----- -2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 - - - -# aggregate_timestamps_avg -statement error Error during planning: The function Avg does not support inputs of type Timestamp\(Nanosecond, None\). -SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t - - -statement ok -drop table t_source; - -statement ok -drop table t; - -# All supported time types - -# Columns are named: -# "nanos" --> Time64NanosecondArray -# "micros" --> Time64MicrosecondArray -# "millis" --> Time32MillisecondArray -# "secs" --> Time32SecondArray -# "names" --> StringArray - -statement ok -create table t_source -as values - ('18:06:30.243620451', 'Row 0'), - ('20:08:28.161121654', 'Row 1'), - ('19:11:04.156423842', 'Row 2'), - ('21:06:28.247821084', 'Row 3'); - - -statement ok -create table t as -select - arrow_cast(column1, 'Time64(Nanosecond)') as nanos, - arrow_cast(column1, 'Time64(Microsecond)') as micros, - arrow_cast(column1, 'Time32(Millisecond)') as millis, - arrow_cast(column1, 'Time32(Second)') as secs, - column2 as names -from t_source; - -# Demonstate the contents -query DDDDT -select * from t; ----- -18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 Row 0 -20:08:28.161121654 20:08:28.161121 20:08:28.161 20:08:28 Row 1 -19:11:04.156423842 19:11:04.156423 19:11:04.156 19:11:04 Row 2 -21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 - -# aggregate_times_sum -statement error DataFusion error: Error during planning: The function Sum does not support inputs of type Time64\(Nanosecond\). -SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t - -# aggregate_times_count -query IIII -SELECT count(nanos), count(micros), count(millis), count(secs) FROM t ----- -4 4 4 4 - - -# aggregate_times_min -query DDDD -SELECT min(nanos), min(micros), min(millis), min(secs) FROM t ----- -18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 - -# aggregate_times_max -query DDDD -SELECT max(nanos), max(micros), max(millis), max(secs) FROM t ----- -21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 - - -# aggregate_times_avg -statement error DataFusion error: Error during planning: The function Avg does not support inputs of type Time64\(Nanosecond\). -SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t - -statement ok -drop table t_source; - -statement ok -drop table t; - -query I -select median(a) from (select 1 as a where 1=0); ----- -NULL - -query error DataFusion error: Execution error: aggregate function needs at least one non-null element -select approx_median(a) from (select 1 as a where 1=0); - - -# aggregate_decimal_sum -query RT -select sum(c1), arrow_typeof(sum(c1)) from d_table; ----- -100 Decimal128(20, 3) - - -# aggregate_decimal_avg -query RT -select avg(c1), arrow_typeof(avg(c1)) from d_table ----- -5 Decimal128(14, 7) - -# Use PostgresSQL dialect -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - -# Creating the table -statement ok -CREATE TABLE test_table (c1 INT, c2 INT, c3 INT) - -# Inserting data -statement ok -INSERT INTO test_table VALUES (1, 10, 50), (1, 20, 60), (2, 10, 70), (2, 20, 80), (3, 10, NULL) - -# query_group_by_with_filter -query II rowsort -SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test_table GROUP BY c1 ----- -1 20 -2 20 -3 NULL - -# query_group_by_avg_with_filter -query IR rowsort -SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test_table GROUP BY c1 ----- -1 20 -2 20 -3 NULL - -# query_group_by_with_multiple_filters -query IIR rowsort -SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3 FROM test_table GROUP BY c1 ----- -1 20 55 -2 20 70 -3 NULL NULL - -# query_group_by_distinct_with_filter -query II rowsort -SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count FROM test_table GROUP BY c1 ----- -1 1 -2 1 -3 0 - -# query_without_group_by_with_filter -query I rowsort -SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test_table ----- -40 - -# count_without_group_by_with_filter -query I rowsort -SELECT COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2 FROM test_table ----- -2 - -# query_with_and_without_filter -query III rowsort -SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result, SUM(c2) as result_no_filter FROM test_table GROUP BY c1; ----- -1 20 30 -2 20 30 -3 NULL 10 - -# query_filter_on_different_column_than_aggregate -query I rowsort -select sum(c1) FILTER (WHERE c2 < 30) from test_table; ----- -9 - -# query_test_empty_filter -query I rowsort -SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table; ----- -NULL - -# Creating the decimal table -statement ok -CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1)) - -# Inserting data -statement ok -INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL) - -# aggregate_decimal_with_group_by -query IIRRRRIIR rowsort -select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1 ----- -1 2 15.15 30.3 10.1 20.2 2 0 NULL -2 2 15.15 30.3 10.1 20.2 2 0 NULL -3 2 10.1 20.2 10.1 10.1 1 0 NULL - -# aggregate_decimal_with_group_by_decimal -query RIRRRRIR rowsort -select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3 ----- -100.1 2 10.1 20.2 10.1 10.1 0 NULL -200.2 1 20.2 20.2 20.2 20.2 0 NULL -700.1 2 15.15 30.3 10.1 20.2 0 NULL -NULL 1 10.1 10.1 10.1 10.1 0 NULL - -# Restore the default dialect -statement ok -set datafusion.sql_parser.dialect = 'Generic'; - -# Prepare the table with dictionary values for testing -statement ok -CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2); - -statement ok -CREATE TABLE value_dict AS SELECT arrow_cast(x, 'Dictionary(Int64, Int32)') AS x_dict FROM value; - -query ? -select x_dict from value_dict; ----- -1 -2 -3 -1 -3 -4 -5 -2 - -query I -select sum(x_dict) from value_dict; ----- -21 - -query R -select avg(x_dict) from value_dict; ----- -2.625 - -query I -select min(x_dict) from value_dict; ----- -1 - -query I -select max(x_dict) from value_dict; ----- -5 - -query I -select sum(x_dict) from value_dict where x_dict > 3; ----- -9 - -query R -select avg(x_dict) from value_dict where x_dict > 3; ----- -4.5 - -query I -select min(x_dict) from value_dict where x_dict > 3; ----- -4 - -query I -select max(x_dict) from value_dict where x_dict > 3; ----- -5 - -query I -select sum(x_dict) from value_dict group by x_dict % 2 order by sum(x_dict); ----- -8 -13 - -query R -select avg(x_dict) from value_dict group by x_dict % 2 order by avg(x_dict); ----- -2.6 -2.666666666667 - -query I -select min(x_dict) from value_dict group by x_dict % 2 order by min(x_dict); ----- -1 -2 - -query I -select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); ----- -4 -5 - -# bool aggregtion -statement ok -CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3); - -query B -select min(x) from value_bool; ----- -false - -query B -select max(x) from value_bool; ----- -true - -query B -select min(x) from value_bool group by g order by g; ----- -false -false -true -NULL - -query B -select max(x) from value_bool group by g order by g; ----- -true -false -true -NULL diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt deleted file mode 100644 index df9edce0b1df8..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ /dev/null @@ -1,206 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -############# -## Array expressions Tests -############# - -# array scalar function #1 -query ??? rowsort -select make_array(1, 2, 3), make_array(1.0, 2.0, 3.0), make_array('h', 'e', 'l', 'l', 'o'); ----- -[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] - -# array scalar function #2 -query ??? rowsort -select make_array(1, 2, 3), make_array(make_array(1, 2), make_array(3, 4)), make_array([[[[1], [2]]]]); ----- -[1, 2, 3] [[1, 2], [3, 4]] [[[[[1], [2]]]]] - -# array scalar function #3 -query ?? rowsort -select make_array([1, 2, 3], [4, 5, 6], [7, 8, 9]), make_array([[1, 2], [3, 4]], [[5, 6], [7, 8]]); ----- -[[1, 2, 3], [4, 5, 6], [7, 8, 9]] [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] - -# array scalar function #4 -query ?? rowsort -select make_array([1.0, 2.0], [3.0, 4.0]), make_array('h', 'e', 'l', 'l', 'o'); ----- -[[1.0, 2.0], [3.0, 4.0]] [h, e, l, l, o] - -# array scalar function #5 -query ? rowsort -select make_array(make_array(make_array(make_array(1, 2, 3), make_array(4, 5, 6)), make_array(make_array(7, 8, 9), make_array(10, 11, 12)))) ----- -[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] - -# array_append scalar function -query ??? rowsort -select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); ----- -[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] - -# array_prepend scalar function -query ??? rowsort -select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, 3.0, 4.0)), array_prepend('h', make_array('e', 'l', 'l', 'o')); ----- -[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] - -# array_fill scalar function #1 -query ??? rowsort -select array_fill(11, make_array(1, 2, 3)), array_fill(3, make_array(2, 3)), array_fill(2, make_array(2)); ----- -[[[11, 11, 11], [11, 11, 11]]] [[3, 3, 3], [3, 3, 3]] [2, 2] - -# array_fill scalar function #2 -query ?? rowsort -select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, 2)); ----- -[[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]] - -# array_concat scalar function #1 -query ?? rowsort -select array_concat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), array_concat(make_array([1], [2]), make_array([3], [4])); ----- -[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] - -# array_concat scalar function #2 -query ? rowsort -select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array(5, 6), make_array(7, 8))); ----- -[[1, 2], [3, 4], [5, 6], [7, 8]] - -# array_concat scalar function #3 -query ? rowsort -select array_concat(make_array([1], [2], [3]), make_array([4], [5], [6]), make_array([7], [8], [9])); ----- -[[1], [2], [3], [4], [5], [6], [7], [8], [9]] - -# array_concat scalar function #4 -query ? rowsort -select array_concat(make_array([[1]]), make_array([[2]])); ----- -[[[1]], [[2]]] - -# array_position scalar function #1 -query III -select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, 4, 5], 5), array_position([1, 1, 1], 1); ----- -3 5 1 - -# array_position scalar function #2 -query III -select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ----- -4 5 2 - -# array_positions scalar function -query III -select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); ----- -[3, 4] [5] [1, 2, 3] - -# array_replace scalar function -query ??? -select array_replace(make_array(1, 2, 3, 4), 2, 3), array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), array_replace(make_array(1, 2, 3), 4, 0); ----- -[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] - -# array_to_string scalar function -query ??? -select array_to_string(['h', 'e', 'l', 'l', 'o'], ','), array_to_string([1, 2, 3, 4, 5], '-'), array_to_string([1.0, 2.0, 3.0], '|'); ----- -h,e,l,l,o 1-2-3-4-5 1|2|3 - -# array_to_string scalar function #2 -query ??? -select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_fill(3, [3, 2, 2]), '/\'); ----- -11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3 - -# cardinality scalar function -query III -select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinality(make_array('h', 'e', 'l', 'l', 'o')); ----- -5 3 5 - -# cardinality scalar function #2 -query II -select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3, array[3, 2, 3])); ----- -6 18 - -# trim_array scalar function -query ??? -select trim_array(make_array(1, 2, 3, 4, 5), 2), trim_array(['h', 'e', 'l', 'l', 'o'], 3), trim_array([1.0, 2.0, 3.0], 2); ----- -[1, 2, 3] [h, e] [1.0] - -# trim_array scalar function #2 -query ?? -select trim_array([[1, 2], [3, 4], [5, 6]], 2), trim_array(array_fill(4, [3, 4, 2]), 2); ----- -[[1, 2]] [[[4, 4], [4, 4], [4, 4], [4, 4]]] - -# array_length scalar function -query III rowsort -select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3)), array_length(make_array([1, 2], [3, 4], [5, 6])); ----- -5 3 3 - -# array_length scalar function #2 -query III rowsort -select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ----- -5 3 3 - -# array_length scalar function #3 -query III rowsort -select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ----- -NULL NULL 2 - -# array_length scalar function #4 -query IIII rowsort -select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, 2, 5]), 2), array_length(array_fill(3, [3, 2, 5]), 3), array_length(array_fill(3, [3, 2, 5]), 4); ----- -3 2 5 NULL - -# array_dims scalar function -query III rowsort -select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), array_dims(make_array([[[[1], [2]]]])); ----- -[3] [2, 2] [1, 1, 1, 2, 1] - -# array_dims scalar function #2 -query II rowsort -select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4])); ----- -[1, 2, 3] [2, 5, 4] - -# array_ndims scalar function -query III rowsort -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); ----- -1 2 5 - -# array_ndims scalar function #2 -query II rowsort -select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ----- -3 21 diff --git a/datafusion/core/tests/sqllogictests/test_files/avro.slt b/datafusion/core/tests/sqllogictests/test_files/avro.slt deleted file mode 100644 index 5a01ae72cb304..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/avro.slt +++ /dev/null @@ -1,97 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -statement ok -CREATE EXTERNAL TABLE alltypes_plain ( - id INT NOT NULL, - bool_col BOOLEAN NOT NULL, - tinyint_col TINYINT NOT NULL, - smallint_col SMALLINT NOT NULL, - int_col INT NOT NULL, - bigint_col BIGINT NOT NULL, - float_col FLOAT NOT NULL, - double_col DOUBLE NOT NULL, - date_string_col BYTEA NOT NULL, - string_col VARCHAR NOT NULL, - timestamp_col TIMESTAMP NOT NULL, -) -STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.avro' - -statement ok -CREATE EXTERNAL TABLE single_nan ( - mycol FLOAT -) -STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/single_nan.avro' - -# test avro query -query IT -SELECT id, CAST(string_col AS varchar) FROM alltypes_plain ----- -4 0 -5 1 -6 0 -7 1 -2 0 -3 1 -0 0 -1 1 - -# test avro single nan schema -query R -SELECT mycol FROM single_nan ----- -NULL - -# test avro query multi files -query IT -SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files ----- -4 0 -5 1 -6 0 -7 1 -2 0 -3 1 -0 0 -1 1 -4 0 -5 1 -6 0 -7 1 -2 0 -3 1 -0 0 -1 1 - -# test avro explain -query TT -EXPLAIN SELECT count(*) from alltypes_plain ----- -logical_plan -Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] ---TableScan: alltypes_plain projection=[id] -physical_plan -AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))] ---CoalescePartitionsExec -----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))] -------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]}, projection=[id] diff --git a/datafusion/core/tests/sqllogictests/test_files/copy.slt b/datafusion/core/tests/sqllogictests/test_files/copy.slt deleted file mode 100644 index e7bde89d2940c..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/copy.slt +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# tests for copy command - -statement ok -create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, 'Bar'); - -# Copy from table -statement error DataFusion error: This feature is not implemented: `COPY \.\. TO \.\.` statement is not yet supported -COPY source_table to '/tmp/table.parquet'; - -# Copy from table with options -statement error DataFusion error: This feature is not implemented: `COPY \.\. TO \.\.` statement is not yet supported -COPY source_table to '/tmp/table.parquet' (row_group_size 55); - -# Copy from table with options (and trailing comma) -statement error DataFusion error: This feature is not implemented: `COPY \.\. TO \.\.` statement is not yet supported -COPY source_table to '/tmp/table.parquet' (row_group_size 55, row_group_limit_bytes 9,); - - -# Error cases: - -# Incomplete statement -statement error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) -COPY (select col2, sum(col1) from source_table - -# Copy from table with non literal -statement error DataFusion error: SQL error: ParserError\("Expected ',' or '\)' after option definition, found: \+"\) -COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); diff --git a/datafusion/core/tests/sqllogictests/test_files/errors.slt b/datafusion/core/tests/sqllogictests/test_files/errors.slt deleted file mode 100644 index 938209d21cc8e..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/errors.slt +++ /dev/null @@ -1,74 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -# create aggregate_test_100 table -statement ok -CREATE EXTERNAL TABLE aggregate_test_100 ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT, - c5 INT, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 BIGINT UNSIGNED NOT NULL, - c10 VARCHAR NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL -) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv' - -# csv_query_error -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'sin\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tsin\(Float64/Float32\) -SELECT sin(c1) FROM aggregate_test_100 - -# cast_expressions_error -statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c' to value of Int32 type -SELECT CAST(c1 AS INT) FROM aggregate_test_100 - -# aggregation_with_bad_arguments -statement error Error during planning: The function Count expects at least one argument -SELECT COUNT(DISTINCT) FROM aggregate_test_100 - -# query_cte_incorrect -statement error Error during planning: table 'datafusion\.public\.t' not found -WITH t AS (SELECT * FROM t) SELECT * from u - -statement error Error during planning: table 'datafusion\.public\.u' not found -WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u - -statement error Error during planning: table 'datafusion\.public\.u' not found -WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u - -# select_wildcard_without_table -statement error Error during planning: SELECT \* with no tables specified is not valid -SELECT * - -# invalid_qualified_table_references -statement error Error during planning: table 'datafusion\.nonexistentschema\.aggregate_test_100' not found -SELECT COUNT(*) FROM nonexistentschema.aggregate_test_100 - -statement error Error during planning: table 'nonexistentcatalog\.public\.aggregate_test_100' not found -SELECT COUNT(*) FROM nonexistentcatalog.public.aggregate_test_100 - -statement error Error during planning: Unsupported compound identifier '\[Ident \{ value: "way", quote_style: None \}, Ident \{ value: "too", quote_style: None \}, Ident \{ value: "many", quote_style: None \}, Ident \{ value: "namespaces", quote_style: None \}, Ident \{ value: "as", quote_style: None \}, Ident \{ value: "ident", quote_style: None \}, Ident \{ value: "prefixes", quote_style: None \}, Ident \{ value: "aggregate_test_100", quote_style: None \}\]' -SELECT COUNT(*) FROM way.too.many.namespaces.as.ident.prefixes.aggregate_test_100 diff --git a/datafusion/core/tests/sqllogictests/test_files/explain.slt b/datafusion/core/tests/sqllogictests/test_files/explain.slt deleted file mode 100644 index d230286adcb9c..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/explain.slt +++ /dev/null @@ -1,251 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -statement ok -CREATE EXTERNAL TABLE aggregate_test_100 ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT NOT NULL, - c5 INTEGER NOT NULL, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 INT UNSIGNED NOT NULL, - c10 BIGINT UNSIGNED NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL - ) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; - -query TT -explain SELECT c1 FROM aggregate_test_100 where c2 > 10 ----- -logical_plan -Projection: aggregate_test_100.c1 ---Filter: aggregate_test_100.c2 > Int8(10) -----TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] -physical_plan -ProjectionExec: expr=[c1@0 as c1] ---CoalesceBatchesExec: target_batch_size=8192 -----FilterExec: c2@1 > 10 -------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2], has_header=true - -# explain_csv_exec_scan_config - -statement ok -CREATE EXTERNAL TABLE aggregate_test_100_with_order ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT NOT NULL, - c5 INTEGER NOT NULL, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 INT UNSIGNED NOT NULL, - c10 BIGINT UNSIGNED NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL - ) -STORED AS CSV -WITH HEADER ROW -WITH ORDER (c1 ASC) -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; - -query TT -explain SELECT c1 FROM aggregate_test_100_with_order order by c1 ASC limit 10 ----- -logical_plan -Limit: skip=0, fetch=10 ---Sort: aggregate_test_100_with_order.c1 ASC NULLS LAST, fetch=10 -----TableScan: aggregate_test_100_with_order projection=[c1] -physical_plan -GlobalLimitExec: skip=0, fetch=10 ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], output_ordering=[c1@0 ASC NULLS LAST], has_header=true - - -## explain_physical_plan_only - -statement ok -set datafusion.explain.physical_plan_only = true - -query TT -EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3) ----- -physical_plan -ProjectionExec: expr=[2 as COUNT(UInt8(1))] ---EmptyExec: produce_one_row=true - -statement ok -set datafusion.explain.physical_plan_only = false - - -## explain nested -statement error Explain must be root of the plan -EXPLAIN explain select 1 - -statement ok -set datafusion.explain.physical_plan_only = true - -statement error Explain must be root of the plan -EXPLAIN explain select 1 - -statement ok -set datafusion.explain.physical_plan_only = false - -########## -# EXPLAIN VERBOSE will get pass prefixed with "logical_plan after" -########## - -statement ok -CREATE EXTERNAL TABLE simple_explain_test ( - a INT, - b INT, - c INT -) -STORED AS CSV -WITH HEADER ROW -LOCATION './tests/data/example.csv' - -query TT -EXPLAIN SELECT a, b, c FROM simple_explain_test ----- -logical_plan TableScan: simple_explain_test projection=[a, b, c] -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true - -# create a sink table, path is same with aggregate_test_100 table -# we do not overwrite this file, we only assert plan. -statement ok -CREATE EXTERNAL TABLE sink_table ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT NOT NULL, - c5 INTEGER NOT NULL, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 INT UNSIGNED NOT NULL, - c10 BIGINT UNSIGNED NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL - ) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; - -query TT -EXPLAIN INSERT INTO sink_table SELECT * FROM aggregate_test_100 ORDER by c1 ----- -logical_plan -Dml: op=[Insert] table=[sink_table] ---Projection: aggregate_test_100.c1 AS c1, aggregate_test_100.c2 AS c2, aggregate_test_100.c3 AS c3, aggregate_test_100.c4 AS c4, aggregate_test_100.c5 AS c5, aggregate_test_100.c6 AS c6, aggregate_test_100.c7 AS c7, aggregate_test_100.c8 AS c8, aggregate_test_100.c9 AS c9, aggregate_test_100.c10 AS c10, aggregate_test_100.c11 AS c11, aggregate_test_100.c12 AS c12, aggregate_test_100.c13 AS c13 -----Sort: aggregate_test_100.c1 ASC NULLS LAST -------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] -physical_plan -InsertExec: sink=CsvSink(writer_mode=Append, file_groups=[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]) ---ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c5@4 as c5, c6@5 as c6, c7@6 as c7, c8@7 as c8, c9@8 as c9, c10@9 as c10, c11@10 as c11, c12@11 as c12, c13@12 as c13] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true - -# test EXPLAIN VERBOSE -query TT -EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test ----- -initial_logical_plan -Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c ---TableScan: simple_explain_test -logical_plan after inline_table_scan SAME TEXT AS ABOVE -logical_plan after type_coercion SAME TEXT AS ABOVE -logical_plan after count_wildcard_rule SAME TEXT AS ABOVE -analyzed_logical_plan SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE -logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE -logical_plan after eliminate_join SAME TEXT AS ABOVE -logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE -logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE -logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE -logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE -logical_plan after eliminate_filter SAME TEXT AS ABOVE -logical_plan after eliminate_cross_join SAME TEXT AS ABOVE -logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after eliminate_limit SAME TEXT AS ABOVE -logical_plan after propagate_empty_relation SAME TEXT AS ABOVE -logical_plan after filter_null_join_keys SAME TEXT AS ABOVE -logical_plan after eliminate_outer_join SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE -logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection -Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c ---TableScan: simple_explain_test projection=[a, b, c] -logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c] -logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE -logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE -logical_plan after eliminate_join SAME TEXT AS ABOVE -logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE -logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE -logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after merge_projection SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE -logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE -logical_plan after eliminate_filter SAME TEXT AS ABOVE -logical_plan after eliminate_cross_join SAME TEXT AS ABOVE -logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after eliminate_limit SAME TEXT AS ABOVE -logical_plan after propagate_empty_relation SAME TEXT AS ABOVE -logical_plan after filter_null_join_keys SAME TEXT AS ABOVE -logical_plan after eliminate_outer_join SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE -logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE -logical_plan after push_down_projection SAME TEXT AS ABOVE -logical_plan after eliminate_projection SAME TEXT AS ABOVE -logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan TableScan: simple_explain_test projection=[a, b, c] -initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true -physical_plan after aggregate_statistics SAME TEXT AS ABOVE -physical_plan after join_selection SAME TEXT AS ABOVE -physical_plan after PipelineFixer SAME TEXT AS ABOVE -physical_plan after repartition SAME TEXT AS ABOVE -physical_plan after global_sort_selection SAME TEXT AS ABOVE -physical_plan after EnforceDistribution SAME TEXT AS ABOVE -physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE -physical_plan after EnforceSorting SAME TEXT AS ABOVE -physical_plan after coalesce_batches SAME TEXT AS ABOVE -physical_plan after PipelineChecker SAME TEXT AS ABOVE -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true diff --git a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt deleted file mode 100644 index ca400e0ef4994..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ /dev/null @@ -1,2591 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -statement ok -CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER) - -statement ok -CREATE TABLE tab1(col0 INTEGER, col1 INTEGER, col2 INTEGER) - -statement ok -CREATE TABLE tab2(col0 INTEGER, col1 INTEGER, col2 INTEGER) - -statement ok -INSERT INTO tab0 VALUES(83,0,38) - -statement ok -INSERT INTO tab0 VALUES(26,0,79) - -statement ok -INSERT INTO tab0 VALUES(43,81,24) - -statement ok -INSERT INTO tab1 VALUES(22,6,8) - -statement ok -INSERT INTO tab1 VALUES(28,57,45) - -statement ok -INSERT INTO tab1 VALUES(82,44,71) - -statement ok -INSERT INTO tab2 VALUES(15,61,87) - -statement ok -INSERT INTO tab2 VALUES(91,59,79) - -statement ok -INSERT INTO tab2 VALUES(92,41,58) - -query I rowsort -SELECT - tab1.col0 * 84 + + 38 AS col2 FROM tab1 GROUP BY tab1.col0 ----- --1810 --2314 --6850 - -query I rowsort -SELECT + cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT DISTINCT - ( + col1 ) + - 51 AS col0 FROM tab1 AS cor0 GROUP BY col1 ----- --108 --57 --95 - -query I rowsort -SELECT col1 * cor0.col1 * 56 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -194936 -208376 -94136 - -query I rowsort label-4 -SELECT ALL + tab2.col1 / tab2.col1 FROM tab2 GROUP BY col1 ----- -1 -1 -1 - -query I rowsort -SELECT ALL + tab1.col0 FROM tab1 GROUP BY col0 ----- -22 -28 -82 - -query I rowsort -SELECT DISTINCT tab1.col0 AS col1 FROM tab1 GROUP BY tab1.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ALL col2 FROM tab1 GROUP BY col2 ----- -45 -71 -8 - -query I rowsort -SELECT ALL + cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- -26 -43 -83 - -query III rowsort -SELECT DISTINCT * FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 ----- -26 0 79 -43 81 24 -83 0 38 - -query III rowsort -SELECT * FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 ----- -26 0 79 -43 81 24 -83 0 38 - -query I rowsort -SELECT - 9 * cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- --369 --531 --549 - -query I rowsort -SELECT DISTINCT - 21 FROM tab2 GROUP BY col2 ----- --21 - -query I rowsort -SELECT DISTINCT - 97 AS col2 FROM tab1 GROUP BY col0 ----- --97 - -query I rowsort -SELECT + ( - 1 ) AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- --1 --1 --1 - -query I rowsort -SELECT - + cor0.col1 FROM tab0, tab0 cor0 GROUP BY cor0.col1 ----- --81 -0 - -query I rowsort -SELECT + cor0.col0 + 36 AS col2 FROM tab0 AS cor0 GROUP BY col0 ----- -119 -62 -79 - -query I rowsort -SELECT cor0.col1 AS col1 FROM tab0 AS cor0 GROUP BY col1 ----- -0 -81 - -query I rowsort -SELECT DISTINCT + cor0.col1 FROM tab2 cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT ALL + cor0.col0 + - col0 col1 FROM tab1 AS cor0 GROUP BY col0 ----- -0 -0 -0 - -query I rowsort -SELECT ALL 54 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -54 -54 -54 - -query I rowsort -SELECT 40 AS col1 FROM tab1 cor0 GROUP BY cor0.col0 ----- -40 -40 -40 - -query I rowsort -SELECT DISTINCT ( cor0.col0 ) AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -26 -43 -83 - -query I rowsort -SELECT 62 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -62 -62 -62 - -query I rowsort -SELECT 23 FROM tab2 GROUP BY tab2.col2 ----- -23 -23 -23 - -query I rowsort -SELECT + ( - tab0.col0 ) col2 FROM tab0, tab0 AS cor0 GROUP BY tab0.col0 ----- --26 --43 --83 - -query I rowsort -SELECT + cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col1 ----- -44 -57 -6 - -query I rowsort -SELECT cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col2 ----- -41 -59 -61 - -query I rowsort -SELECT DISTINCT + 80 + cor0.col2 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -104 -118 -159 - -query I rowsort -SELECT DISTINCT 30 * - 9 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- --270 - -query I rowsort -SELECT DISTINCT - col2 FROM tab1 AS cor0 GROUP BY col2 ----- --45 --71 --8 - -query I rowsort -SELECT ALL - col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- --45 --71 --8 - -query I rowsort -SELECT DISTINCT + 82 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -82 - -query I rowsort -SELECT 79 * 19 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -1501 -1501 -1501 - -query I rowsort -SELECT ALL ( + 68 ) FROM tab1 cor0 GROUP BY cor0.col2 ----- -68 -68 -68 - -query I rowsort -SELECT - col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- --22 --28 --82 - -query I rowsort -SELECT + 81 col2 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -81 -81 -81 - -query I rowsort -SELECT ALL cor0.col2 AS col1 FROM tab2 cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT ALL + cor0.col0 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT - cor0.col2 AS col0 FROM tab0 cor0 GROUP BY cor0.col2 ----- --24 --38 --79 - -query I rowsort -SELECT cor0.col0 FROM tab1 AS cor0 GROUP BY col0, cor0.col1, cor0.col1 ----- -22 -28 -82 - -query I rowsort -SELECT 58 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -58 -58 - -query I rowsort -SELECT ALL cor0.col1 + - 20 AS col1 FROM tab0 cor0 GROUP BY cor0.col1 ----- --20 -61 - -query I rowsort -SELECT ALL + col1 col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT DISTINCT - - 56 FROM tab2, tab0 AS cor0 GROUP BY cor0.col1 ----- -56 - -query I rowsort -SELECT - 10 AS col0 FROM tab2, tab1 AS cor0, tab2 AS cor1 GROUP BY cor1.col0 ----- --10 --10 --10 - -query I rowsort -SELECT 31 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -31 -31 -31 - -query I rowsort -SELECT col2 AS col0 FROM tab0 cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT + 70 AS col1 FROM tab0 GROUP BY col0 ----- -70 -70 -70 - -query I rowsort -SELECT DISTINCT cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT - cor0.col1 FROM tab2, tab2 AS cor0 GROUP BY cor0.col1 ----- --41 --59 --61 - -query I rowsort -SELECT DISTINCT + tab0.col0 col1 FROM tab0 GROUP BY tab0.col0 ----- -26 -43 -83 - -query I rowsort -SELECT DISTINCT - cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- --24 --38 --79 - -query I rowsort -SELECT + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT - 5 AS col2 FROM tab2, tab2 AS cor0, tab2 AS cor1 GROUP BY tab2.col1 ----- --5 --5 --5 - -query I rowsort -SELECT DISTINCT 0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -0 - -query I rowsort -SELECT DISTINCT - - tab2.col0 FROM tab2 GROUP BY col0 ----- -15 -91 -92 - -query III rowsort -SELECT DISTINCT * FROM tab2 AS cor0 GROUP BY cor0.col0, col1, cor0.col2 ----- -15 61 87 -91 59 79 -92 41 58 - -query I rowsort label-58 -SELECT 9 / + cor0.col0 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- -0 -0 -0 - -query I rowsort -SELECT ( - 72 ) AS col1 FROM tab1 cor0 GROUP BY cor0.col0, cor0.col2 ----- --72 --72 --72 - -query I rowsort -SELECT cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ( col0 ) FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort label-62 -SELECT ALL 59 / 26 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -2 -2 -2 - -query I rowsort -SELECT 15 FROM tab1 AS cor0 GROUP BY col2, col2 ----- -15 -15 -15 - -query I rowsort -SELECT CAST ( NULL AS INTEGER ) FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col2 ----- -NULL -NULL -NULL - -query I rowsort -SELECT ALL - 79 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- --79 --79 --79 - -query I rowsort -SELECT ALL 69 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -69 -69 -69 - -query I rowsort -SELECT ALL 37 col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -37 -37 - -query I rowsort -SELECT ALL 55 * 15 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -825 -825 -825 - -query I rowsort -SELECT ( 63 ) FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -63 -63 -63 - -query I rowsort -SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- --45 --71 --8 - -query I rowsort -SELECT - col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- --58 --79 --87 - -query I rowsort -SELECT ALL 81 * 11 FROM tab2 AS cor0 GROUP BY col1, cor0.col0 ----- -891 -891 -891 - -query I rowsort -SELECT ALL 9 FROM tab2 AS cor0 GROUP BY col2 ----- -9 -9 -9 - -query I rowsort -SELECT DISTINCT ( - 31 ) col1 FROM tab1 GROUP BY tab1.col0 ----- --31 - -query I rowsort label-75 -SELECT + + cor0.col0 / - cor0.col0 FROM tab1, tab0 AS cor0 GROUP BY cor0.col0 ----- --1 --1 --1 - -query I rowsort -SELECT cor0.col2 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT ALL cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -0 -81 - -query I rowsort -SELECT ALL + - ( - tab0.col2 ) AS col0 FROM tab0 GROUP BY tab0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT 72 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -72 -72 - -query I rowsort -SELECT - 20 - + col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- --101 --20 - -query I rowsort -SELECT - - 63 FROM tab1 GROUP BY tab1.col0 ----- -63 -63 -63 - -query I rowsort -SELECT cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col2, col1 ----- -45 -71 -8 - -query I rowsort -SELECT + cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -0 -81 - -query I rowsort -SELECT DISTINCT cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col1 ----- -44 -57 -6 - -query I rowsort -SELECT cor0.col0 - col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -0 -0 -0 - -query I rowsort -SELECT 50 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -50 -50 -50 - -query I rowsort -SELECT - 18 AS col0 FROM tab1 cor0 GROUP BY cor0.col2 ----- --18 --18 --18 - -query I rowsort -SELECT + cor0.col2 * cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- -1444 -576 -6241 - -query I rowsort -SELECT ALL 91 / cor0.col1 FROM tab2 AS cor0 GROUP BY col1, cor0.col1 ----- -1 -1 -2 - -query I rowsort -SELECT cor0.col2 AS col2 FROM tab0 AS cor0 GROUP BY col2 ----- -24 -38 -79 - -query I rowsort -SELECT ALL + 85 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -85 -85 -85 - -query I rowsort -SELECT + 49 AS col2 FROM tab0 cor0 GROUP BY cor0.col0 ----- -49 -49 -49 - -query I rowsort -SELECT cor0.col2 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT - col0 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- --15 --91 --92 - -query I rowsort -SELECT DISTINCT - 87 AS col1 FROM tab0 AS cor0 GROUP BY col0 ----- --87 - -query I rowsort -SELECT + 39 FROM tab0 AS cor0 GROUP BY col1 ----- -39 -39 - -query I rowsort -SELECT ALL cor0.col2 * + col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -3364 -6241 -7569 - -query I rowsort -SELECT 40 FROM tab0 GROUP BY tab0.col1 ----- -40 -40 - -query I rowsort -SELECT tab1.col2 AS col0 FROM tab1 GROUP BY tab1.col2 ----- -45 -71 -8 - -query I rowsort -SELECT tab2.col0 FROM tab2 GROUP BY tab2.col0 ----- -15 -91 -92 - -query I rowsort -SELECT + col0 * + col0 FROM tab0 GROUP BY tab0.col0 ----- -1849 -676 -6889 - -query I rowsort -SELECT ALL cor0.col2 + cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -158 -48 -76 - -query I rowsort -SELECT DISTINCT cor0.col2 FROM tab1 cor0 GROUP BY cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT ALL + cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT cor0.col2 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort label-106 -SELECT - 53 / cor0.col0 col0 FROM tab1 cor0 GROUP BY cor0.col0 ----- --1 --2 -0 - -query I rowsort -SELECT cor0.col1 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -0 -81 - -query I rowsort -SELECT DISTINCT + cor0.col1 col0 FROM tab2 cor0 GROUP BY cor0.col1, cor0.col0 ----- -41 -59 -61 - -query I rowsort -SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- --45 --71 --8 - -query I rowsort -SELECT cor0.col1 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col1 ----- -0 -81 - -query I rowsort -SELECT 25 AS col1 FROM tab2 cor0 GROUP BY cor0.col0 ----- -25 -25 -25 - -query I rowsort -SELECT cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT DISTINCT + 6 FROM tab1 cor0 GROUP BY col2, cor0.col0 ----- -6 - -query I rowsort -SELECT cor0.col2 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT ALL 72 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -72 -72 -72 - -query I rowsort -SELECT ALL + 73 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -73 -73 -73 - -query I rowsort -SELECT tab1.col0 AS col2 FROM tab1 GROUP BY col0 ----- -22 -28 -82 - -query I rowsort -SELECT + cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT DISTINCT - cor0.col1 col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- --81 -0 - -query I rowsort -SELECT cor0.col0 * 51 FROM tab1 AS cor0 GROUP BY col0 ----- -1122 -1428 -4182 - -query I rowsort -SELECT ALL + 89 FROM tab2, tab1 AS cor0, tab1 AS cor1 GROUP BY cor0.col2 ----- -89 -89 -89 - -query I rowsort -SELECT ALL + cor0.col0 - + cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -0 -0 -0 - -query I rowsort -SELECT ALL 71 AS col0 FROM tab0 GROUP BY col1 ----- -71 -71 - -query I rowsort -SELECT - ( + cor0.col0 ) AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- --26 --43 --83 - -query I rowsort -SELECT 62 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -62 -62 -62 - -query I rowsort -SELECT ALL - 97 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- --97 --97 --97 - -query I rowsort -SELECT DISTINCT + 29 * ( cor0.col0 ) + + 47 FROM tab1 cor0 GROUP BY cor0.col0 ----- -2425 -685 -859 - -query I rowsort -SELECT DISTINCT col2 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT ALL 40 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -40 -40 -40 - -query I rowsort -SELECT cor0.col1 + cor0.col1 AS col2 FROM tab2 cor0 GROUP BY cor0.col1 ----- -118 -122 -82 - -query I rowsort -SELECT ( + cor0.col1 ) FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT cor0.col1 * + cor0.col1 col1 FROM tab1 AS cor0 GROUP BY cor0.col1 ----- -1936 -3249 -36 - -query I rowsort -SELECT ALL + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT - 9 FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col1, col2 ----- --9 --9 --9 - -query I rowsort -SELECT ALL - 7 * cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col1 ----- --308 --399 --42 - -query I rowsort -SELECT - 21 AS col2 FROM tab1 cor0 GROUP BY cor0.col1, cor0.col1 ----- --21 --21 --21 - -query I rowsort -SELECT DISTINCT tab1.col2 FROM tab1 GROUP BY tab1.col2 ----- -45 -71 -8 - -query I rowsort -SELECT DISTINCT - 76 FROM tab2 GROUP BY tab2.col2 ----- --76 - -query I rowsort -SELECT DISTINCT - cor0.col1 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- --41 --59 --61 - -query I rowsort -SELECT cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -0 -81 - -query I rowsort -SELECT ALL - cor0.col2 + - 55 AS col1 FROM tab0 AS cor0 GROUP BY col2 ----- --134 --79 --93 - -query I rowsort -SELECT - + 28 FROM tab0, tab2 cor0 GROUP BY tab0.col1 ----- --28 --28 - -query I rowsort -SELECT ALL col1 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT ALL + 35 * 14 AS col1 FROM tab2 GROUP BY tab2.col1 ----- -490 -490 -490 - -query I rowsort -SELECT ALL cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, cor0.col1 ----- -15 -91 -92 - -query I rowsort -SELECT DISTINCT - cor0.col2 * 18 + + 56 FROM tab2 AS cor0 GROUP BY col2 ----- --1366 --1510 --988 - -query I rowsort -SELECT cor0.col0 FROM tab0 cor0 GROUP BY col0 ----- -26 -43 -83 - -query I rowsort -SELECT ALL - 38 AS col1 FROM tab2 GROUP BY tab2.col2 ----- --38 --38 --38 - -query I rowsort -SELECT - 79 FROM tab0, tab0 cor0, tab0 AS cor1 GROUP BY cor1.col0 ----- --79 --79 --79 - -query I rowsort -SELECT + cor0.col2 FROM tab1 cor0 GROUP BY cor0.col2, cor0.col1 ----- -45 -71 -8 - -query I rowsort -SELECT cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -26 -43 -83 - -query I rowsort -SELECT cor0.col2 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -24 -38 -79 - -query I rowsort -SELECT + - 57 AS col1 FROM tab2 GROUP BY tab2.col2 ----- --57 --57 --57 - -query I rowsort -SELECT ALL - cor0.col1 FROM tab2 cor0 GROUP BY cor0.col1 ----- --41 --59 --61 - -query I rowsort -SELECT DISTINCT cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT - cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- --26 --43 --83 - -query I rowsort -SELECT ( - cor0.col1 ) FROM tab1 AS cor0 GROUP BY cor0.col1 ----- --44 --57 --6 - -query I rowsort -SELECT DISTINCT - cor0.col2 FROM tab0 cor0 GROUP BY cor0.col2, cor0.col2 ----- --24 --38 --79 - -query I rowsort -SELECT DISTINCT tab1.col1 * ( + tab1.col1 ) FROM tab1 GROUP BY col1 ----- -1936 -3249 -36 - -query I rowsort -SELECT - cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- --41 --59 --61 - -query III rowsort -SELECT * FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 ----- -15 61 87 -91 59 79 -92 41 58 - -query I rowsort -SELECT + 83 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -83 -83 -83 - -query I rowsort -SELECT + ( 97 ) + - tab0.col1 FROM tab0, tab1 AS cor0 GROUP BY tab0.col1 ----- -16 -97 - -query I rowsort -SELECT 61 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -61 -61 -61 - -query I rowsort -SELECT ALL cor0.col2 FROM tab0 cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT cor0.col2 FROM tab0, tab1 AS cor0 GROUP BY cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT + - 3 FROM tab2 GROUP BY col1 ----- --3 --3 --3 - -query I rowsort -SELECT DISTINCT + 96 FROM tab2 GROUP BY tab2.col1 ----- -96 - -query I rowsort -SELECT ALL 81 FROM tab1 AS cor0 GROUP BY cor0.col1 ----- -81 -81 -81 - -query I rowsort -SELECT cor0.col0 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -26 -43 -83 - -query I rowsort -SELECT - + 51 col2 FROM tab2, tab2 AS cor0 GROUP BY cor0.col1 ----- --51 --51 --51 - -query I rowsort -SELECT cor0.col1 + - cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -0 -0 -0 - -query I rowsort -SELECT 35 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col1 ----- -35 -35 -35 - -query I rowsort -SELECT + tab2.col1 col0 FROM tab2 GROUP BY tab2.col1 ----- -41 -59 -61 - -query I rowsort -SELECT 37 AS col1 FROM tab0 AS cor0 GROUP BY col0 ----- -37 -37 -37 - -query I rowsort -SELECT + cor0.col1 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - -query I rowsort -SELECT cor0.col1 FROM tab2, tab1 AS cor0 GROUP BY cor0.col1 ----- -44 -57 -6 - -query I rowsort -SELECT ALL - col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- --22 --28 --82 - -query I rowsort -SELECT + 77 AS col1 FROM tab1 AS cor0 CROSS JOIN tab0 AS cor1 GROUP BY cor0.col2 ----- -77 -77 -77 - -query I rowsort -SELECT ALL cor0.col0 col1 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT + cor0.col2 * + cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- -1032 -2054 -3154 - -query I rowsort -SELECT DISTINCT 39 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -39 - -query III rowsort -SELECT DISTINCT * FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2, cor0.col1 ----- -22 6 8 -28 57 45 -82 44 71 - -query I rowsort -SELECT ALL + 28 FROM tab2 cor0 GROUP BY cor0.col0 ----- -28 -28 -28 - -query I rowsort -SELECT cor0.col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ALL cor0.col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT + ( col0 ) * col0 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -225 -8281 -8464 - -query I rowsort label-188 -SELECT - 21 - + 57 / cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- --21 --22 --23 - -query I rowsort -SELECT + 37 + cor0.col0 * cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2, col0 ----- -1342 -5373 -7226 - -query I rowsort -SELECT ALL cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -45 -71 -8 - -query III rowsort -SELECT * FROM tab1 AS cor0 GROUP BY col2, cor0.col1, cor0.col0 ----- -22 6 8 -28 57 45 -82 44 71 - -query I rowsort -SELECT ( cor0.col2 ) AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT DISTINCT 28 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -28 - -query I rowsort -SELECT ALL - 18 FROM tab0, tab1 AS cor0 GROUP BY cor0.col0 ----- --18 --18 --18 - -query I rowsort -SELECT DISTINCT cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT + col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT - cor0.col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 ----- --22 --28 --82 - -query I rowsort -SELECT 29 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 ----- -29 -29 -29 - -query I rowsort -SELECT - + cor0.col0 - 39 AS col0 FROM tab0, tab0 cor0 GROUP BY cor0.col0 ----- --122 --65 --82 - -query I rowsort -SELECT ALL 45 AS col0 FROM tab0 GROUP BY tab0.col0 ----- -45 -45 -45 - -query I rowsort -SELECT + 74 AS col1 FROM tab1 GROUP BY tab1.col0 ----- -74 -74 -74 - -query I rowsort -SELECT cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort label-203 -SELECT - cor0.col2 + CAST ( 80 AS INTEGER ) FROM tab1 AS cor0 GROUP BY col2 ----- -35 -72 -9 - -query I rowsort -SELECT DISTINCT - cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- --81 -0 - -query I rowsort -SELECT - 51 * + cor0.col2 FROM tab0, tab2 cor0, tab1 AS cor1 GROUP BY cor0.col2 ----- --2958 --4029 --4437 - -query I rowsort -SELECT ALL + col0 * cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -225 -8281 -8464 - -query I rowsort -SELECT DISTINCT ( col0 ) FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -26 -43 -83 - -query I rowsort -SELECT 87 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -87 -87 -87 - -query I rowsort -SELECT + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT DISTINCT + 45 col0 FROM tab1 AS cor0 GROUP BY col0 ----- -45 - -query I rowsort label-211 -SELECT ALL CAST ( NULL AS INTEGER ) FROM tab2 AS cor0 GROUP BY col1 ----- -NULL -NULL -NULL - -query I rowsort -SELECT ALL cor0.col1 + col1 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -0 -162 - -query I rowsort -SELECT - cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- --81 -0 - -query I rowsort -SELECT DISTINCT + 99 * 76 + + tab2.col1 AS col2 FROM tab2 GROUP BY col1 ----- -7565 -7583 -7585 - -query I rowsort -SELECT ALL 54 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -54 -54 -54 - -query I rowsort -SELECT + cor0.col2 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2, cor0.col0 ----- -58 -79 -87 - -query I rowsort -SELECT cor0.col0 + + 87 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -109 -115 -169 - -query I rowsort -SELECT cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, cor0.col1, cor0.col0 ----- -15 -91 -92 - -query I rowsort -SELECT ALL col0 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT DISTINCT - cor0.col0 - + cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- --182 --184 --30 - -query I rowsort -SELECT ALL - 68 * + cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col1 ----- --5508 -0 - -query I rowsort -SELECT col2 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2 ----- -24 -38 -79 - -query I rowsort -SELECT ALL - 11 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- --11 --11 --11 - -query I rowsort -SELECT 66 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -66 -66 -66 - -query I rowsort -SELECT - cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- --58 --79 --87 - -query I rowsort -SELECT ALL 37 FROM tab2, tab0 AS cor0 GROUP BY cor0.col1 ----- -37 -37 - -query I rowsort -SELECT DISTINCT + 20 col2 FROM tab0 GROUP BY tab0.col1 ----- -20 - -query I rowsort -SELECT 42 FROM tab0 cor0 GROUP BY col2 ----- -42 -42 -42 - -query I rowsort -SELECT ALL - cor0.col1 AS col1 FROM tab1 cor0 GROUP BY cor0.col1 ----- --44 --57 --6 - -query I rowsort -SELECT - col2 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- --58 --79 --87 - -query I rowsort -SELECT DISTINCT + 86 FROM tab1 GROUP BY tab1.col2 ----- -86 - -query I rowsort -SELECT + cor0.col1 AS col1 FROM tab2, tab0 cor0 GROUP BY cor0.col1 ----- -0 -81 - -query I rowsort -SELECT - 13 FROM tab0 cor0 GROUP BY cor0.col1 ----- --13 --13 - -query I rowsort -SELECT tab1.col0 AS col1 FROM tab1 GROUP BY tab1.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ALL cor0.col1 * cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -1681 -3481 -3721 - -query I rowsort -SELECT - cor0.col0 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- --15 --91 --92 - -query I rowsort -SELECT cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT ALL - 67 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- --67 --67 --67 - -query I rowsort -SELECT + 75 AS col2 FROM tab1 cor0 GROUP BY cor0.col0 ----- -75 -75 -75 - -query I rowsort -SELECT ALL cor0.col1 FROM tab0 AS cor0 GROUP BY col0, cor0.col1 ----- -0 -0 -81 - -query I rowsort -SELECT ALL + cor0.col1 FROM tab0 AS cor0 GROUP BY col1 ----- -0 -81 - -query I rowsort -SELECT DISTINCT - 38 - - cor0.col0 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- --12 -45 -5 - -query I rowsort -SELECT + cor0.col0 + - col0 + 21 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -21 -21 -21 - -query I rowsort -SELECT + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0, cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ALL - cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col0 ----- --26 --43 --83 - -query III rowsort -SELECT * FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col1, cor0.col0 ----- -26 0 79 -43 81 24 -83 0 38 - -query I rowsort -SELECT DISTINCT + + tab2.col2 FROM tab2, tab1 AS cor0 GROUP BY tab2.col2 ----- -58 -79 -87 - -query I rowsort -SELECT cor0.col0 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -15 -91 -92 - -query I rowsort -SELECT col0 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 ----- -15 -91 -92 - -query I rowsort -SELECT - cor0.col0 AS col1 FROM tab1 AS cor0 GROUP BY col0 ----- --22 --28 --82 - -query I rowsort -SELECT DISTINCT ( + 71 ) col1 FROM tab1 GROUP BY tab1.col2 ----- -71 - -query I rowsort -SELECT + 96 * 29 col1 FROM tab2, tab1 AS cor0 GROUP BY tab2.col0 ----- -2784 -2784 -2784 - -query I rowsort -SELECT + 3 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- -3 -3 -3 - -query I rowsort -SELECT 37 FROM tab0 AS cor0 GROUP BY col0 ----- -37 -37 -37 - -query I rowsort -SELECT 82 FROM tab0 cor0 GROUP BY cor0.col1 ----- -82 -82 - -query I rowsort -SELECT cor0.col2 FROM tab2 cor0 GROUP BY cor0.col2 ----- -58 -79 -87 - -query I rowsort -SELECT DISTINCT - 87 FROM tab1, tab2 AS cor0, tab2 AS cor1 GROUP BY tab1.col0 ----- --87 - -query I rowsort -SELECT 55 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col1 ----- -55 -55 -55 - -query I rowsort -SELECT DISTINCT 35 FROM tab0 cor0 GROUP BY cor0.col2, cor0.col0 ----- -35 - -query I rowsort -SELECT cor0.col0 FROM tab2 cor0 GROUP BY col0 ----- -15 -91 -92 - -query I rowsort -SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY col2 ----- --45 --71 --8 - -query I rowsort -SELECT ALL ( cor0.col2 ) AS col1 FROM tab2, tab1 AS cor0 GROUP BY cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT DISTINCT - col2 FROM tab1 GROUP BY tab1.col2 ----- --45 --71 --8 - -query I rowsort -SELECT 38 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col1 ----- -38 -38 -38 - -query I rowsort -SELECT - 16 * - cor0.col0 * 47 FROM tab0 AS cor0 GROUP BY cor0.col0 ----- -19552 -32336 -62416 - -query I rowsort -SELECT - 31 FROM tab2 AS cor0 GROUP BY cor0.col2 ----- --31 --31 --31 - -query I rowsort -SELECT ( + 34 ) AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -34 -34 -34 - -query I rowsort -SELECT cor0.col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 ----- -45 -71 -8 - -query I rowsort -SELECT DISTINCT 21 FROM tab0 AS cor0 GROUP BY cor0.col2 ----- -21 - -query I rowsort -SELECT 62 AS col2 FROM tab0 cor0 GROUP BY cor0.col1, cor0.col2 ----- -62 -62 -62 - -query I rowsort -SELECT cor0.col0 FROM tab1 cor0 GROUP BY cor0.col0, cor0.col1 ----- -22 -28 -82 - -query I rowsort -SELECT DISTINCT cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, col1 ----- -15 -91 -92 - -query I rowsort -SELECT DISTINCT cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 ----- -22 -28 -82 - -query I rowsort -SELECT ALL - ( 30 ) * + cor0.col1 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- --1230 --1770 --1830 - -query I rowsort -SELECT DISTINCT 94 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 ----- -94 - -query I rowsort -SELECT DISTINCT + col1 FROM tab2 AS cor0 GROUP BY cor0.col1 ----- -41 -59 -61 - - - -# Columns in the table are a,b,c,d. Source is CsvExec which is ordered by -# a,b,c column. Column a has cardinality 2, column b has cardinality 4. -# Column c has cardinality 100 (unique entries). Column d has cardinality 5. -statement ok -CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite2 ( - a0 INTEGER, - a INTEGER, - b INTEGER, - c INTEGER, - d INTEGER -) -STORED AS CSV -WITH HEADER ROW -WITH ORDER (a ASC, b ASC, c ASC) -LOCATION 'tests/data/window_2.csv'; - -# Create a table with 2 ordered columns. -# In the next step, we will expect to observe the removed sort execs. -statement ok -CREATE EXTERNAL TABLE multiple_ordered_table ( - a0 INTEGER, - a INTEGER, - b INTEGER, - c INTEGER, - d INTEGER -) -STORED AS CSV -WITH HEADER ROW -WITH ORDER (a ASC, b ASC) -WITH ORDER (c ASC) -LOCATION 'tests/data/window_2.csv'; - -# Expected a sort exec for b DESC -query TT -EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY b DESC; ----- -logical_plan -Projection: multiple_ordered_table.a ---Sort: multiple_ordered_table.b DESC NULLS FIRST -----TableScan: multiple_ordered_table projection=[a, b] -physical_plan -ProjectionExec: expr=[a@0 as a] ---SortExec: expr=[b@1 DESC] -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true - -# Final plan shouldn't have SortExec c ASC, -# because table already satisfies this ordering. -query TT -EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY c ASC; ----- -logical_plan -Projection: multiple_ordered_table.a ---Sort: multiple_ordered_table.c ASC NULLS LAST -----TableScan: multiple_ordered_table projection=[a, c] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true - -# Final plan shouldn't have SortExec a ASC, b ASC, -# because table already satisfies this ordering. -query TT -EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY a ASC, b ASC; ----- -logical_plan -Projection: multiple_ordered_table.a ---Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST -----TableScan: multiple_ordered_table projection=[a, b] -physical_plan -ProjectionExec: expr=[a@0 as a] ---CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true - -# test_window_agg_sort -statement ok -set datafusion.execution.target_partitions = 1; - -# test_source_sorted_groupby -query TT -EXPLAIN SELECT a, b, - SUM(c) as summation1 - FROM annotated_data_infinite2 - GROUP BY b, a ----- -logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotated_data_infinite2.c) AS summation1 ---Aggregate: groupBy=[[annotated_data_infinite2.b, annotated_data_infinite2.a]], aggr=[[SUM(annotated_data_infinite2.c)]] -----TableScan: annotated_data_infinite2 projection=[a, b, c] -physical_plan -ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] ---AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true - - -query III - SELECT a, b, - SUM(c) as summation1 - FROM annotated_data_infinite2 - GROUP BY b, a ----- -0 0 300 -0 1 925 -1 2 1550 -1 3 2175 - - -# test_source_sorted_groupby2 -# If ordering is not important for the aggregation function, we should ignore the ordering requirement. Hence -# "ORDER BY a DESC" should have no effect. -query TT -EXPLAIN SELECT a, d, - SUM(c ORDER BY a DESC) as summation1 - FROM annotated_data_infinite2 - GROUP BY d, a ----- -logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS summation1 ---Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] -----TableScan: annotated_data_infinite2 projection=[a, c, d] -physical_plan -ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] ---AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true - -query III -SELECT a, d, - SUM(c ORDER BY a DESC) as summation1 - FROM annotated_data_infinite2 - GROUP BY d, a ----- -0 0 292 -0 2 196 -0 1 315 -0 4 164 -0 3 258 -1 0 622 -1 3 299 -1 1 1043 -1 4 913 -1 2 848 - -# test_source_sorted_groupby3 - -query TT -EXPLAIN SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS first_c ---Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] -----TableScan: annotated_data_infinite2 projection=[a, b, c] -physical_plan -ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true - -query III -SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -0 0 0 -0 1 25 -1 2 50 -1 3 75 - -# test_source_sorted_groupby4 - -query TT -EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS last_c ---Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] -----TableScan: annotated_data_infinite2 projection=[a, b, c] -physical_plan -ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true - -query III -SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -0 0 24 -0 1 49 -1 2 74 -1 3 99 - -# when LAST_VALUE, or FIRST_VALUE value do not contain ordering requirement -# queries should still work, However, result depends on the scanning order and -# not deterministic -query TT -EXPLAIN SELECT a, b, LAST_VALUE(c) as last_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) AS last_c ---Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c)]] -----TableScan: annotated_data_infinite2 projection=[a, b, c] -physical_plan -ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] ---AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered -----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true - -query III -SELECT a, b, LAST_VALUE(c) as last_c - FROM annotated_data_infinite2 - GROUP BY a, b ----- -0 0 24 -0 1 49 -1 2 74 -1 3 99 - -statement ok -drop table annotated_data_infinite2; - -# create a table for testing -statement ok -CREATE TABLE sales_global (zip_code INT, - country VARCHAR(3), - sn INT, - ts TIMESTAMP, - currency VARCHAR(3), - amount FLOAT - ) as VALUES - (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), - (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), - (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), - (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), - (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), - (0, 'GRC', 4, '2022-01-03 10:00:00'::timestamp, 'EUR', 80.0) - -# test_ordering_sensitive_aggregation -# ordering sensitive requirement should add a SortExec in the final plan. To satisfy amount ASC -# in the aggregation -statement ok -set datafusion.execution.target_partitions = 1; - -query TT -EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts ---Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] -----TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] -----SortExec: expr=[amount@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - - -query T? -SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts - FROM sales_global - GROUP BY country ----- -GRC [30.0, 80.0] -FRA [50.0, 200.0] -TUR [75.0, 100.0] - -# test_ordering_sensitive_aggregation2 -# We should be able to satisfy the finest requirement among all aggregators, when we have multiple aggregators. -# Hence final plan should have SortExec: expr=[amount@1 DESC] to satisfy array_agg requirement. -query TT -EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM sales_global AS s - GROUP BY s.country ----- -logical_plan -Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] -----SubqueryAlias: s -------TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)] -----SortExec: expr=[amount@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?R -SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM sales_global AS s - GROUP BY s.country ----- -FRA [200.0, 50.0] 250 -TUR [100.0, 75.0] 175 -GRC [80.0, 30.0] 110 - -# test_ordering_sensitive_multiple_req -# Currently we do not support multiple ordering requirement for aggregation -# once this support is added. This test should change -# See issue: https://github.com/sqlparser-rs/sqlparser-rs/issues/875 -statement error DataFusion error: This feature is not implemented: ARRAY_AGG only supports a single ORDER BY expression\. Got 2 -SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC, s.country DESC) AS amounts, - SUM(s.amount ORDER BY s.amount DESC) AS sum1 - FROM sales_global AS s - GROUP BY s.country - -# test_ordering_sensitive_aggregation3 -# When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. -# test below should raise Plan Error. -statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported -SELECT ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - ARRAY_AGG(s.amount ORDER BY s.amount ASC) AS amounts2, - ARRAY_AGG(s.amount ORDER BY s.sn ASC) AS amounts3 - FROM sales_global AS s - GROUP BY s.country - -# test_ordering_sensitive_aggregation4 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should append requirement to -# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. -# This test checks for whether we can satisfy aggregation requirement in FullyOrdered mode. -query TT -EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -logical_plan -Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] -----SubqueryAlias: s -------Sort: sales_global.country ASC NULLS LAST ---------TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered -----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?R -SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -FRA [200.0, 50.0] 250 -GRC [80.0, 30.0] 110 -TUR [100.0, 75.0] 175 - -# test_ordering_sensitive_aggregation5 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should be append requirement to -# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. -# This test checks for whether we can satisfy aggregation requirement in PartiallyOrdered mode. -query TT -EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country, s.zip_code ----- -logical_plan -Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] -----SubqueryAlias: s -------Sort: sales_global.country ASC NULLS LAST ---------TableScan: sales_global projection=[zip_code, country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] ---AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=PartiallyOrdered -----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query TI?R -SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country, s.zip_code ----- -FRA 1 [200.0, 50.0] 250 -GRC 0 [80.0, 30.0] 110 -TUR 1 [100.0, 75.0] 175 - -# test_ordering_sensitive_aggregation6 -# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should be append requirement to -# the existing ordering. When group by expressions contain aggregation requirement, we shouldn't append redundant expression. -# Hence in the final plan SortExec should be SortExec: expr=[country@0 DESC] not SortExec: expr=[country@0 ASC NULLS LAST,country@0 DESC] -query TT -EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -logical_plan -Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(s.amount)]] -----SubqueryAlias: s -------Sort: sales_global.country ASC NULLS LAST ---------TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered -----SortExec: expr=[country@0 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?R -SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -FRA [200.0, 50.0] 250 -GRC [80.0, 30.0] 110 -TUR [100.0, 75.0] 175 - -# test_ordering_sensitive_aggregation7 -# Lexicographical ordering requirement can be given as -# argument to the aggregate functions -query TT -EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -logical_plan -Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(s.amount)]] -----SubqueryAlias: s -------Sort: sales_global.country ASC NULLS LAST ---------TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered -----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?R -SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, - SUM(s.amount) AS sum1 - FROM (SELECT * - FROM sales_global - ORDER BY country) AS s - GROUP BY s.country ----- -FRA [200.0, 50.0] 250 -GRC [80.0, 30.0] 110 -TUR [100.0, 75.0] 175 - -# test_reverse_aggregate_expr -# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering -# that have contradictory requirements at first glance. -query TT -EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, - FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2 - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 ---Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] -----TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] -----SortExec: expr=[amount@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?RR -SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, - FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2 - FROM sales_global - GROUP BY country ----- -FRA [200.0, 50.0] 50 50 -TUR [100.0, 75.0] 75 75 -GRC [80.0, 30.0] 30 30 - -# test_reverse_aggregate_expr2 -# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering -# that have contradictory requirements at first glance. -query TT -EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, - FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2 - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 ---Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] -----TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] -----SortExec: expr=[amount@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - -query T?RR -SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, - FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2 - FROM sales_global - GROUP BY country ----- -GRC [30.0, 80.0] 30 30 -FRA [50.0, 200.0] 50 50 -TUR [75.0, 100.0] 75 75 - -# test_reverse_aggregate_expr3 -# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering -# that have contradictory requirements at first glance. This algorithm shouldn't depend -# on the order of the aggregation expressions. -query TT -EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2, - ARRAY_AGG(amount ORDER BY amount ASC) AS amounts - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts ---Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] -----TableScan: sales_global projection=[country, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] -----SortExec: expr=[amount@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - -query TRR? -SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2, - ARRAY_AGG(amount ORDER BY amount ASC) AS amounts - FROM sales_global - GROUP BY country ----- -GRC 30 30 [30.0, 80.0] -FRA 50 50 [50.0, 200.0] -TUR 75 75 [75.0, 100.0] - -# test_reverse_aggregate_expr4 -# Ordering requirement by the ordering insensitive aggregators shouldn't have effect on -# final plan. Hence seemingly conflicting requirements by SUM and ARRAY_AGG shouldn't raise error. -query TT -EXPLAIN SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, - ARRAY_AGG(amount ORDER BY amount ASC) AS amounts - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts ---Aggregate: groupBy=[[sales_global.country]], aggr=[[SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] -----TableScan: sales_global projection=[country, ts, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[SUM(sales_global.amount), ARRAY_AGG(sales_global.amount)] -----SortExec: expr=[amount@2 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - -query TR? -SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, - ARRAY_AGG(amount ORDER BY amount ASC) AS amounts - FROM sales_global - GROUP BY country ----- -GRC 110 [30.0, 80.0] -FRA 250 [50.0, 200.0] -TUR 175 [75.0, 100.0] - -# test_reverse_aggregate_expr5 -# If all of the ordering sensitive aggregation functions are reversible -# we should be able to reverse requirements, if this helps to remove a SortExec. -# Hence in query below, FIRST_VALUE, and LAST_VALUE should be reversed to calculate its result according to `ts ASC` ordering. -# Please note that after `ts ASC` ordering because of inner query. There is no SortExec in the final plan. -query TT -EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, - LAST_VALUE(amount ORDER BY ts DESC) as lv1, - SUM(amount ORDER BY ts DESC) as sum1 - FROM (SELECT * - FROM sales_global - ORDER BY ts ASC) - GROUP BY country ----- -logical_plan -Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 ---Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] -----Sort: sales_global.ts ASC NULLS LAST -------TableScan: sales_global projection=[country, ts, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 ASC NULLS LAST] -------MemoryExec: partitions=1, partition_sizes=[1] - -query TRRR -SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, - LAST_VALUE(amount ORDER BY ts DESC) as lv1, - SUM(amount ORDER BY ts DESC) as sum1 - FROM (SELECT * - FROM sales_global - ORDER BY ts ASC) - GROUP BY country ----- -GRC 80 30 110 -FRA 200 50 250 -TUR 100 75 175 - -# If existing ordering doesn't satisfy requirement, we should do calculations -# on naive requirement (by convention, otherwise the final plan will be unintuitive), -# even if reverse ordering is possible. -# hence, below query should add `SortExec(ts DESC)` to the final plan. -query TT -EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, - LAST_VALUE(amount ORDER BY ts DESC) as lv1, - SUM(amount ORDER BY ts DESC) as sum1 - FROM sales_global - GROUP BY country ----- -logical_plan -Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 ---Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] -----TableScan: sales_global projection=[country, ts, amount] -physical_plan -ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] -----SortExec: expr=[ts@1 DESC] -------MemoryExec: partitions=1, partition_sizes=[1] - -query TRRR -SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, - LAST_VALUE(amount ORDER BY ts DESC) as lv1, - SUM(amount ORDER BY ts DESC) as sum1 - FROM sales_global - GROUP BY country ----- -TUR 100 75 175 -GRC 80 30 110 -FRA 200 50 250 - -# Run order-sensitive aggregators in multiple partitions -statement ok -set datafusion.execution.target_partitions = 2; - -# Currently, we do not support running order-sensitive aggregators in multiple partitions. -statement error This feature is not implemented: Order-sensitive aggregators is not supported on multiple partitions -SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, - FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, - LAST_VALUE(amount ORDER BY amount DESC) AS fv2 - FROM sales_global - GROUP BY country diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt deleted file mode 100644 index 38f1d2cd05d40..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt +++ /dev/null @@ -1,405 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -# Verify the information schema does not exit by default -statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found -SELECT * from information_schema.tables - -statement error DataFusion error: Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled -show all - -# Turn it on - -# expect that the queries now work -statement ok -set datafusion.catalog.information_schema = true; - -# Verify the information schema now does exist and is empty -query TTTT rowsort -SELECT * from information_schema.tables; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW - -# Disable information_schema and verify it now errors again -statement ok -set datafusion.catalog.information_schema = false - -statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found -SELECT * from information_schema.tables - -statement error Error during planning: table 'datafusion.information_schema.columns' not found -SELECT * from information_schema.columns; - - -############ -## Enable information schema for the rest of the test -############ -statement ok -set datafusion.catalog.information_schema = true - -############ -# New tables should show up in information schema -########### -statement ok -create table t as values (1); - -query TTTT rowsort -SELECT * from information_schema.tables; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW -datafusion public t BASE TABLE - -# Another new table should show up in information schema -statement ok -create table t2 as values (1); - -query TTTT rowsort -SELECT * from information_schema.tables; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW -datafusion public t BASE TABLE -datafusion public t2 BASE TABLE - -query TTTT rowsort -SELECT * from information_schema.tables WHERE tables.table_schema='information_schema'; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW - -query TTTT rowsort -SELECT * from information_schema.tables WHERE information_schema.tables.table_schema='information_schema'; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW - -query TTTT rowsort -SELECT * from information_schema.tables WHERE datafusion.information_schema.tables.table_schema='information_schema'; ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW - -# Cleanup -statement ok -drop table t - -statement ok -drop table t2 - -############ -## SHOW VARIABLES should work -########### - -# target_partitions defaults to num_cores, so set -# to a known value that is unlikely to be -# the real number of cores on a system -statement ok -SET datafusion.execution.target_partitions=7 - -# planning_concurrency defaults to num_cores, so set -# to a known value that is unlikely to be -# the real number of cores on a system -statement ok -SET datafusion.execution.planning_concurrency=13 - -# show all variables -query TT rowsort -SHOW ALL ----- -datafusion.catalog.create_default_catalog_and_schema true -datafusion.catalog.default_catalog datafusion -datafusion.catalog.default_schema public -datafusion.catalog.format NULL -datafusion.catalog.has_header false -datafusion.catalog.information_schema true -datafusion.catalog.location NULL -datafusion.execution.aggregate.scalar_update_factor 10 -datafusion.execution.batch_size 8192 -datafusion.execution.coalesce_batches true -datafusion.execution.collect_statistics false -datafusion.execution.parquet.enable_page_index true -datafusion.execution.parquet.metadata_size_hint NULL -datafusion.execution.parquet.pruning true -datafusion.execution.parquet.pushdown_filters false -datafusion.execution.parquet.reorder_filters false -datafusion.execution.parquet.skip_metadata true -datafusion.execution.planning_concurrency 13 -datafusion.execution.target_partitions 7 -datafusion.execution.time_zone +00:00 -datafusion.explain.logical_plan_only false -datafusion.explain.physical_plan_only false -datafusion.optimizer.allow_symmetric_joins_without_pruning true -datafusion.optimizer.enable_round_robin_repartition true -datafusion.optimizer.filter_null_join_keys false -datafusion.optimizer.hash_join_single_partition_threshold 1048576 -datafusion.optimizer.max_passes 3 -datafusion.optimizer.prefer_hash_join true -datafusion.optimizer.repartition_aggregations true -datafusion.optimizer.repartition_file_min_size 10485760 -datafusion.optimizer.repartition_file_scans true -datafusion.optimizer.repartition_joins true -datafusion.optimizer.repartition_sorts true -datafusion.optimizer.repartition_windows true -datafusion.optimizer.skip_failed_rules false -datafusion.optimizer.top_down_join_key_reordering true -datafusion.sql_parser.dialect generic -datafusion.sql_parser.enable_ident_normalization true -datafusion.sql_parser.parse_float_as_decimal false - -# show_variable_in_config_options -query TT -SHOW datafusion.execution.batch_size ----- -datafusion.execution.batch_size 8192 - -# show_time_zone_default_utc -# https://github.com/apache/arrow-datafusion/issues/3255 -query TT -SHOW TIME ZONE ----- -datafusion.execution.time_zone +00:00 - -# show_timezone_default_utc -# https://github.com/apache/arrow-datafusion/issues/3255 -query TT -SHOW TIMEZONE ----- -datafusion.execution.time_zone +00:00 - - -# information_schema_describe_table - -## some_table -statement ok -CREATE OR REPLACE TABLE some_table AS VALUES (1,2),(3,4); - -query TTT rowsort -DESCRIBE some_table ----- -column1 Int64 YES -column2 Int64 YES - -statement ok -DROP TABLE public.some_table; - -## public.some_table - -statement ok -CREATE OR REPLACE TABLE public.some_table AS VALUES (1,2),(3,4); - -query TTT rowsort -DESCRIBE public.some_table ----- -column1 Int64 YES -column2 Int64 YES - -statement ok -DROP TABLE public.some_table; - -## datafusion.public.some_table - -statement ok -CREATE OR REPLACE TABLE datafusion.public.some_table AS VALUES (1,2),(3,4); - -query TTT rowsort -DESCRIBE datafusion.public.some_table ----- -column1 Int64 YES -column2 Int64 YES - -statement ok -DROP TABLE datafusion.public.some_table; - -# information_schema_describe_table_not_exists - -statement error Error during planning: table 'datafusion.public.table' not found -describe table; - - -# information_schema_show_tables -query TTTT rowsort -SHOW TABLES ----- -datafusion information_schema columns VIEW -datafusion information_schema df_settings VIEW -datafusion information_schema tables VIEW -datafusion information_schema views VIEW - - -# information_schema_show_tables_no_information_schema - -statement ok -set datafusion.catalog.information_schema = false; - -statement error Error during planning: SHOW TABLES is not supported unless information_schema is enabled -SHOW TABLES - -statement ok -set datafusion.catalog.information_schema = true; - - -# information_schema_show_columns -statement ok -CREATE TABLE t AS SELECT 1::int as i; - -statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported -SHOW COLUMNS FROM t LIKE 'f'; - -statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported -SHOW COLUMNS FROM t WHERE column_name = 'bar'; - -query TTTTTT -SHOW COLUMNS FROM t; ----- -datafusion public t i Int32 NO - -# This isn't ideal but it is consistent behavior for `SELECT * from "T"` -statement error Error during planning: table 'datafusion.public.T' not found -SHOW columns from "T" - -# information_schema_show_columns_full_extended -query TTTTITTTIIIIIIT -SHOW FULL COLUMNS FROM t; ----- -datafusion public t i 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL - -# expect same as above -query TTTTITTTIIIIIIT -SHOW EXTENDED COLUMNS FROM t; ----- -datafusion public t i 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL - -# information_schema_show_columns_no_information_schema - -statement ok -set datafusion.catalog.information_schema = false; - -statement error Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled -SHOW COLUMNS FROM t - -statement ok -set datafusion.catalog.information_schema = true; - - -# information_schema_show_columns_names() -query TTTTTT -SHOW columns from public.t ----- -datafusion public t i Int32 NO - -query TTTTTT -SHOW columns from datafusion.public.t ----- -datafusion public t i Int32 NO - -statement error Error during planning: table 'datafusion.public.t2' not found -SHOW columns from t2 - -statement error Error during planning: table 'datafusion.public.t2' not found -SHOW columns from datafusion.public.t2 - - -# show_non_existing_variable -# FIXME -# currently we cannot know whether a variable exists, this will output 0 row instead -statement ok -SHOW SOMETHING_UNKNOWN; - -statement ok -DROP TABLE t; - -# show_unsupported_when_information_schema_off - -statement ok -set datafusion.catalog.information_schema = false; - -statement error Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled -SHOW SOMETHING - -statement ok -set datafusion.catalog.information_schema = true; - - - -# show_create_view() -statement ok -CREATE TABLE abc AS VALUES (1,2,3), (4,5,6); - -statement ok -CREATE VIEW xyz AS SELECT * FROM abc - -query TTTT -SHOW CREATE TABLE xyz ----- -datafusion public xyz CREATE VIEW xyz AS SELECT * FROM abc - -statement ok -DROP TABLE abc; - -statement ok -DROP VIEW xyz; - -# show_create_view_in_catalog -statement ok -CREATE TABLE abc AS VALUES (1,2,3), (4,5,6) - -statement ok -CREATE SCHEMA test; - -statement ok -CREATE VIEW test.xyz AS SELECT * FROM abc; - -query TTTT -SHOW CREATE TABLE test.xyz ----- -datafusion test xyz CREATE VIEW test.xyz AS SELECT * FROM abc - -statement error DataFusion error: Execution error: Cannot drop schema test because other tables depend on it: xyz -DROP SCHEMA test; - -statement ok -DROP TABLE abc; - -statement ok -DROP VIEW test.xyz - - -# show_external_create_table() -statement ok -CREATE EXTERNAL TABLE abc -STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv'; - -query TTTT -SHOW CREATE TABLE abc; ----- -datafusion public abc CREATE EXTERNAL TABLE abc STORED AS CSV LOCATION ../../testing/data/csv/aggregate_test_100.csv diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt deleted file mode 100644 index dae5bb94f4fe4..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/insert.slt +++ /dev/null @@ -1,236 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## INSERT tests -########## - - -statement ok -CREATE EXTERNAL TABLE aggregate_test_100 ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT, - c5 INT, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 BIGINT UNSIGNED NOT NULL, - c10 VARCHAR NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL -) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv' - -# test_insert_into - -statement ok -set datafusion.execution.target_partitions = 8; - -statement ok -CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); - -query TT -EXPLAIN -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) -FROM aggregate_test_100 -ORDER by c1 ----- -logical_plan -Dml: op=[Insert] table=[table_without_values] ---Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -----Sort: aggregate_test_100.c1 ASC NULLS LAST -------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 ---------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] -----------TableScan: aggregate_test_100 projection=[c1, c4, c9] -physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] -----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1)), c1@0 as c1] ---------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: "SUM(aggregate_test_100.c4)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 8), input_partitions=8 -----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true - -query II -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) -FROM aggregate_test_100 -ORDER by c1 ----- -100 - -# verify there is data now in the table -query I -SELECT COUNT(*) from table_without_values; ----- -100 - -# verify there is data now in the table -query II -SELECT * -FROM table_without_values -ORDER BY field1, field2 -LIMIT 5; ----- --70111 3 --65362 3 --62295 3 --56721 3 --55414 3 - -statement ok -drop table table_without_values; - - - -# test_insert_into_as_select_multi_partitioned -statement ok -CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) - -query TT -EXPLAIN -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 -FROM aggregate_test_100 ----- -logical_plan -Dml: op=[Insert] table=[table_without_values] ---Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] -------TableScan: aggregate_test_100 projection=[c1, c4, c9] -physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---CoalescePartitionsExec -----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: "SUM(aggregate_test_100.c4)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ---------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 8), input_partitions=8 ---------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true - - - -query II -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 -FROM aggregate_test_100 ----- -100 - -statement ok -drop table table_without_values; - - -# test_insert_into_as_select_single_partition - -statement ok -CREATE TABLE table_without_values AS SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 -FROM aggregate_test_100 - - -# // TODO: The generated plan is suboptimal since SortExec is in global state. -query TT -EXPLAIN -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 -FROM aggregate_test_100 -ORDER BY c1 ----- -logical_plan -Dml: op=[Insert] table=[table_without_values] ---Projection: a1 AS a1, a2 AS a2 -----Sort: aggregate_test_100.c1 ASC NULLS LAST -------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 ---------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] -----------TableScan: aggregate_test_100 projection=[c1, c4, c9] -physical_plan -InsertExec: sink=MemoryTable (partitions=8) ---ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] -----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] ---------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: "SUM(aggregate_test_100.c4)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 8), input_partitions=8 -----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true - - -query II -INSERT INTO table_without_values SELECT -SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, -COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 -FROM aggregate_test_100 -ORDER BY c1 ----- -100 - - -statement ok -drop table table_without_values; - -# test_insert_into_with_sort - -statement ok -create table table_without_values(c1 varchar not null); - -# verify that the sort order of the insert query is maintained into the -# insert (there should be a SortExec in the following plan) -# See https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 for more background -query TT -explain insert into table_without_values select c1 from aggregate_test_100 order by c1; ----- -logical_plan -Dml: op=[Insert] table=[table_without_values] ---Projection: aggregate_test_100.c1 AS c1 -----Sort: aggregate_test_100.c1 ASC NULLS LAST -------TableScan: aggregate_test_100 projection=[c1] -physical_plan -InsertExec: sink=MemoryTable (partitions=1) ---ProjectionExec: expr=[c1@0 as c1] -----SortExec: expr=[c1@0 ASC NULLS LAST] -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true - -query T -insert into table_without_values select c1 from aggregate_test_100 order by c1; ----- -100 - -query I -select count(*) from table_without_values; ----- -100 - -statement ok -drop table table_without_values; diff --git a/datafusion/core/tests/sqllogictests/test_files/joins.slt b/datafusion/core/tests/sqllogictests/test_files/joins.slt deleted file mode 100644 index eb8f72470c6cb..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/joins.slt +++ /dev/null @@ -1,624 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## Joins Tests -########## - -# create table t1 -statement ok -CREATE TABLE t1(a INT, b INT, c INT) AS VALUES -(1, 10, 50), -(2, 20, 60), -(3, 30, 70), -(4, 40, 80) - -# create table t2 -statement ok -CREATE TABLE t2(a INT, b INT, c INT) AS VALUES -(1, 100, 500), -(2, 200, 600), -(9, 300, 700), -(4, 400, 800) - -# equijoin -query II nosort -SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a ----- -1 100 -2 200 -4 400 - -query II nosort -SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a ----- -1 100 -2 200 -4 400 - -# inner_join_nulls -query ?? -SELECT * FROM (SELECT null AS id1) t1 -INNER JOIN (SELECT null AS id2) t2 ON id1 = id2 ----- - -statement ok -DROP TABLE t1 - -statement ok -DROP TABLE t2 - - -# create table a -statement ok -CREATE TABLE a(a INT, b INT, c INT) AS VALUES -(1, 10, 50), -(2, 20, 60), -(3, 30, 70), -(4, 40, 80) - -# create table b -statement ok -CREATE TABLE b(a INT, b INT, c INT) AS VALUES -(1, 100, 500), -(2, 200, 600), -(9, 300, 700), -(4, 400, 800) - -# issue_3002 -# // repro case for https://github.com/apache/arrow-datafusion/issues/3002 - -query II -select a.a, b.b from a join b on a.a = b.b ----- - -statement ok -DROP TABLE a - -statement ok -DROP TABLE b - -# create table t1 -statement ok -CREATE TABLE t1(t1_id INT, t1_name VARCHAR) AS VALUES -(11, 'a'), -(22, 'b'), -(33, 'c'), -(44, 'd'), -(77, 'e') - -# create table t2 -statement ok -CREATE TABLE t2(t2_id INT, t2_name VARCHAR) AS VALUES -(11, 'z'), -(22, 'y'), -(44, 'x'), -(55, 'w') - -# left_join_unbalanced -# // the t1_id is larger than t2_id so the join_selection optimizer should kick in -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id ----- -11 a z -22 b y -33 c NULL -44 d x -77 e NULL - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id ----- -11 a z -22 b y -33 c NULL -44 d x -77 e NULL - - -# cross_join_unbalanced -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name ----- -11 a w -11 a x -11 a y -11 a z -22 b w -22 b x -22 b y -22 b z -33 c w -33 c x -33 c y -33 c z -44 d w -44 d x -44 d y -44 d z -77 e w -77 e x -77 e y -77 e z - -statement ok -DROP TABLE t1 - -statement ok -DROP TABLE t2 - -# create table t1 -statement ok -CREATE TABLE t1(t1_id INT, t1_name VARCHAR) AS VALUES -(11, 'a'), -(22, 'b'), -(33, 'c'), -(44, 'd'), -(77, 'e'), -(88, NULL), -(99, NULL) - -# create table t2 -statement ok -CREATE TABLE t2(t2_id INT, t2_name VARCHAR) AS VALUES -(11, 'z'), -(22, NULL), -(44, 'x'), -(55, 'w'), -(99, 'u') - -# left_join_null_filter -# // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter. -# // Note that this is only true because IS NULL does not remove nulls. For filters that -# // remove nulls, we can rewrite the join as an inner join and then push down the filter. -query IIT nosort -SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id ----- -22 22 NULL -33 NULL NULL -77 NULL NULL -88 NULL NULL - -# left_join_null_filter_on_join_column -# // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter. -query IIT nosort -SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id ----- -33 NULL NULL -77 NULL NULL -88 NULL NULL - -# left_join_not_null_filter -query IIT nosort -SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id ----- -11 11 z -44 44 x -99 99 u - -# left_join_not_null_filter_on_join_column -query IIT nosort -SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id ----- -11 11 z -22 22 NULL -44 44 x -99 99 u - -# self_join_non_equijoin -query II nosort -SELECT x.t1_id, y.t1_id FROM t1 x JOIN t1 y ON x.t1_id = 11 AND y.t1_id = 44 ----- -11 44 - -# right_join_null_filter -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id ----- -NULL NULL 55 -99 NULL 99 - -# right_join_null_filter_on_join_column -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id ----- -NULL NULL 55 - -# right_join_not_null_filter -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id ----- -11 a 11 -22 b 22 -44 d 44 - -# right_join_not_null_filter_on_join_column -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id ----- -11 a 11 -22 b 22 -44 d 44 -99 NULL 99 - -# full_join_null_filter -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id ----- -88 NULL NULL -99 NULL 99 -NULL NULL 55 - -# full_join_not_null_filter -query ITI nosort -SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id ----- -11 a 11 -22 b 22 -33 c NULL -44 d 44 -77 e NULL - -statement ok -DROP TABLE t1 - -statement ok -DROP TABLE t2 - -# create table t1 -statement ok -CREATE TABLE t1(id INT, t1_name VARCHAR, t1_int INT) AS VALUES -(11, 'a', 1), -(22, 'b', 2), -(33, 'c', 3), -(44, 'd', 4) - -# create table t2 -statement ok -CREATE TABLE t2(id INT, t2_name VARCHAR, t2_int INT) AS VALUES -(11, 'z', 3), -(22, 'y', 1), -(44, 'x', 3), -(55, 'w', 3) - -# left_join_using - -# set repartition_joins to true -statement ok -set datafusion.optimizer.repartition_joins = true - -query ITT nosort -SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id ----- -11 a z -22 b y -33 c NULL -44 d x - -# set repartition_joins to false -statement ok -set datafusion.optimizer.repartition_joins = false - -query ITT nosort -SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id ----- -11 a z -22 b y -33 c NULL -44 d x - -statement ok -DROP TABLE t1 - -statement ok -DROP TABLE t2 - -# create table t1 -statement ok -CREATE TABLE t1(t1_id INT, t1_name VARCHAR, t1_int INT) AS VALUES -(11, 'a', 1), -(22, 'b', 2), -(33, 'c', 3), -(44, 'd', 4) - -# create table t2 -statement ok -CREATE TABLE t2(t2_id INT, t2_name VARCHAR, t2_int INT) AS VALUES -(11, 'z', 3), -(22, 'y', 1), -(44, 'x', 3), -(55, 'w', 3) - -# cross_join - -# set repartition_joins to true -statement ok -set datafusion.optimizer.repartition_joins = true - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITITI rowsort -SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 ----- -11 a 11 z 3 -11 a 11 z 3 -11 a 22 y 1 -11 a 22 y 1 -11 a 44 x 3 -11 a 44 x 3 -11 a 55 w 3 -11 a 55 w 3 -22 b 11 z 3 -22 b 11 z 3 -22 b 22 y 1 -22 b 22 y 1 -22 b 44 x 3 -22 b 44 x 3 -22 b 55 w 3 -22 b 55 w 3 -33 c 11 z 3 -33 c 11 z 3 -33 c 22 y 1 -33 c 22 y 1 -33 c 44 x 3 -33 c 44 x 3 -33 c 55 w 3 -33 c 55 w 3 -44 d 11 z 3 -44 d 11 z 3 -44 d 22 y 1 -44 d 22 y 1 -44 d 44 x 3 -44 d 44 x 3 -44 d 55 w 3 -44 d 55 w 3 - -query ITT rowsort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2_data ----- -11 a w -11 a w -11 a x -11 a x -11 a y -11 a y -11 a z -11 a z -22 b w -22 b w -22 b x -22 b x -22 b y -22 b y -22 b z -22 b z -33 c w -33 c w -33 c x -33 c x -33 c y -33 c y -33 c z -33 c z -44 d w -44 d w -44 d x -44 d x -44 d y -44 d y -44 d z -44 d z - -# set repartition_joins to true -statement ok -set datafusion.optimizer.repartition_joins = false - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id ----- -11 a z -11 a y -11 a x -11 a w -22 b z -22 b y -22 b x -22 b w -33 c z -33 c y -33 c x -33 c w -44 d z -44 d y -44 d x -44 d w - -query ITITI rowsort -SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 ----- -11 a 11 z 3 -11 a 11 z 3 -11 a 22 y 1 -11 a 22 y 1 -11 a 44 x 3 -11 a 44 x 3 -11 a 55 w 3 -11 a 55 w 3 -22 b 11 z 3 -22 b 11 z 3 -22 b 22 y 1 -22 b 22 y 1 -22 b 44 x 3 -22 b 44 x 3 -22 b 55 w 3 -22 b 55 w 3 -33 c 11 z 3 -33 c 11 z 3 -33 c 22 y 1 -33 c 22 y 1 -33 c 44 x 3 -33 c 44 x 3 -33 c 55 w 3 -33 c 55 w 3 -44 d 11 z 3 -44 d 11 z 3 -44 d 22 y 1 -44 d 22 y 1 -44 d 44 x 3 -44 d 44 x 3 -44 d 55 w 3 -44 d 55 w 3 - -query ITT rowsort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2_data ----- -11 a w -11 a w -11 a x -11 a x -11 a y -11 a y -11 a z -11 a z -22 b w -22 b w -22 b x -22 b x -22 b y -22 b y -22 b z -22 b z -33 c w -33 c w -33 c x -33 c x -33 c y -33 c y -33 c z -33 c z -44 d w -44 d w -44 d x -44 d x -44 d y -44 d y -44 d z -44 d z - -statement ok -DROP TABLE t1 - -statement ok -DROP TABLE t2 diff --git a/datafusion/core/tests/sqllogictests/test_files/limit.slt b/datafusion/core/tests/sqllogictests/test_files/limit.slt deleted file mode 100644 index 253ca8f335afb..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/limit.slt +++ /dev/null @@ -1,302 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## Limit Tests -########## - -statement ok -CREATE EXTERNAL TABLE aggregate_test_100 ( - c1 VARCHAR NOT NULL, - c2 TINYINT NOT NULL, - c3 SMALLINT NOT NULL, - c4 SMALLINT, - c5 INT, - c6 BIGINT NOT NULL, - c7 SMALLINT NOT NULL, - c8 INT NOT NULL, - c9 BIGINT UNSIGNED NOT NULL, - c10 VARCHAR NOT NULL, - c11 FLOAT NOT NULL, - c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL -) -STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv' - -# async fn csv_query_limit -query T -SELECT c1 FROM aggregate_test_100 LIMIT 2 ----- -c -d - -# async fn csv_query_limit_bigger_than_nbr_of_rows -query I -SELECT c2 FROM aggregate_test_100 LIMIT 200 ----- -2 -5 -1 -1 -5 -4 -3 -3 -1 -4 -1 -4 -3 -2 -1 -1 -2 -1 -3 -2 -4 -1 -5 -4 -2 -1 -4 -5 -2 -3 -4 -2 -1 -5 -3 -1 -2 -3 -3 -3 -2 -4 -1 -3 -2 -5 -2 -1 -4 -1 -4 -2 -5 -4 -2 -3 -4 -4 -4 -5 -4 -2 -1 -2 -4 -2 -3 -5 -1 -1 -4 -2 -1 -2 -1 -1 -5 -4 -5 -2 -3 -2 -4 -1 -3 -4 -3 -2 -5 -3 -3 -2 -5 -5 -4 -1 -3 -3 -4 -4 - -# async fn csv_query_limit_with_same_nbr_of_rows -query I -SELECT c2 FROM aggregate_test_100 LIMIT 100 ----- -2 -5 -1 -1 -5 -4 -3 -3 -1 -4 -1 -4 -3 -2 -1 -1 -2 -1 -3 -2 -4 -1 -5 -4 -2 -1 -4 -5 -2 -3 -4 -2 -1 -5 -3 -1 -2 -3 -3 -3 -2 -4 -1 -3 -2 -5 -2 -1 -4 -1 -4 -2 -5 -4 -2 -3 -4 -4 -4 -5 -4 -2 -1 -2 -4 -2 -3 -5 -1 -1 -4 -2 -1 -2 -1 -1 -5 -4 -5 -2 -3 -2 -4 -1 -3 -4 -3 -2 -5 -3 -3 -2 -5 -5 -4 -1 -3 -3 -4 -4 - -# async fn csv_query_limit_zero -query T -SELECT c1 FROM aggregate_test_100 LIMIT 0 ----- - -# async fn csv_offset_without_limit_99 -query T -SELECT c1 FROM aggregate_test_100 OFFSET 99 ----- -e - -# async fn csv_offset_without_limit_100 -query T -SELECT c1 FROM aggregate_test_100 OFFSET 100 ----- - -# async fn csv_offset_without_limit_101 -query T -SELECT c1 FROM aggregate_test_100 OFFSET 101 ----- - -# async fn csv_query_offset -query T -SELECT c1 FROM aggregate_test_100 OFFSET 2 LIMIT 2 ----- -b -a - -# async fn csv_query_offset_the_same_as_nbr_of_rows -query T -SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 100 ----- - -# async fn csv_query_offset_bigger_than_nbr_of_rows -query T -SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101 ----- - -######## -# Clean up after the test -######## - -statement ok -drop table aggregate_test_100; diff --git a/datafusion/core/tests/sqllogictests/test_files/math.slt b/datafusion/core/tests/sqllogictests/test_files/math.slt deleted file mode 100644 index 152e8b78bdfa3..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/math.slt +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## Math expression Tests -########## - -statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv'; - -# Round -query R -SELECT ROUND(c1) FROM aggregate_simple ----- -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 - -# Round -query R -SELECT round(c1/3, 2) FROM aggregate_simple order by c1 ----- -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 - -# Round -query R -SELECT round(c1, 4) FROM aggregate_simple order by c1 ----- -0 -0 -0 -0 -0 -0 -0 -0 -0 -0 -0.0001 -0.0001 -0.0001 -0.0001 -0.0001 - -# Round -query RRRRRRRR -SELECT round(125.2345, -3), round(125.2345, -2), round(125.2345, -1), round(125.2345), round(125.2345, 0), round(125.2345, 1), round(125.2345, 2), round(125.2345, 3) ----- -0 100 130 125 125 125.2 125.23 125.235 - -# atan2 -query RRRRRRR -SELECT atan2(2.0, 1.0), atan2(-2.0, 1.0), atan2(2.0, -1.0), atan2(-2.0, -1.0), atan2(NULL, 1.0), atan2(2.0, NULL), atan2(NULL, NULL); ----- -1.107148717794 -1.107148717794 2.034443935796 -2.034443935796 NULL NULL NULL diff --git a/datafusion/core/tests/sqllogictests/test_files/subquery.slt b/datafusion/core/tests/sqllogictests/test_files/subquery.slt deleted file mode 100644 index 780b24be63fd6..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/subquery.slt +++ /dev/null @@ -1,109 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -############# -## Subquery Tests -############# - -# two tables for subquery -statement ok -CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES -(11, 'a', 1), -(22, 'b', 2), -(33, 'c', 3), -(44, 'd', 4); - -statement ok -CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES -(11, 'z', 3), -(22, 'y', 1), -(44, 'x', 3), -(55, 'w', 3); - - -# in_subquery_to_join_with_correlated_outer_filter -query ITI rowsort -select t1.t1_id, - t1.t1_name, - t1.t1_int -from t1 -where t1.t1_id + 12 in ( - select t2.t2_id + 1 from t2 where t1.t1_int > 0 - ) ----- -11 a 1 -33 c 3 -44 d 4 - -# not_in_subquery_to_join_with_correlated_outer_filter -query ITI rowsort -select t1.t1_id, - t1.t1_name, - t1.t1_int -from t1 -where t1.t1_id + 12 not in ( - select t2.t2_id + 1 from t2 where t1.t1_int > 0 - ) ----- -22 b 2 - -# in subquery with two parentheses, see #5529 -query ITI rowsort -select t1.t1_id, - t1.t1_name, - t1.t1_int -from t1 -where t1.t1_id in (( - select t2.t2_id from t2 - )) ----- -11 a 1 -22 b 2 -44 d 4 - -query ITI rowsort -select t1.t1_id, - t1.t1_name, - t1.t1_int -from t1 -where t1.t1_id in (( - select t2.t2_id from t2 - )) -and t1.t1_int < 3 ----- -11 a 1 -22 b 2 - -query ITI rowsort -select t1.t1_id, - t1.t1_name, - t1.t1_int -from t1 -where t1.t1_id not in (( - select t2.t2_id from t2 where t2.t2_int = 3 - )) ----- -22 b 2 -33 c 3 - -# VALUES in subqueries, see 6017 -query I -select t1_id -from t1 -where t1_int = (select max(i) from (values (1)) as s(i)); ----- -11 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part deleted file mode 100644 index bc6d166b8680f..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part +++ /dev/null @@ -1,181 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -query TT -explain select - ps_partkey, - sum(ps_supplycost * ps_availqty) as value -from - partsupp, - supplier, - nation -where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' -group by - ps_partkey having - sum(ps_supplycost * ps_availqty) > ( - select - sum(ps_supplycost * ps_availqty) * 0.0001 - from - partsupp, - supplier, - nation - where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' - ) -order by - value desc -limit 10; ----- -logical_plan -Limit: skip=0, fetch=10 ---Sort: value DESC NULLS FIRST, fetch=10 -----Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value -------Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.__value ---------CrossJoin: -----------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost ---------------Inner Join: supplier.s_nationkey = nation.n_nationkey -----------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey ---------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] ---------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -----------------Projection: nation.n_nationkey -------------------Filter: nation.n_name = Utf8("GERMANY") ---------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] -----------SubqueryAlias: __scalar_sq_1 -------------Projection: CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) AS __value ---------------Aggregate: groupBy=[[]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost -------------------Inner Join: supplier.s_nationkey = nation.n_nationkey ---------------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -----------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -------------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] -------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] ---------------------Projection: nation.n_nationkey -----------------------Filter: nation.n_name = Utf8("GERMANY") -------------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] -physical_plan -GlobalLimitExec: skip=0, fetch=10 ---SortExec: fetch=10, expr=[value@1 DESC] -----ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] -------CoalesceBatchesExec: target_batch_size=8192 ---------FilterExec: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > __value@2 -----------CrossJoinExec -------------CoalescePartitionsExec ---------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 ---------------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -----------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_availqty@1 as ps_availqty, ps_supplycost@2 as ps_supplycost] -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })] -----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 3 }], 4), input_partitions=4 ---------------------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_availqty@2 as ps_availqty, ps_supplycost@3 as ps_supplycost, s_nationkey@5 as s_nationkey] -----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_suppkey", index: 1 }, Column { name: "s_suppkey", index: 0 })] ---------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 1 }], 4), input_partitions=4 -------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], has_header=false ---------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 -------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 ---------------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] -----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------FilterExec: n_name@1 = GERMANY ---------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false -------------ProjectionExec: expr=[CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as __value] ---------------AggregateExec: mode=Final, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -----------------CoalescePartitionsExec -------------------AggregateExec: mode=Partial, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] ---------------------ProjectionExec: expr=[ps_availqty@0 as ps_availqty, ps_supplycost@1 as ps_supplycost] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 2 }, Column { name: "n_nationkey", index: 0 })] ---------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 2 }], 4), input_partitions=4 -------------------------------ProjectionExec: expr=[ps_availqty@1 as ps_availqty, ps_supplycost@2 as ps_supplycost, s_nationkey@4 as s_nationkey] ---------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_suppkey", index: 0 }, Column { name: "s_suppkey", index: 0 })] -------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 0 }], 4), input_partitions=4 -----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], has_header=false -------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 -----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ---------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 -------------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] ---------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------FilterExec: n_name@1 = GERMANY -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false - - - -query IR -select - ps_partkey, - sum(ps_supplycost * ps_availqty) as value -from - partsupp, - supplier, - nation -where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' -group by - ps_partkey having - sum(ps_supplycost * ps_availqty) > ( - select - sum(ps_supplycost * ps_availqty) * 0.0001 - from - partsupp, - supplier, - nation - where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' - ) -order by - value desc -limit 10; ----- -12098 16227681.21 -5134 15709338.52 -13334 15023662.41 -17052 14351644.2 -3452 14070870.14 -12552 13332469.18 -1084 13170428.29 -5797 13038622.72 -12633 12892561.61 -403 12856217.34 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part deleted file mode 100644 index 64848f41fc5af..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part +++ /dev/null @@ -1,118 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -query TT -explain select - c_count, - count(*) as custdist -from - ( - select - c_custkey, - count(o_orderkey) - from - customer left outer join orders on - c_custkey = o_custkey - and o_comment not like '%special%requests%' - group by - c_custkey - ) as c_orders (c_custkey, c_count) -group by - c_count -order by - custdist desc, - c_count desc -limit 10; ----- -logical_plan -Limit: skip=0, fetch=10 ---Sort: custdist DESC NULLS FIRST, c_count DESC NULLS FIRST, fetch=10 -----Projection: c_count, COUNT(UInt8(1)) AS custdist -------Aggregate: groupBy=[[c_count]], aggr=[[COUNT(UInt8(1))]] ---------Projection: c_orders.COUNT(orders.o_orderkey) AS c_count -----------SubqueryAlias: c_orders -------------Projection: COUNT(orders.o_orderkey) ---------------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[COUNT(orders.o_orderkey)]] -----------------Projection: customer.c_custkey, orders.o_orderkey -------------------Left Join: customer.c_custkey = orders.o_custkey ---------------------TableScan: customer projection=[c_custkey] ---------------------Projection: orders.o_orderkey, orders.o_custkey -----------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") -------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] -physical_plan -GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC] -----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC] -------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as custdist] ---------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(UInt8(1))] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "c_count", index: 0 }], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[COUNT(UInt8(1))] -----------------ProjectionExec: expr=[COUNT(orders.o_orderkey)@0 as c_count] -------------------ProjectionExec: expr=[COUNT(orders.o_orderkey)@1 as COUNT(orders.o_orderkey)] ---------------------AggregateExec: mode=Single, gby=[c_custkey@0 as c_custkey], aggr=[COUNT(orders.o_orderkey)] -----------------------ProjectionExec: expr=[c_custkey@0 as c_custkey, o_orderkey@1 as o_orderkey] -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 1 })] -----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 ---------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], has_header=false -----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 1 }], 4), input_partitions=4 ---------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] -----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------FilterExec: o_comment@2 NOT LIKE %special%requests% ---------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_comment], has_header=false - - - -query II -select - c_count, - count(*) as custdist -from - ( - select - c_custkey, - count(o_orderkey) - from - customer left outer join orders on - c_custkey = o_custkey - and o_comment not like '%special%requests%' - group by - c_custkey - ) as c_orders (c_custkey, c_count) -group by - c_count -order by - custdist desc, - c_count desc -limit 10; ----- -0 5000 -10 665 -9 657 -11 621 -12 567 -8 564 -13 492 -18 482 -7 480 -20 456 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q22.slt.part b/datafusion/core/tests/sqllogictests/test_files/tpch/q22.slt.part deleted file mode 100644 index 9c7dd85ccd82f..0000000000000 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q22.slt.part +++ /dev/null @@ -1,161 +0,0 @@ - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -query TT -explain select - cntrycode, - count(*) as numcust, - sum(c_acctbal) as totacctbal -from - ( - select - substring(c_phone from 1 for 2) as cntrycode, - c_acctbal - from - customer - where - substring(c_phone from 1 for 2) in - ('13', '31', '23', '29', '30', '18', '17') - and c_acctbal > ( - select - avg(c_acctbal) - from - customer - where - c_acctbal > 0.00 - and substring(c_phone from 1 for 2) in - ('13', '31', '23', '29', '30', '18', '17') - ) - and not exists ( - select - * - from - orders - where - o_custkey = c_custkey - ) - ) as custsale -group by - cntrycode -order by - cntrycode; ----- -logical_plan -Sort: custsale.cntrycode ASC NULLS LAST ---Projection: custsale.cntrycode, COUNT(UInt8(1)) AS numcust, SUM(custsale.c_acctbal) AS totacctbal -----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(custsale.c_acctbal)]] -------SubqueryAlias: custsale ---------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -----------Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_11.__value -------------CrossJoin: ---------------Projection: customer.c_phone, customer.c_acctbal -----------------LeftAnti Join: customer.c_custkey = __correlated_sq_13.o_custkey -------------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) ---------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] -------------------SubqueryAlias: __correlated_sq_13 ---------------------TableScan: orders projection=[o_custkey] ---------------SubqueryAlias: __scalar_sq_11 -----------------Projection: AVG(customer.c_acctbal) AS __value -------------------Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] ---------------------Projection: customer.c_acctbal -----------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) -------------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] -physical_plan -SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] ---SortExec: expr=[cntrycode@0 ASC NULLS LAST] -----ProjectionExec: expr=[cntrycode@0 as cntrycode, COUNT(UInt8(1))@1 as numcust, SUM(custsale.c_acctbal)@2 as totacctbal] -------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(UInt8(1)), SUM(custsale.c_acctbal)] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "cntrycode", index: 0 }], 4), input_partitions=1 -------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(UInt8(1)), SUM(custsale.c_acctbal)] ---------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------FilterExec: CAST(c_acctbal@1 AS Decimal128(19, 6)) > __value@2 ---------------------CrossJoinExec -----------------------CoalescePartitionsExec -------------------------ProjectionExec: expr=[c_phone@1 as c_phone, c_acctbal@2 as c_acctbal] ---------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 0 })] -------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 -----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------FilterExec: Use substr(c_phone@1, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) ---------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], has_header=false -------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 0 }], 4), input_partitions=4 -----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_custkey], has_header=false -----------------------ProjectionExec: expr=[AVG(customer.c_acctbal)@0 as __value] -------------------------AggregateExec: mode=Final, gby=[], aggr=[AVG(customer.c_acctbal)] ---------------------------CoalescePartitionsExec -----------------------------AggregateExec: mode=Partial, gby=[], aggr=[AVG(customer.c_acctbal)] -------------------------------ProjectionExec: expr=[c_acctbal@1 as c_acctbal] ---------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], has_header=false - - - -query TIR -select - cntrycode, - count(*) as numcust, - sum(c_acctbal) as totacctbal -from - ( - select - substring(c_phone from 1 for 2) as cntrycode, - c_acctbal - from - customer - where - substring(c_phone from 1 for 2) in - ('13', '31', '23', '29', '30', '18', '17') - and c_acctbal > ( - select - avg(c_acctbal) - from - customer - where - c_acctbal > 0.00 - and substring(c_phone from 1 for 2) in - ('13', '31', '23', '29', '30', '18', '17') - ) - and not exists ( - select - * - from - orders - where - o_custkey = c_custkey - ) - ) as custsale -group by - cntrycode -order by - cntrycode; ----- -13 94 714035.05 -17 96 722560.15 -18 99 738012.52 -23 93 708285.25 -29 85 632693.46 -30 87 646748.02 -31 87 647372.5 diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index a4875d5cbf338..4db97c75cb33e 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -557,7 +557,6 @@ async fn tpcds_physical_q5() -> Result<()> { create_physical_plan(5).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q6() -> Result<()> { create_physical_plan(6).await @@ -568,13 +567,11 @@ async fn tpcds_physical_q7() -> Result<()> { create_physical_plan(7).await } -#[ignore] // The type of Int32 = Int64 of binary physical should be same #[tokio::test] async fn tpcds_physical_q8() -> Result<()> { create_physical_plan(8).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await @@ -601,7 +598,6 @@ async fn tpcds_physical_q13() -> Result<()> { create_physical_plan(13).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q14() -> Result<()> { create_physical_plan(14).await @@ -647,7 +643,6 @@ async fn tpcds_physical_q22() -> Result<()> { create_physical_plan(22).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q23() -> Result<()> { create_physical_plan(23).await @@ -755,7 +750,6 @@ async fn tpcds_physical_q43() -> Result<()> { create_physical_plan(43).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await @@ -807,7 +801,6 @@ async fn tpcds_physical_q53() -> Result<()> { create_physical_plan(53).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q54() -> Result<()> { create_physical_plan(54).await @@ -828,7 +821,6 @@ async fn tpcds_physical_q57() -> Result<()> { create_physical_plan(57).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q58() -> Result<()> { create_physical_plan(58).await @@ -965,7 +957,6 @@ async fn tpcds_physical_q84() -> Result<()> { create_physical_plan(84).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q85() -> Result<()> { create_physical_plan(85).await @@ -1054,7 +1045,7 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { let sql = fs::read_to_string(filename).expect("Could not read query"); let config = SessionConfig::default(); - let ctx = SessionContext::with_config(config); + let ctx = SessionContext::new_with_config(config); let tables = get_table_definitions(); for table in &tables { ctx.register_table( diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs new file mode 100644 index 0000000000000..6c6d966cc3aab --- /dev/null +++ b/datafusion/core/tests/user_defined/mod.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Tests for user defined Scalar functions +mod user_defined_scalar_functions; + +/// Tests for User Defined Aggregate Functions +mod user_defined_aggregates; + +/// Tests for User Defined Plans +mod user_defined_plan; + +/// Tests for User Defined Window Functions +mod user_defined_window_functions; + +/// Tests for User Defined Table Functions +mod user_defined_table_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs new file mode 100644 index 0000000000000..fb0ecd02c6b09 --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -0,0 +1,626 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains end to end demonstrations of creating +//! user defined aggregate functions + +use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::Int32Array; +use arrow_schema::Schema; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use datafusion::datasource::MemTable; +use datafusion::{ + arrow::{ + array::{ArrayRef, Float64Array, TimestampNanosecondArray}, + datatypes::{DataType, Field, Float64Type, TimeUnit, TimestampNanosecondType}, + record_batch::RecordBatch, + }, + assert_batches_eq, + error::Result, + logical_expr::{ + AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, + StateTypeFunction, TypeSignature, Volatility, + }, + physical_plan::Accumulator, + prelude::SessionContext, + scalar::ScalarValue, +}; +use datafusion_common::{ + assert_contains, cast::as_primitive_array, exec_err, DataFusionError, +}; +use datafusion_expr::create_udaf; +use datafusion_physical_expr::expressions::AvgAccumulator; + +/// Test to show the contents of the setup +#[tokio::test] +async fn test_setup() { + let TestContext { ctx, test_state: _ } = TestContext::new(); + let sql = "SELECT * from t order by time"; + let expected = [ + "+-------+----------------------------+", + "| value | time |", + "+-------+----------------------------+", + "| 2.0 | 1970-01-01T00:00:00.000002 |", + "| 3.0 | 1970-01-01T00:00:00.000003 |", + "| 1.0 | 1970-01-01T00:00:00.000004 |", + "| 5.0 | 1970-01-01T00:00:00.000005 |", + "| 5.0 | 1970-01-01T00:00:00.000005 |", + "+-------+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +/// Basic user defined aggregate +#[tokio::test] +async fn test_udaf() { + let TestContext { ctx, test_state } = TestContext::new(); + assert!(!test_state.update_batch()); + let sql = "SELECT time_sum(time) from t"; + let expected = [ + "+----------------------------+", + "| time_sum(t.time) |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000019 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + // normal aggregates call update_batch + assert!(test_state.update_batch()); + assert!(!test_state.retract_batch()); +} + +/// User defined aggregate used as a window function +#[tokio::test] +async fn test_udaf_as_window() { + let TestContext { ctx, test_state } = TestContext::new(); + let sql = "SELECT time_sum(time) OVER() as time_sum from t"; + let expected = [ + "+----------------------------+", + "| time_sum |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000019 |", + "| 1970-01-01T00:00:00.000019 |", + "| 1970-01-01T00:00:00.000019 |", + "| 1970-01-01T00:00:00.000019 |", + "| 1970-01-01T00:00:00.000019 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + // aggregate over the entire window function call update_batch + assert!(test_state.update_batch()); + assert!(!test_state.retract_batch()); +} + +/// User defined aggregate used as a window function with a window frame +#[tokio::test] +async fn test_udaf_as_window_with_frame() { + let TestContext { ctx, test_state } = TestContext::new(); + let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; + let expected = [ + "+----------------------------+", + "| time_sum |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000005 |", + "| 1970-01-01T00:00:00.000009 |", + "| 1970-01-01T00:00:00.000012 |", + "| 1970-01-01T00:00:00.000014 |", + "| 1970-01-01T00:00:00.000010 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + // user defined aggregates with window frame should be calling retract batch + assert!(test_state.update_batch()); + assert!(test_state.retract_batch()); +} + +/// Ensure that User defined aggregate used as a window function with a window +/// frame, but that does not implement retract_batch, returns an error +#[tokio::test] +async fn test_udaf_as_window_with_frame_without_retract_batch() { + let test_state = Arc::new(TestState::new().with_error_on_retract_batch()); + + let TestContext { ctx, test_state: _ } = TestContext::new_with_test_state(test_state); + let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; + // Note if this query ever does start working + let err = execute(&ctx, sql).await.unwrap_err(); + assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\""); +} + +/// Basic query for with a udaf returning a structure +#[tokio::test] +async fn test_udaf_returning_struct() { + let TestContext { ctx, test_state: _ } = TestContext::new(); + let sql = "SELECT first(value, time) from t"; + let expected = [ + "+------------------------------------------------+", + "| first(t.value,t.time) |", + "+------------------------------------------------+", + "| {value: 2.0, time: 1970-01-01T00:00:00.000002} |", + "+------------------------------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +/// Demonstrate extracting the fields from a structure using a subquery +#[tokio::test] +async fn test_udaf_returning_struct_subquery() { + let TestContext { ctx, test_state: _ } = TestContext::new(); + let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq"; + let expected = [ + "+-----------------+----------------------------+", + "| sq.first[value] | sq.first[time] |", + "+-----------------+----------------------------+", + "| 2.0 | 1970-01-01T00:00:00.000002 |", + "+-----------------+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +#[tokio::test] +async fn test_udaf_shadows_builtin_fn() { + let TestContext { + mut ctx, + test_state, + } = TestContext::new(); + let sql = "SELECT sum(arrow_cast(time, 'Int64')) from t"; + + // compute with builtin `sum` aggregator + let expected = [ + "+-------------+", + "| SUM(t.time) |", + "+-------------+", + "| 19000 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); + + // Register `TimeSum` with name `sum`. This will shadow the builtin one + let sql = "SELECT sum(time) from t"; + TimeSum::register(&mut ctx, test_state.clone(), "sum"); + let expected = [ + "+----------------------------+", + "| sum(t.time) |", + "+----------------------------+", + "| 1970-01-01T00:00:00.000019 |", + "+----------------------------+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +async fn execute(ctx: &SessionContext, sql: &str) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch1 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![4, 5]))], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + + // define a udaf, using a DataFusion's accumulator + let my_avg = create_udaf( + "my_avg", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?; + + let expected = [ + "+-------------+", + "| my_avg(t.a) |", + "+-------------+", + "| 3.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +#[tokio::test] +async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + // Note capitalization + let my_avg = create_udaf( + "MY_AVG", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|_| Ok(Box::::default())), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + // doesn't work as it was registered as non lowercase + let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err(); + assert!(err + .to_string() + .contains("Error during planning: Invalid function \'my_avg\'")); + + // Can call it if you put quotes + let result = ctx + .sql("SELECT \"MY_AVG\"(i) FROM t") + .await? + .collect() + .await?; + + let expected = [ + "+-------------+", + "| MY_AVG(t.i) |", + "+-------------+", + "| 1.0 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +/// Returns an context with a table "t" and the "first" and "time_sum" +/// aggregate functions registered. +/// +/// "t" contains this data: +/// +/// ```text +/// value | time +/// 3.0 | 1970-01-01T00:00:00.000003 +/// 2.0 | 1970-01-01T00:00:00.000002 +/// 1.0 | 1970-01-01T00:00:00.000004 +/// 5.0 | 1970-01-01T00:00:00.000005 +/// 5.0 | 1970-01-01T00:00:00.000005 +/// ``` +struct TestContext { + ctx: SessionContext, + test_state: Arc, +} + +impl TestContext { + fn new() -> Self { + let test_state = Arc::new(TestState::new()); + Self::new_with_test_state(test_state) + } + + fn new_with_test_state(test_state: Arc) -> Self { + let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]); + let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]); + + let batch = RecordBatch::try_from_iter(vec![ + ("value", Arc::new(value) as _), + ("time", Arc::new(time) as _), + ]) + .unwrap(); + + let mut ctx = SessionContext::new(); + + ctx.register_batch("t", batch).unwrap(); + + // Tell DataFusion about the "first" function + FirstSelector::register(&mut ctx); + // Tell DataFusion about the "time_sum" function + TimeSum::register(&mut ctx, Arc::clone(&test_state), "time_sum"); + + Self { ctx, test_state } + } +} + +#[derive(Debug, Default)] +struct TestState { + /// was update_batch called? + update_batch: AtomicBool, + /// was retract_batch called? + retract_batch: AtomicBool, + /// should the udaf throw an error if retract batch is called? Can + /// only be configured at construction time. + error_on_retract_batch: bool, +} + +impl TestState { + fn new() -> Self { + Default::default() + } + + /// Has `update_batch` been called? + fn update_batch(&self) -> bool { + self.update_batch.load(Ordering::SeqCst) + } + + /// Set the `update_batch` flag + fn set_update_batch(&self) { + self.update_batch.store(true, Ordering::SeqCst) + } + + /// Has `retract_batch` been called? + fn retract_batch(&self) -> bool { + self.retract_batch.load(Ordering::SeqCst) + } + + /// set the `retract_batch` flag + fn set_retract_batch(&self) { + self.retract_batch.store(true, Ordering::SeqCst) + } + + /// Is this state configured to return an error on retract batch? + fn error_on_retract_batch(&self) -> bool { + self.error_on_retract_batch + } + + /// Configure the test to return error on retract batch + fn with_error_on_retract_batch(mut self) -> Self { + self.error_on_retract_batch = true; + self + } +} + +/// Models a user defined aggregate function that computes the a sum +/// of timestamps (not a quantity that has much real world meaning) +#[derive(Debug)] +struct TimeSum { + sum: i64, + test_state: Arc, +} + +impl TimeSum { + fn new(test_state: Arc) -> Self { + Self { sum: 0, test_state } + } + + fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { + let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + + // Returns the same type as its input + let return_type = Arc::new(timestamp_type.clone()); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::clone(&return_type))); + + let state_type = Arc::new(vec![timestamp_type.clone()]); + let state_type: StateTypeFunction = + Arc::new(move |_| Ok(Arc::clone(&state_type))); + + let volatility = Volatility::Immutable; + + let signature = Signature::exact(vec![timestamp_type], volatility); + + let captured_state = Arc::clone(&test_state); + let accumulator: AccumulatorFactoryFunction = + Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); + + let time_sum = + AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type); + + // register the selector as "time_sum" + ctx.register_udaf(time_sum) + } +} + +impl Accumulator for TimeSum { + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.test_state.set_update_batch(); + assert_eq!(values.len(), 1); + let arr = &values[0]; + let arr = arr.as_primitive::(); + + for v in arr.values().iter() { + println!("Adding {v}"); + self.sum += v; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // merge and update is the same for time sum + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + println!("Evaluating to {}", self.sum); + Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None)) + } + + fn size(&self) -> usize { + // accurate size estimates are not important for this example + 42 + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if self.test_state.error_on_retract_batch() { + return exec_err!("Error in Retract Batch"); + } + + self.test_state.set_retract_batch(); + assert_eq!(values.len(), 1); + let arr = &values[0]; + let arr = arr.as_primitive::(); + + for v in arr.values().iter() { + println!("Retracting {v}"); + self.sum -= v; + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + !self.test_state.error_on_retract_batch() + } +} + +/// Models a specialized timeseries aggregate function +/// called a "selector" in InfluxQL and Flux. +/// +/// It returns the value and corresponding timestamp of the +/// input with the earliest timestamp as a structure. +#[derive(Debug, Clone)] +struct FirstSelector { + value: f64, + time: i64, +} + +impl FirstSelector { + /// Create a new empty selector + fn new() -> Self { + Self { + value: 0.0, + time: i64::MAX, + } + } + + fn register(ctx: &mut SessionContext) { + let return_type = Arc::new(Self::output_datatype()); + let state_type = Arc::new(Self::state_datatypes()); + + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + + // Possible input signatures + let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; + + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::new(Self::new()))); + + let volatility = Volatility::Immutable; + + let name = "first"; + + let first = AggregateUDF::new( + name, + &Signature::one_of(signatures, volatility), + &return_type, + &accumulator, + &state_type, + ); + + // register the selector as "first" + ctx.register_udaf(first) + } + + /// Return the schema fields + fn fields() -> Fields { + vec![ + Field::new("value", DataType::Float64, true), + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + ] + .into() + } + + fn output_datatype() -> DataType { + DataType::Struct(Self::fields()) + } + + fn input_datatypes() -> Vec { + vec![ + DataType::Float64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ] + } + + // Internally, keep the data types as this type + fn state_datatypes() -> Vec { + vec![ + DataType::Float64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ] + } + + /// Convert to a set of ScalarValues + fn to_state(&self) -> Vec { + vec![ + ScalarValue::Float64(Some(self.value)), + ScalarValue::TimestampNanosecond(Some(self.time), None), + ] + } + + /// return this selector as a single scalar (struct) value + fn to_scalar(&self) -> ScalarValue { + ScalarValue::Struct(Some(self.to_state()), Self::fields()) + } +} + +impl Accumulator for FirstSelector { + fn state(&self) -> Result> { + let state = self.to_state().into_iter().collect::>(); + + Ok(state) + } + + /// produce the output structure + fn evaluate(&self) -> Result { + Ok(self.to_scalar()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // cast argumets to the appropriate type (DataFusion will type + // check these based on the declared allowed input types) + let v = as_primitive_array::(&values[0])?; + let t = as_primitive_array::(&values[1])?; + + // Update the actual values + for (value, time) in v.iter().zip(t.iter()) { + if let (Some(time), Some(value)) = (time, value) { + if time < self.time { + self.value = value; + self.time = time; + } + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // same logic is needed as in update_batch + self.update_batch(states) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/core/tests/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs similarity index 95% rename from datafusion/core/tests/user_defined_plan.rs rename to datafusion/core/tests/user_defined/user_defined_plan.rs index 7738a123949f9..d4a8842c0a7ad 100644 --- a/datafusion/core/tests/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -58,7 +58,9 @@ //! N elements, reducing the total amount of required buffer memory. //! -use futures::{Stream, StreamExt}; +use std::fmt::Debug; +use std::task::{Context, Poll}; +use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use arrow::{ array::{Int64Array, StringArray}, @@ -68,7 +70,7 @@ use arrow::{ }; use datafusion::{ common::cast::{as_int64_array, as_string_array}, - common::DFSchemaRef, + common::{internal_err, DFSchemaRef}, error::{DataFusionError, Result}, execution::{ context::{QueryPlanner, SessionState, TaskContext}, @@ -80,19 +82,16 @@ use datafusion::{ }, optimizer::{optimize_children, OptimizerConfig, OptimizerRule}, physical_plan::{ - expressions::PhysicalSortExpr, - planner::{DefaultPhysicalPlanner, ExtensionPlanner}, - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, - RecordBatchStream, SendableRecordBatchStream, Statistics, + expressions::PhysicalSortExpr, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, }; -use fmt::Debug; -use std::task::{Context, Poll}; -use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; - use async_trait::async_trait; +use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. @@ -219,10 +218,8 @@ async fn topk_query() -> Result<()> { async fn topk_plan() -> Result<()> { let mut ctx = setup_table(make_topk_context()).await?; - let mut expected = vec![ - "| logical_plan after topk | TopK: k=3 |", - "| | TableScan: sales projection=[customer_id,revenue] |", - ].join("\n"); + let mut expected = ["| logical_plan after topk | TopK: k=3 |", + "| | TableScan: sales projection=[customer_id,revenue] |"].join("\n"); let explain_query = format!("EXPLAIN VERBOSE {QUERY}"); let actual_output = exec_sql(&mut ctx, &explain_query).await?; @@ -248,10 +245,10 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::with_config_rt(config, runtime) + let state = SessionState::new_with_config_rt(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - SessionContext::with_state(state) + SessionContext::new_with_state(state) } // ------ The implementation of the TopK code follows ----- @@ -422,6 +419,20 @@ impl Debug for TopKExec { } } +impl DisplayAs for TopKExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "TopKExec: k={}", self.k) + } + } + } +} + #[async_trait] impl ExecutionPlan for TopKExec { /// Return a reference to Any that can be used for downcasting @@ -466,9 +477,7 @@ impl ExecutionPlan for TopKExec { context: Arc, ) -> Result { if 0 != partition { - return Err(DataFusionError::Internal(format!( - "TopKExec invalid partition {partition}" - ))); + return internal_err!("TopKExec invalid partition {partition}"); } Ok(Box::pin(TopKReader { @@ -479,22 +488,10 @@ impl ExecutionPlan for TopKExec { })) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "TopKExec: k={}", self.k) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // to improve the optimizability of this plan // better statistics inference could be provided - Statistics::default() + Ok(Statistics::new_unknown(&self.schema())) } } diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs similarity index 54% rename from datafusion/core/tests/sql/udf.rs rename to datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a31028fd71cb6..985b0bd5bc767 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,26 +15,56 @@ // specific language governing permissions and limitations // under the License. -use super::*; -use arrow::compute::add; +use arrow::compute::kernels::numeric::add; +use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::prelude::*; use datafusion::{ execution::registry::FunctionRegistry, - physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function}, + physical_plan::functions::make_scalar_function, test_util, }; -use datafusion_common::{cast::as_int32_array, ScalarValue}; -use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder}; +use datafusion_common::cast::as_float64_array; +use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; +use datafusion_expr::{ + create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility, +}; +use std::sync::Arc; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and /// physical plan have the same schema. #[tokio::test] async fn csv_query_custom_udf_with_cast() -> Result<()> { - let ctx = create_ctx(); + let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = execute(&ctx, sql).await; - let expected = vec![vec!["0.6584408483418833"]]; - assert_float_eq(&expected, &actual); + let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let expected = [ + "+------------------------------------------+", + "| AVG(custom_sqrt(aggregate_test_100.c11)) |", + "+------------------------------------------+", + "| 0.6584408483418833 |", + "+------------------------------------------+", + ]; + assert_batches_eq!(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_avg_sqrt() -> Result<()> { + let ctx = create_udf_context(); + register_aggregate_csv(&ctx).await?; + // Note it is a different column (c12) than above (c11) + let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; + let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let expected = [ + "+------------------------------------------+", + "| AVG(custom_sqrt(aggregate_test_100.c12)) |", + "+------------------------------------------+", + "| 0.6706002946036462 |", + "+------------------------------------------+", + ]; + assert_batches_eq!(&expected, &actual); Ok(()) } @@ -92,7 +122,7 @@ async fn scalar_udf() -> Result<()> { let result = DataFrame::new(ctx.state(), plan).collect().await?; - let expected = vec![ + let expected = [ "+-----+-----+-----------------+", "| a | b | my_add(t.a,t.b) |", "+-----+-----+-----------------+", @@ -148,7 +178,7 @@ async fn scalar_udf_zero_params() -> Result<()> { )); let result = plan_and_collect(&ctx, "select get_100() a from t").await?; - let expected = vec![ + let expected = [ "+-----+", // "| a |", // "+-----+", // @@ -156,22 +186,22 @@ async fn scalar_udf_zero_params() -> Result<()> { "| 100 |", // "| 100 |", // "| 100 |", // - "+-----+", // + "+-----+", ]; assert_batches_eq!(expected, &result); let result = plan_and_collect(&ctx, "select get_100() a").await?; - let expected = vec![ + let expected = [ "+-----+", // "| a |", // "+-----+", // "| 100 |", // - "+-----+", // + "+-----+", ]; assert_batches_eq!(expected, &result); let result = plan_and_collect(&ctx, "select get_100() from t where a=999").await?; - let expected = vec![ + let expected = [ "++", // "++", ]; @@ -179,53 +209,36 @@ async fn scalar_udf_zero_params() -> Result<()> { Ok(()) } -/// tests the creation, registration and usage of a UDAF #[tokio::test] -async fn simple_udaf() -> Result<()> { +async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( + let batch = RecordBatch::try_new( Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![4, 5]))], + vec![Arc::new(Int32Array::from(vec![-100]))], )?; - let ctx = SessionContext::new(); - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - // define a udaf, using a DataFusion's accumulator - let my_avg = create_udaf( - "my_avg", - DataType::Float64, - Arc::new(DataType::Float64), + ctx.register_batch("t", batch)?; + // register a UDF that has the same name as a builtin function (abs) and just returns 1 regardless of input + ctx.register_udf(create_udf( + "abs", + vec![DataType::Int32], + Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), - Arc::new(vec![DataType::UInt64, DataType::Float64]), - ); - - ctx.register_udaf(my_avg); - - let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?; + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), + )); - let expected = vec![ - "+-------------+", - "| my_avg(t.a) |", - "+-------------+", - "| 3.0 |", - "+-------------+", + // Make sure that the UDF is used instead of the built-in function + let result = plan_and_collect(&ctx, "select abs(a) a from t").await?; + let expected = [ + "+---+", // + "| a |", // + "+---+", // + "| 1 |", // + "+---+", ]; assert_batches_eq!(expected, &result); - Ok(()) } @@ -258,7 +271,7 @@ async fn udaf_as_window_func() -> Result<()> { let my_acc = create_udaf( "my_acc", - DataType::Int32, + vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, Arc::new(|_| Ok(Box::new(MyAccumulator))), @@ -286,3 +299,123 @@ async fn udaf_as_window_func() -> Result<()> { assert_eq!(format!("{:?}", dataframe.logical_plan()), expected); Ok(()) } + +#[tokio::test] +async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + ctx.register_udf(create_udf( + "MY_FUNC", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + )); + + // doesn't work as it was registered with non lowercase + let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") + .await + .unwrap_err(); + assert!(err + .to_string() + .contains("Error during planning: Invalid function \'my_func\'")); + + // Can call it if you put quotes + let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; + + let expected = [ + "+--------------+", + "| MY_FUNC(t.i) |", + "+--------------+", + "| 1 |", + "+--------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + +#[tokio::test] +async fn test_user_defined_functions_with_alias() -> Result<()> { + let ctx = SessionContext::new(); + let arr = Int32Array::from(vec![1]); + let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; + ctx.register_batch("t", batch).unwrap(); + + let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = make_scalar_function(myfunc); + + let udf = create_udf( + "dummy", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + ) + .with_aliases(vec!["dummy_alias"]); + + ctx.register_udf(udf); + + let expected = [ + "+------------+", + "| dummy(t.i) |", + "+------------+", + "| 1 |", + "+------------+", + ]; + let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; + assert_batches_eq!(expected, &result); + + let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; + assert_batches_eq!(expected, &alias_result); + + Ok(()) +} + +fn create_udf_context() -> SessionContext { + let ctx = SessionContext::new(); + // register a custom UDF + ctx.register_udf(create_udf( + "custom_sqrt", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(custom_sqrt), + )); + + ctx +} + +fn custom_sqrt(args: &[ColumnarValue]) -> Result { + let arg = &args[0]; + if let ColumnarValue::Array(v) = arg { + let input = as_float64_array(v).expect("cast failed"); + let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); + Ok(ColumnarValue::Array(Arc::new(array))) + } else { + unimplemented!() + } +} + +async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + let schema = test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + +/// Execute SQL and return results as a RecordBatch +async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs new file mode 100644 index 0000000000000..b5d10b1c5b9ba --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::csv::reader::Format; +use arrow::csv::ReaderBuilder; +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::function::TableFunctionImpl; +use datafusion::datasource::TableProvider; +use datafusion::error::Result; +use datafusion::execution::context::SessionState; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + +/// test simple udtf with define read_csv with parameters +#[tokio::test] +async fn test_simple_read_csv_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + // read csv with at most 5 rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str()) + .await? + .collect() + .await?; + + let excepted = [ + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+", ]; + assert_batches_eq!(excepted, &rbs); + + // just run, return all rows + let rbs = ctx + .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .await? + .collect() + .await?; + let excepted = [ + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| n_nationkey | n_name | n_regionkey | n_comment |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+", + "| 1 | ARGENTINA | 1 | al foxes promise slyly according to the regular accounts. bold requests alon |", + "| 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully special packages are about the ironic forges. slyly special |", + "| 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold |", + "| 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d |", + "| 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |", + "| 6 | FRANCE | 3 | refully final requests. regular, ironi |", + "| 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, regular acco |", + "| 8 | INDIA | 2 | ss excuses cajole slyly across the packages. deposits print aroun |", + "| 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull |", + "| 10 | IRAN | 4 | efully alongside of the slyly final dependencies. |", + "+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+" + ]; + assert_batches_eq!(excepted, &rbs); + + Ok(()) +} + +struct SimpleCsvTable { + schema: SchemaRef, + exprs: Vec, + batches: Vec, +} + +#[async_trait] +impl TableProvider for SimpleCsvTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let batches = if !self.exprs.is_empty() { + let max_return_lines = self.interpreter_expr(state).await?; + // get max return rows from self.batches + let mut batches = vec![]; + let mut lines = 0; + for batch in &self.batches { + let batch_lines = batch.num_rows(); + if lines + batch_lines > max_return_lines as usize { + let batch_lines = max_return_lines as usize - lines; + batches.push(batch.slice(0, batch_lines)); + break; + } else { + batches.push(batch.clone()); + lines += batch_lines; + } + } + batches + } else { + self.batches.clone() + }; + Ok(Arc::new(MemoryExec::try_new( + &[batches], + TableProvider::schema(self), + projection.cloned(), + )?)) + } +} + +impl SimpleCsvTable { + async fn interpreter_expr(&self, state: &SessionState) -> Result { + use datafusion::logical_expr::expr_rewriter::normalize_col; + use datafusion::logical_expr::utils::columnize_expr; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }); + let logical_plan = Projection::try_new( + vec![columnize_expr( + normalize_col(self.exprs[0].clone(), &plan)?, + plan.schema(), + )], + Arc::new(plan), + ) + .map(LogicalPlan::Projection)?; + let rbs = collect( + state.create_physical_plan(&logical_plan).await?, + Arc::new(TaskContext::from(state)), + ) + .await?; + let limit = rbs[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + Ok(limit) + } +} + +struct SimpleCsvTableFunc {} + +impl TableFunctionImpl for SimpleCsvTableFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + let mut new_exprs = vec![]; + let mut filepath = String::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + filepath = path.clone() + } + expr => new_exprs.push(expr.clone()), + } + } + let (schema, batches) = read_csv_batches(filepath)?; + let table = SimpleCsvTable { + schema, + exprs: new_exprs.clone(), + batches, + }; + Ok(Arc::new(table)) + } +} + +fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { + let mut file = File::open(csv_path)?; + let (schema, _) = Format::default() + .with_header(true) + .infer_schema(&mut file, None)?; + file.rewind()?; + + let reader = ReaderBuilder::new(Arc::new(schema.clone())) + .with_header(true) + .build(file)?; + let mut batches = vec![]; + for bacth in reader { + batches.push(bacth?); + } + let schema = Arc::new(schema); + Ok((schema, batches)) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs new file mode 100644 index 0000000000000..5f99391572174 --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -0,0 +1,566 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains end to end tests of creating +//! user defined window functions + +use std::{ + ops::Range, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use arrow::array::AsArray; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; +use arrow_schema::DataType; +use datafusion::{assert_batches_eq, prelude::SessionContext}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction, + Signature, Volatility, WindowUDF, +}; + +/// A query with a window function evaluated over the entire partition +const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ + odd_counter(val) OVER (PARTITION BY x ORDER BY y) \ + from t ORDER BY x, y"; + +/// A query with a window function evaluated over a moving window +const BOUNDED_WINDOW_QUERY: &str = + "SELECT x, y, val, \ + odd_counter(val) OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ + from t ORDER BY x, y"; + +/// Test to show the contents of the setup +#[tokio::test] +async fn test_setup() { + let test_state = TestState::new(); + let TestContext { ctx, test_state: _ } = TestContext::new(test_state); + + let sql = "SELECT * from t order by x, y"; + let expected = vec![ + "+---+---+-----+", + "| x | y | val |", + "+---+---+-----+", + "| 1 | a | 0 |", + "| 1 | b | 1 |", + "| 1 | c | 2 |", + "| 2 | d | 3 |", + "| 2 | e | 4 |", + "| 2 | f | 5 |", + "| 2 | g | 6 |", + "| 2 | h | 6 |", + "| 2 | i | 6 |", + "| 2 | j | 6 |", + "+---+---+-----+", + ]; + assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); +} + +/// Basic user defined window function +#[tokio::test] +async fn test_udwf() { + let test_state = TestState::new(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 2 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + // evaluated on two distinct batches + assert_eq!(test_state.evaluate_all_called(), 2); +} + +/// Basic user defined window function with bounded window +#[tokio::test] +async fn test_udwf_bounded_window_ignores_frame() { + let test_state = TestState::new(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + // Since the UDWF doesn't say it needs the window frame, the frame is ignored + let expected = vec![ + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 2 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 2); +} + +/// Basic user defined window function with bounded window +#[tokio::test] +async fn test_udwf_bounded_window() { + let test_state = TestState::new().with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 1 |", + "| 2 | g | 6 | 1 |", + "| 2 | h | 6 | 0 |", + "| 2 | i | 6 | 0 |", + "| 2 | j | 6 | 0 |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // Evaluate is called for each input rows + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// Basic stateful user defined window function +#[tokio::test] +async fn test_stateful_udwf() { + let test_state = TestState::new() + .with_supports_bounded_execution() + .with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 0 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 1 |", + "| 2 | f | 5 | 2 |", + "| 2 | g | 6 | 2 |", + "| 2 | h | 6 | 2 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 2 |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// Basic stateful user defined window function with bounded window +#[tokio::test] +async fn test_stateful_udwf_bounded_window() { + let test_state = TestState::new() + .with_supports_bounded_execution() + .with_uses_window_frame(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 1 |", + "| 2 | g | 6 | 1 |", + "| 2 | h | 6 | 0 |", + "| 2 | i | 6 | 0 |", + "| 2 | j | 6 | 0 |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // Evaluate and update_state is called for each input row + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +/// user defined window function using rank +#[tokio::test] +async fn test_udwf_query_include_rank() { + let test_state = TestState::new().with_include_rank(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 3 |", + "| 1 | b | 1 | 2 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 7 |", + "| 2 | e | 4 | 6 |", + "| 2 | f | 5 | 5 |", + "| 2 | g | 6 | 4 |", + "| 2 | h | 6 | 3 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 1 |", + "+---+---+-----+-----------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 0); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_all_with_rank_called(), 2); +} + +/// user defined window function with bounded window using rank +#[tokio::test] +async fn test_udwf_bounded_query_include_rank() { + let test_state = TestState::new().with_include_rank(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 3 |", + "| 1 | b | 1 | 2 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 7 |", + "| 2 | e | 4 | 6 |", + "| 2 | f | 5 | 5 |", + "| 2 | g | 6 | 4 |", + "| 2 | h | 6 | 3 |", + "| 2 | i | 6 | 2 |", + "| 2 | j | 6 | 1 |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + assert_eq!(test_state.evaluate_called(), 0); + assert_eq!(test_state.evaluate_all_called(), 0); + // evaluated on 2 distinct batches (when x=1 and x=2) + assert_eq!(test_state.evaluate_all_with_rank_called(), 2); +} + +/// Basic user defined window function that can return NULL. +#[tokio::test] +async fn test_udwf_bounded_window_returns_null() { + let test_state = TestState::new() + .with_uses_window_frame() + .with_null_for_zero(); + let TestContext { ctx, test_state } = TestContext::new(test_state); + + let expected = vec![ + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + "| 1 | a | 0 | 1 |", + "| 1 | b | 1 | 1 |", + "| 1 | c | 2 | 1 |", + "| 2 | d | 3 | 1 |", + "| 2 | e | 4 | 2 |", + "| 2 | f | 5 | 1 |", + "| 2 | g | 6 | 1 |", + "| 2 | h | 6 | |", + "| 2 | i | 6 | |", + "| 2 | j | 6 | |", + "+---+---+-----+--------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!( + expected, + &execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap() + ); + // Evaluate is called for each input rows + assert_eq!(test_state.evaluate_called(), 10); + assert_eq!(test_state.evaluate_all_called(), 0); +} + +async fn execute(ctx: &SessionContext, sql: &str) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// Returns an context with a table "t" and the "first" and "time_sum" +/// aggregate functions registered. +/// +/// "t" contains this data: +/// +/// ```text +/// x | y | val +/// 1 | a | 0 +/// 1 | b | 1 +/// 1 | c | 2 +/// 2 | d | 3 +/// 2 | e | 4 +/// 2 | f | 5 +/// 2 | g | 6 +/// 2 | h | 6 +/// 2 | i | 6 +/// 2 | j | 6 +/// ``` +struct TestContext { + ctx: SessionContext, + test_state: Arc, +} + +impl TestContext { + fn new(test_state: TestState) -> Self { + let test_state = Arc::new(test_state); + let x = Int64Array::from(vec![1, 1, 1, 2, 2, 2, 2, 2, 2, 2]); + let y = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]); + let val = Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6, 6, 6, 6]); + + let batch = RecordBatch::try_from_iter(vec![ + ("x", Arc::new(x) as _), + ("y", Arc::new(y) as _), + ("val", Arc::new(val) as _), + ]) + .unwrap(); + + let mut ctx = SessionContext::new(); + + ctx.register_batch("t", batch).unwrap(); + + // Tell DataFusion about the window function + OddCounter::register(&mut ctx, Arc::clone(&test_state)); + + Self { ctx, test_state } + } +} + +#[derive(Debug, Default)] +struct TestState { + /// How many times was `evaluate_all` called? + evaluate_all_called: AtomicUsize, + /// How many times was `evaluate` called? + evaluate_called: AtomicUsize, + /// How many times was `evaluate_all_with_rank` called? + evaluate_all_with_rank_called: AtomicUsize, + /// should the functions say they use the window frame? + uses_window_frame: bool, + /// should the functions say they support bounded execution + supports_bounded_execution: bool, + /// should the functions they need include rank + include_rank: bool, + /// should the functions return NULL for 0s? + null_for_zero: bool, +} + +impl TestState { + fn new() -> Self { + Default::default() + } + + /// Set that this function should use the window frame + fn with_uses_window_frame(mut self) -> Self { + self.uses_window_frame = true; + self + } + + /// Set that this function should use bounded / stateful execution + fn with_supports_bounded_execution(mut self) -> Self { + self.supports_bounded_execution = true; + self + } + + /// Set that this function should include rank + fn with_include_rank(mut self) -> Self { + self.include_rank = true; + self + } + + // Set that this function should return NULL instead of zero. + fn with_null_for_zero(mut self) -> Self { + self.null_for_zero = true; + self + } + + /// return the evaluate_all_called counter + fn evaluate_all_called(&self) -> usize { + self.evaluate_all_called.load(Ordering::SeqCst) + } + + /// update the evaluate_all_called counter + fn inc_evaluate_all_called(&self) { + self.evaluate_all_called.fetch_add(1, Ordering::SeqCst); + } + + /// return the evaluate_called counter + fn evaluate_called(&self) -> usize { + self.evaluate_called.load(Ordering::SeqCst) + } + + /// update the evaluate_called counter + fn inc_evaluate_called(&self) { + self.evaluate_called.fetch_add(1, Ordering::SeqCst); + } + + /// return the evaluate_all_with_rank_called counter + fn evaluate_all_with_rank_called(&self) -> usize { + self.evaluate_all_with_rank_called.load(Ordering::SeqCst) + } + + /// update the evaluate_all_with_rank_called counter + fn inc_evaluate_all_with_rank_called(&self) { + self.evaluate_all_with_rank_called + .fetch_add(1, Ordering::SeqCst); + } +} + +// Partition Evaluator that counts the number of odd numbers in the window frame using evaluate +#[derive(Debug)] +struct OddCounter { + test_state: Arc, +} + +impl OddCounter { + fn new(test_state: Arc) -> Self { + Self { test_state } + } + + fn register(ctx: &mut SessionContext, test_state: Arc) { + let name = "odd_counter"; + let volatility = Volatility::Immutable; + + let signature = Signature::exact(vec![DataType::Int64], volatility); + + let return_type = Arc::new(DataType::Int64); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::clone(&return_type))); + + let partition_evaluator_factory: PartitionEvaluatorFactory = + Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state))))); + + ctx.register_udwf(WindowUDF::new( + name, + &signature, + &return_type, + &partition_evaluator_factory, + )) + } +} + +impl PartitionEvaluator for OddCounter { + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + println!("evaluate, values: {values:#?}, range: {range:?}"); + + self.test_state.inc_evaluate_called(); + let values: &Int64Array = values[0].as_primitive(); + let values = values.slice(range.start, range.len()); + let scalar = ScalarValue::Int64( + match (odd_count(&values), self.test_state.null_for_zero) { + (0, true) => None, + (n, _) => Some(n), + }, + ); + Ok(scalar) + } + + fn evaluate_all( + &mut self, + values: &[arrow_array::ArrayRef], + num_rows: usize, + ) -> Result { + println!("evaluate_all, values: {values:#?}, num_rows: {num_rows}"); + + self.test_state.inc_evaluate_all_called(); + Ok(odd_count_arr(values[0].as_primitive(), num_rows)) + } + + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + self.test_state.inc_evaluate_all_with_rank_called(); + println!("evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}"); + // when evaluating with ranks, just return the inverse rank instead + let array: Int64Array = ranks_in_partition + .iter() + // cloned range is an iterator + .cloned() + .flatten() + .map(|v| (num_rows - v) as i64) + .collect(); + Ok(Arc::new(array)) + } + + fn supports_bounded_execution(&self) -> bool { + self.test_state.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.test_state.uses_window_frame + } + + fn include_rank(&self) -> bool { + self.test_state.include_rank + } +} + +/// returns the number of entries in arr that are odd +fn odd_count(arr: &Int64Array) -> i64 { + arr.iter().filter_map(|x| x.map(|x| x % 2)).sum() +} + +/// returns an array of num_rows that has the number of odd values in `arr` +fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { + let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); + Arc::new(array) +} diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs deleted file mode 100644 index 1047f73df4cd6..0000000000000 --- a/datafusion/core/tests/user_defined_aggregates.rs +++ /dev/null @@ -1,241 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains end to end demonstrations of creating -//! user defined aggregate functions - -use arrow::datatypes::Fields; -use std::sync::Arc; - -use datafusion::{ - arrow::{ - array::{ArrayRef, Float64Array, TimestampNanosecondArray}, - datatypes::{DataType, Field, Float64Type, TimeUnit, TimestampNanosecondType}, - record_batch::RecordBatch, - }, - assert_batches_eq, - error::Result, - logical_expr::{ - AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, TypeSignature, Volatility, - }, - physical_plan::Accumulator, - prelude::SessionContext, - scalar::ScalarValue, -}; -use datafusion_common::cast::as_primitive_array; - -#[tokio::test] -/// Basic query for with a udaf returning a structure -async fn test_udf_returning_struct() { - let ctx = udaf_struct_context(); - let sql = "SELECT first(value, time) from t"; - let expected = vec![ - "+------------------------------------------------+", - "| first(t.value,t.time) |", - "+------------------------------------------------+", - "| {value: 2.0, time: 1970-01-01T00:00:00.000002} |", - "+------------------------------------------------+", - ]; - assert_batches_eq!(expected, &execute(&ctx, sql).await); -} - -#[tokio::test] -/// Demonstrate extracting the fields from the a structure using a subquery -async fn test_udf_returning_struct_sq() { - let ctx = udaf_struct_context(); - let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq"; - let expected = vec![ - "+-----------------+----------------------------+", - "| sq.first[value] | sq.first[time] |", - "+-----------------+----------------------------+", - "| 2.0 | 1970-01-01T00:00:00.000002 |", - "+-----------------+----------------------------+", - ]; - assert_batches_eq!(expected, &execute(&ctx, sql).await); -} - -async fn execute(ctx: &SessionContext, sql: &str) -> Vec { - ctx.sql(sql).await.unwrap().collect().await.unwrap() -} - -/// Returns an context with a table "t" and the "first" aggregate registered. -/// -/// "t" contains this data: -/// -/// ```text -/// value | time -/// 3.0 | 1970-01-01T00:00:00.000003 -/// 2.0 | 1970-01-01T00:00:00.000002 -/// 1.0 | 1970-01-01T00:00:00.000004 -/// ``` -fn udaf_struct_context() -> SessionContext { - let value: Float64Array = vec![3.0, 2.0, 1.0].into_iter().map(Some).collect(); - let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000]); - - let batch = RecordBatch::try_from_iter(vec![ - ("value", Arc::new(value) as _), - ("time", Arc::new(time) as _), - ]) - .unwrap(); - - let mut ctx = SessionContext::new(); - ctx.register_batch("t", batch).unwrap(); - - // Tell datafusion about the "first" function - register_aggregate(&mut ctx); - - ctx -} - -fn register_aggregate(ctx: &mut SessionContext) { - let return_type = Arc::new(FirstSelector::output_datatype()); - let state_type = Arc::new(FirstSelector::state_datatypes()); - - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); - - // Possible input signatures - let signatures = vec![TypeSignature::Exact(FirstSelector::input_datatypes())]; - - let accumulator: AccumulatorFunctionImplementation = - Arc::new(|_| Ok(Box::new(FirstSelector::new()))); - - let volatility = Volatility::Immutable; - - let name = "first"; - - let first = AggregateUDF::new( - name, - &Signature::one_of(signatures, volatility), - &return_type, - &accumulator, - &state_type, - ); - - // register the selector as "first" - ctx.register_udaf(first) -} - -/// This structureg models a specialized timeseries aggregate function -/// called a "selector" in InfluxQL and Flux. -/// -/// It returns the value and corresponding timestamp of the -/// input with the earliest timestamp as a structure. -#[derive(Debug, Clone)] -struct FirstSelector { - value: f64, - time: i64, -} - -impl FirstSelector { - /// Create a new empty selector - fn new() -> Self { - Self { - value: 0.0, - time: i64::MAX, - } - } - - /// Return the schema fields - fn fields() -> Fields { - vec![ - Field::new("value", DataType::Float64, true), - Field::new( - "time", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - ] - .into() - } - - // output data type - fn output_datatype() -> DataType { - DataType::Struct(Self::fields()) - } - - // input argument data types - fn input_datatypes() -> Vec { - vec![ - DataType::Float64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ] - } - - // Internally, keep the data types as this type - fn state_datatypes() -> Vec { - vec![ - DataType::Float64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ] - } - - /// Convert to a set of ScalarValues - fn to_state(&self) -> Vec { - vec![ - ScalarValue::Float64(Some(self.value)), - ScalarValue::TimestampNanosecond(Some(self.time), None), - ] - } - - /// return this selector as a single scalar (struct) value - fn to_scalar(&self) -> ScalarValue { - ScalarValue::Struct(Some(self.to_state()), Self::fields()) - } -} - -impl Accumulator for FirstSelector { - fn state(&self) -> Result> { - let state = self.to_state().into_iter().collect::>(); - - Ok(state) - } - - /// produce the output structure - fn evaluate(&self) -> Result { - Ok(self.to_scalar()) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // cast argumets to the appropriate type (DataFusion will type - // check these based on the declared allowed input types) - let v = as_primitive_array::(&values[0])?; - let t = as_primitive_array::(&values[1])?; - - // Update the actual values - for (value, time) in v.iter().zip(t.iter()) { - if let (Some(time), Some(value)) = (time, value) { - if time < self.time { - self.value = value; - self.time = time; - } - } - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // same logic is needed as in update_batch - self.update_batch(states) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/core/tests/user_defined_integration.rs b/datafusion/core/tests/user_defined_integration.rs new file mode 100644 index 0000000000000..4f9cc89529adb --- /dev/null +++ b/datafusion/core/tests/user_defined_integration.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Run all tests that are found in the `user_defined` directory +mod user_defined; + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); +} diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index f75d79a0fa752..e9bb87e9f8ac3 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -18,10 +18,10 @@ [package] name = "datafusion-execution" description = "Execution configuration support for DataFusion query engine" -keywords = [ "arrow", "query", "sql" ] +keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -33,13 +33,16 @@ name = "datafusion_execution" path = "src/lib.rs" [dependencies] -dashmap = "5.4.0" -datafusion-common = { path = "../common", version = "26.0.0" } -datafusion-expr = { path = "../expr", version = "26.0.0" } +arrow = { workspace = true } +chrono = { version = "0.4", default-features = false } +dashmap = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +futures = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } -log = "^0.4" -object_store = "0.6.1" -parking_lot = "0.12" -rand = "0.8" -tempfile = "3" -url = "2.2" +log = { workspace = true } +object_store = { workspace = true } +parking_lot = { workspace = true } +rand = { workspace = true } +tempfile = { workspace = true } +url = { workspace = true } diff --git a/datafusion/execution/README.md b/datafusion/execution/README.md new file mode 100644 index 0000000000000..67aac6be82b3f --- /dev/null +++ b/datafusion/execution/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Common + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides execution runtime such as the memory pools and disk manager. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs new file mode 100644 index 0000000000000..97529263688bf --- /dev/null +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cache::CacheAccessor; +use datafusion_common::{Result, Statistics}; +use object_store::path::Path; +use object_store::ObjectMeta; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +/// The cache of listing files statistics. +/// if set [`CacheManagerConfig::with_files_statistics_cache`] +/// Will avoid infer same file statistics repeatedly during the session lifetime, +/// this cache will store in [`crate::runtime_env::RuntimeEnv`]. +pub type FileStatisticsCache = + Arc, Extra = ObjectMeta>>; + +pub type ListFilesCache = + Arc>, Extra = ObjectMeta>>; + +impl Debug for dyn CacheAccessor, Extra = ObjectMeta> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Cache name: {} with length: {}", self.name(), self.len()) + } +} + +impl Debug for dyn CacheAccessor>, Extra = ObjectMeta> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Cache name: {} with length: {}", self.name(), self.len()) + } +} + +#[derive(Default, Debug)] +pub struct CacheManager { + file_statistic_cache: Option, + list_files_cache: Option, +} + +impl CacheManager { + pub fn try_new(config: &CacheManagerConfig) -> Result> { + let mut manager = CacheManager::default(); + if let Some(cc) = &config.table_files_statistics_cache { + manager.file_statistic_cache = Some(cc.clone()) + } + if let Some(lc) = &config.list_files_cache { + manager.list_files_cache = Some(lc.clone()) + } + Ok(Arc::new(manager)) + } + + /// Get the cache of listing files statistics. + pub fn get_file_statistic_cache(&self) -> Option { + self.file_statistic_cache.clone() + } + + /// Get the cache of objectMeta under same path. + pub fn get_list_files_cache(&self) -> Option { + self.list_files_cache.clone() + } +} + +#[derive(Clone, Default)] +pub struct CacheManagerConfig { + /// Enable cache of files statistics when listing files. + /// Avoid get same file statistics repeatedly in same datafusion session. + /// Default is disable. Fow now only supports Parquet files. + pub table_files_statistics_cache: Option, + /// Enable cache of file metadata when listing files. + /// This setting avoids listing file meta of the same path repeatedly + /// in same session, which may be expensive in certain situations (e.g. remote object storage). + /// Note that if this option is enabled, DataFusion will not see any updates to the underlying + /// location. + /// Default is disable. + pub list_files_cache: Option, +} + +impl CacheManagerConfig { + pub fn with_files_statistics_cache( + mut self, + cache: Option, + ) -> Self { + self.table_files_statistics_cache = cache; + self + } + + pub fn with_list_files_cache(mut self, cache: Option) -> Self { + self.list_files_cache = cache; + self + } +} diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs new file mode 100644 index 0000000000000..25f9b9fa4d687 --- /dev/null +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::cache::CacheAccessor; + +use datafusion_common::Statistics; + +use dashmap::DashMap; +use object_store::path::Path; +use object_store::ObjectMeta; + +/// Collected statistics for files +/// Cache is invalided when file size or last modification has changed +#[derive(Default)] +pub struct DefaultFileStatisticsCache { + statistics: DashMap)>, +} + +impl CacheAccessor> for DefaultFileStatisticsCache { + type Extra = ObjectMeta; + + /// Get `Statistics` for file location. + fn get(&self, k: &Path) -> Option> { + self.statistics + .get(k) + .map(|s| Some(s.value().1.clone())) + .unwrap_or(None) + } + + /// Get `Statistics` for file location. Returns None if file has changed or not found. + fn get_with_extra(&self, k: &Path, e: &Self::Extra) -> Option> { + self.statistics + .get(k) + .map(|s| { + let (saved_meta, statistics) = s.value(); + if saved_meta.size != e.size + || saved_meta.last_modified != e.last_modified + { + // file has changed + None + } else { + Some(statistics.clone()) + } + }) + .unwrap_or(None) + } + + /// Save collected file statistics + fn put(&self, _key: &Path, _value: Arc) -> Option> { + panic!("Put cache in DefaultFileStatisticsCache without Extra not supported.") + } + + fn put_with_extra( + &self, + key: &Path, + value: Arc, + e: &Self::Extra, + ) -> Option> { + self.statistics + .insert(key.clone(), (e.clone(), value)) + .map(|x| x.1) + } + + fn remove(&mut self, k: &Path) -> Option> { + self.statistics.remove(k).map(|x| x.1 .1) + } + + fn contains_key(&self, k: &Path) -> bool { + self.statistics.contains_key(k) + } + + fn len(&self) -> usize { + self.statistics.len() + } + + fn clear(&self) { + self.statistics.clear() + } + fn name(&self) -> String { + "DefaultFileStatisticsCache".to_string() + } +} + +/// Collected files metadata for listing files. +/// Cache will not invalided until user call remove or clear. +#[derive(Default)] +pub struct DefaultListFilesCache { + statistics: DashMap>>, +} + +impl CacheAccessor>> for DefaultListFilesCache { + type Extra = ObjectMeta; + + fn get(&self, k: &Path) -> Option>> { + self.statistics.get(k).map(|x| x.value().clone()) + } + + fn get_with_extra( + &self, + _k: &Path, + _e: &Self::Extra, + ) -> Option>> { + panic!("Not supported DefaultListFilesCache get_with_extra") + } + + fn put( + &self, + key: &Path, + value: Arc>, + ) -> Option>> { + self.statistics.insert(key.clone(), value) + } + + fn put_with_extra( + &self, + _key: &Path, + _value: Arc>, + _e: &Self::Extra, + ) -> Option>> { + panic!("Not supported DefaultListFilesCache put_with_extra") + } + + fn remove(&mut self, k: &Path) -> Option>> { + self.statistics.remove(k).map(|x| x.1) + } + + fn contains_key(&self, k: &Path) -> bool { + self.statistics.contains_key(k) + } + + fn len(&self) -> usize { + self.statistics.len() + } + + fn clear(&self) { + self.statistics.clear() + } + + fn name(&self) -> String { + "DefaultListFilesCache".to_string() + } +} + +#[cfg(test)] +mod tests { + use crate::cache::cache_unit::{DefaultFileStatisticsCache, DefaultListFilesCache}; + use crate::cache::CacheAccessor; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use chrono::DateTime; + use datafusion_common::Statistics; + use object_store::path::Path; + use object_store::ObjectMeta; + + #[test] + fn test_statistics_cache() { + let meta = ObjectMeta { + location: Path::from("test"), + last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") + .unwrap() + .into(), + size: 1024, + e_tag: None, + version: None, + }; + let cache = DefaultFileStatisticsCache::default(); + assert!(cache.get_with_extra(&meta.location, &meta).is_none()); + + cache.put_with_extra( + &meta.location, + Statistics::new_unknown(&Schema::new(vec![Field::new( + "test_column", + DataType::Timestamp(TimeUnit::Second, None), + false, + )])) + .into(), + &meta, + ); + assert!(cache.get_with_extra(&meta.location, &meta).is_some()); + + // file size changed + let mut meta2 = meta.clone(); + meta2.size = 2048; + assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); + + // file last_modified changed + let mut meta2 = meta.clone(); + meta2.last_modified = DateTime::parse_from_rfc3339("2022-09-27T22:40:00+02:00") + .unwrap() + .into(); + assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); + + // different file + let mut meta2 = meta; + meta2.location = Path::from("test2"); + assert!(cache.get_with_extra(&meta2.location, &meta2).is_none()); + } + + #[test] + fn test_list_file_cache() { + let meta = ObjectMeta { + location: Path::from("test"), + last_modified: DateTime::parse_from_rfc3339("2022-09-27T22:36:00+02:00") + .unwrap() + .into(), + size: 1024, + e_tag: None, + version: None, + }; + + let cache = DefaultListFilesCache::default(); + assert!(cache.get(&meta.location).is_none()); + + cache.put(&meta.location, vec![meta.clone()].into()); + assert_eq!( + cache.get(&meta.location).unwrap().first().unwrap().clone(), + meta.clone() + ); + } +} diff --git a/datafusion/execution/src/cache/mod.rs b/datafusion/execution/src/cache/mod.rs new file mode 100644 index 0000000000000..da19bff5658af --- /dev/null +++ b/datafusion/execution/src/cache/mod.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod cache_manager; +pub mod cache_unit; + +/// The cache accessor, users usually working on this interface while manipulating caches. +/// This interface does not get `mut` references and thus has to handle its own +/// locking via internal mutability. It can be accessed via multiple concurrent queries +/// during planning and execution. + +pub trait CacheAccessor: Send + Sync { + // Extra info but not part of the cache key or cache value. + type Extra: Clone; + + /// Get value from cache. + fn get(&self, k: &K) -> Option; + /// Get value from cache. + fn get_with_extra(&self, k: &K, e: &Self::Extra) -> Option; + /// Put value into cache. Returns the old value associated with the key if there was one. + fn put(&self, key: &K, value: V) -> Option; + /// Put value into cache. Returns the old value associated with the key if there was one. + fn put_with_extra(&self, key: &K, value: V, e: &Self::Extra) -> Option; + /// Remove an entry from the cache, returning value if they existed in the map. + fn remove(&mut self, k: &K) -> Option; + /// Check if the cache contains a specific key. + fn contains_key(&self, k: &K) -> bool; + /// Fetch the total number of cache entries. + fn len(&self) -> usize; + /// Check if the Cache collection is empty or not. + fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Remove all entries from the cache. + fn clear(&self); + /// Return the cache name. + fn name(&self) -> String; +} diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 97770eb99c581..8556335b395a9 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -25,7 +25,7 @@ use std::{ use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; /// Configuration options for Execution context -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct SessionConfig { /// Configuration options options: ConfigOptions, @@ -86,7 +86,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + self.set(key, ScalarValue::from(value)) } /// Customize batch size @@ -145,6 +145,14 @@ impl SessionConfig { self.options.optimizer.repartition_sorts } + /// Prefer existing sort (true) or maximize parallelism (false). See + /// [prefer_existing_sort] for more details + /// + /// [prefer_existing_sort]: datafusion_common::config::OptimizerOptions::prefer_existing_sort + pub fn prefer_existing_sort(&self) -> bool { + self.options.optimizer.prefer_existing_sort + } + /// Are statistics collected during execution? pub fn collect_statistics(&self) -> bool { self.options.execution.collect_statistics @@ -215,6 +223,15 @@ impl SessionConfig { self } + /// Prefer existing sort (true) or maximize parallelism (false). See + /// [prefer_existing_sort] for more details + /// + /// [prefer_existing_sort]: datafusion_common::config::OptimizerOptions::prefer_existing_sort + pub fn with_prefer_existing_sort(mut self, enabled: bool) -> Self { + self.options.optimizer.prefer_existing_sort = enabled; + self + } + /// Enables or disables the use of pruning predicate for parquet readers to skip row groups pub fn with_parquet_pruning(mut self, enabled: bool) -> Self { self.options.execution.parquet.pruning = enabled; @@ -274,6 +291,32 @@ impl SessionConfig { self.options.optimizer.enable_round_robin_repartition } + /// Set the size of [`sort_spill_reservation_bytes`] to control + /// memory pre-reservation + /// + /// [`sort_spill_reservation_bytes`]: datafusion_common::config::ExecutionOptions::sort_spill_reservation_bytes + pub fn with_sort_spill_reservation_bytes( + mut self, + sort_spill_reservation_bytes: usize, + ) -> Self { + self.options.execution.sort_spill_reservation_bytes = + sort_spill_reservation_bytes; + self + } + + /// Set the size of [`sort_in_place_threshold_bytes`] to control + /// how sort does things. + /// + /// [`sort_in_place_threshold_bytes`]: datafusion_common::config::ExecutionOptions::sort_in_place_threshold_bytes + pub fn with_sort_in_place_threshold_bytes( + mut self, + sort_in_place_threshold_bytes: usize, + ) -> Self { + self.options.execution.sort_in_place_threshold_bytes = + sort_in_place_threshold_bytes; + self + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 107c58fbe327d..fa9a75b2f496e 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -22,7 +22,7 @@ use datafusion_common::{DataFusionError, Result}; use log::debug; use parking_lot::Mutex; use rand::{thread_rng, Rng}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; use tempfile::{Builder, NamedTempFile, TempDir}; @@ -75,7 +75,7 @@ pub struct DiskManager { /// /// If `Some(vec![])` a new OS specified temporary directory will be created /// If `None` an error will be returned (configured not to spill) - local_dirs: Mutex>>, + local_dirs: Mutex>>>, } impl DiskManager { @@ -102,11 +102,21 @@ impl DiskManager { } } + /// Return true if this disk manager supports creating temporary + /// files. If this returns false, any call to `create_tmp_file` + /// will error. + pub fn tmp_files_enabled(&self) -> bool { + self.local_dirs.lock().is_some() + } + /// Return a temporary file from a randomized choice in the configured locations /// /// If the file can not be created for some reason, returns an /// error message referencing the request description - pub fn create_tmp_file(&self, request_description: &str) -> Result { + pub fn create_tmp_file( + &self, + request_description: &str, + ) -> Result { let mut guard = self.local_dirs.lock(); let local_dirs = guard.as_mut().ok_or_else(|| { DataFusionError::ResourcesExhausted(format!( @@ -124,18 +134,42 @@ impl DiskManager { request_description, ); - local_dirs.push(tempdir); + local_dirs.push(Arc::new(tempdir)); } let dir_index = thread_rng().gen_range(0..local_dirs.len()); - Builder::new() - .tempfile_in(&local_dirs[dir_index]) - .map_err(DataFusionError::IoError) + Ok(RefCountedTempFile { + parent_temp_dir: local_dirs[dir_index].clone(), + tempfile: Builder::new() + .tempfile_in(local_dirs[dir_index].as_ref()) + .map_err(DataFusionError::IoError)?, + }) + } +} + +/// A wrapper around a [`NamedTempFile`] that also contains +/// a reference to its parent temporary directory +#[derive(Debug)] +pub struct RefCountedTempFile { + /// The reference to the directory in which temporary files are created to ensure + /// it is not cleaned up prior to the NamedTempFile + #[allow(dead_code)] + parent_temp_dir: Arc, + tempfile: NamedTempFile, +} + +impl RefCountedTempFile { + pub fn path(&self) -> &Path { + self.tempfile.path() + } + + pub fn inner(&self) -> &NamedTempFile { + &self.tempfile } } /// Setup local dirs by creating one new dir in each of the given dirs -fn create_local_dirs(local_dirs: Vec) -> Result> { +fn create_local_dirs(local_dirs: Vec) -> Result>> { local_dirs .iter() .map(|root| { @@ -147,6 +181,7 @@ fn create_local_dirs(local_dirs: Vec) -> Result> { .tempdir_in(root) .map_err(DataFusionError::IoError) }) + .map(|result| result.map(Arc::new)) .collect() } @@ -198,6 +233,7 @@ mod tests { ); let dm = DiskManager::try_new(config)?; + assert!(dm.tmp_files_enabled()); let actual = dm.create_tmp_file("Testing")?; // the file should be in one of the specified local directories @@ -210,8 +246,9 @@ mod tests { fn test_disabled_disk_manager() { let config = DiskManagerConfig::Disabled; let manager = DiskManager::try_new(config).unwrap(); + assert!(!manager.tmp_files_enabled()); assert_eq!( - manager.create_tmp_file("Testing").unwrap_err().to_string(), + manager.create_tmp_file("Testing").unwrap_err().strip_backtrace(), "Resources exhausted: Memory Exhausted while Testing (DiskManager is disabled)", ) } @@ -241,4 +278,41 @@ mod tests { assert!(found, "Can't find {file_path:?} in dirs: {dirs:?}"); } + + #[test] + fn test_temp_file_still_alive_after_disk_manager_dropped() -> Result<()> { + // Test for the case using OS arranged temporary directory + let config = DiskManagerConfig::new(); + let dm = DiskManager::try_new(config)?; + let temp_file = dm.create_tmp_file("Testing")?; + let temp_file_path = temp_file.path().to_owned(); + assert!(temp_file_path.exists()); + + drop(dm); + assert!(temp_file_path.exists()); + + drop(temp_file); + assert!(!temp_file_path.exists()); + + // Test for the case using specified directories + let local_dir1 = TempDir::new()?; + let local_dir2 = TempDir::new()?; + let local_dir3 = TempDir::new()?; + let local_dirs = [local_dir1.path(), local_dir2.path(), local_dir3.path()]; + let config = DiskManagerConfig::new_specified( + local_dirs.iter().map(|p| p.into()).collect(), + ); + let dm = DiskManager::try_new(config)?; + let temp_file = dm.create_tmp_file("Testing")?; + let temp_file_path = temp_file.path().to_owned(); + assert!(temp_file_path.exists()); + + drop(dm); + assert!(temp_file_path.exists()); + + drop(temp_file); + assert!(!temp_file_path.exists()); + + Ok(()) + } } diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 46ffe12942568..a1a1551c2ca61 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -17,14 +17,17 @@ //! DataFusion execution configuration and runtime structures +pub mod cache; pub mod config; pub mod disk_manager; pub mod memory_pool; pub mod object_store; pub mod registry; pub mod runtime_env; +mod stream; mod task; pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; +pub use stream::{RecordBatchStream, SendableRecordBatchStream}; pub use task::TaskContext; diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index d002cda8d8aba..55555014f2ef7 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,7 +18,7 @@ //! Manages all available memory during query execution use datafusion_common::Result; -use std::sync::Arc; +use std::{cmp::Ordering, sync::Arc}; mod pool; pub mod proxy; @@ -46,9 +46,9 @@ pub use pool::*; /// /// The following memory pool implementations are available: /// -/// * [`UnboundedMemoryPool`](pool::UnboundedMemoryPool) -/// * [`GreedyMemoryPool`](pool::GreedyMemoryPool) -/// * [`FairSpillPool`](pool::FairSpillPool) +/// * [`UnboundedMemoryPool`] +/// * [`GreedyMemoryPool`] +/// * [`FairSpillPool`] pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Registers a new [`MemoryConsumer`] /// @@ -77,7 +77,9 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { fn reserved(&self) -> usize; } -/// A memory consumer that can be tracked by [`MemoryReservation`] in a [`MemoryPool`] +/// A memory consumer that can be tracked by [`MemoryReservation`] in +/// a [`MemoryPool`]. All allocations are registered to a particular +/// `MemoryConsumer`; #[derive(Debug)] pub struct MemoryConsumer { name: String, @@ -113,20 +115,40 @@ impl MemoryConsumer { pub fn register(self, pool: &Arc) -> MemoryReservation { pool.register(&self); MemoryReservation { - consumer: self, + registration: Arc::new(SharedRegistration { + pool: Arc::clone(pool), + consumer: self, + }), size: 0, - policy: Arc::clone(pool), } } } -/// A [`MemoryReservation`] tracks a reservation of memory in a [`MemoryPool`] -/// that is freed back to the pool on drop +/// A registration of a [`MemoryConsumer`] with a [`MemoryPool`]. +/// +/// Calls [`MemoryPool::unregister`] on drop to return any memory to +/// the underlying pool. #[derive(Debug)] -pub struct MemoryReservation { +struct SharedRegistration { + pool: Arc, consumer: MemoryConsumer, +} + +impl Drop for SharedRegistration { + fn drop(&mut self) { + self.pool.unregister(&self.consumer); + } +} + +/// A [`MemoryReservation`] tracks an individual reservation of a +/// number of bytes of memory in a [`MemoryPool`] that is freed back +/// to the pool on drop. +/// +/// The reservation can be grown or shrunk over time. +#[derive(Debug)] +pub struct MemoryReservation { + registration: Arc, size: usize, - policy: Arc, } impl MemoryReservation { @@ -135,7 +157,13 @@ impl MemoryReservation { self.size } - /// Frees all bytes from this reservation returning the number of bytes freed + /// Returns [MemoryConsumer] for this [MemoryReservation] + pub fn consumer(&self) -> &MemoryConsumer { + &self.registration.consumer + } + + /// Frees all bytes from this reservation back to the underlying + /// pool, returning the number of bytes freed. pub fn free(&mut self) -> usize { let size = self.size; if size != 0 { @@ -151,13 +179,12 @@ impl MemoryReservation { /// Panics if `capacity` exceeds [`Self::size`] pub fn shrink(&mut self, capacity: usize) { let new_size = self.size.checked_sub(capacity).unwrap(); - self.policy.shrink(self, capacity); + self.registration.pool.shrink(self, capacity); self.size = new_size } /// Sets the size of this reservation to `capacity` pub fn resize(&mut self, capacity: usize) { - use std::cmp::Ordering; match capacity.cmp(&self.size) { Ordering::Greater => self.grow(capacity - self.size), Ordering::Less => self.shrink(self.size - capacity), @@ -167,7 +194,6 @@ impl MemoryReservation { /// Try to set the size of this reservation to `capacity` pub fn try_resize(&mut self, capacity: usize) -> Result<()> { - use std::cmp::Ordering; match capacity.cmp(&self.size) { Ordering::Greater => self.try_grow(capacity - self.size)?, Ordering::Less => self.shrink(self.size - capacity), @@ -178,22 +204,55 @@ impl MemoryReservation { /// Increase the size of this reservation by `capacity` bytes pub fn grow(&mut self, capacity: usize) { - self.policy.grow(self, capacity); + self.registration.pool.grow(self, capacity); self.size += capacity; } - /// Try to increase the size of this reservation by `capacity` bytes + /// Try to increase the size of this reservation by `capacity` + /// bytes, returning error if there is insufficient capacity left + /// in the pool. pub fn try_grow(&mut self, capacity: usize) -> Result<()> { - self.policy.try_grow(self, capacity)?; + self.registration.pool.try_grow(self, capacity)?; self.size += capacity; Ok(()) } + + /// Splits off `capacity` bytes from this [`MemoryReservation`] + /// into a new [`MemoryReservation`] with the same + /// [`MemoryConsumer`]. + /// + /// This can be useful to free part of this reservation with RAAI + /// style dropping + /// + /// # Panics + /// + /// Panics if `capacity` exceeds [`Self::size`] + pub fn split(&mut self, capacity: usize) -> MemoryReservation { + self.size = self.size.checked_sub(capacity).unwrap(); + Self { + size: capacity, + registration: self.registration.clone(), + } + } + + /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn new_empty(&self) -> Self { + Self { + size: 0, + registration: self.registration.clone(), + } + } + + /// Splits off all the bytes from this [`MemoryReservation`] into + /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn take(&mut self) -> MemoryReservation { + self.split(self.size) + } } impl Drop for MemoryReservation { fn drop(&mut self) { self.free(); - self.policy.unregister(&self.consumer); } } @@ -253,4 +312,59 @@ mod tests { a2.try_grow(25).unwrap(); assert_eq!(pool.reserved(), 25); } + + #[test] + fn test_split() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 20); + + // take 5 from r1, should still have same reservation split + let r2 = r1.split(5); + assert_eq!(r1.size(), 15); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 20); + + // dropping r1 frees 15 but retains 5 as they have the same consumer + drop(r1); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 5); + } + + #[test] + fn test_new_empty() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.new_empty(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 20); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 25); + } + + #[test] + fn test_take() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.take(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 0); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 25); + + // r1 can still grow again + r1.try_grow(3).unwrap(); + assert_eq!(r1.size(), 3); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 28); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 7b68a86244b70..4a491630fe205 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -49,7 +49,7 @@ impl MemoryPool for UnboundedMemoryPool { /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. /// /// This pool works well for queries that do not need to spill or have -/// a single spillable operator. See [`GreedyMemoryPool`] if there are +/// a single spillable operator. See [`FairSpillPool`] if there are /// multiple spillable operators that all will spill. #[derive(Debug)] pub struct GreedyMemoryPool { @@ -84,7 +84,11 @@ impl MemoryPool for GreedyMemoryPool { (new_used <= self.pool_size).then_some(new_used) }) .map_err(|used| { - insufficient_capacity_err(reservation, additional, self.pool_size - used) + insufficient_capacity_err( + reservation, + additional, + self.pool_size.saturating_sub(used), + ) })?; Ok(()) } @@ -139,6 +143,7 @@ struct FairSpillPoolState { impl FairSpillPool { /// Allocate up to `limit` bytes pub fn new(pool_size: usize) -> Self { + debug!("Created new FairSpillPool(pool_size={pool_size})"); Self { pool_size, state: Mutex::new(FairSpillPoolState { @@ -159,13 +164,14 @@ impl MemoryPool for FairSpillPool { fn unregister(&self, consumer: &MemoryConsumer) { if consumer.can_spill { - self.state.lock().num_spill -= 1; + let mut state = self.state.lock(); + state.num_spill = state.num_spill.checked_sub(1).unwrap(); } } fn grow(&self, reservation: &MemoryReservation, additional: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable += additional, false => state.unspillable += additional, } @@ -173,7 +179,7 @@ impl MemoryPool for FairSpillPool { fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable -= shrink, false => state.unspillable -= shrink, } @@ -182,7 +188,7 @@ impl MemoryPool for FairSpillPool { fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => { // The total amount of memory available to spilling consumers let spill_available = self.pool_size.saturating_sub(state.unspillable); @@ -230,7 +236,7 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.consumer.name, reservation.size, available)) + DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available)) } #[cfg(test)] @@ -247,7 +253,7 @@ mod tests { r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("s1") + let mut r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -255,11 +261,11 @@ mod tests { assert_eq!(pool.reserved(), 4000); - let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + let err = r2.try_grow(1).unwrap_err().strip_backtrace(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); - let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + let err = r2.try_grow(1).unwrap_err().strip_backtrace(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); r1.shrink(1990); r2.shrink(2000); @@ -269,7 +275,7 @@ mod tests { r1.try_grow(10).unwrap(); assert_eq!(pool.reserved(), 20); - // Can grow a2 to 80 as only spilling consumer + // Can grow r2 to 80 as only spilling consumer r2.try_grow(80).unwrap(); assert_eq!(pool.reserved(), 100); @@ -279,19 +285,19 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("s2") + let mut r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); - let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + let err = r3.try_grow(70).unwrap_err().strip_backtrace(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); - //Shrinking a2 to zero doesn't allow a3 to allocate more than 45 + //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 r2.free(); - let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + let err = r3.try_grow(70).unwrap_err().strip_backtrace(); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); - // But dropping a2 does + // But dropping r2 does drop(r2); assert_eq!(pool.reserved(), 20); r3.try_grow(80).unwrap(); @@ -301,7 +307,7 @@ mod tests { assert_eq!(pool.reserved(), 80); let mut r4 = MemoryConsumer::new("s4").register(&pool); - let err = r4.try_grow(30).unwrap_err().to_string(); + let err = r4.try_grow(30).unwrap_err().strip_backtrace(); assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated - maximum available is 20"); } } diff --git a/datafusion/execution/src/memory_pool/proxy.rs b/datafusion/execution/src/memory_pool/proxy.rs index 43532f9a81f13..ced977b3bff33 100644 --- a/datafusion/execution/src/memory_pool/proxy.rs +++ b/datafusion/execution/src/memory_pool/proxy.rs @@ -26,6 +26,11 @@ pub trait VecAllocExt { /// [Push](Vec::push) new element to vector and store additional allocated bytes in `accounting` (additive). fn push_accounted(&mut self, x: Self::T, accounting: &mut usize); + + /// Return the amount of memory allocated by this Vec (not + /// recursively counting any heap allocations contained within the + /// structure). Does not include the size of `self` + fn allocated_size(&self) -> usize; } impl VecAllocExt for Vec { @@ -44,6 +49,9 @@ impl VecAllocExt for Vec { self.push(x); } + fn allocated_size(&self) -> usize { + std::mem::size_of::() * self.capacity() + } } /// Extension trait for [`RawTable`] to account for allocations. @@ -76,7 +84,7 @@ impl RawTableAllocExt for RawTable { Err(x) => { // need to request more memory - let bump_elements = (self.capacity() * 2).max(16); + let bump_elements = self.capacity().max(16); let bump_size = bump_elements * std::mem::size_of::(); *accounting = (*accounting).checked_add(bump_size).expect("overflow"); diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index 803f703452afc..5a1cdb769098c 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -20,7 +20,7 @@ //! and query data inside these systems. use dashmap::DashMap; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; use object_store::local::LocalFileSystem; use object_store::ObjectStore; use std::sync::Arc; @@ -40,9 +40,9 @@ impl ObjectStoreUrl { let remaining = &parsed[url::Position::BeforePath..]; if !remaining.is_empty() && remaining != "/" { - return Err(DataFusionError::Execution(format!( + return exec_err!( "ObjectStoreUrl must only contain scheme and authority, got: {remaining}" - ))); + ); } // Always set path for consistency @@ -234,20 +234,20 @@ mod tests { assert_eq!(url.as_str(), "s3://username:password@host:123/"); let err = ObjectStoreUrl::parse("s3://bucket:invalid").unwrap_err(); - assert_eq!(err.to_string(), "External error: invalid port number"); + assert_eq!(err.strip_backtrace(), "External error: invalid port number"); let err = ObjectStoreUrl::parse("s3://bucket?").unwrap_err(); - assert_eq!(err.to_string(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?"); + assert_eq!(err.strip_backtrace(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?"); let err = ObjectStoreUrl::parse("s3://bucket?foo=bar").unwrap_err(); - assert_eq!(err.to_string(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?foo=bar"); + assert_eq!(err.strip_backtrace(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: ?foo=bar"); let err = ObjectStoreUrl::parse("s3://host:123/foo").unwrap_err(); - assert_eq!(err.to_string(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"); + assert_eq!(err.strip_backtrace(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"); let err = ObjectStoreUrl::parse("s3://username:password@host:123/foo").unwrap_err(); - assert_eq!(err.to_string(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"); + assert_eq!(err.strip_backtrace(), "Execution error: ObjectStoreUrl must only contain scheme and authority, got: /foo"); } #[test] diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index ef06c74cc2923..9ba487e715b3b 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -18,7 +18,7 @@ //! FunctionRegistry trait use datafusion_common::Result; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode}; +use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::{collections::HashSet, sync::Arc}; /// A registry knows how to build logical expressions out of user-defined function' names @@ -31,6 +31,9 @@ pub trait FunctionRegistry { /// Returns a reference to the udaf named `name`. fn udaf(&self, name: &str) -> Result>; + + /// Returns a reference to the udwf named `name`. + fn udwf(&self, name: &str) -> Result>; } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 8f9c594681d0e..e78a9e0de9f04 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -24,6 +24,7 @@ use crate::{ object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, }; +use crate::cache::cache_manager::{CacheManager, CacheManagerConfig}; use datafusion_common::{DataFusionError, Result}; use object_store::ObjectStore; use std::fmt::{Debug, Formatter}; @@ -33,19 +34,22 @@ use url::Url; #[derive(Clone)] /// Execution runtime environment that manages system resources such -/// as memory, disk and storage. +/// as memory, disk, cache and storage. /// /// A [`RuntimeEnv`] is created from a [`RuntimeConfig`] and has the /// following resource management functionality: /// /// * [`MemoryPool`]: Manage memory /// * [`DiskManager`]: Manage temporary files on local disk +/// * [`CacheManager`]: Manage temporary cache data during the session lifetime /// * [`ObjectStoreRegistry`]: Manage mapping URLs to object store instances pub struct RuntimeEnv { /// Runtime memory management pub memory_pool: Arc, /// Manage temporary files during query execution pub disk_manager: Arc, + /// Manage temporary cache during query execution + pub cache_manager: Arc, /// Object Store Registry pub object_store_registry: Arc, } @@ -62,6 +66,7 @@ impl RuntimeEnv { let RuntimeConfig { memory_pool, disk_manager, + cache_manager, object_store_registry, } = config; @@ -71,6 +76,7 @@ impl RuntimeEnv { Ok(Self { memory_pool, disk_manager: DiskManager::try_new(disk_manager)?, + cache_manager: CacheManager::try_new(&cache_manager)?, object_store_registry, }) } @@ -116,6 +122,8 @@ pub struct RuntimeConfig { /// /// Defaults to using an [`UnboundedMemoryPool`] if `None` pub memory_pool: Option>, + /// CacheManager to manage cache data + pub cache_manager: CacheManagerConfig, /// ObjectStoreRegistry to get object store based on url pub object_store_registry: Arc, } @@ -132,6 +140,7 @@ impl RuntimeConfig { Self { disk_manager: Default::default(), memory_pool: Default::default(), + cache_manager: Default::default(), object_store_registry: Arc::new(DefaultObjectStoreRegistry::default()), } } @@ -148,6 +157,12 @@ impl RuntimeConfig { self } + /// Customize cache policy + pub fn with_cache_manager(mut self, cache_manager: CacheManagerConfig) -> Self { + self.cache_manager = cache_manager; + self + } + /// Customize object store registry pub fn with_object_store_registry( mut self, diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs new file mode 100644 index 0000000000000..7fc5e458b86b5 --- /dev/null +++ b/datafusion/execution/src/stream.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::Result; +use futures::Stream; +use std::pin::Pin; + +/// Trait for types that stream [arrow::record_batch::RecordBatch] +pub trait RecordBatchStream: Stream> { + /// Returns the schema of this `RecordBatchStream`. + /// + /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this + /// stream should have the same schema as returned from this method. + fn schema(&self) -> SchemaRef; +} + +/// Trait for a [`Stream`] of [`RecordBatch`]es +pub type SendableRecordBatchStream = Pin>; diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index ca1bc9369e351..52c183b1612c6 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -22,21 +22,25 @@ use std::{ use datafusion_common::{ config::{ConfigOptions, Extensions}, - DataFusionError, Result, + plan_datafusion_err, DataFusionError, Result, }; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ - config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, - runtime_env::RuntimeEnv, + config::SessionConfig, + memory_pool::MemoryPool, + registry::FunctionRegistry, + runtime_env::{RuntimeConfig, RuntimeEnv}, }; /// Task Execution Context /// -/// A [`TaskContext`] has represents the state available during a single query's -/// execution. +/// A [`TaskContext`] contains the state available during a single +/// query's execution. Please see [`SessionContext`] for a user level +/// multi-query API. /// -/// # Task Context +/// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html +#[derive(Debug)] pub struct TaskContext { /// Session Id session_id: String, @@ -48,18 +52,43 @@ pub struct TaskContext { scalar_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, + /// Window functions associated with this task context + window_functions: HashMap>, /// Runtime environment associated with this task context runtime: Arc, } +impl Default for TaskContext { + fn default() -> Self { + let runtime = RuntimeEnv::new(RuntimeConfig::new()) + .expect("defauly runtime created successfully"); + + // Create a default task context, mostly useful for testing + Self { + session_id: "DEFAULT".to_string(), + task_id: None, + session_config: SessionConfig::new(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + runtime: Arc::new(runtime), + } + } +} + impl TaskContext { - /// Create a new task context instance + /// Create a new [`TaskContext`] instance. + /// + /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s + /// + /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, aggregate_functions: HashMap>, + window_functions: HashMap>, runtime: Arc, ) -> Self { Self { @@ -68,6 +97,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, } } @@ -92,6 +122,7 @@ impl TaskContext { config.set(&k, &v)?; } let session_config = SessionConfig::from(config); + let window_functions = HashMap::new(); Ok(Self::new( Some(task_id), @@ -99,6 +130,7 @@ impl TaskContext { session_config, scalar_functions, aggregate_functions, + window_functions, runtime, )) } @@ -127,6 +159,18 @@ impl TaskContext { pub fn runtime_env(&self) -> Arc { self.runtime.clone() } + + /// Update the [`ConfigOptions`] + pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { + self.session_config = session_config; + self + } + + /// Update the [`RuntimeEnv`] + pub fn with_runtime(mut self, runtime: Arc) -> Self { + self.runtime = runtime; + self + } } impl FunctionRegistry for TaskContext { @@ -138,18 +182,24 @@ impl FunctionRegistry for TaskContext { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDF named \"{name}\" in the TaskContext" - )) + plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext") }) } fn udaf(&self, name: &str) -> Result> { let result = self.aggregate_functions.get(name); + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext") + }) + } + + fn udwf(&self, name: &str) -> Result> { + let result = self.window_functions.get(name); + result.cloned().ok_or_else(|| { DataFusionError::Internal(format!( - "There is no UDAF named \"{name}\" in the TaskContext" + "There is no UDWF named \"{name}\" in the TaskContext" )) }) } @@ -186,6 +236,7 @@ mod tests { session_config, HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 76b3ef136ebae..3e05dae61954a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-expr" description = "Logical plan and expression representation for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -35,14 +35,17 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } -datafusion-common = { path = "../common", version = "26.0.0" } -lazy_static = { version = "^1.4.0" } -sqlparser = "0.34" -strum = { version = "0.24", features = ["derive"] } -strum_macros = "0.24" +arrow-array = { workspace = true } +datafusion-common = { workspace = true } +paste = "^1.0" +sqlparser = { workspace = true } +strum = { version = "0.25.0", features = ["derive"] } +strum_macros = "0.25.0" [dev-dependencies] -ctor = "0.2.0" -env_logger = "0.10" +ctor = { workspace = true } +env_logger = { workspace = true } diff --git a/datafusion/expr/README.md b/datafusion/expr/README.md index bcce30be39d95..b086f930e871b 100644 --- a/datafusion/expr/README.md +++ b/datafusion/expr/README.md @@ -19,7 +19,7 @@ # DataFusion Logical Plan and Expressions -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for logical plans and expressions. diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 7e941d0cff97f..32de88b3d99f3 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -18,71 +18,178 @@ //! Accumulator module contains the trait definition for aggregation function's accumulators. use arrow::array::ArrayRef; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use std::fmt::Debug; -/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. +/// Describes an aggregate functions's state. +/// +/// `Accumulator`s are stateful objects that live throughout the +/// evaluation of multiple rows and aggregate multiple values together +/// into a final output aggregate. /// /// An accumulator knows how to: -/// * update its state from inputs via `update_batch` -/// * retract an update to its state from given inputs via `retract_batch` -/// * convert its internal state to a vector of aggregate values -/// * update its state from multiple accumulators' states via `merge_batch` -/// * compute the final value from its internal state via `evaluate` +/// * update its state from inputs via [`update_batch`] +/// +/// * compute the final value from its internal state via [`evaluate`] +/// +/// * retract an update to its state from given inputs via +/// [`retract_batch`] (when used as a window aggregate [window +/// function]) +/// +/// * convert its internal state to a vector of aggregate values via +/// [`state`] and combine the state from multiple accumulators' +/// via [`merge_batch`], as part of efficient multi-phase grouping. +/// +/// [`update_batch`]: Self::update_batch +/// [`retract_batch`]: Self::retract_batch +/// [`state`]: Self::state +/// [`evaluate`]: Self::evaluate +/// [`merge_batch`]: Self::merge_batch +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) pub trait Accumulator: Send + Sync + Debug { - /// Returns the partial intermediate state of the accumulator. This - /// partial state is serialied as `Arrays` and then combined with - /// other partial states from different instances of this - /// accumulator (that ran on different partitions, for - /// example). + /// Updates the accumulator's state from its input. /// - /// The state can be and often is a different type than the output - /// type of the [`Accumulator`]. + /// `values` contains the arguments to this aggregate function. + /// + /// For example, the `SUM` accumulator maintains a running sum, + /// and `update_batch` adds each of the input values to the + /// running sum. + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; + + /// Returns the final aggregate value. + /// + /// For example, the `SUM` accumulator maintains a running sum, + /// and `evaluate` will produce that running sum as its output. + fn evaluate(&self) -> Result; + + /// Returns the allocated size required for this accumulator, in + /// bytes, including `Self`. /// - /// See [`Self::merge_batch`] for more details on the merging process. + /// This value is used to calculate the memory used during + /// execution so DataFusion can stay within its allotted limit. + /// + /// "Allocated" means that for internal containers such as `Vec`, + /// the `capacity` should be used not the `len`. + fn size(&self) -> usize; + + /// Returns the intermediate state of the accumulator. + /// + /// Intermediate state is used for "multi-phase" grouping in + /// DataFusion, where an aggregate is computed in parallel with + /// multiple `Accumulator` instances, as illustrated below: + /// + /// # MultiPhase Grouping + /// + /// ```text + /// ▲ + /// │ evaluate() is called to + /// │ produce the final aggregate + /// │ value per group + /// │ + /// ┌─────────────────────────┐ + /// │GroupBy │ + /// │(AggregateMode::Final) │ state() is called for each + /// │ │ group and the resulting + /// └─────────────────────────┘ RecordBatches passed to the + /// ▲ + /// │ + /// ┌────────────────┴───────────────┐ + /// │ │ + /// │ │ + /// ┌─────────────────────────┐ ┌─────────────────────────┐ + /// │ GroubyBy │ │ GroubyBy │ + /// │(AggregateMode::Partial) │ │(AggregateMode::Partial) │ + /// └─────────────────────────┘ └────────────▲────────────┘ + /// ▲ │ + /// │ │ update_batch() is called for + /// │ │ each input RecordBatch + /// .─────────. .─────────. + /// ,─' '─. ,─' '─. + /// ; Input : ; Input : + /// : Partition 0 ; : Partition 1 ; + /// ╲ ╱ ╲ ╱ + /// '─. ,─' '─. ,─' + /// `───────' `───────' + /// ``` + /// + /// The partial state is serialied as `Arrays` and then combined + /// with other partial states from different instances of this + /// Accumulator (that ran on different partitions, for example). + /// + /// The state can be and often is a different type than the output + /// type of the [`Accumulator`] and needs different merge + /// operations (for example, the partial state for `COUNT` needs + /// to be summed together) /// /// Some accumulators can return multiple values for their /// intermediate states. For example average, tracks `sum` and /// `n`, and this function should return /// a vector of two values, sum and n. /// - /// `ScalarValue::List` can also be used to pass multiple values - /// if the number of intermediate values is not known at planning - /// time (e.g. median) + /// Note that [`ScalarValue::List`] can be used to pass multiple + /// values if the number of intermediate values is not known at + /// planning time (e.g. for `MEDIAN`) fn state(&self) -> Result>; - /// Updates the accumulator's state from a vector of arrays. - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; + /// Updates the accumulator's state from an `Array` containing one + /// or more intermediate values. + /// + /// For some aggregates (such as `SUM`), merge_batch is the same + /// as `update_batch`, but for some aggregrates (such as `COUNT`) + /// the operations differ. See [`Self::state`] for more details on how + /// state is used and merged. + /// + /// The `states` array passed was formed by concatenating the + /// results of calling [`Self::state`] on zero or more other + /// `Accumulator` instances. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; - /// Retracts an update (caused by the given inputs) to + /// Retracts (removed) an update (caused by the given inputs) to /// accumulator's state. /// /// This is the inverse operation of [`Self::update_batch`] and is used - /// to incrementally calculate window aggregates where the OVER + /// to incrementally calculate window aggregates where the `OVER` /// clause defines a bounded window. + /// + /// # Example + /// + /// For example, given the following input partition + /// + /// ```text + /// │ current │ + /// window + /// │ │ + /// ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐ + /// Input │ A │ B │ C │ D │ E │ F │ G │ H │ I │ + /// partition └────┴────┴────┴────┼────┴────┴────┴────┼────┘ + /// + /// │ next │ + /// window + /// ``` + /// + /// First, [`Self::evaluate`] will be called to produce the output + /// for the current window. + /// + /// Then, to advance to the next window: + /// + /// First, [`Self::retract_batch`] will be called with the values + /// that are leaving the window, `[B, C, D]` and then + /// [`Self::update_batch`] will be called with the values that are + /// entering the window, `[F, G, H]`. fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { // TODO add retract for all accumulators - Err(DataFusionError::Internal( - "Retract should be implemented for aggregate functions when used with custom window frame queries".to_string() - )) + internal_err!( + "Retract should be implemented for aggregate functions when used with custom window frame queries" + ) } - /// Updates the accumulator's state from an `Array` containing one - /// or more intermediate values. + /// Does the accumulator support incrementally updating its value + /// by *removing* values. /// - /// The `states` array passed was formed by concatenating the - /// results of calling `[state]` on zero or more other accumulator - /// instances. - /// - /// `states` is an array of the same types as returned by [`Self::state`] - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; - - /// Returns the final aggregate value based on its current state. - fn evaluate(&self) -> Result; - - /// Allocated size required for this accumulator, in bytes, including `Self`. - /// Allocated means that for internal containers such as `Vec`, the `capacity` should be used - /// not the `len` - fn size(&self) -> usize; + /// If this function returns true, [`Self::retract_batch`] will be + /// called for sliding window functions such as queries with an + /// `OVER (ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING)` + fn supports_retract_batch(&self) -> bool { + false + } } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 5b0676a815099..cea72c3cb5e6b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -17,14 +17,17 @@ //! Aggregate function module contains all built-in aggregate functions definitions +use crate::utils; use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; use strum_macros::EnumIter; /// Enum of all built-in aggregate functions +// Contributor's guide for adding new aggregate functions +// https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { /// count @@ -61,6 +64,24 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, + /// Slope from linear regression + RegrSlope, + /// Intercept from linear regression + RegrIntercept, + /// Number of input rows in which both expressions are not null + RegrCount, + /// R-squared value from linear regression + RegrR2, + /// Average of the independent variable + RegrAvgx, + /// Average of the dependent variable + RegrAvgy, + /// Sum of squares of the independent variable + RegrSXX, + /// Sum of squares of the dependent variable + RegrSYY, + /// Sum of products of pairs of numbers + RegrSXY, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -79,10 +100,12 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", @@ -95,13 +118,22 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", LastValue => "LAST_VALUE", - Variance => "VARIANCE", - VariancePop => "VARIANCE_POP", + Variance => "VAR", + VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - Covariance => "COVARIANCE", - CovariancePop => "COVARIANCE_POP", - Correlation => "CORRELATION", + Covariance => "COVAR", + CovariancePop => "COVAR_POP", + Correlation => "CORR", + RegrSlope => "REGR_SLOPE", + RegrIntercept => "REGR_INTERCEPT", + RegrCount => "REGR_COUNT", + RegrR2 => "REGR_R2", + RegrAvgx => "REGR_AVGX", + RegrAvgy => "REGR_AVGY", + RegrSXX => "REGR_SXX", + RegrSYY => "REGR_SYY", + RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", ApproxMedian => "APPROX_MEDIAN", @@ -111,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -141,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -152,6 +186,15 @@ impl FromStr for AggregateFunction { "var" => AggregateFunction::Variance, "var_pop" => AggregateFunction::VariancePop, "var_samp" => AggregateFunction::Variance, + "regr_slope" => AggregateFunction::RegrSlope, + "regr_intercept" => AggregateFunction::RegrIntercept, + "regr_count" => AggregateFunction::RegrCount, + "regr_r2" => AggregateFunction::RegrR2, + "regr_avgx" => AggregateFunction::RegrAvgx, + "regr_avgy" => AggregateFunction::RegrAvgy, + "regr_sxx" => AggregateFunction::RegrSXX, + "regr_syy" => AggregateFunction::RegrSYY, + "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_distinct" => AggregateFunction::ApproxDistinct, "approx_median" => AggregateFunction::ApproxMedian, @@ -162,9 +205,7 @@ impl FromStr for AggregateFunction { // other "grouping" => AggregateFunction::Grouping, _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))); + return plan_err!("There is no built-in function named {name}"); } }) } @@ -172,58 +213,97 @@ impl FromStr for AggregateFunction { /// Returns the datatype of the aggregate function. /// This is used to get the returned data type for aggregate expr. +#[deprecated( + since = "27.0.0", + note = "please use `AggregateFunction::return_type` instead" +)] pub fn return_type( fun: &AggregateFunction, input_expr_types: &[DataType], ) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. + fun.return_type(input_expr_types) +} - let coerced_data_types = crate::type_coercion::aggregates::coerce_types( - fun, - input_expr_types, - &signature(fun), - )?; +impl AggregateFunction { + /// Returns the datatype of the aggregate function given its argument types + /// + /// This is used to get the returned data type for aggregate expr. + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. - match fun { - AggregateFunction::Count | AggregateFunction::ApproxDistinct => { - Ok(DataType::Int64) - } - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => Ok(coerced_data_types[0].clone()), - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean), - AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), - AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]), - AggregateFunction::Covariance => covariance_return_type(&coerced_data_types[0]), - AggregateFunction::CovariancePop => { - covariance_return_type(&coerced_data_types[0]) - } - AggregateFunction::Correlation => correlation_return_type(&coerced_data_types[0]), - AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), - AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), - AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), - AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( - "item", - coerced_data_types[0].clone(), - true, - )))), - AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), - AggregateFunction::ApproxPercentileContWithWeight => { - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::ApproxMedian | AggregateFunction::Median => { - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::FirstValue | AggregateFunction::LastValue => { - Ok(coerced_data_types[0].clone()) + let coerced_data_types = coerce_types(self, input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(DataType::Int64) + } + AggregateFunction::Max | AggregateFunction::Min => { + // For min and max agg function, the returned type is same as input type. + // The coerced_data_types is same with input_types. + Ok(coerced_data_types[0].clone()) + } + AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::BitAnd + | AggregateFunction::BitOr + | AggregateFunction::BitXor => Ok(coerced_data_types[0].clone()), + AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { + Ok(DataType::Boolean) + } + AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), + AggregateFunction::VariancePop => { + variance_return_type(&coerced_data_types[0]) + } + AggregateFunction::Covariance => { + covariance_return_type(&coerced_data_types[0]) + } + AggregateFunction::CovariancePop => { + covariance_return_type(&coerced_data_types[0]) + } + AggregateFunction::Correlation => { + correlation_return_type(&coerced_data_types[0]) + } + AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => Ok(DataType::Float64), + AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), + AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( + "item", + coerced_data_types[0].clone(), + true, + )))), + AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), + AggregateFunction::ApproxPercentileContWithWeight => { + Ok(coerced_data_types[0].clone()) + } + AggregateFunction::ApproxMedian | AggregateFunction::Median => { + Ok(coerced_data_types[0].clone()) + } + AggregateFunction::Grouping => Ok(DataType::Int32), + AggregateFunction::FirstValue | AggregateFunction::LastValue => { + Ok(coerced_data_types[0].clone()) + } + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -236,79 +316,127 @@ pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result { let coerced_data_types = crate::type_coercion::aggregates::coerce_types( &fun, input_expr_types, - &signature(&fun), + &fun.signature(), )?; avg_sum_type(&coerced_data_types[0]) } /// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `AggregateFunction::signature` instead" +)] pub fn signature(fun: &AggregateFunction) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match fun { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), - AggregateFunction::ApproxDistinct - | AggregateFunction::Grouping - | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => { - Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable) - } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) - } - AggregateFunction::Avg - | AggregateFunction::Sum - | AggregateFunction::Variance - | AggregateFunction::VariancePop - | AggregateFunction::Stddev - | AggregateFunction::StddevPop - | AggregateFunction::Median - | AggregateFunction::ApproxMedian - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => { - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::Covariance | AggregateFunction::CovariancePop => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::Correlation => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::ApproxPercentileCont => { - // Accept any numeric value paired with a float64 percentile - let with_tdigest_size = NUMERICS.iter().map(|t| { - TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()]) - }); - Signature::one_of( + fun.signature() +} + +impl AggregateFunction { + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), + AggregateFunction::ApproxDistinct + | AggregateFunction::Grouping + | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), + AggregateFunction::Min | AggregateFunction::Max => { + let valid = STRINGS + .iter() + .chain(NUMERICS.iter()) + .chain(TIMESTAMPS.iter()) + .chain(DATES.iter()) + .chain(TIMES.iter()) + .chain(BINARYS.iter()) + .cloned() + .collect::>(); + Signature::uniform(1, valid, Volatility::Immutable) + } + AggregateFunction::BitAnd + | AggregateFunction::BitOr + | AggregateFunction::BitXor => { + Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable) + } + AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { + Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) + } + AggregateFunction::Avg + | AggregateFunction::Sum + | AggregateFunction::Variance + | AggregateFunction::VariancePop + | AggregateFunction::Stddev + | AggregateFunction::StddevPop + | AggregateFunction::Median + | AggregateFunction::ApproxMedian + | AggregateFunction::FirstValue + | AggregateFunction::LastValue => { + Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) + } + AggregateFunction::Covariance + | AggregateFunction::CovariancePop + | AggregateFunction::Correlation + | AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => { + Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) + } + AggregateFunction::ApproxPercentileCont => { + // Accept any numeric value paired with a float64 percentile + let with_tdigest_size = NUMERICS.iter().map(|t| { + TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()]) + }); + Signature::one_of( + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) + .chain(with_tdigest_size) + .collect(), + Volatility::Immutable, + ) + } + AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( + // Accept any numeric value paired with a float64 percentile NUMERICS .iter() - .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) - .chain(with_tdigest_size) + .map(|t| { + TypeSignature::Exact(vec![ + t.clone(), + t.clone(), + DataType::Float64, + ]) + }) .collect(), Volatility::Immutable, - ) + ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + + #[test] + // Test for AggregateFuncion's Display and from_str() implementations. + // For each variant in AggregateFuncion, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in AggregateFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = + AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); + assert_eq!(func_from_str, func_original); } - AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![t.clone(), t.clone(), DataType::Float64]) - }) - .collect(), - Volatility::Immutable, - ), } } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 2272997fae061..977b556b26cf0 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -17,17 +17,30 @@ //! Built-in functions module contains all the built-in functions definitions. -use crate::Volatility; -use datafusion_common::{DataFusionError, Result}; +use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + +use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::signature::TIMEZONE_WILDCARD; +use crate::type_coercion::binary::get_wider_type; +use crate::type_coercion::functions::data_types; +use crate::{ + conditional_expressions, struct_expressions, FuncMonotonicity, Signature, + TypeSignature, Volatility, +}; + +use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + use strum::IntoEnumIterator; use strum_macros::EnumIter; -use lazy_static::lazy_static; - /// Enum of all built-in scalar functions +// Contributor's guide for adding new scalar functions +// https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-scalar-function #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions @@ -57,10 +70,14 @@ pub enum BuiltinScalarFunction { Cos, /// cos Cosh, + /// Decode + Decode, /// degrees Degrees, /// Digest Digest, + /// Encode + Encode, /// exp Exp, /// factorial @@ -71,6 +88,10 @@ pub enum BuiltinScalarFunction { Gcd, /// lcm, Least common multiple Lcm, + /// isnan + Isnan, + /// iszero + Iszero, /// ln, Natural logarithm Ln, /// log, same as log10 @@ -79,6 +100,8 @@ pub enum BuiltinScalarFunction { Log10, /// log2 Log2, + /// nanvl + Nanvl, /// pi Pi, /// power @@ -101,16 +124,34 @@ pub enum BuiltinScalarFunction { Tanh, /// trunc Trunc, + /// cot + Cot, // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, + /// array_has + ArrayHas, + /// array_has_all + ArrayHasAll, + /// array_has_any + ArrayHasAny, + /// array_pop_front + ArrayPopFront, + /// array_pop_back + ArrayPopBack, /// array_dims ArrayDims, - /// array_fill - ArrayFill, + /// array_distinct + ArrayDistinct, + /// array_element + ArrayElement, + /// array_empty + ArrayEmpty, /// array_length ArrayLength, /// array_ndims @@ -123,16 +164,40 @@ pub enum BuiltinScalarFunction { ArrayPrepend, /// array_remove ArrayRemove, + /// array_remove_n + ArrayRemoveN, + /// array_remove_all + ArrayRemoveAll, + /// array_repeat + ArrayRepeat, /// array_replace ArrayReplace, + /// array_replace_n + ArrayReplaceN, + /// array_replace_all + ArrayReplaceAll, + /// array_slice + ArraySlice, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, + /// array_union + ArrayUnion, + /// array_except + ArrayExcept, /// cardinality Cardinality, /// construct an array from columns MakeArray, - /// trim_array - TrimArray, + /// Flatten + Flatten, + /// Range + Range, + + // struct functions + /// struct + Struct, // string functions /// ascii @@ -197,6 +262,8 @@ pub enum BuiltinScalarFunction { SHA512, /// split_part SplitPart, + /// string_to_array + StringToArray, /// starts_with StartsWith, /// strpos @@ -211,6 +278,8 @@ pub enum BuiltinScalarFunction { ToTimestampMillis, /// to_timestamp_micros ToTimestampMicros, + /// to_timestamp_nanos + ToTimestampNanos, /// to_timestamp_seconds ToTimestampSeconds, /// from_unixtime @@ -231,48 +300,64 @@ pub enum BuiltinScalarFunction { Uuid, /// regexp_match RegexpMatch, - /// struct - Struct, /// arrow_typeof ArrowTypeof, + /// overlay + OverLay, + /// levenshtein + Levenshtein, + /// substr_index + SubstrIndex, + /// find_in_set + FindInSet, } -lazy_static! { - /// Maps the sql function name to `BuiltinScalarFunction` - static ref NAME_TO_FUNCTION: HashMap<&'static str, BuiltinScalarFunction> = { - let mut map: HashMap<&'static str, BuiltinScalarFunction> = HashMap::new(); +/// Maps the sql function name to `BuiltinScalarFunction` +fn name_to_function() -> &'static HashMap<&'static str, BuiltinScalarFunction> { + static NAME_TO_FUNCTION_LOCK: OnceLock> = + OnceLock::new(); + NAME_TO_FUNCTION_LOCK.get_or_init(|| { + let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - let a = aliases(&func); - a.iter().for_each(|a| {map.insert(a, func);}); + func.aliases().iter().for_each(|&a| { + map.insert(a, func); + }); }); map - }; + }) +} - /// Maps `BuiltinScalarFunction` --> canonical sql function - /// First alias in the array is used to display function names - static ref FUNCTION_TO_NAME: HashMap = { - let mut map: HashMap = HashMap::new(); +/// Maps `BuiltinScalarFunction` --> canonical sql function +/// First alias in the array is used to display function names +fn function_to_name() -> &'static HashMap { + static FUNCTION_TO_NAME_LOCK: OnceLock> = + OnceLock::new(); + FUNCTION_TO_NAME_LOCK.get_or_init(|| { + let mut map = HashMap::new(); BuiltinScalarFunction::iter().for_each(|func| { - map.insert(func, aliases(&func).first().unwrap_or(&"NO_ALIAS")); + map.insert(func, *func.aliases().first().unwrap_or(&"NO_ALIAS")); }); map - }; + }) } impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. + #[deprecated( + since = "32.0.0", + note = "please use TypeSignature::supports_zero_argument instead" + )] pub fn supports_zero_argument(&self) -> bool { - matches!( - self, - BuiltinScalarFunction::Pi - | BuiltinScalarFunction::Random - | BuiltinScalarFunction::Now - | BuiltinScalarFunction::CurrentDate - | BuiltinScalarFunction::CurrentTime - | BuiltinScalarFunction::Uuid - ) + self.signature().type_signature.supports_zero_argument() + } + + /// Returns the name of this function + pub fn name(&self) -> &str { + // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction + function_to_name().get(self).unwrap() } + /// Returns the [Volatility] of the builtin function. pub fn volatility(&self) -> Volatility { match self { @@ -289,16 +374,21 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, BuiltinScalarFunction::Cosh => Volatility::Immutable, + BuiltinScalarFunction::Decode => Volatility::Immutable, BuiltinScalarFunction::Degrees => Volatility::Immutable, + BuiltinScalarFunction::Encode => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Floor => Volatility::Immutable, BuiltinScalarFunction::Gcd => Volatility::Immutable, + BuiltinScalarFunction::Isnan => Volatility::Immutable, + BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Lcm => Volatility::Immutable, BuiltinScalarFunction::Ln => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Log10 => Volatility::Immutable, BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, @@ -307,24 +397,43 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sinh => Volatility::Immutable, BuiltinScalarFunction::Sqrt => Volatility::Immutable, BuiltinScalarFunction::Cbrt => Volatility::Immutable, + BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Tan => Volatility::Immutable, BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, + BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, + BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, + BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, + BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, - BuiltinScalarFunction::ArrayFill => Volatility::Immutable, + BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable, + BuiltinScalarFunction::ArrayElement => Volatility::Immutable, + BuiltinScalarFunction::ArrayExcept => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, + BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable, + BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, BuiltinScalarFunction::ArrayPrepend => Volatility::Immutable, + BuiltinScalarFunction::ArrayRepeat => Volatility::Immutable, BuiltinScalarFunction::ArrayRemove => Volatility::Immutable, + BuiltinScalarFunction::ArrayRemoveN => Volatility::Immutable, + BuiltinScalarFunction::ArrayRemoveAll => Volatility::Immutable, BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, + BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, + BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, + BuiltinScalarFunction::Flatten => Volatility::Immutable, + BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, + BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, + BuiltinScalarFunction::Range => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, - BuiltinScalarFunction::TrimArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::Btrim => Volatility::Immutable, @@ -357,6 +466,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SHA512 => Volatility::Immutable, BuiltinScalarFunction::Digest => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, + BuiltinScalarFunction::StringToArray => Volatility::Immutable, BuiltinScalarFunction::StartsWith => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, @@ -364,6 +474,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampNanos => Volatility::Immutable, BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, BuiltinScalarFunction::Trim => Volatility::Immutable, @@ -372,6 +483,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Struct => Volatility::Immutable, BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, + BuiltinScalarFunction::OverLay => Volatility::Immutable, + BuiltinScalarFunction::Levenshtein => Volatility::Immutable, + BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, + BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -383,149 +498,1235 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Uuid => Volatility::Volatile, } } -} -fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { - match func { - BuiltinScalarFunction::Abs => &["abs"], - BuiltinScalarFunction::Acos => &["acos"], - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asin => &["asin"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], - BuiltinScalarFunction::Cbrt => &["cbrt"], - BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cos => &["cos"], - BuiltinScalarFunction::Cosh => &["cosh"], - BuiltinScalarFunction::Degrees => &["degrees"], - BuiltinScalarFunction::Exp => &["exp"], - BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], - BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], - BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - BuiltinScalarFunction::Log2 => &["log2"], - BuiltinScalarFunction::Pi => &["pi"], - BuiltinScalarFunction::Power => &["power", "pow"], - BuiltinScalarFunction::Radians => &["radians"], - BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Signum => &["signum"], - BuiltinScalarFunction::Sin => &["sin"], - BuiltinScalarFunction::Sinh => &["sinh"], - BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], - BuiltinScalarFunction::Tanh => &["tanh"], - BuiltinScalarFunction::Trunc => &["trunc"], - - // conditional functions - BuiltinScalarFunction::Coalesce => &["coalesce"], - BuiltinScalarFunction::NullIf => &["nullif"], - - // string functions - BuiltinScalarFunction::Ascii => &["ascii"], - BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] + /// Returns the dimension [`DataType`] of [`DataType::List`] if + /// treated as a N-dimensional array. + /// + /// ## Examples: + /// + /// * `Int64` has dimension 1 + /// * `List(Int64)` has dimension 2 + /// * `List(List(Int64))` has dimension 3 + /// * etc. + fn return_dimension(self, input_expr_type: &DataType) -> u64 { + let mut result: u64 = 1; + let mut current_data_type = input_expr_type; + while let DataType::List(field) = current_data_type { + current_data_type = field.data_type(); + result += 1; + } + result + } + + /// Returns the output [`DataType`] of this function + /// + /// This method should be invoked only after `input_expr_types` have been validated + /// against the function's `TypeSignature` using `type_coercion::functions::data_types()`. + /// + /// This method will: + /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. + /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. + pub fn return_type(self, input_expr_types: &[DataType]) -> Result { + use DataType::*; + use TimeUnit::*; + + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // the return type of the built in function. + // Some built-in functions' return type depends on the incoming type. + match self { + BuiltinScalarFunction::Flatten => { + fn get_base_type(data_type: &DataType) -> Result { + match data_type { + DataType::List(field) => match field.data_type() { + DataType::List(_) => get_base_type(field.data_type()), + _ => Ok(data_type.to_owned()), + }, + _ => internal_err!("Not reachable, data_type should be List"), + } + } + + let data_type = get_base_type(&input_expr_types[0])?; + Ok(data_type) + } + BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayConcat => { + let mut expr_type = Null; + let mut max_dims = 0; + for input_expr_type in input_expr_types { + match input_expr_type { + List(field) => { + if !field.data_type().equals_datatype(&Null) { + let dims = self.return_dimension(input_expr_type); + expr_type = match max_dims.cmp(&dims) { + Ordering::Greater => expr_type, + Ordering::Equal => { + get_wider_type(&expr_type, input_expr_type)? + } + Ordering::Less => { + max_dims = dims; + input_expr_type.clone() + } + }; + } + } + _ => { + return plan_err!( + "The {self} function can only accept list as the args." + ) + } + } + } + + Ok(expr_type) + } + BuiltinScalarFunction::ArrayHasAll + | BuiltinScalarFunction::ArrayHasAny + | BuiltinScalarFunction::ArrayHas + | BuiltinScalarFunction::ArrayEmpty => Ok(Boolean), + BuiltinScalarFunction::ArrayDims => { + Ok(List(Arc::new(Field::new("item", UInt64, true)))) + } + BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { + List(field) => Ok(field.data_type().clone()), + _ => plan_err!( + "The {self} function can only accept list as the first argument" + ), + }, + BuiltinScalarFunction::ArrayLength => Ok(UInt64), + BuiltinScalarFunction::ArrayNdims => Ok(UInt64), + BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayPosition => Ok(UInt64), + BuiltinScalarFunction::ArrayPositions => { + Ok(List(Arc::new(Field::new("item", UInt64, true)))) + } + BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()), + BuiltinScalarFunction::ArrayRepeat => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::ArrayRemove => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayRemoveN => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayRemoveAll => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::Range => { + Ok(List(Arc::new(Field::new("item", Int64, true)))) + } + BuiltinScalarFunction::ArrayExcept => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, _) | (_, DataType::Null) => { + Ok(input_expr_types[0].clone()) + } + (dt, _) => Ok(dt), + } + } + BuiltinScalarFunction::Cardinality => Ok(UInt64), + BuiltinScalarFunction::MakeArray => match input_expr_types.len() { + 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), + _ => { + let mut expr_type = Null; + for input_expr_type in input_expr_types { + if !input_expr_type.equals_datatype(&Null) { + expr_type = input_expr_type.clone(); + break; + } + } + + Ok(List(Arc::new(Field::new("item", expr_type, true)))) + } + }, + BuiltinScalarFunction::Ascii => Ok(Int32), + BuiltinScalarFunction::BitLength => { + utf8_to_int_type(&input_expr_types[0], "bit_length") + } + BuiltinScalarFunction::Btrim => { + utf8_to_str_type(&input_expr_types[0], "btrim") + } + BuiltinScalarFunction::CharacterLength => { + utf8_to_int_type(&input_expr_types[0], "character_length") + } + BuiltinScalarFunction::Chr => Ok(Utf8), + BuiltinScalarFunction::Coalesce => { + // COALESCE has multiple args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|types| types[0].clone()) + } + BuiltinScalarFunction::Concat => Ok(Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), + BuiltinScalarFunction::DatePart => Ok(Float64), + BuiltinScalarFunction::DateBin | BuiltinScalarFunction::DateTrunc => { + match &input_expr_types[1] { + Timestamp(Nanosecond, None) | Utf8 | Null => { + Ok(Timestamp(Nanosecond, None)) + } + Timestamp(Nanosecond, tz_opt) => { + Ok(Timestamp(Nanosecond, tz_opt.clone())) + } + Timestamp(Microsecond, tz_opt) => { + Ok(Timestamp(Microsecond, tz_opt.clone())) + } + Timestamp(Millisecond, tz_opt) => { + Ok(Timestamp(Millisecond, tz_opt.clone())) + } + Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())), + _ => plan_err!( + "The {self} function can only accept timestamp as the second arg." + ), + } + } + BuiltinScalarFunction::InitCap => { + utf8_to_str_type(&input_expr_types[0], "initcap") + } + BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), + BuiltinScalarFunction::Lower => { + utf8_to_str_type(&input_expr_types[0], "lower") + } + BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), + BuiltinScalarFunction::Ltrim => { + utf8_to_str_type(&input_expr_types[0], "ltrim") + } + BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), + BuiltinScalarFunction::NullIf => { + // NULLIF has two args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|typs| typs[0].clone()) + } + BuiltinScalarFunction::OctetLength => { + utf8_to_int_type(&input_expr_types[0], "octet_length") + } + BuiltinScalarFunction::Pi => Ok(Float64), + BuiltinScalarFunction::Random => Ok(Float64), + BuiltinScalarFunction::Uuid => Ok(Utf8), + BuiltinScalarFunction::RegexpReplace => { + utf8_to_str_type(&input_expr_types[0], "regexp_replace") + } + BuiltinScalarFunction::Repeat => { + utf8_to_str_type(&input_expr_types[0], "repeat") + } + BuiltinScalarFunction::Replace => { + utf8_to_str_type(&input_expr_types[0], "replace") + } + BuiltinScalarFunction::Reverse => { + utf8_to_str_type(&input_expr_types[0], "reverse") + } + BuiltinScalarFunction::Right => { + utf8_to_str_type(&input_expr_types[0], "right") + } + BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), + BuiltinScalarFunction::Rtrim => { + utf8_to_str_type(&input_expr_types[0], "rtrim") + } + BuiltinScalarFunction::SHA224 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") + } + BuiltinScalarFunction::SHA256 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") + } + BuiltinScalarFunction::SHA384 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") + } + BuiltinScalarFunction::SHA512 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") + } + BuiltinScalarFunction::Digest => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") + } + BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] { + Utf8 => Utf8, + LargeUtf8 => LargeUtf8, + Binary => Utf8, + LargeBinary => LargeUtf8, + Null => Null, + _ => { + return plan_err!( + "The encode function can only accept utf8 or binary." + ); + } + }), + BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] { + Utf8 => Binary, + LargeUtf8 => LargeBinary, + Binary => Binary, + LargeBinary => LargeBinary, + Null => Null, + _ => { + return plan_err!( + "The decode function can only accept utf8 or binary." + ); + } + }), + BuiltinScalarFunction::SplitPart => { + utf8_to_str_type(&input_expr_types[0], "split_part") + } + BuiltinScalarFunction::StringToArray => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::StartsWith => Ok(Boolean), + BuiltinScalarFunction::Strpos => { + utf8_to_int_type(&input_expr_types[0], "strpos") + } + BuiltinScalarFunction::Substr => { + utf8_to_str_type(&input_expr_types[0], "substr") + } + BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + return plan_err!("The to_hex function can only accept integers."); + } + }), + BuiltinScalarFunction::SubstrIndex => { + utf8_to_str_type(&input_expr_types[0], "substr_index") + } + BuiltinScalarFunction::FindInSet => { + utf8_to_int_type(&input_expr_types[0], "find_in_set") + } + BuiltinScalarFunction::ToTimestamp + | BuiltinScalarFunction::ToTimestampNanos => Ok(Timestamp(Nanosecond, None)), + BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), + BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), + BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::Now => { + Ok(Timestamp(Nanosecond, Some("+00:00".into()))) + } + BuiltinScalarFunction::CurrentDate => Ok(Date32), + BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), + BuiltinScalarFunction::Translate => { + utf8_to_str_type(&input_expr_types[0], "translate") + } + BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), + BuiltinScalarFunction::Upper => { + utf8_to_str_type(&input_expr_types[0], "upper") + } + BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { + LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), + Utf8 => List(Arc::new(Field::new("item", Utf8, true))), + Null => Null, + _ => { + return plan_err!( + "The regexp_extract function can only accept strings." + ); + } + }), + + BuiltinScalarFunction::Factorial + | BuiltinScalarFunction::Gcd + | BuiltinScalarFunction::Lcm => Ok(Int64), + + BuiltinScalarFunction::Power => match &input_expr_types[0] { + Int64 => Ok(Int64), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Struct => { + let return_fields = input_expr_types + .iter() + .enumerate() + .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) + .collect::>(); + Ok(Struct(Fields::from(return_fields))) + } + + BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Log => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Nanvl => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => Ok(Boolean), + + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), + + BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()), + + BuiltinScalarFunction::OverLay => { + utf8_to_str_type(&input_expr_types[0], "overlay") + } + + BuiltinScalarFunction::Levenshtein => { + utf8_to_int_type(&input_expr_types[0], "levenshtein") + } + + BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc + | BuiltinScalarFunction::Cot => match input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + } + } + + /// Return the argument [`Signature`] supported by this function + pub fn signature(&self) -> Signature { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + use TypeSignature::*; + // note: the physical expression must accept the type returned by this function or the execution panics. + + // for now, the list is small, as we do not have many built-in functions. + match self { + BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayConcat => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayHasAll + | BuiltinScalarFunction::ArrayHasAny + | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayLength => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayPosition => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), + BuiltinScalarFunction::ArrayReplaceAll => { + Signature::any(3, self.volatility()) + } + BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArrayToString => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) + } + BuiltinScalarFunction::Range => Signature::one_of( + vec![ + Exact(vec![Int64]), + Exact(vec![Int64, Int64]), + Exact(vec![Int64, Int64, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Struct => Signature::variadic( + struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::Concat + | BuiltinScalarFunction::ConcatWithSeparator => { + Signature::variadic(vec![Utf8], self.volatility()) + } + BuiltinScalarFunction::Coalesce => Signature::variadic( + conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::SHA224 + | BuiltinScalarFunction::SHA256 + | BuiltinScalarFunction::SHA384 + | BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::MD5 => Signature::uniform( + 1, + vec![Utf8, LargeUtf8, Binary, LargeBinary], + self.volatility(), + ), + BuiltinScalarFunction::Ascii + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::InitCap + | BuiltinScalarFunction::Lower + | BuiltinScalarFunction::OctetLength + | BuiltinScalarFunction::Reverse + | BuiltinScalarFunction::Upper => { + Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) + } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::Trim => Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + self.volatility(), + ), + BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + self.volatility(), + ) + } + BuiltinScalarFunction::Left + | BuiltinScalarFunction::Repeat + | BuiltinScalarFunction::Right => Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestamp => Signature::uniform( + 1, + vec![ + Int64, + Float64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampNanos => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::FromUnixtime => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Digest => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Encode => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Decode => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::DateTrunc => Signature::one_of( + vec![ + Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![ + Utf8, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), + ], + self.volatility(), + ), + BuiltinScalarFunction::DateBin => { + let base_sig = |array_type: TimeUnit| { + vec![ + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + ]), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), None), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), + ]), + ] + }; + + let full_sig = [Nanosecond, Microsecond, Millisecond, Second] + .into_iter() + .map(base_sig) + .collect::>() + .concat(); + + Signature::one_of(full_sig, self.volatility()) + } + BuiltinScalarFunction::DatePart => Signature::one_of( + vec![ + Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![ + Utf8, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![ + Utf8, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), + Exact(vec![Utf8, Date64]), + Exact(vec![Utf8, Date32]), + ], + self.volatility(), + ), + BuiltinScalarFunction::SplitPart => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::StringToArray => Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), + TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ) + } + + BuiltinScalarFunction::Substr => Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::SubstrIndex => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::FindInSet => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), + + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) + } + BuiltinScalarFunction::RegexpReplace => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8, Utf8]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::NullIf => { + Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), self.volatility()) + } + BuiltinScalarFunction::RegexpMatch => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Power => Signature::one_of( + vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Round => Signature::one_of( + vec![ + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Trunc => Signature::one_of( + vec![ + Exact(vec![Float32, Int64]), + Exact(vec![Float64, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Atan2 => Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Log => Signature::one_of( + vec![ + Exact(vec![Float32]), + Exact(vec![Float64]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Nanvl => Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Factorial => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { + Signature::uniform(2, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), + BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()), + BuiltinScalarFunction::OverLay => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Levenshtein => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), + BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Cot => { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + Signature::uniform(1, vec![Float64, Float32], self.volatility()) + } + BuiltinScalarFunction::Now + | BuiltinScalarFunction::CurrentDate + | BuiltinScalarFunction::CurrentTime => { + Signature::uniform(0, vec![], self.volatility()) + } + BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => { + Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + self.volatility(), + ) + } + } + } + + /// This function specifies monotonicity behaviors for built-in scalar functions. + /// The list can be extended, only mathematical and datetime functions are + /// considered for the initial implementation of this feature. + pub fn monotonicity(&self) -> Option { + if matches!( + &self, + BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Factorial + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc + | BuiltinScalarFunction::Pi + ) { + Some(vec![Some(true)]) + } else if matches!( + &self, + BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin + ) { + Some(vec![None, Some(true)]) + } else if *self == BuiltinScalarFunction::Log { + Some(vec![Some(true), Some(false)]) + } else { + None + } + } + + /// Returns all names that can be used to call this function + pub fn aliases(&self) -> &'static [&'static str] { + match self { + BuiltinScalarFunction::Abs => &["abs"], + BuiltinScalarFunction::Acos => &["acos"], + BuiltinScalarFunction::Acosh => &["acosh"], + BuiltinScalarFunction::Asin => &["asin"], + BuiltinScalarFunction::Asinh => &["asinh"], + BuiltinScalarFunction::Atan => &["atan"], + BuiltinScalarFunction::Atanh => &["atanh"], + BuiltinScalarFunction::Atan2 => &["atan2"], + BuiltinScalarFunction::Cbrt => &["cbrt"], + BuiltinScalarFunction::Ceil => &["ceil"], + BuiltinScalarFunction::Cos => &["cos"], + BuiltinScalarFunction::Cot => &["cot"], + BuiltinScalarFunction::Cosh => &["cosh"], + BuiltinScalarFunction::Degrees => &["degrees"], + BuiltinScalarFunction::Exp => &["exp"], + BuiltinScalarFunction::Factorial => &["factorial"], + BuiltinScalarFunction::Floor => &["floor"], + BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Isnan => &["isnan"], + BuiltinScalarFunction::Iszero => &["iszero"], + BuiltinScalarFunction::Lcm => &["lcm"], + BuiltinScalarFunction::Ln => &["ln"], + BuiltinScalarFunction::Log => &["log"], + BuiltinScalarFunction::Log10 => &["log10"], + BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], + BuiltinScalarFunction::Pi => &["pi"], + BuiltinScalarFunction::Power => &["power", "pow"], + BuiltinScalarFunction::Radians => &["radians"], + BuiltinScalarFunction::Random => &["random"], + BuiltinScalarFunction::Round => &["round"], + BuiltinScalarFunction::Signum => &["signum"], + BuiltinScalarFunction::Sin => &["sin"], + BuiltinScalarFunction::Sinh => &["sinh"], + BuiltinScalarFunction::Sqrt => &["sqrt"], + BuiltinScalarFunction::Tan => &["tan"], + BuiltinScalarFunction::Tanh => &["tanh"], + BuiltinScalarFunction::Trunc => &["trunc"], + + // conditional functions + BuiltinScalarFunction::Coalesce => &["coalesce"], + BuiltinScalarFunction::NullIf => &["nullif"], + + // string functions + BuiltinScalarFunction::Ascii => &["ascii"], + BuiltinScalarFunction::BitLength => &["bit_length"], + BuiltinScalarFunction::Btrim => &["btrim"], + BuiltinScalarFunction::CharacterLength => { + &["character_length", "char_length", "length"] + } + BuiltinScalarFunction::Concat => &["concat"], + BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], + BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::Left => &["left"], + BuiltinScalarFunction::Lower => &["lower"], + BuiltinScalarFunction::Lpad => &["lpad"], + BuiltinScalarFunction::Ltrim => &["ltrim"], + BuiltinScalarFunction::OctetLength => &["octet_length"], + BuiltinScalarFunction::Repeat => &["repeat"], + BuiltinScalarFunction::Replace => &["replace"], + BuiltinScalarFunction::Reverse => &["reverse"], + BuiltinScalarFunction::Right => &["right"], + BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Rtrim => &["rtrim"], + BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StringToArray => { + &["string_to_array", "string_to_list"] + } + BuiltinScalarFunction::StartsWith => &["starts_with"], + BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Substr => &["substr"], + BuiltinScalarFunction::ToHex => &["to_hex"], + BuiltinScalarFunction::Translate => &["translate"], + BuiltinScalarFunction::Trim => &["trim"], + BuiltinScalarFunction::Upper => &["upper"], + BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], + BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], + BuiltinScalarFunction::FindInSet => &["find_in_set"], + + // regex functions + BuiltinScalarFunction::RegexpMatch => &["regexp_match"], + BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], + + // time/date functions + BuiltinScalarFunction::Now => &["now"], + BuiltinScalarFunction::CurrentDate => &["current_date"], + BuiltinScalarFunction::CurrentTime => &["current_time"], + BuiltinScalarFunction::DateBin => &["date_bin"], + BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], + BuiltinScalarFunction::DatePart => &["date_part", "datepart"], + BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], + BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], + BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], + BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], + BuiltinScalarFunction::ToTimestampNanos => &["to_timestamp_nanos"], + BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], + + // hashing functions + BuiltinScalarFunction::Digest => &["digest"], + BuiltinScalarFunction::MD5 => &["md5"], + BuiltinScalarFunction::SHA224 => &["sha224"], + BuiltinScalarFunction::SHA256 => &["sha256"], + BuiltinScalarFunction::SHA384 => &["sha384"], + BuiltinScalarFunction::SHA512 => &["sha512"], + + // encode/decode + BuiltinScalarFunction::Encode => &["encode"], + BuiltinScalarFunction::Decode => &["decode"], + + // other functions + BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + + // array functions + BuiltinScalarFunction::ArrayAppend => &[ + "array_append", + "list_append", + "array_push_back", + "list_push_back", + ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], + BuiltinScalarFunction::ArrayConcat => { + &["array_concat", "array_cat", "list_concat", "list_cat"] + } + BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"], + BuiltinScalarFunction::ArrayEmpty => &["empty"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], + BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], + BuiltinScalarFunction::Flatten => &["flatten"], + BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], + BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], + BuiltinScalarFunction::ArrayHas => { + &["array_has", "list_has", "array_contains", "list_contains"] + } + BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], + BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], + BuiltinScalarFunction::ArrayPopFront => { + &["array_pop_front", "list_pop_front"] + } + BuiltinScalarFunction::ArrayPopBack => &["array_pop_back", "list_pop_back"], + BuiltinScalarFunction::ArrayPosition => &[ + "array_position", + "list_position", + "array_indexof", + "list_indexof", + ], + BuiltinScalarFunction::ArrayPositions => { + &["array_positions", "list_positions"] + } + BuiltinScalarFunction::ArrayPrepend => &[ + "array_prepend", + "list_prepend", + "array_push_front", + "list_push_front", + ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], + BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], + BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], + BuiltinScalarFunction::ArrayRemoveAll => { + &["array_remove_all", "list_remove_all"] + } + BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], + BuiltinScalarFunction::ArrayReplaceN => { + &["array_replace_n", "list_replace_n"] + } + BuiltinScalarFunction::ArrayReplaceAll => { + &["array_replace_all", "list_replace_all"] + } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], + BuiltinScalarFunction::ArrayToString => &[ + "array_to_string", + "list_to_string", + "array_join", + "list_join", + ], + BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"], + BuiltinScalarFunction::Cardinality => &["cardinality"], + BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => { + &["array_intersect", "list_intersect"] + } + BuiltinScalarFunction::OverLay => &["overlay"], + BuiltinScalarFunction::Range => &["range", "generate_series"], + + // struct functions + BuiltinScalarFunction::Struct => &["struct"], } - BuiltinScalarFunction::Concat => &["concat"], - BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], - BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], - BuiltinScalarFunction::OctetLength => &["octet_length"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], - BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StartsWith => &["starts_with"], - BuiltinScalarFunction::Strpos => &["strpos"], - BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], - BuiltinScalarFunction::Uuid => &["uuid"], - - // regex functions - BuiltinScalarFunction::RegexpMatch => &["regexp_match"], - BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], - - // time/date functions - BuiltinScalarFunction::Now => &["now"], - BuiltinScalarFunction::CurrentDate => &["current_date"], - BuiltinScalarFunction::CurrentTime => &["current_time"], - BuiltinScalarFunction::DateBin => &["date_bin"], - BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], - BuiltinScalarFunction::DatePart => &["date_part", "datepart"], - BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], - BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], - BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], - BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], - BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], - - // hashing functions - BuiltinScalarFunction::Digest => &["digest"], - BuiltinScalarFunction::MD5 => &["md5"], - BuiltinScalarFunction::SHA224 => &["sha224"], - BuiltinScalarFunction::SHA256 => &["sha256"], - BuiltinScalarFunction::SHA384 => &["sha384"], - BuiltinScalarFunction::SHA512 => &["sha512"], - - // other functions - BuiltinScalarFunction::Struct => &["struct"], - BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], - - // array functions - BuiltinScalarFunction::ArrayAppend => &["array_append"], - BuiltinScalarFunction::ArrayConcat => &["array_concat"], - BuiltinScalarFunction::ArrayDims => &["array_dims"], - BuiltinScalarFunction::ArrayFill => &["array_fill"], - BuiltinScalarFunction::ArrayLength => &["array_length"], - BuiltinScalarFunction::ArrayNdims => &["array_ndims"], - BuiltinScalarFunction::ArrayPosition => &["array_position"], - BuiltinScalarFunction::ArrayPositions => &["array_positions"], - BuiltinScalarFunction::ArrayPrepend => &["array_prepend"], - BuiltinScalarFunction::ArrayRemove => &["array_remove"], - BuiltinScalarFunction::ArrayReplace => &["array_replace"], - BuiltinScalarFunction::ArrayToString => &["array_to_string"], - BuiltinScalarFunction::Cardinality => &["cardinality"], - BuiltinScalarFunction::MakeArray => &["make_array"], - BuiltinScalarFunction::TrimArray => &["trim_array"], } } impl fmt::Display for BuiltinScalarFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction - write!(f, "{}", FUNCTION_TO_NAME.get(self).unwrap()) + write!(f, "{}", self.name()) } } impl FromStr for BuiltinScalarFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - if let Some(func) = NAME_TO_FUNCTION.get(name) { + if let Some(func) = name_to_function().get(name) { Ok(*func) } else { - Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))) + plan_err!("There is no built-in function named {name}") } } } +/// Creates a function that returns the return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! make_utf8_to_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 => $largeUtf8Type, + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 => $utf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 => $largeUtf8Type, + DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 => $utf8Type, + DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return plan_err!( + "The {:?} function can only accept strings, but got {:?}.", + name, + **value_type + ); + } + }, + data_type => { + return plan_err!( + "The {:?} function can only accept strings, but got {:?}.", + name, + data_type + ); + } + }) + } + }; +} +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. +make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::LargeBinary => DataType::Binary, + DataType::Null => DataType::Null, + _ => { + return plan_err!( + "The {name:?} function can only accept strings or binary arrays." + ); + } + }) +} + #[cfg(test)] mod tests { use super::*; @@ -534,9 +1735,10 @@ mod tests { // Test for BuiltinScalarFunction's Display and from_str() implementations. // For each variant in BuiltinScalarFunction, it converts the variant to a string // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 fn test_display_and_from_str() { - for (_, func_original) in NAME_TO_FUNCTION.iter() { + for (_, func_original) in name_to_function().iter() { let func_name = func_original.to_string(); let func_from_str = BuiltinScalarFunction::from_str(&func_name).unwrap(); assert_eq!(func_from_str, *func_original); diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index 9a18bdceabe4c..7a28839281697 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::array::ArrayRef; use arrow::array::NullArray; use arrow::datatypes::DataType; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single @@ -41,17 +41,21 @@ impl ColumnarValue { pub fn data_type(&self) -> DataType { match self { ColumnarValue::Array(array_value) => array_value.data_type().clone(), - ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), + ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(), } } /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is /// converted by repeating the same scalar multiple times. - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { + /// + /// # Errors + /// + /// Errors if `self` is a Scalar that fails to be converted into an array of size + pub fn into_array(self, num_rows: usize) -> Result { + Ok(match self { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows)?, + }) } /// null columnar values are implemented as a null array in order to pass batch diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index aba44061387a1..c31bd04eafa0f 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -19,7 +19,7 @@ use crate::expr::Case; use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{plan_err, DFSchema, DataFusionError, Result}; use std::collections::HashSet; /// Currently supported types by the coalesce function. @@ -102,9 +102,9 @@ impl CaseBuilder { } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { - return Err(DataFusionError::Plan(format!( + return plan_err!( "CASE expression 'then' values had multiple data types: {unique_types:?}" - ))); + ); } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 86480f9a96b54..958f4f4a34561 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,8 +17,6 @@ //! Expr module contains core type definition for `Expr`. -use crate::aggregate_function; -use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::udaf; @@ -26,7 +24,11 @@ use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; use crate::window_function; use crate::Operator; +use crate::{aggregate_function, ExprSchemable}; +use crate::{built_in_function, BuiltinScalarFunction}; use arrow::datatypes::DataType; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; @@ -38,7 +40,7 @@ use std::sync::Arc; /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. /// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) +/// An `Expr` can compute its [DataType] /// and nullability, and has functions for building up complex /// expressions. /// @@ -79,10 +81,10 @@ use std::sync::Arc; /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Expr { /// An expression with a specific name. - Alias(Box, String), + Alias(Alias), /// A named reference to a qualified filed in a schema. Column(Column), /// A named reference to a variable in a registry. @@ -93,31 +95,30 @@ pub enum Expr { BinaryExpr(BinaryExpr), /// LIKE expression Like(Like), - /// Case-insensitive LIKE expression - ILike(Like), /// LIKE expression that uses regular expressions SimilarTo(Like), /// Negation of an expression. The expression's type must be a boolean to make sense. Not(Box), - /// Whether an expression is not Null. This expression is never null. + /// True if argument is not NULL, false otherwise. This expression itself is never NULL. IsNotNull(Box), - /// Whether an expression is Null. This expression is never null. + /// True if argument is NULL, false otherwise. This expression itself is never NULL. IsNull(Box), - /// Whether an expression is True. Boolean operation + /// True if argument is true, false otherwise. This expression itself is never NULL. IsTrue(Box), - /// Whether an expression is False. Boolean operation + /// True if argument is false, false otherwise. This expression itself is never NULL. IsFalse(Box), - /// Whether an expression is Unknown. Boolean operation + /// True if argument is NULL, false otherwise. This expression itself is never NULL. IsUnknown(Box), - /// Whether an expression is not True. Boolean operation + /// True if argument is FALSE or NULL, false otherwise. This expression itself is never NULL. IsNotTrue(Box), - /// Whether an expression is not False. Boolean operation + /// True if argument is TRUE OR NULL, false otherwise. This expression itself is never NULL. IsNotFalse(Box), - /// Whether an expression is not Unknown. Boolean operation + /// True if argument is TRUE or FALSE, false otherwise. This expression itself is never NULL. IsNotUnknown(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key + /// Returns the field of a [`arrow::array::ListArray`] or + /// [`arrow::array::StructArray`] by index or range GetIndexedField(GetIndexedField), /// Whether an expression is between a given range. Between(Between), @@ -147,16 +148,12 @@ pub enum Expr { TryCast(TryCast), /// A sort expression, that can be used to sort values. Sort(Sort), - /// Represents the call of a built-in scalar function with a set of arguments. + /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF(ScalarUDF), /// Represents the call of an aggregate built-in function with arguments. AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -165,10 +162,12 @@ pub enum Expr { InSubquery(InSubquery), /// Scalar subquery ScalarSubquery(Subquery), - /// Represents a reference to all fields in a schema. - Wildcard, - /// Represents a reference to all fields in a specific schema. - QualifiedWildcard { qualifier: String }, + /// Represents a reference to all available fields in a specific schema, + /// with an optional (schema) qualifier. + /// + /// This expr has to be resolved to a list of columns before translating logical + /// plan into physical plan. + Wildcard { qualifier: Option }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -180,8 +179,31 @@ pub enum Expr { OuterReferenceColumn(DataType, Column), } +/// Alias expression +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct Alias { + pub expr: Box, + pub relation: Option, + pub name: String, +} + +impl Alias { + /// Create an alias with an optional schema/field qualifier. + pub fn new( + expr: Expr, + relation: Option>, + name: impl Into, + ) -> Self { + Self { + expr: Box::new(expr), + relation: relation.map(|r| r.into()), + name: name.into(), + } + } +} + /// Binary expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct BinaryExpr { /// Left-hand side of the expression pub left: Box, @@ -258,12 +280,14 @@ impl Case { } /// LIKE expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Like { pub negated: bool, pub expr: Box, pub pattern: Box, pub escape_char: Option, + /// Whether to ignore case on comparing + pub case_insensitive: bool, } impl Like { @@ -273,18 +297,20 @@ impl Like { expr: Box, pattern: Box, escape_char: Option, + case_insensitive: bool, ) -> Self { Self { negated, expr, pattern, escape_char, + case_insensitive, } } } /// BETWEEN expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Between { /// The value to compare pub expr: Box, @@ -308,56 +334,96 @@ impl Between { } } -/// ScalarFunction expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. +pub enum ScalarFunctionDefinition { + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term + BuiltIn(BuiltinScalarFunction), + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +/// ScalarFunction expression invokes a built-in scalar function +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function - pub fun: built_in_function::BuiltinScalarFunction, + pub func_def: ScalarFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, } impl ScalarFunction { - /// Create a new ScalarFunction expression - pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { - Self { fun, args } + // return the Function's name + pub fn name(&self) -> &str { + self.func_def.name() } } -/// ScalarUDF expression -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct ScalarUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, +impl ScalarFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => fun.name(), + ScalarFunctionDefinition::UDF(udf) => udf.name(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } } -impl ScalarUDF { - /// Create a new ScalarUDF expression - pub fn new(fun: Arc, args: Vec) -> Self { - Self { fun, args } +impl ScalarFunction { + /// Create a new ScalarFunction expression + pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + } } + + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf(udf: Arc, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + } + } +} + +/// Access a sub field of a nested type, such as `Field` or `List` +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum GetFieldAccess { + /// Named field, for example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key: Box }, + /// List range, for example `list[i:j]` + ListRange { start: Box, stop: Box }, } -/// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key -#[derive(Clone, PartialEq, Eq, Hash)] +/// Returns the field of a [`arrow::array::ListArray`] or +/// [`arrow::array::StructArray`] by `key`. See [`GetFieldAccess`] for +/// details. +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct GetIndexedField { - /// the expression to take the field from + /// The expression to take the field from pub expr: Box, /// The name of the field to take - pub key: ScalarValue, + pub field: GetFieldAccess, } impl GetIndexedField { /// Create a new GetIndexedField expression - pub fn new(expr: Box, key: ScalarValue) -> Self { - Self { expr, key } + pub fn new(expr: Box, field: GetFieldAccess) -> Self { + Self { expr, field } } } /// Cast expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Cast { /// The expression being cast pub expr: Box, @@ -373,7 +439,7 @@ impl Cast { } /// TryCast Expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct TryCast { /// The expression being cast pub expr: Box, @@ -389,7 +455,7 @@ impl TryCast { } /// SORT expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Sort { /// The expression to sort on pub expr: Box, @@ -410,11 +476,33 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum AggregateFunctionDefinition { + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function + UDF(Arc), + /// A aggregation function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -434,7 +522,24 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), + args, + distinct, + filter, + order_by, + } + } + + /// Create a new AggregateFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), args, distinct, filter, @@ -444,7 +549,7 @@ impl AggregateFunction { } /// Window function -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function pub fun: window_function::WindowFunction, @@ -478,7 +583,7 @@ impl WindowFunction { } // Exists expression. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { /// subquery that will produce a single column of data pub subquery: Subquery, @@ -493,7 +598,7 @@ impl Exists { } } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateUDF { /// The function pub fun: Arc, @@ -523,7 +628,7 @@ impl AggregateUDF { } /// InList expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct InList { /// The expression to compare pub expr: Box, @@ -545,7 +650,7 @@ impl InList { } /// IN subquery -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct InSubquery { /// The expression to compare pub expr: Box, @@ -566,10 +671,13 @@ impl InSubquery { } } -/// Placeholder -#[derive(Clone, PartialEq, Eq, Hash)] +/// Placeholder, representing bind parameter values such as `$1` or `$name`. +/// +/// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] +/// or can be specified directly using `PREPARE` statements. +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Placeholder { - /// The identifier of the parameter (e.g, $1 or $foo) + /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`) pub id: String, /// The type the parameter will be filled in with pub data_type: Option, @@ -587,7 +695,7 @@ impl Placeholder { /// for Postgres definition. /// See /// for Apache Spark definition. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum GroupingSet { /// Rollup grouping sets Rollup(Vec), @@ -659,7 +767,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -675,7 +782,6 @@ impl Expr { Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", - Expr::ILike { .. } => "ILike", Expr::SimilarTo { .. } => "RLike", Expr::IsTrue(..) => "IsTrue", Expr::IsFalse(..) => "IsFalse", @@ -687,15 +793,13 @@ impl Expr { Expr::Negative(..) => "Negative", Expr::Not(..) => "Not", Expr::Placeholder(_) => "Placeholder", - Expr::QualifiedWildcard { .. } => "QualifiedWildcard", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", - Expr::ScalarUDF(..) => "ScalarUDF", Expr::ScalarVariable(..) => "ScalarVariable", Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", - Expr::Wildcard => "Wildcard", + Expr::Wildcard { .. } => "Wildcard", } } @@ -741,33 +845,100 @@ impl Expr { /// Return `self LIKE other` pub fn like(self, other: Expr) -> Expr { - Expr::Like(Like::new(false, Box::new(self), Box::new(other), None)) + Expr::Like(Like::new( + false, + Box::new(self), + Box::new(other), + None, + false, + )) } /// Return `self NOT LIKE other` pub fn not_like(self, other: Expr) -> Expr { - Expr::Like(Like::new(true, Box::new(self), Box::new(other), None)) + Expr::Like(Like::new( + true, + Box::new(self), + Box::new(other), + None, + false, + )) } /// Return `self ILIKE other` pub fn ilike(self, other: Expr) -> Expr { - Expr::ILike(Like::new(false, Box::new(self), Box::new(other), None)) + Expr::Like(Like::new( + false, + Box::new(self), + Box::new(other), + None, + true, + )) } /// Return `self NOT ILIKE other` pub fn not_ilike(self, other: Expr) -> Expr { - Expr::ILike(Like::new(true, Box::new(self), Box::new(other), None)) + Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) + } + + /// Return the name to use for the specific Expr, recursing into + /// `Expr::Sort` as appropriate + pub fn name_for_alias(&self) -> Result { + match self { + // call Expr::display_name() on a Expr::Sort will throw an error + Expr::Sort(Sort { expr, .. }) => expr.name_for_alias(), + expr => expr.display_name(), + } + } + + /// Ensure `expr` has the name as `original_name` by adding an + /// alias if necessary. + pub fn alias_if_changed(self, original_name: String) -> Result { + let new_name = self.name_for_alias()?; + + if new_name == original_name { + return Ok(self); + } + + Ok(self.alias(original_name)) } /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - Expr::Alias(Box::new(self), name.into()) + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), + _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), + } + } + + /// Return `self AS name` alias expression with a specific qualifier + pub fn alias_qualified( + self, + relation: Option>, + name: impl Into, + ) -> Expr { + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + Box::new(expr.alias_qualified(relation, name)), + asc, + nulls_first, + )), + _ => Expr::Alias(Alias::new(self, relation, name.into())), + } } /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { - Expr::Alias(expr, _) => expr.as_ref().clone(), + Expr::Alias(alias) => alias.expr.as_ref().clone(), _ => self, } } @@ -848,10 +1019,94 @@ impl Expr { )) } + /// Return access to the named field. Example `expr["name"]` + /// + /// ## Access field "my_field" from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// { + /// "my_field": 123.34, + /// "other_field": "Boston", + /// } + /// ``` + /// + /// You can access column "my_field" with + /// + /// ``` + /// # use datafusion_expr::{col}; + /// let expr = col("c1") + /// .field("my_field"); + /// assert_eq!(expr.display_name().unwrap(), "c1[my_field]"); + /// ``` + pub fn field(self, name: impl Into) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::NamedStructField { + name: ScalarValue::from(name.into()), + }, + }) + } + + /// Return access to the element field. Example `expr["name"]` + /// + /// ## Example Access element 2 from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// [10, 20, 30, 40] + /// ``` + /// + /// You can access the value "30" with + /// + /// ``` + /// # use datafusion_expr::{lit, col, Expr}; + /// let expr = col("c1") + /// .index(lit(3)); + /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(3)]"); + /// ``` + pub fn index(self, key: Expr) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::ListIndex { key: Box::new(key) }, + }) + } + + /// Return elements between `1` based `start` and `stop`, for + /// example `expr[1:3]` + /// + /// ## Example: Access element 2, 3, 4 from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// [10, 20, 30, 40] + /// ``` + /// + /// You can access the value `[20, 30, 40]` with + /// + /// ``` + /// # use datafusion_expr::{lit, col}; + /// let expr = col("c1") + /// .range(lit(2), lit(4)); + /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]"); + /// ``` + pub fn range(self, start: Expr, stop: Expr) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::ListRange { + start: Box::new(start), + stop: Box::new(stop), + }, + }) + } + pub fn try_into_col(&self) -> Result { match self { Expr::Column(it) => Ok(it.clone()), - _ => plan_err!(format!("Could not coerce '{self}' into Column!")), + _ => plan_err!("Could not coerce '{self}' into Column!"), } } @@ -867,22 +1122,71 @@ impl Expr { pub fn contains_outer(&self) -> bool { !find_out_reference_exprs(self).is_empty() } + + /// Recursively find all [`Expr::Placeholder`] expressions, and + /// to infer their [`DataType`] from the context of their use. + /// + /// For example, gicen an expression like ` = $0` will infer `$0` to + /// have type `int32`. + pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { + self.transform(&|mut expr| { + // Default to assuming the arguments are the same type + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + }; + if let Expr::Between(Between { + expr, + negated: _, + low, + high, + }) = &mut expr + { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Ok(Transformed::Yes(expr)) + }) + } } -/// Format expressions for display as part of a logical plan. In many cases, this will produce -/// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Display for Expr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{self:?}") +// modifies expr if it is a placeholder with datatype of right +fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { + if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { + if data_type.is_none() { + let other_dt = other.get_type(schema); + match other_dt { + Err(e) => { + Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; + } + Ok(dt) => { + *data_type = Some(dt); + } + } + }; } + Ok(()) +} + +#[macro_export] +macro_rules! expr_vec_fmt { + ( $ARRAY:expr ) => {{ + $ARRAY + .iter() + .map(|e| format!("{e}")) + .collect::>() + .join(", ") + }}; } /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Debug for Expr { +impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Expr::Alias(expr, alias) => write!(f, "{expr:?} AS {alias}"), + Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), @@ -890,32 +1194,32 @@ impl fmt::Debug for Expr { Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { - write!(f, "{e:?} ")?; + write!(f, "{e} ")?; } for (w, t) in &case.when_then_expr { - write!(f, "WHEN {w:?} THEN {t:?} ")?; + write!(f, "WHEN {w} THEN {t} ")?; } if let Some(e) = &case.else_expr { - write!(f, "ELSE {e:?} ")?; + write!(f, "ELSE {e} ")?; } write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr:?} AS {data_type:?})") + write!(f, "CAST({expr} AS {data_type:?})") } Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr:?} AS {data_type:?})") + write!(f, "TRY_CAST({expr} AS {data_type:?})") } - Expr::Not(expr) => write!(f, "NOT {expr:?}"), - Expr::Negative(expr) => write!(f, "(- {expr:?})"), - Expr::IsNull(expr) => write!(f, "{expr:?} IS NULL"), - Expr::IsNotNull(expr) => write!(f, "{expr:?} IS NOT NULL"), - Expr::IsTrue(expr) => write!(f, "{expr:?} IS TRUE"), - Expr::IsFalse(expr) => write!(f, "{expr:?} IS FALSE"), - Expr::IsUnknown(expr) => write!(f, "{expr:?} IS UNKNOWN"), - Expr::IsNotTrue(expr) => write!(f, "{expr:?} IS NOT TRUE"), - Expr::IsNotFalse(expr) => write!(f, "{expr:?} IS NOT FALSE"), - Expr::IsNotUnknown(expr) => write!(f, "{expr:?} IS NOT UNKNOWN"), + Expr::Not(expr) => write!(f, "NOT {expr}"), + Expr::Negative(expr) => write!(f, "(- {expr})"), + Expr::IsNull(expr) => write!(f, "{expr} IS NULL"), + Expr::IsNotNull(expr) => write!(f, "{expr} IS NOT NULL"), + Expr::IsTrue(expr) => write!(f, "{expr} IS TRUE"), + Expr::IsFalse(expr) => write!(f, "{expr} IS FALSE"), + Expr::IsUnknown(expr) => write!(f, "{expr} IS UNKNOWN"), + Expr::IsNotTrue(expr) => write!(f, "{expr} IS NOT TRUE"), + Expr::IsNotFalse(expr) => write!(f, "{expr} IS NOT FALSE"), + Expr::IsNotUnknown(expr) => write!(f, "{expr} IS NOT UNKNOWN"), Expr::Exists(Exists { subquery, negated: true, @@ -928,12 +1232,12 @@ impl fmt::Debug for Expr { expr, subquery, negated: true, - }) => write!(f, "{expr:?} NOT IN ({subquery:?})"), + }) => write!(f, "{expr} NOT IN ({subquery:?})"), Expr::InSubquery(InSubquery { expr, subquery, negated: false, - }) => write!(f, "{expr:?} IN ({subquery:?})"), + }) => write!(f, "{expr} IN ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::Sort(Sort { @@ -942,9 +1246,9 @@ impl fmt::Debug for Expr { nulls_first, }) => { if *asc { - write!(f, "{expr:?} ASC")?; + write!(f, "{expr} ASC")?; } else { - write!(f, "{expr:?} DESC")?; + write!(f, "{expr} DESC")?; } if *nulls_first { write!(f, " NULLS FIRST") @@ -952,11 +1256,8 @@ impl fmt::Debug for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, false) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, false) + Expr::ScalarFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -965,12 +1266,12 @@ impl fmt::Debug for Expr { order_by, window_frame, }) => { - fmt_function(f, &fun.to_string(), false, args, false)?; + fmt_function(f, &fun.to_string(), false, args, true)?; if !partition_by.is_empty() { - write!(f, " PARTITION BY {partition_by:?}")?; + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; } if !order_by.is_empty() { - write!(f, " ORDER BY {order_by:?}")?; + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; } write!( f, @@ -980,35 +1281,19 @@ impl fmt::Debug for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } if let Some(ob) = order_by { - write!(f, " ORDER BY {:?}", ob)?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. - }) => { - fmt_function(f, &fun.name, false, args, false)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY {:?}", ob)?; + write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; } Ok(()) } @@ -1019,9 +1304,9 @@ impl fmt::Debug for Expr { high, }) => { if *negated { - write!(f, "{expr:?} NOT BETWEEN {low:?} AND {high:?}") + write!(f, "{expr} NOT BETWEEN {low} AND {high}") } else { - write!(f, "{expr:?} BETWEEN {low:?} AND {high:?}") + write!(f, "{expr} BETWEEN {low} AND {high}") } } Expr::Like(Like { @@ -1029,31 +1314,17 @@ impl fmt::Debug for Expr { expr, pattern, escape_char, + case_insensitive, }) => { - write!(f, "{expr:?}")?; + write!(f, "{expr}")?; + let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; if *negated { write!(f, " NOT")?; } if let Some(char) = escape_char { - write!(f, " LIKE {pattern:?} ESCAPE '{char}'") + write!(f, " {op_name} {pattern} ESCAPE '{char}'") } else { - write!(f, " LIKE {pattern:?}") - } - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - write!(f, "{expr:?}")?; - if *negated { - write!(f, " NOT")?; - } - if let Some(char) = escape_char { - write!(f, " ILIKE {pattern:?} ESCAPE '{char}'") - } else { - write!(f, " ILIKE {pattern:?}") + write!(f, " {op_name} {pattern}") } } Expr::SimilarTo(Like { @@ -1061,15 +1332,16 @@ impl fmt::Debug for Expr { expr, pattern, escape_char, + case_insensitive: _, }) => { - write!(f, "{expr:?}")?; + write!(f, "{expr}")?; if *negated { write!(f, " NOT")?; } if let Some(char) = escape_char { - write!(f, " SIMILAR TO {pattern:?} ESCAPE '{char}'") + write!(f, " SIMILAR TO {pattern} ESCAPE '{char}'") } else { - write!(f, " SIMILAR TO {pattern:?}") + write!(f, " SIMILAR TO {pattern}") } } Expr::InList(InList { @@ -1078,40 +1350,32 @@ impl fmt::Debug for Expr { negated, }) => { if *negated { - write!(f, "{expr:?} NOT IN ({list:?})") + write!(f, "{expr} NOT IN ([{}])", expr_vec_fmt!(list)) } else { - write!(f, "{expr:?} IN ({list:?})") + write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard => write!(f, "*"), - Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - write!(f, "({expr:?})[{key}]") - } + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => write!(f, "{qualifier}.*"), + None => write!(f, "*"), + }, + Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { + GetFieldAccess::NamedStructField { name } => { + write!(f, "({expr})[{name}]") + } + GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"), + GetFieldAccess::ListRange { start, stop } => { + write!(f, "({expr})[{start}:{stop}]") + } + }, Expr::GroupingSet(grouping_sets) => match grouping_sets { GroupingSet::Rollup(exprs) => { // ROLLUP (c0, c1, c2) - write!( - f, - "ROLLUP ({})", - exprs - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - ) + write!(f, "ROLLUP ({})", expr_vec_fmt!(exprs)) } GroupingSet::Cube(exprs) => { // CUBE (c0, c1, c2) - write!( - f, - "CUBE ({})", - exprs - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - ) + write!(f, "CUBE ({})", expr_vec_fmt!(exprs)) } GroupingSet::GroupingSets(lists_of_exprs) => { // GROUPING SETS ((c0), (c1, c2), (c3, c4)) @@ -1120,14 +1384,7 @@ impl fmt::Debug for Expr { "GROUPING SETS ({})", lists_of_exprs .iter() - .map(|exprs| format!( - "({})", - exprs - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - )) + .map(|exprs| format!("({})", expr_vec_fmt!(exprs))) .collect::>() .join(", ") ) @@ -1171,7 +1428,7 @@ fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result 2)". fn create_name(e: &Expr) -> Result { match e { - Expr::Alias(_, name) => Ok(name.clone()), + Expr::Alias(Alias { name, .. }) => Ok(name.clone()), Expr::Column(c) => Ok(c.flat_name()), Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), @@ -1186,30 +1443,13 @@ fn create_name(e: &Expr) -> Result { expr, pattern, escape_char, + case_insensitive, }) => { let s = format!( - "{} {} {} {}", + "{} {}{} {} {}", expr, - if *negated { "NOT LIKE" } else { "LIKE" }, - pattern, - if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - } - ); - Ok(s) - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let s = format!( - "{} {} {} {}", - expr, - if *negated { "NOT ILIKE" } else { "ILIKE" }, + if *negated { "NOT " } else { "" }, + if *case_insensitive { "ILIKE" } else { "LIKE" }, pattern, if let Some(char) = escape_char { format!("CHAR '{char}'") @@ -1224,6 +1464,7 @@ fn create_name(e: &Expr) -> Result { expr, pattern, escape_char, + case_insensitive: _, }) => { let s = format!( "{} {} {} {}", @@ -1315,16 +1556,24 @@ fn create_name(e: &Expr) -> Result { Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).name().clone()) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { let expr = create_name(expr)?; - Ok(format!("{expr}[{key}]")) - } - Expr::ScalarFunction(func) => { - create_function_name(&func.fun.to_string(), false, &func.args) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(&fun.name, false, args) + match field { + GetFieldAccess::NamedStructField { name } => { + Ok(format!("{expr}[{name}]")) + } + GetFieldAccess::ListIndex { key } => { + let key = create_name(key)?; + Ok(format!("{expr}[{key}]")) + } + GetFieldAccess::ListRange { start, stop } => { + let start = create_name(start)?; + let stop = create_name(stop)?; + Ok(format!("{expr}[{start}:{stop}]")) + } + } } + Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), Expr::WindowFunction(WindowFunction { fun, args, @@ -1335,48 +1584,48 @@ fn create_name(e: &Expr) -> Result { let mut parts: Vec = vec![create_function_name(&fun.to_string(), false, args)?]; if !partition_by.is_empty() { - parts.push(format!("PARTITION BY {partition_by:?}")); + parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by))); } if !order_by.is_empty() { - parts.push(format!("ORDER BY {order_by:?}")); + parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by))); } parts.push(format!("{window_frame}")); Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY {order_by:?}"); + let name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? + } + AggregateFunctionDefinition::UDF(..) => { + let names: Vec = + args.iter().map(create_name).collect::>()?; + names.join(",") + } }; - Ok(name) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } let mut info = String::new(); if let Some(fe) = filter { info += &format!(" FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ({:?})", ob); - } - Ok(format!("{}({}){}", fun.name, names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { @@ -1421,13 +1670,15 @@ fn create_name(e: &Expr) -> Result { Ok(format!("{expr} BETWEEN {low} AND {high}")) } } - Expr::Sort { .. } => Err(DataFusionError::Internal( - "Create name does not support sort expression".to_string(), - )), - Expr::Wildcard => Ok("*".to_string()), - Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( - "Create name does not support qualified wildcard".to_string(), - )), + Expr::Sort { .. } => { + internal_err!("Create name does not support sort expression") + } + Expr::Wildcard { qualifier } => match qualifier { + Some(qualifier) => internal_err!( + "Create name does not support qualified wildcard, got {qualifier}" + ), + None => Ok("*".to_string()), + }, Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()), } } @@ -1459,7 +1710,6 @@ mod test { let expected = "CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END"; assert_eq!(expected, expr.canonical_name()); assert_eq!(expected, format!("{expr}")); - assert_eq!(expected, format!("{expr:?}")); assert_eq!(expected, expr.display_name()?); Ok(()) } @@ -1473,7 +1723,6 @@ mod test { let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); - assert_eq!(expected_canonical, format!("{expr:?}")); // note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. assert_eq!("Float32(1.23)", expr.display_name()?); @@ -1515,4 +1764,40 @@ mod test { Ok(()) } + + #[test] + fn test_logical_ops() { + assert_eq!( + format!("{}", lit(1u32).eq(lit(2u32))), + "UInt32(1) = UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).not_eq(lit(2u32))), + "UInt32(1) != UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).gt(lit(2u32))), + "UInt32(1) > UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).gt_eq(lit(2u32))), + "UInt32(1) >= UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).lt(lit(2u32))), + "UInt32(1) < UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).lt_eq(lit(2u32))), + "UInt32(1) <= UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).and(lit(2u32))), + "UInt32(1) AND UInt32(2)" + ); + assert_eq!( + format!("{}", lit(1u32).or(lit(2u32))), + "UInt32(1) OR UInt32(2)" + ); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 819f8e4aa7fc8..cedf1d845137f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,16 +19,19 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - ScalarFunction, TryCast, + Placeholder, ScalarFunction, TryCast, }; +use crate::function::PartitionEvaluatorFactory; +use crate::WindowUDF; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, - logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF, + logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, }; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; +use std::ops::Not; use std::sync::Arc; /// Create a column expression based on a qualified or unqualified column name. Will @@ -78,6 +81,37 @@ pub fn ident(name: impl Into) -> Expr { Expr::Column(Column::from_name(name)) } +/// Create placeholder value that will be filled in (such as `$1`) +/// +/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`] +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{placeholder}; +/// let p = placeholder("$0"); // $0, refers to parameter 1 +/// assert_eq!(p.to_string(), "$0") +/// ``` +pub fn placeholder(id: impl Into) -> Expr { + Expr::Placeholder(Placeholder { + id: id.into(), + data_type: None, + }) +} + +/// Create an '*' [`Expr::Wildcard`] expression that matches all columns +/// +/// # Example +/// +/// ```rust +/// # use datafusion_expr::{wildcard}; +/// let p = wildcard(); +/// assert_eq!(p.to_string(), "*") +/// ``` +pub fn wildcard() -> Expr { + Expr::Wildcard { qualifier: None } +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) @@ -101,6 +135,11 @@ pub fn or(left: Expr, right: Expr) -> Expr { )) } +/// Return a new expression with a logical NOT +pub fn not(expr: Expr) -> Expr { + expr.not() +} + /// Create an expression to represent the min() aggregate function pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( @@ -134,6 +173,17 @@ pub fn sum(expr: Expr) -> Expr { )) } +/// Create an expression to represent the array_agg() aggregate function +pub fn array_agg(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::ArrayAgg, + vec![expr], + false, + None, + None, + )) +} + /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( @@ -474,6 +524,7 @@ scalar_expr!(Cbrt, cbrt, num, "cube root of a number"); scalar_expr!(Sin, sin, num, "sine"); scalar_expr!(Cos, cos, num, "cosine"); scalar_expr!(Tan, tan, num, "tangent"); +scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); scalar_expr!(Tanh, tanh, num, "hyperbolic tangent"); @@ -499,7 +550,11 @@ scalar_expr!( scalar_expr!(Degrees, degrees, num, "converts radians to degrees"); scalar_expr!(Radians, radians, num, "converts degrees to radians"); nary_scalar_expr!(Round, round, "round to nearest integer"); -scalar_expr!(Trunc, trunc, num, "truncate toward zero"); +nary_scalar_expr!( + Trunc, + trunc, + "truncate toward zero, with optional precision" +); scalar_expr!(Abs, abs, num, "absolute value"); scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) "); scalar_expr!(Exp, exp, num, "exponential"); @@ -517,7 +572,7 @@ scalar_expr!( num, "returns the hexdecimal representation of an integer" ); -scalar_expr!(Uuid, uuid, , "Returns uuid v4 as a string value"); +scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); // array functions @@ -527,7 +582,54 @@ scalar_expr!( array element, "appends an element to the end of an array." ); + +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); + +scalar_expr!( + ArrayPopBack, + array_pop_back, + array, + "returns the array without the last element." +); + +scalar_expr!( + ArrayPopFront, + array_pop_front, + array, + "returns the array without the first element." +); + nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); +scalar_expr!( + ArrayHas, + array_has, + first_array second_array, + "returns true, if the element appears in the first array, otherwise false." +); +scalar_expr!( + ArrayEmpty, + array_empty, + array, + "returns 1 for an empty array or 0 for a non-empty array." +); +scalar_expr!( + ArrayHasAll, + array_has_all, + first_array second_array, + "returns true if each element of the second array appears in the first array; otherwise, it returns false." +); +scalar_expr!( + ArrayHasAny, + array_has_any, + first_array second_array, + "returns true if at least one element of the second array appears in the first array; otherwise, it returns false." +); +scalar_expr!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array." +); scalar_expr!( ArrayDims, array_dims, @@ -535,10 +637,16 @@ scalar_expr!( "returns an array of the array's dimensions." ); scalar_expr!( - ArrayFill, - array_fill, - element array, - "returns an array filled with copies of the given value." + ArrayElement, + array_element, + array element, + "extracts the element with the index n from the array." +); +scalar_expr!( + ArrayExcept, + array_except, + first_array second_array, + "Returns an array of the elements that appear in the first array but not in the second." ); scalar_expr!( ArrayLength, @@ -552,6 +660,12 @@ scalar_expr!( array, "returns the number of dimensions of the array." ); +scalar_expr!( + ArrayDistinct, + array_distinct, + array, + "return distinct values from the array after removing duplicates." +); scalar_expr!( ArrayPosition, array_position, @@ -570,24 +684,62 @@ scalar_expr!( array element, "prepends an element to the beginning of an array." ); +scalar_expr!( + ArrayRepeat, + array_repeat, + element count, + "returns an array containing element `count` times." +); scalar_expr!( ArrayRemove, array_remove, array element, - "removes all elements equal to the given value from the array." + "removes the first element from the array equal to the given value." +); +scalar_expr!( + ArrayRemoveN, + array_remove_n, + array element max, + "removes the first `max` elements from the array equal to the given value." +); +scalar_expr!( + ArrayRemoveAll, + array_remove_all, + array element, + "removes all elements from the array equal to the given value." ); scalar_expr!( ArrayReplace, array_replace, array from to, - "replaces a specified element with another specified element." + "replaces the first occurrence of the specified element with another specified element." +); +scalar_expr!( + ArrayReplaceN, + array_replace_n, + array from to max, + "replaces the first `max` occurrences of the specified element with another specified element." +); +scalar_expr!( + ArrayReplaceAll, + array_replace_all, + array from to, + "replaces all occurrences of the specified element with another specified element." +); +scalar_expr!( + ArraySlice, + array_slice, + array offset length, + "returns a slice of the array." ); scalar_expr!( ArrayToString, array_to_string, - array delimeter, + array delimiter, "converts each element to its text representation." ); +scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates."); + scalar_expr!( Cardinality, cardinality, @@ -600,10 +752,16 @@ nary_scalar_expr!( "returns an Arrow array using the specified input expressions." ); scalar_expr!( - TrimArray, - trim_array, - array n, - "removes the last n elements from the array." + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); + +nary_scalar_expr!( + Range, + gen_range, + "Returns a list of values in the range between start and stop with step." ); // string functions @@ -627,6 +785,8 @@ scalar_expr!( "converts the Unicode code point to a UTF8 character" ); scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input`, using the `algorithm`"); +scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"); +scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. encoding can be base64 or hex"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); @@ -658,6 +818,7 @@ scalar_expr!(SHA256, sha256, string, "SHA-256 hash"); scalar_expr!(SHA384, sha384, string, "SHA-384 hash"); scalar_expr!(SHA512, sha512, string, "SHA-512 hash"); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); +scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`"); scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); @@ -704,6 +865,11 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); +nary_scalar_expr!( + OverLay, + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring" +); // date functions scalar_expr!(DatePart, date_part, part date, "extracts a subfield from the date"); @@ -721,6 +887,12 @@ scalar_expr!( date, "converts a string to a `Timestamp(Microseconds, None)`" ); +scalar_expr!( + ToTimestampNanos, + to_timestamp_nanos, + date, + "converts a string to a `Timestamp(Nanoseconds, None)`" +); scalar_expr!( ToTimestampSeconds, to_timestamp_seconds, @@ -736,8 +908,31 @@ scalar_expr!( scalar_expr!(CurrentDate, current_date, ,"returns current UTC date as a [`DataType::Date32`] value"); scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the same value for all instances of now() in same statement"); scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value"); +scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); +scalar_expr!( + Isnan, + isnan, + num, + "returns true if a given number is +NaN or -NaN otherwise returns false" +); +scalar_expr!( + Iszero, + iszero, + num, + "returns true if a given number is +0.0 or -0.0 otherwise returns false" +); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); +scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); +scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); + +scalar_expr!( + Struct, + struct_fun, + val, + "returns a vector of fields from the struct" +); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -774,23 +969,44 @@ pub fn create_udf( /// The signature and state type must match the `Accumulator's implementation`. pub fn create_udaf( name: &str, - input_type: DataType, + input_type: Vec, return_type: Arc, volatility: Volatility, - accumulator: AccumulatorFunctionImplementation, + accumulator: AccumulatorFactoryFunction, state_type: Arc>, ) -> AggregateUDF { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); AggregateUDF::new( name, - &Signature::exact(vec![input_type], volatility), + &Signature::exact(input_type, volatility), &return_type, &accumulator, &state_type, ) } +/// Creates a new UDWF with a specific signature, state type and return type. +/// +/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. +/// +/// [`PartitionEvaluator`]: crate::PartitionEvaluator +pub fn create_udwf( + name: &str, + input_type: DataType, + return_type: Arc, + volatility: Volatility, + partition_evaluator_factory: PartitionEvaluatorFactory, +) -> WindowUDF { + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + WindowUDF::new( + name, + &Signature::exact(vec![input_type], volatility), + &return_type, + &partition_evaluator_factory, + ) +} + /// Calls a named built in function /// ``` /// use datafusion_expr::{col, lit, call_fn}; @@ -808,23 +1024,25 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; - use crate::lit; + use crate::{lit, ScalarFunctionDefinition}; #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); let col_not_null = ident("col2"); - assert_eq!(format!("{:?}", col_null.is_null()), "col1 IS NULL"); + assert_eq!(format!("{}", col_null.is_null()), "col1 IS NULL"); assert_eq!( - format!("{:?}", col_not_null.is_not_null()), + format!("{}", col_not_null.is_not_null()), "col2 IS NOT NULL" ); } macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - $FUNC(col("tableA.a")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = $FUNC(col("tableA.a")) { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); @@ -836,42 +1054,42 @@ mod test { } macro_rules! test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = vec![$(stringify!($arg)),*]; - let result = $FUNC( + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + vec![ $( col(stringify!($arg.to_string())) ),* - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } - - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = vec![$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } + ] + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} #[test] fn scalar_function_definitions() { @@ -880,6 +1098,7 @@ mod test { test_unary_scalar_expr!(Sin, sin); test_unary_scalar_expr!(Cos, cos); test_unary_scalar_expr!(Tan, tan); + test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); test_unary_scalar_expr!(Tanh, tanh); @@ -896,7 +1115,8 @@ mod test { test_unary_scalar_expr!(Radians, radians); test_nary_scalar_expr!(Round, round, input); test_nary_scalar_expr!(Round, round, input, decimal_places); - test_unary_scalar_expr!(Trunc, trunc); + test_nary_scalar_expr!(Trunc, trunc, num); + test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Abs, abs); test_unary_scalar_expr!(Signum, signum); test_unary_scalar_expr!(Exp, exp); @@ -904,6 +1124,9 @@ mod test { test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); + test_scalar_expr!(Nanvl, nanvl, x, y); + test_scalar_expr!(Isnan, isnan, input); + test_scalar_expr!(Iszero, iszero, input); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); @@ -912,6 +1135,8 @@ mod test { test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Digest, digest, string, algorithm); + test_scalar_expr!(Encode, encode, string, encoding); + test_scalar_expr!(Decode, decode, string, encoding); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); @@ -951,6 +1176,7 @@ mod test { test_scalar_expr!(SHA384, sha384, string); test_scalar_expr!(SHA512, sha512, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); + test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value); test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); @@ -966,26 +1192,41 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); + test_scalar_expr!(ArrayPopFront, array_pop_front, array); + test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); - test_scalar_expr!(ArrayFill, array_fill, element, array); test_scalar_expr!(ArrayLength, array_length, array, dimension); test_unary_scalar_expr!(ArrayNdims, array_ndims); test_scalar_expr!(ArrayPosition, array_position, array, element, index); test_scalar_expr!(ArrayPositions, array_positions, array, element); test_scalar_expr!(ArrayPrepend, array_prepend, array, element); + test_scalar_expr!(ArrayRepeat, array_repeat, element, count); test_scalar_expr!(ArrayRemove, array_remove, array, element); + test_scalar_expr!(ArrayRemoveN, array_remove_n, array, element, max); + test_scalar_expr!(ArrayRemoveAll, array_remove_all, array, element); test_scalar_expr!(ArrayReplace, array_replace, array, from, to); + test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max); + test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to); test_scalar_expr!(ArrayToString, array_to_string, array, delimiter); test_unary_scalar_expr!(Cardinality, cardinality); test_nary_scalar_expr!(MakeArray, array, input); - test_scalar_expr!(TrimArray, trim_array, array, n); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); + test_nary_scalar_expr!(OverLay, overlay, string, characters, position); + test_scalar_expr!(Levenshtein, levenshtein, string1, string2); + test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); + test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } #[test] fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = uuid() + { let name = BuiltinScalarFunction::Uuid; assert_eq!(name, fun); assert_eq!(0, args.len()); @@ -996,8 +1237,10 @@ mod test { #[test] fn digest_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - digest(col("tableA.a"), lit("md5")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = digest(col("tableA.a"), lit("md5")) { let name = BuiltinScalarFunction::Digest; assert_eq!(name, fun); @@ -1006,4 +1249,34 @@ mod test { unreachable!(); } } + + #[test] + fn encode_function_definitions() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = encode(col("tableA.a"), lit("base64")) + { + let name = BuiltinScalarFunction::Encode; + assert_eq!(name, fun); + assert_eq!(2, args.len()); + } else { + unreachable!(); + } + } + + #[test] + fn decode_function_definitions() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = decode(col("tableA.a"), lit("hex")) + { + let name = BuiltinScalarFunction::Decode; + assert_eq!(name, fun); + assert_eq!(2, args.len()); + } else { + unreachable!(); + } + } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index ca9383709f5eb..1f04c80833f09 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,9 +17,10 @@ //! Expression rewriter +use crate::expr::Alias; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -137,6 +138,23 @@ pub fn unnormalize_col(expr: Expr) -> Expr { .expect("Unnormalize is infallable") } +/// Create a Column from the Scalar Expr +pub fn create_col_from_scalar_expr( + scalar_expr: &Expr, + subqry_alias: String, +) -> Result { + match scalar_expr { + Expr::Alias(Alias { name, .. }) => Ok(Column::new(Some(subqry_alias), name)), + Expr::Column(Column { relation: _, name }) => { + Ok(Column::new(Some(subqry_alias), name)) + } + _ => { + let scalar_column = scalar_expr.display_name()?; + Ok(Column::new(Some(subqry_alias), scalar_column)) + } + } +} + /// Recursively un-normalize all [`Column`] expressions in a list of expression trees #[inline] pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { @@ -204,8 +222,8 @@ fn coerce_exprs_for_schema( let new_type = dst_schema.field(idx).data_type(); if new_type != &expr.get_type(src_schema)? { match expr { - Expr::Alias(e, alias) => { - Ok(e.cast_to(new_type, src_schema)?.alias(alias)) + Expr::Alias(Alias { expr, name, .. }) => { + Ok(expr.cast_to(new_type, src_schema)?.alias(name)) } _ => expr.cast_to(new_type, src_schema), } @@ -216,13 +234,38 @@ fn coerce_exprs_for_schema( .collect::>() } +/// Recursively un-alias an expressions +#[inline] +pub fn unalias(expr: Expr) -> Expr { + match expr { + Expr::Alias(Alias { expr, .. }) => unalias(*expr), + _ => expr, + } +} + +/// Rewrites `expr` using `rewriter`, ensuring that the output has the +/// same name as `expr` prior to rewrite, adding an alias if necessary. +/// +/// This is important when optimizing plans to ensure the output +/// schema of plan nodes don't change after optimization +pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +where + R: TreeNodeRewriter, +{ + let original_name = expr.name_for_alias()?; + let expr = expr.rewrite(rewriter)?; + expr.alias_if_changed(original_name) +} + #[cfg(test)] mod test { use super::*; - use crate::{col, lit}; + use crate::expr::Sort; + use crate::{col, lit, Cast}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; + use std::ops::Add; #[derive(Default)] struct RecordingRewriter { @@ -233,12 +276,12 @@ mod test { type N = Expr; fn pre_visit(&mut self, expr: &Expr) -> Result { - self.v.push(format!("Previsited {expr:?}")); + self.v.push(format!("Previsited {expr}")); Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { - self.v.push(format!("Mutated {expr:?}")); + self.v.push(format!("Mutated {expr}")); Ok(expr) } } @@ -331,7 +374,7 @@ mod test { let error = normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[]) .unwrap_err() - .to_string(); + .strip_backtrace(); assert_eq!( error, r#"Schema error: No field named b. Valid fields are "tableA".a."# @@ -370,4 +413,64 @@ mod test { ] ) } + + #[test] + fn test_rewrite_preserving_name() { + test_rewrite(col("a"), col("a")); + + test_rewrite(col("a"), col("b")); + + // cast data types + test_rewrite( + col("a"), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), + ); + + // change literal type from i32 to i64 + test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); + + // SortExpr a+1 ==> b + 2 + test_rewrite( + Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), + Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), + ); + } + + /// rewrites `expr_from` to `rewrite_to` using + /// `rewrite_preserving_name` verifying the result is `expected_expr` + fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { + struct TestRewriter { + rewrite_to: Expr, + } + + impl TreeNodeRewriter for TestRewriter { + type N = Expr; + + fn mutate(&mut self, _: Expr) -> Result { + Ok(self.rewrite_to.clone()) + } + } + + let mut rewriter = TestRewriter { + rewrite_to: rewrite_to.clone(), + }; + let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + + let original_name = match &expr_from { + Expr::Sort(Sort { expr, .. }) => expr.display_name(), + expr => expr.display_name(), + } + .unwrap(); + + let new_name = match &expr { + Expr::Sort(Sort { expr, .. }) => expr.display_name(), + expr => expr.display_name(), + } + .unwrap(); + + assert_eq!( + original_name, new_name, + "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" + ) + } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index ce832d11fd598..c87a724d5646b 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -17,7 +17,7 @@ //! Rewrite for order by expressions -use crate::expr::Sort; +use crate::expr::{Alias, Sort}; use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; use datafusion_common::tree_node::{Transformed, TreeNode}; @@ -137,12 +137,12 @@ fn rewrite_in_terms_of_projection( /// Does the underlying expr match e? /// so avg(c) as average will match avgc -fn expr_match(needle: &Expr, haystack: &Expr) -> bool { +fn expr_match(needle: &Expr, expr: &Expr) -> bool { // check inside aliases - if let Expr::Alias(haystack, _) = &haystack { - haystack.as_ref() == needle + if let Expr::Alias(Alias { expr, .. }) = &expr { + expr.as_ref() == needle } else { - haystack == needle + expr == needle } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3c68c4acd7dd0..e5b0185d90e0b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,17 +17,21 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, BinaryExpr, Cast, GetIndexedField, InList, - InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; -use crate::field_util::get_indexed_field; +use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::{ - aggregate_function, function, window_function, LogicalPlan, Projection, Subquery, -}; +use crate::type_coercion::functions::data_types; +use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; -use arrow::datatypes::DataType; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError, ExprSchema, Result}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{ + internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, + DataFusionError, ExprSchema, Result, +}; +use std::collections::HashMap; use std::sync::Arc; /// trait to allow expr to typable with respect to a schema @@ -38,6 +42,9 @@ pub trait ExprSchemable { /// given a schema, return the nullability of the expr fn nullable(&self, input_schema: &S) -> Result; + /// given a schema, return the expr's optional metadata + fn metadata(&self, schema: &S) -> Result>; + /// convert to a field with respect to a schema fn to_field(&self, input_schema: &DFSchema) -> Result; @@ -60,7 +67,7 @@ impl ExprSchemable for Expr { /// (e.g. `[utf8] + [bool]`). fn get_type(&self, schema: &S) -> Result { match self { - Expr::Alias(expr, name) => match &**expr { + Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { None => schema.data_type(&Column::from_name(name)).cloned(), Some(dt) => Ok(dt.clone()), @@ -71,44 +78,62 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.get_datatype()), + Expr::Literal(l) => Ok(l.data_type()), Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let data_types = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - function::return_type(fun, &data_types) + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&arg_data_types)?) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - window_function::return_type(fun, &data_types) - } - Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregate_function::return_type(fun, &data_types) + fun.return_type(&data_types) } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -131,32 +156,25 @@ impl ExprSchemable for Expr { ref right, ref op, }) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?), - Expr::Like { .. } | Expr::ILike { .. } | Expr::SimilarTo { .. } => { - Ok(DataType::Boolean) - } + Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { - DataFusionError::Plan( - "Placeholder type could not be resolved".to_owned(), - ) + plan_datafusion_err!("Placeholder type could not be resolved") }) } - Expr::Wildcard => { + Expr::Wildcard { qualifier } => { // Wildcard do not really have a type and do not appear in projections - Ok(DataType::Null) + match qualifier { + Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"), + None => Ok(DataType::Null) + } } - Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( - "QualifiedWildcard expressions are not valid in a logical query plan" - .to_owned(), - )), Expr::GroupingSet(_) => { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + field_for_index(expr, field, schema).map(|x| x.data_type().clone()) } } } @@ -172,12 +190,40 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &S) -> Result { match self { - Expr::Alias(expr, _) + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) - | Expr::Sort(Sort { expr, .. }) - | Expr::InList(InList { expr, .. }) => expr.nullable(input_schema), - Expr::Between(Between { expr, .. }) => expr.nullable(input_schema), + | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema), + + Expr::InList(InList { expr, list, .. }) => { + // Avoid inspecting too many expressions. + const MAX_INSPECT_LIMIT: usize = 6; + // Stop if a nullable expression is found or an error occurs. + let has_nullable = std::iter::once(expr.as_ref()) + .chain(list) + .take(MAX_INSPECT_LIMIT) + .find_map(|e| { + e.nullable(input_schema) + .map(|nullable| if nullable { Some(()) } else { None }) + .transpose() + }) + .transpose()?; + Ok(match has_nullable { + // If a nullable subexpression is found, the result may also be nullable. + Some(_) => true, + // If the list is too long, we assume it is nullable. + None if list.len() + 1 > MAX_INSPECT_LIMIT => true, + // All the subexpressions are non-nullable, so the result must be non-nullable. + _ => false, + }) + } + + Expr::Between(Between { + expr, low, high, .. + }) => Ok(expr.nullable(input_schema)? + || low.nullable(input_schema)? + || high.nullable(input_schema)?), + Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), @@ -202,10 +248,8 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) @@ -225,19 +269,15 @@ impl ExprSchemable for Expr { ref right, .. }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Like(Like { expr, .. }) => expr.nullable(input_schema), - Expr::ILike(Like { expr, .. }) => expr.nullable(input_schema), - Expr::SimilarTo(Like { expr, .. }) => expr.nullable(input_schema), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( - "QualifiedWildcard expressions are not valid in a logical query plan" - .to_owned(), - )), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) + } + Expr::Wildcard { .. } => internal_err!( + "Wildcard expressions are not valid in a logical query plan" + ), + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear @@ -247,6 +287,14 @@ impl ExprSchemable for Expr { } } + fn metadata(&self, schema: &S) -> Result> { + match self { + Expr::Column(c) => Ok(schema.metadata(c)?.clone()), + Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), + _ => Ok(HashMap::new()), + } + } + /// Returns a [arrow::datatypes::Field] compatible with this expression. /// /// So for example, a projected expression `col(c1) + col(c2)` is @@ -258,12 +306,21 @@ impl ExprSchemable for Expr { &c.name, self.get_type(input_schema)?, self.nullable(input_schema)?, - )), + ) + .with_metadata(self.metadata(input_schema)?)), + Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( + relation.clone(), + name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ) + .with_metadata(self.metadata(input_schema)?)), _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, self.nullable(input_schema)?, - )), + ) + .with_metadata(self.metadata(input_schema)?)), } } @@ -291,13 +348,33 @@ impl ExprSchemable for Expr { _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), } } else { - Err(DataFusionError::Plan(format!( - "Cannot automatically convert {this_type:?} to {cast_to_type:?}" - ))) + plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") } } } +/// return the schema [`Field`] for the type referenced by `get_indexed_field` +fn field_for_index( + expr: &Expr, + field: &GetFieldAccess, + schema: &S, +) -> Result { + let expr_dt = expr.get_type(schema)?; + match field { + GetFieldAccess::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } + } + GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.get_type(schema)?, + }, + GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange { + start_dt: start.get_type(schema)?, + stop_dt: stop.get_type(schema)?, + }, + } + .get_accessed_field(&expr_dt) +} + /// cast subquery in InSubquery/ScalarSubquery to a given type. pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { if subquery.subquery.schema().field(0).data_type() == cast_to_type { @@ -362,6 +439,71 @@ mod tests { test_is_expr_nullable!(is_not_unknown); } + #[test] + fn test_between_nullability() { + let get_schema = |nullable| { + MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(nullable) + }; + + let expr = col("foo").between(lit(1), lit(2)); + assert!(!expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.nullable(&get_schema(true)).unwrap()); + + let null = lit(ScalarValue::Int32(None)); + + let expr = col("foo").between(null.clone(), lit(2)); + assert!(expr.nullable(&get_schema(false)).unwrap()); + + let expr = col("foo").between(lit(1), null.clone()); + assert!(expr.nullable(&get_schema(false)).unwrap()); + + let expr = col("foo").between(null.clone(), null); + assert!(expr.nullable(&get_schema(false)).unwrap()); + } + + #[test] + fn test_inlist_nullability() { + let get_schema = |nullable| { + MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(nullable) + }; + + let expr = col("foo").in_list(vec![lit(1); 5], false); + assert!(!expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.nullable(&get_schema(true)).unwrap()); + // Testing nullable() returns an error. + assert!(expr + .nullable(&get_schema(false).with_error_on_nullable(true)) + .is_err()); + + let null = lit(ScalarValue::Int32(None)); + let expr = col("foo").in_list(vec![null, lit(1)], false); + assert!(expr.nullable(&get_schema(false)).unwrap()); + + // Testing on long list + let expr = col("foo").in_list(vec![lit(1); 6], false); + assert!(expr.nullable(&get_schema(false)).unwrap()); + } + + #[test] + fn test_like_nullability() { + let get_schema = |nullable| { + MockExprSchema::new() + .with_data_type(DataType::Utf8) + .with_nullable(nullable) + }; + + let expr = col("foo").like(lit("bar")); + assert!(!expr.nullable(&get_schema(false)).unwrap()); + assert!(expr.nullable(&get_schema(true)).unwrap()); + + let expr = col("foo").like(lit(ScalarValue::Utf8(None))); + assert!(expr.nullable(&get_schema(false)).unwrap()); + } + #[test] fn expr_schema_data_type() { let expr = col("foo"); @@ -372,10 +514,46 @@ mod tests { ); } + #[test] + fn test_expr_metadata() { + let mut meta = HashMap::new(); + meta.insert("bar".to_string(), "buzz".to_string()); + let expr = col("foo"); + let schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_metadata(meta.clone()); + + // col and alias should be metadata-preserving + assert_eq!(meta, expr.metadata(&schema).unwrap()); + assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap()); + + // cast should drop input metadata since the type has changed + assert_eq!( + HashMap::new(), + expr.clone() + .cast_to(&DataType::Int64, &schema) + .unwrap() + .metadata(&schema) + .unwrap() + ); + + let schema = DFSchema::new_with_metadata( + vec![DFField::new_unqualified("foo", DataType::Int32, true) + .with_metadata(meta.clone())], + HashMap::new(), + ) + .unwrap(); + + // verify to_field method populates metadata + assert_eq!(&meta, expr.to_field(&schema).unwrap().metadata()); + } + #[derive(Debug)] struct MockExprSchema { nullable: bool, data_type: DataType, + error_on_nullable: bool, + metadata: HashMap, } impl MockExprSchema { @@ -383,6 +561,8 @@ mod tests { Self { nullable: false, data_type: DataType::Null, + error_on_nullable: false, + metadata: HashMap::new(), } } @@ -395,15 +575,33 @@ mod tests { self.data_type = data_type; self } + + fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self { + self.error_on_nullable = error_on_nullable; + self + } + + fn with_metadata(mut self, metadata: HashMap) -> Self { + self.metadata = metadata; + self + } } impl ExprSchema for MockExprSchema { fn nullable(&self, _col: &Column) -> Result { - Ok(self.nullable) + if self.error_on_nullable { + internal_err!("nullable error") + } else { + Ok(self.nullable) + } } fn data_type(&self, _col: &Column) -> Result<&DataType> { Ok(&self.data_type) } + + fn metadata(&self, _col: &Column) -> Result<&HashMap> { + Ok(&self.metadata) + } } } diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index feb96928c1206..3829a2086b26e 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -18,36 +18,82 @@ //! Utility functions for complex field access use arrow::datatypes::{DataType, Field}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, +}; -/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] -/// # Error -/// Errors if -/// * the `data_type` is not a Struct or, -/// * there is no field key is not of the required index type -pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { - match (data_type, key) { - (DataType::List(lt), ScalarValue::Int64(Some(i))) => { - Ok(Field::new(i.to_string(), lt.data_type().clone(), true)) - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - Err(DataFusionError::Plan( - "Struct based indexed access requires a non empty string".to_string(), - )) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) +/// Types of the field access expression of a nested type, such as `Field` or `List` +pub enum GetFieldAccessSchema { + /// Named field, For example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key_dt: DataType }, + /// List range, for example `list[i:j]` + ListRange { + start_dt: DataType, + stop_dt: DataType, + }, +} + +impl GetFieldAccessSchema { + /// Returns the schema [`Field`] from a [`DataType::List`] or + /// [`DataType::Struct`] indexed by this structure + /// + /// # Error + /// Errors if + /// * the `data_type` is not a Struct or a List, + /// * the `data_type` of the name/index/start-stop do not match a supported index type + pub fn get_accessed_field(&self, data_type: &DataType) -> Result { + match self { + Self::NamedStructField{ name } => { + match (data_type, name) { + (DataType::Map(fields, _), _) => { + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second columnis the "value" column both here and in + // execution. + let value_field = fields.get(1).expect("fields should have exactly two members"); + Ok(Field::new("map", value_field.data_type().clone(), true)) + }, + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + } + } + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.as_ref().clone()) + } + } + (DataType::Struct(_), _) => plan_err!( + "Only utf8 strings are valid as an indexed field in a struct" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"), + } + } + Self::ListIndex{ key_dt } => { + match (data_type, key_dt) { + (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), + (DataType::List(_), _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } + } + Self::ListRange{ start_dt, stop_dt } => { + match (data_type, start_dt, stop_dt) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), + (DataType::List(_), _, _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } } } - (DataType::Struct(_), _) => Err(DataFusionError::Plan( - "Only utf8 strings are valid as an indexed field in a struct".to_string(), - )), - (DataType::List(_), _) => Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a list".to_string(), - )), - (other, _) => Err(DataFusionError::Plan( - format!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}") - )), } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bec672ab6f6c3..3e30a5574be0e 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,17 +17,13 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::function_err::generate_signature_error_msg; -use crate::nullif::SUPPORTED_NULLIF_TYPES; -use crate::type_coercion::functions::data_types; -use crate::ColumnarValue; -use crate::{ - conditional_expressions, struct_expressions, Accumulator, BuiltinScalarFunction, - Signature, TypeSignature, -}; -use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{DataFusionError, Result}; +use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; +use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use arrow::datatypes::DataType; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::Result; use std::sync::Arc; +use strum::IntoEnumIterator; /// Scalar function /// @@ -46,788 +42,66 @@ pub type ReturnTypeFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype. -pub type AccumulatorFunctionImplementation = +pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; +/// Factory that creates a PartitionEvaluator for the given window +/// function +pub type PartitionEvaluatorFactory = + Arc Result> + Send + Sync>; + /// Factory that returns the types used by an aggregator to serialize /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; -macro_rules! make_utf8_to_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 => $largeUtf8Type, - DataType::Utf8 => $utf8Type, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {:?} function can only accept strings.", - name - ))); - } - }) - } - }; -} - -make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - -fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::LargeBinary => DataType::Binary, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {name:?} function can only accept strings or binary arrays." - ))); - } - }) -} - /// Returns the datatype of the scalar function +#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::return_type` instead" +)] pub fn return_type( fun: &BuiltinScalarFunction, input_expr_types: &[DataType], ) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - if input_expr_types.is_empty() && !fun.supports_zero_argument() { - return Err(DataFusionError::Plan(generate_signature_error_msg( - fun, - input_expr_types, - ))); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature(fun)).map_err(|_| { - DataFusionError::Plan(generate_signature_error_msg(fun, input_expr_types)) - })?; - - // the return type of the built in function. - // Some built-in functions' return type depends on the incoming type. - match fun { - BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept fixed size list as the args." - ))), - }, - BuiltinScalarFunction::ArrayDims => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayFill => Ok(DataType::List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::ArrayLength => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayNdims => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayPosition => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayPositions => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Cardinality => Ok(DataType::UInt64), - BuiltinScalarFunction::MakeArray => Ok(DataType::List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Ascii => Ok(DataType::Int32), - BuiltinScalarFunction::BitLength => { - utf8_to_int_type(&input_expr_types[0], "bit_length") - } - BuiltinScalarFunction::Btrim => utf8_to_str_type(&input_expr_types[0], "btrim"), - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } - BuiltinScalarFunction::Chr => Ok(DataType::Utf8), - BuiltinScalarFunction::Coalesce => { - // COALESCE has multiple args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|types| types[0].clone()) - } - BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), - BuiltinScalarFunction::DatePart => Ok(DataType::Float64), - BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { - match input_expr_types[1] { - DataType::Timestamp(TimeUnit::Nanosecond, _) | DataType::Utf8 => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - DataType::Timestamp(TimeUnit::Second, _) => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept timestamp as the second arg." - ))), - } - } - BuiltinScalarFunction::InitCap => { - utf8_to_str_type(&input_expr_types[0], "initcap") - } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lower => utf8_to_str_type(&input_expr_types[0], "lower"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"), - BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), - BuiltinScalarFunction::NullIf => { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|typs| typs[0].clone()) - } - BuiltinScalarFunction::OctetLength => { - utf8_to_int_type(&input_expr_types[0], "octet_length") - } - BuiltinScalarFunction::Pi => Ok(DataType::Float64), - BuiltinScalarFunction::Random => Ok(DataType::Float64), - BuiltinScalarFunction::Uuid => Ok(DataType::Utf8), - BuiltinScalarFunction::RegexpReplace => { - utf8_to_str_type(&input_expr_types[0], "regex_replace") - } - BuiltinScalarFunction::Repeat => utf8_to_str_type(&input_expr_types[0], "repeat"), - BuiltinScalarFunction::Replace => { - utf8_to_str_type(&input_expr_types[0], "replace") - } - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => utf8_to_str_type(&input_expr_types[0], "right"), - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrimp"), - BuiltinScalarFunction::SHA224 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") - } - BuiltinScalarFunction::SHA256 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") - } - BuiltinScalarFunction::SHA384 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") - } - BuiltinScalarFunction::SHA512 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") - } - BuiltinScalarFunction::Digest => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") - } - BuiltinScalarFunction::SplitPart => { - utf8_to_str_type(&input_expr_types[0], "split_part") - } - BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean), - BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"), - BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"), - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Utf8 - } - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The to_hex function can only accept integers.".to_string(), - )); - } - }), - BuiltinScalarFunction::ToTimestamp => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::ToTimestampMillis => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - BuiltinScalarFunction::ToTimestampMicros => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - BuiltinScalarFunction::ToTimestampSeconds => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - BuiltinScalarFunction::FromUnixtime => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - BuiltinScalarFunction::Now => Ok(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+00:00".into()), - )), - BuiltinScalarFunction::CurrentDate => Ok(DataType::Date32), - BuiltinScalarFunction::CurrentTime => Ok(DataType::Time64(TimeUnit::Nanosecond)), - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"), - BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { - DataType::LargeUtf8 => { - DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, true))) - } - DataType::Utf8 => { - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))) - } - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The regexp_extract function can only accept strings.".to_string(), - )); - } - }), - - BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(DataType::Int64), - - BuiltinScalarFunction::Power => match &input_expr_types[0] { - DataType::Int64 => Ok(DataType::Int64), - _ => Ok(DataType::Float64), - }, - - BuiltinScalarFunction::Struct => { - let return_fields = input_expr_types - .iter() - .enumerate() - .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) - .collect::>(); - Ok(DataType::Struct(Fields::from(return_fields))) - } - - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), - }, - - BuiltinScalarFunction::Log => match &input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), - }, - - BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8), - - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => match input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), - }, - } + fun.return_type(input_expr_types) } /// Return the [`Signature`] supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::signature` instead" +)] pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - - // for now, the list is small, as we do not have many built-in functions. - match fun { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayConcat => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayDims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayFill => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayLength => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayNdims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayPosition => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayPositions => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayRemove => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayReplace => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayToString => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::Cardinality => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::MakeArray => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::TrimArray => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { - Signature::variadic(vec![DataType::Utf8], fun.volatility()) - } - BuiltinScalarFunction::Coalesce => Signature::variadic( - conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::SHA224 - | BuiltinScalarFunction::SHA256 - | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 - | BuiltinScalarFunction::MD5 => Signature::uniform( - 1, - vec![ - DataType::Utf8, - DataType::LargeUtf8, - DataType::Binary, - DataType::LargeBinary, - ], - fun.volatility(), - ), - BuiltinScalarFunction::Ascii - | BuiltinScalarFunction::BitLength - | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => Signature::uniform( - 1, - vec![DataType::Utf8, DataType::LargeUtf8], - fun.volatility(), - ), - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) - } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::LargeUtf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::LargeUtf8, - ]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Left - | BuiltinScalarFunction::Repeat - | BuiltinScalarFunction::Right => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestamp => Signature::uniform( - 1, - vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( - 1, - vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( - 1, - vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( - 1, - vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::FromUnixtime => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) - } - BuiltinScalarFunction::Digest => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Binary, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeBinary, DataType::Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DateTrunc => Signature::exact( - vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DateBin => { - let base_sig = |array_type: TimeUnit| { - vec![ - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Timestamp(array_type.clone(), None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::DayTime), - DataType::Timestamp(array_type.clone(), None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Timestamp(array_type.clone(), None), - ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::DayTime), - DataType::Timestamp(array_type, None), - ]), - ] - }; - - let full_sig = [ - TimeUnit::Nanosecond, - TimeUnit::Microsecond, - TimeUnit::Millisecond, - TimeUnit::Second, - ] - .into_iter() - .map(base_sig) - .collect::>() - .concat(); - - Signature::one_of(full_sig, fun.volatility()) - } - BuiltinScalarFunction::DatePart => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date32]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Second, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Microsecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Millisecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - ]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::LargeUtf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::LargeUtf8, - DataType::Int64, - ]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), - ], - fun.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::Int64, - ]), - ], - fun.volatility(), - ), + fun.signature() +} - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of( - vec![TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ])], - fun.volatility(), - ) - } - BuiltinScalarFunction::RegexpReplace => Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - ], - fun.volatility(), - ), +/// Suggest a valid function based on an invalid input function name +pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + AggregateFunction::iter() + .map(|func| func.to_string()) + .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) + .collect() + } else { + // All scalar functions and aggregate functions + BuiltinScalarFunction::iter() + .map(|func| func.to_string()) + .chain(AggregateFunction::iter().map(|func| func.to_string())) + .collect() + }; + find_closest_match(valid_funcs, input_function_name) +} - BuiltinScalarFunction::NullIf => { - Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility()) - } - BuiltinScalarFunction::RegexpMatch => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Utf8, - ]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Pi => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Uuid => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Power => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Round => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float64]), - TypeSignature::Exact(vec![DataType::Float32]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Log => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64]), - TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Factorial => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) - } - BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![DataType::Int64], fun.volatility()) - } - BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => { - // math expressions expect 1 argument of type f64 or f32 - // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we - // return the best approximation for it (in f64). - // We accept f32 because in this case it is clear that the best approximation - // will be as good as the number of digits in the number - Signature::uniform( - 1, - vec![DataType::Float64, DataType::Float32], - fun.volatility(), - ) - } - BuiltinScalarFunction::Now - | BuiltinScalarFunction::CurrentDate - | BuiltinScalarFunction::CurrentTime => { - Signature::uniform(0, vec![], fun.volatility()) - } - } +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty } diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs deleted file mode 100644 index e97e0f92cd80f..0000000000000 --- a/datafusion/expr/src/function_err.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Function_err module enhances frontend error messages for unresolved functions due to incorrect parameters, -//! by providing the correct function signatures. -//! -//! For example, a query like `select round(3.14, 1.1);` would yield: -//! ```text -//! Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. -//! Candidate functions: -//! round(Float64, Int64) -//! round(Float32, Int64) -//! round(Float64) -//! round(Float32) -//! ``` - -use crate::function::signature; -use crate::{ - AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, TypeSignature, -}; -use arrow::datatypes::DataType; -use datafusion_common::utils::datafusion_strsim; -use strum::IntoEnumIterator; - -impl TypeSignature { - fn to_string_repr(&self) -> Vec { - match self { - TypeSignature::Variadic(types) => { - vec![format!("{}, ..", join_types(types, "/"))] - } - TypeSignature::Uniform(arg_count, valid_types) => { - vec![std::iter::repeat(join_types(valid_types, "/")) - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::Exact(types) => { - vec![join_types(types, ", ")] - } - TypeSignature::Any(arg_count) => { - vec![std::iter::repeat("Any") - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], - TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], - TypeSignature::OneOf(sigs) => { - sigs.iter().flat_map(|s| s.to_string_repr()).collect() - } - } - } -} - -/// Helper function to join types with specified delimiter. -fn join_types(types: &[T], delimiter: &str) -> String { - types - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(delimiter) -} - -/// Creates a detailed error message for a function with wrong signature. -pub fn generate_signature_error_msg( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> String { - let candidate_signatures = signature(fun) - .type_signature - .to_string_repr() - .iter() - .map(|args_str| format!("\t{}({})", fun, args_str)) - .collect::>() - .join("\n"); - - format!( - "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", - fun, join_types(input_expr_types, ", "), candidate_signatures - ) -} - -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { - let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) - }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty -} - -/// Suggest a valid function based on an invalid input function name -pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { - let valid_funcs = if is_window_func { - // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() - } else { - // All scalar functions and aggregate functions - BuiltinScalarFunction::iter() - .map(|func| func.to_string()) - .chain(AggregateFunction::iter().map(|func| func.to_string())) - .collect() - }; - find_closest_match(valid_funcs, input_function_name) -} diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs new file mode 100644 index 0000000000000..5d34fe91c3ace --- /dev/null +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -0,0 +1,3307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetic library + +use std::borrow::Borrow; +use std::fmt::{self, Display, Formatter}; +use std::ops::{AddAssign, SubAssign}; + +use crate::type_coercion::binary::get_result_type; +use crate::Operator; + +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; +use arrow::datatypes::{IntervalUnit, TimeUnit}; +use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; + +macro_rules! get_extreme_value { + ($extreme:ident, $value:expr) => { + match $value { + DataType::UInt8 => ScalarValue::UInt8(Some(u8::$extreme)), + DataType::UInt16 => ScalarValue::UInt16(Some(u16::$extreme)), + DataType::UInt32 => ScalarValue::UInt32(Some(u32::$extreme)), + DataType::UInt64 => ScalarValue::UInt64(Some(u64::$extreme)), + DataType::Int8 => ScalarValue::Int8(Some(i8::$extreme)), + DataType::Int16 => ScalarValue::Int16(Some(i16::$extreme)), + DataType::Int32 => ScalarValue::Int32(Some(i32::$extreme)), + DataType::Int64 => ScalarValue::Int64(Some(i64::$extreme)), + DataType::Float32 => ScalarValue::Float32(Some(f32::$extreme)), + DataType::Float64 => ScalarValue::Float64(Some(f64::$extreme)), + DataType::Duration(TimeUnit::Second) => { + ScalarValue::DurationSecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(Some(i64::$extreme)) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(Some(i64::$extreme)) + } + DataType::Timestamp(TimeUnit::Second, _) => { + ScalarValue::TimestampSecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + ScalarValue::TimestampMillisecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + ScalarValue::TimestampMicrosecond(Some(i64::$extreme), None) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + ScalarValue::TimestampNanosecond(Some(i64::$extreme), None) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(Some(i32::$extreme)) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(Some(i64::$extreme)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(Some(i128::$extreme)) + } + _ => unreachable!(), + } + }; +} + +macro_rules! value_transition { + ($bound:ident, $direction:expr, $value:expr) => { + match $value { + UInt8(Some(value)) if value == u8::$bound => UInt8(None), + UInt16(Some(value)) if value == u16::$bound => UInt16(None), + UInt32(Some(value)) if value == u32::$bound => UInt32(None), + UInt64(Some(value)) if value == u64::$bound => UInt64(None), + Int8(Some(value)) if value == i8::$bound => Int8(None), + Int16(Some(value)) if value == i16::$bound => Int16(None), + Int32(Some(value)) if value == i32::$bound => Int32(None), + Int64(Some(value)) if value == i64::$bound => Int64(None), + Float32(Some(value)) if value == f32::$bound => Float32(None), + Float64(Some(value)) if value == f64::$bound => Float64(None), + DurationSecond(Some(value)) if value == i64::$bound => DurationSecond(None), + DurationMillisecond(Some(value)) if value == i64::$bound => { + DurationMillisecond(None) + } + DurationMicrosecond(Some(value)) if value == i64::$bound => { + DurationMicrosecond(None) + } + DurationNanosecond(Some(value)) if value == i64::$bound => { + DurationNanosecond(None) + } + TimestampSecond(Some(value), tz) if value == i64::$bound => { + TimestampSecond(None, tz) + } + TimestampMillisecond(Some(value), tz) if value == i64::$bound => { + TimestampMillisecond(None, tz) + } + TimestampMicrosecond(Some(value), tz) if value == i64::$bound => { + TimestampMicrosecond(None, tz) + } + TimestampNanosecond(Some(value), tz) if value == i64::$bound => { + TimestampNanosecond(None, tz) + } + IntervalYearMonth(Some(value)) if value == i32::$bound => { + IntervalYearMonth(None) + } + IntervalDayTime(Some(value)) if value == i64::$bound => IntervalDayTime(None), + IntervalMonthDayNano(Some(value)) if value == i128::$bound => { + IntervalMonthDayNano(None) + } + _ => next_value_helper::<$direction>($value), + } + }; +} + +/// The `Interval` type represents a closed interval used for computing +/// reliable bounds for mathematical expressions. +/// +/// Conventions: +/// +/// 1. **Closed bounds**: The interval always encompasses its endpoints. We +/// accommodate operations resulting in open intervals by incrementing or +/// decrementing the interval endpoint value to its successor/predecessor. +/// +/// 2. **Unbounded endpoints**: If the `lower` or `upper` bounds are indeterminate, +/// they are labeled as *unbounded*. This is represented using a `NULL`. +/// +/// 3. **Overflow handling**: If the `lower` or `upper` endpoints exceed their +/// limits after any operation, they either become unbounded or they are fixed +/// to the maximum/minimum value of the datatype, depending on the direction +/// of the overflowing endpoint, opting for the safer choice. +/// +/// 4. **Floating-point special cases**: +/// - `INF` values are converted to `NULL`s while constructing an interval to +/// ensure consistency, with other data types. +/// - `NaN` (Not a Number) results are conservatively result in unbounded +/// endpoints. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Interval { + lower: ScalarValue, + upper: ScalarValue, +} + +/// This macro handles the `NaN` and `INF` floating point values. +/// +/// - `NaN` values are always converted to unbounded i.e. `NULL` values. +/// - For lower bounds: +/// - A `NEG_INF` value is converted to a `NULL`. +/// - An `INF` value is conservatively converted to the maximum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as a `NEG_INF`. +/// - For upper bounds: +/// - An `INF` value is converted to a `NULL`. +/// - An `NEG_INF` value is conservatively converted to the minimum representable +/// number for the floating-point type in question. In this case, converting +/// to `NULL` doesn't make sense as it would be interpreted as an `INF`. +macro_rules! handle_float_intervals { + ($scalar_type:ident, $primitive_type:ident, $lower:expr, $upper:expr) => {{ + let lower = match $lower { + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::NEG_INFINITY || l_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(l_val)) + if l_val == $primitive_type::INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MAX)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + let upper = match $upper { + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::INFINITY || r_val.is_nan() => + { + ScalarValue::$scalar_type(None) + } + ScalarValue::$scalar_type(Some(r_val)) + if r_val == $primitive_type::NEG_INFINITY => + { + ScalarValue::$scalar_type(Some($primitive_type::MIN)) + } + value @ ScalarValue::$scalar_type(Some(_)) => value, + _ => ScalarValue::$scalar_type(None), + }; + + Interval { lower, upper } + }}; +} + +/// Ordering floating-point numbers according to their binary representations +/// contradicts with their natural ordering. Floating-point number ordering +/// after unsigned integer transmutation looks like: +/// +/// ```text +/// 0, 1, 2, 3, ..., MAX, -0, -1, -2, ..., -MAX +/// ``` +/// +/// This macro applies a one-to-one map that fixes the ordering above. +macro_rules! map_floating_point_order { + ($value:expr, $ty:ty) => {{ + let num_bits = std::mem::size_of::<$ty>() * 8; + let sign_bit = 1 << (num_bits - 1); + if $value & sign_bit == sign_bit { + // Negative numbers: + !$value + } else { + // Positive numbers: + $value | sign_bit + } + }}; +} + +impl Interval { + /// Attempts to create a new `Interval` from the given lower and upper bounds. + /// + /// # Notes + /// + /// This constructor creates intervals in a "canonical" form where: + /// - **Boolean intervals**: + /// - Unboundedness (`NULL`) for boolean endpoints is converted to `false` + /// for lower and `true` for upper bounds. + /// - **Floating-point intervals**: + /// - Floating-point endpoints with `NaN`, `INF`, or `NEG_INF` are converted + /// to `NULL`s. + pub fn try_new(lower: ScalarValue, upper: ScalarValue) -> Result { + if lower.data_type() != upper.data_type() { + return internal_err!("Endpoints of an Interval should have the same type"); + } + + let interval = Self::new(lower, upper); + + if interval.lower.is_null() + || interval.upper.is_null() + || interval.lower <= interval.upper + { + Ok(interval) + } else { + internal_err!( + "Interval's lower bound {} is greater than the upper bound {}", + interval.lower, + interval.upper + ) + } + } + + /// Only for internal usage. Responsible for standardizing booleans and + /// floating-point values, as well as fixing NaNs. It doesn't validate + /// the given bounds for ordering, or verify that they have the same data + /// type. For its user-facing counterpart and more details, see + /// [`Interval::try_new`]. + fn new(lower: ScalarValue, upper: ScalarValue) -> Self { + if let ScalarValue::Boolean(lower_bool) = lower { + let ScalarValue::Boolean(upper_bool) = upper else { + // We are sure that upper and lower bounds have the same type. + unreachable!(); + }; + // Standardize boolean interval endpoints: + Self { + lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), + upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), + } + } + // Standardize floating-point endpoints: + else if lower.data_type() == DataType::Float32 { + handle_float_intervals!(Float32, f32, lower, upper) + } else if lower.data_type() == DataType::Float64 { + handle_float_intervals!(Float64, f64, lower, upper) + } else { + // Other data types do not require standardization: + Self { lower, upper } + } + } + + /// Convenience function to create a new `Interval` from the given (optional) + /// bounds, for use in tests only. Absence of either endpoint indicates + /// unboundedness on that side. See [`Interval::try_new`] for more information. + pub fn make(lower: Option, upper: Option) -> Result + where + ScalarValue: From>, + { + Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) + } + + /// Creates an unbounded interval from both sides if the datatype supported. + pub fn make_unbounded(data_type: &DataType) -> Result { + let unbounded_endpoint = ScalarValue::try_from(data_type)?; + Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) + } + + /// Returns a reference to the lower bound. + pub fn lower(&self) -> &ScalarValue { + &self.lower + } + + /// Returns a reference to the upper bound. + pub fn upper(&self) -> &ScalarValue { + &self.upper + } + + /// Converts this `Interval` into its boundary scalar values. It's useful + /// when you need to work with the individual bounds directly. + pub fn into_bounds(self) -> (ScalarValue, ScalarValue) { + (self.lower, self.upper) + } + + /// This function returns the data type of this interval. + pub fn data_type(&self) -> DataType { + let lower_type = self.lower.data_type(); + let upper_type = self.upper.data_type(); + + // There must be no way to create an interval whose endpoints have + // different types. + assert!( + lower_type == upper_type, + "Interval bounds have different types: {lower_type} != {upper_type}" + ); + lower_type + } + + /// Casts this interval to `data_type` using `cast_options`. + pub fn cast_to( + &self, + data_type: &DataType, + cast_options: &CastOptions, + ) -> Result { + Self::try_new( + cast_scalar_value(&self.lower, data_type, cast_options)?, + cast_scalar_value(&self.upper, data_type, cast_options)?, + ) + } + + pub const CERTAINLY_FALSE: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + }; + + pub const UNCERTAIN: Self = Self { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + }; + + pub const CERTAINLY_TRUE: Self = Self { + lower: ScalarValue::Boolean(Some(true)), + upper: ScalarValue::Boolean(Some(true)), + }; + + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && self.upper <= rhs.lower + { + // Values in this interval are certainly less than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_FALSE) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && (self.lower > rhs.upper) + { + // Values in this interval are certainly greater than those in the + // given interval. + Ok(Self::CERTAINLY_TRUE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly greater than or equal to, possibly + /// greater than or equal to, or can't be greater than or equal to `other` + /// by returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn gt_eq>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + internal_err!( + "Only intervals with the same data type are comparable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !(self.lower.is_null() || rhs.upper.is_null()) + && self.lower >= rhs.upper + { + // Values in this interval are certainly greater than or equal to + // those in the given interval. + Ok(Self::CERTAINLY_TRUE) + } else if !(self.upper.is_null() || rhs.lower.is_null()) + && (self.upper < rhs.lower) + { + // Values in this interval are certainly less than those in the + // given interval. + Ok(Self::CERTAINLY_FALSE) + } else { + // All outcomes are possible. + Ok(Self::UNCERTAIN) + } + } + + /// Decide if this interval is certainly less than, possibly less than, or + /// can't be less than `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt>(&self, other: T) -> Result { + other.borrow().gt(self) + } + + /// Decide if this interval is certainly less than or equal to, possibly + /// less than or equal to, or can't be less than or equal to `other` by + /// returning `[true, true]`, `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn lt_eq>(&self, other: T) -> Result { + other.borrow().gt_eq(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, or + /// can't be equal to `other` by returning `[true, true]`, `[false, true]` + /// or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub(crate) fn equal>(&self, other: T) -> Result { + let rhs = other.borrow(); + if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { + internal_err!( + "Interval data types must be compatible for equality checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ) + } else if !self.lower.is_null() + && (self.lower == self.upper) + && (rhs.lower == rhs.upper) + && (self.lower == rhs.lower) + { + Ok(Self::CERTAINLY_TRUE) + } else if self.intersect(rhs)?.is_none() { + Ok(Self::CERTAINLY_FALSE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. + pub(crate) fn and>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower && other_lower; + let upper = self_upper && other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + + /// Compute the logical negation of this (boolean) interval. + pub(crate) fn not(&self) -> Result { + if self.data_type().ne(&DataType::Boolean) { + internal_err!("Cannot apply logical negation to a non-boolean interval") + } else if self == &Self::CERTAINLY_TRUE { + Ok(Self::CERTAINLY_FALSE) + } else if self == &Self::CERTAINLY_FALSE { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + + /// Compute the intersection of this interval with the given interval. + /// If the intersection is empty, return `None`. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn intersect>(&self, other: T) -> Result> { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + // If it is evident that the result is an empty interval, short-circuit + // and directly return `None`. + if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) + || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + { + return Ok(None); + } + + let lower = max_of_bounds(&self.lower, &rhs.lower); + let upper = min_of_bounds(&self.upper, &rhs.upper); + + // New lower and upper bounds must always construct a valid interval. + assert!( + (lower.is_null() || upper.is_null() || (lower <= upper)), + "The intersection of two intervals can not be an invalid interval" + ); + + Ok(Some(Self { lower, upper })) + } + + /// Decide if this interval certainly contains, possibly contains, or can't + /// contain a [`ScalarValue`] (`other`) by returning `[true, true]`, + /// `[false, true]` or `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains_value>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Data types must be compatible for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + } + + // We only check the upper bound for a `None` value because `None` + // values are less than `Some` values according to Rust. + Ok(&self.lower <= rhs && (self.upper.is_null() || rhs <= &self.upper)) + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if self.data_type().ne(&rhs.data_type()) { + return internal_err!( + "Interval data types must match for containment checks, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + match self.intersect(rhs)? { + Some(intersection) => { + if &intersection == rhs { + Ok(Self::CERTAINLY_TRUE) + } else { + Ok(Self::UNCERTAIN) + } + } + None => Ok(Self::CERTAINLY_FALSE), + } + } + + /// Add the given interval (`other`) to this interval. Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their sum is `[a1 + a2, b1 + b2]`. Note + /// that this represents all possible values the sum can take if one can + /// choose single values arbitrarily from each of the operands. + pub fn add>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Plus, &rhs.data_type())?; + + Ok(Self::new( + add_bounds::(&dt, &self.lower, &rhs.lower), + add_bounds::(&dt, &self.upper, &rhs.upper), + )) + } + + /// Subtract the given interval (`other`) from this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their difference is + /// `[a1 - b2, b1 - a2]`. Note that this represents all possible values the + /// difference can take if one can choose single values arbitrarily from + /// each of the operands. + pub fn sub>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = get_result_type(&self.data_type(), &Operator::Minus, &rhs.data_type())?; + + Ok(Self::new( + sub_bounds::(&dt, &self.lower, &rhs.upper), + sub_bounds::(&dt, &self.upper, &rhs.lower), + )) + } + + /// Multiply the given interval (`other`) with this interval. Say we have + /// intervals `[a1, b1]` and `[a2, b2]`, then their product is `[min(a1 * a2, + /// a1 * b2, b1 * a2, b1 * b2), max(a1 * a2, a1 * b2, b1 * a2, b1 * b2)]`. + /// Note that this represents all possible values the product can take if + /// one can choose single values arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn mul>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + + let result = match ( + self.contains_value(&zero)?, + rhs.contains_value(&zero)?, + dt.is_unsigned_integer(), + ) { + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, false, false) => { + mul_helper_single_zero_inclusive(&dt, self, rhs, zero) + } + (false, true, false) => { + mul_helper_single_zero_inclusive(&dt, rhs, self, zero) + } + _ => mul_helper_zero_exclusive(&dt, self, rhs, zero), + }; + Ok(result) + } + + /// Divide this interval by the given interval (`other`). Say we have intervals + /// `[a1, b1]` and `[a2, b2]`, then their division is `[a1, b1] * [1 / b2, 1 / a2]` + /// if `0 ∉ [a2, b2]` and `[NEG_INF, INF]` otherwise. Note that this represents + /// all possible values the quotient can take if one can choose single values + /// arbitrarily from each of the operands. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + /// + /// **TODO**: Once interval sets are supported, cases where the divisor contains + /// zero should result in an interval set, not the universal set. + pub fn div>(&self, other: T) -> Result { + let rhs = other.borrow(); + let dt = if self.data_type().eq(&rhs.data_type()) { + self.data_type() + } else { + return internal_err!( + "Intervals must have the same data type for division, lhs:{}, rhs:{}", + self.data_type(), + rhs.data_type() + ); + }; + + let zero = ScalarValue::new_zero(&dt)?; + // We want 0 to be approachable from both negative and positive sides. + let zero_point = match &dt { + DataType::Float32 | DataType::Float64 => Self::new(zero.clone(), zero), + _ => Self::new(prev_value(zero.clone()), next_value(zero)), + }; + + // Exit early with an unbounded interval if zero is strictly inside the + // right hand side: + if rhs.contains(&zero_point)? == Self::CERTAINLY_TRUE && !dt.is_unsigned_integer() + { + Self::make_unbounded(&dt) + } + // At this point, we know that only one endpoint of the right hand side + // can be zero. + else if self.contains(&zero_point)? == Self::CERTAINLY_TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + } else { + Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + } + } + + /// Returns the cardinality of this interval, which is the number of all + /// distinct points inside it. This function returns `None` if: + /// - The interval is unbounded from either side, or + /// - Cardinality calculations for the datatype in question is not + /// implemented yet, or + /// - An overflow occurs during the calculation: This case can only arise + /// when the calculated cardinality does not fit in an `u64`. + pub fn cardinality(&self) -> Option { + let data_type = self.data_type(); + if data_type.is_integer() { + self.upper.distance(&self.lower).map(|diff| diff as u64) + } else if data_type.is_floating() { + // Negative numbers are sorted in the reverse order. To + // always have a positive difference after the subtraction, + // we perform following transformation: + match (&self.lower, &self.upper) { + // Exploit IEEE 754 ordering properties to calculate the correct + // cardinality in all cases (including subnormals). + ( + ScalarValue::Float32(Some(lower)), + ScalarValue::Float32(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u32); + let upper_bits = map_floating_point_order!(upper.to_bits(), u32); + Some((upper_bits - lower_bits) as u64) + } + ( + ScalarValue::Float64(Some(lower)), + ScalarValue::Float64(Some(upper)), + ) => { + let lower_bits = map_floating_point_order!(lower.to_bits(), u64); + let upper_bits = map_floating_point_order!(upper.to_bits(), u64); + let count = upper_bits - lower_bits; + (count != u64::MAX).then_some(count) + } + _ => None, + } + } else { + // Cardinality calculations are not implemented for this data type yet: + None + } + .map(|result| result + 1) + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "[{}, {}]", self.lower, self.upper) + } +} + +/// Applies the given binary operator the `lhs` and `rhs` arguments. +pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { + match *op { + Operator::Eq => lhs.equal(rhs), + Operator::NotEq => lhs.equal(rhs)?.not(), + Operator::Gt => lhs.gt(rhs), + Operator::GtEq => lhs.gt_eq(rhs), + Operator::Lt => lhs.lt(rhs), + Operator::LtEq => lhs.lt_eq(rhs), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + Operator::Multiply => lhs.mul(rhs), + Operator::Divide => lhs.div(rhs), + _ => internal_err!("Interval arithmetic does not support the operator {op}"), + } +} + +/// Helper function used for adding the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn add_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.add_checked(rhs)) + } + _ => lhs.add_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Plus, lhs, rhs)) +} + +/// Helper function used for subtracting the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn sub_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.sub_checked(rhs)) + } + _ => lhs.sub_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Minus, lhs, rhs)) +} + +/// Helper function used for multiplying the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn mul_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + if lhs.is_null() || rhs.is_null() { + return ScalarValue::try_from(dt).unwrap(); + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.mul_checked(rhs)) + } + _ => lhs.mul_checked(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Multiply, lhs, rhs)) +} + +/// Helper function used for dividing the end-point values of intervals. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, and the following +/// interval creation is standardized with `Interval::new`. +fn div_bounds( + dt: &DataType, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + + if (lhs.is_null() || rhs.eq(&zero)) || (dt.is_unsigned_integer() && rhs.is_null()) { + return ScalarValue::try_from(dt).unwrap(); + } else if rhs.is_null() { + return zero; + } + + match dt { + DataType::Float64 | DataType::Float32 => { + alter_fp_rounding_mode::(lhs, rhs, |lhs, rhs| lhs.div(rhs)) + } + _ => lhs.div(rhs), + } + .unwrap_or_else(|_| handle_overflow::(dt, Operator::Divide, lhs, rhs)) +} + +/// This function handles cases where an operation results in an overflow. Such +/// results are converted to an *unbounded endpoint* if: +/// - We are calculating an upper bound and we have a positive overflow. +/// - We are calculating a lower bound and we have a negative overflow. +/// Otherwise; the function sets the endpoint as: +/// - The minimum representable number with the given datatype (`dt`) if +/// we are calculating an upper bound and we have a negative overflow. +/// - The maximum representable number with the given datatype (`dt`) if +/// we are calculating a lower bound and we have a positive overflow. +/// +/// **Caution:** This function contains multiple calls to `unwrap()`, and may +/// return non-standardized interval bounds. Therefore, it should be used +/// with caution. Currently, it is used in contexts where the `DataType` +/// (`dt`) is validated prior to calling this function, `op` is supported by +/// interval library, and the following interval creation is standardized with +/// `Interval::new`. +fn handle_overflow( + dt: &DataType, + op: Operator, + lhs: &ScalarValue, + rhs: &ScalarValue, +) -> ScalarValue { + let zero = ScalarValue::new_zero(dt).unwrap(); + let positive_sign = match op { + Operator::Multiply | Operator::Divide => { + lhs.lt(&zero) && rhs.lt(&zero) || lhs.gt(&zero) && rhs.gt(&zero) + } + Operator::Plus => lhs.ge(&zero), + Operator::Minus => lhs.ge(rhs), + _ => { + unreachable!() + } + }; + match (UPPER, positive_sign) { + (true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(), + (true, false) => { + get_extreme_value!(MIN, dt) + } + (false, true) => { + get_extreme_value!(MAX, dt) + } + } +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn next_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MAX, true, value) +} + +// This function should remain private since it may corrupt the an interval if +// used without caution. +fn prev_value(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + value_transition!(MIN, false, value) +} + +trait OneTrait: Sized + std::ops::Add + std::ops::Sub { + fn one() -> Self; +} +macro_rules! impl_OneTrait{ + ($($m:ty),*) => {$( impl OneTrait for $m { fn one() -> Self { 1 as $m } })*} +} +impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} + +/// This function either increments or decrements its argument, depending on +/// the `INC` value (where a `true` value corresponds to the increment). +fn increment_decrement( + mut value: T, +) -> T { + if INC { + value.add_assign(T::one()); + } else { + value.sub_assign(T::one()); + } + value +} + +/// This function returns the next/previous value depending on the `INC` value. +/// If `true`, it returns the next value; otherwise it returns the previous value. +fn next_value_helper(value: ScalarValue) -> ScalarValue { + use ScalarValue::*; + match value { + // f32/f64::NEG_INF/INF and f32/f64::NaN values should not emerge at this point. + Float32(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float32(Some(if INC { next_up(val) } else { next_down(val) })) + } + Float64(Some(val)) => { + assert!(val.is_finite(), "Non-standardized floating point usage"); + Float64(Some(if INC { next_up(val) } else { next_down(val) })) + } + Int8(Some(val)) => Int8(Some(increment_decrement::(val))), + Int16(Some(val)) => Int16(Some(increment_decrement::(val))), + Int32(Some(val)) => Int32(Some(increment_decrement::(val))), + Int64(Some(val)) => Int64(Some(increment_decrement::(val))), + UInt8(Some(val)) => UInt8(Some(increment_decrement::(val))), + UInt16(Some(val)) => UInt16(Some(increment_decrement::(val))), + UInt32(Some(val)) => UInt32(Some(increment_decrement::(val))), + UInt64(Some(val)) => UInt64(Some(increment_decrement::(val))), + DurationSecond(Some(val)) => { + DurationSecond(Some(increment_decrement::(val))) + } + DurationMillisecond(Some(val)) => { + DurationMillisecond(Some(increment_decrement::(val))) + } + DurationMicrosecond(Some(val)) => { + DurationMicrosecond(Some(increment_decrement::(val))) + } + DurationNanosecond(Some(val)) => { + DurationNanosecond(Some(increment_decrement::(val))) + } + TimestampSecond(Some(val), tz) => { + TimestampSecond(Some(increment_decrement::(val)), tz) + } + TimestampMillisecond(Some(val), tz) => { + TimestampMillisecond(Some(increment_decrement::(val)), tz) + } + TimestampMicrosecond(Some(val), tz) => { + TimestampMicrosecond(Some(increment_decrement::(val)), tz) + } + TimestampNanosecond(Some(val), tz) => { + TimestampNanosecond(Some(increment_decrement::(val)), tz) + } + IntervalYearMonth(Some(val)) => { + IntervalYearMonth(Some(increment_decrement::(val))) + } + IntervalDayTime(Some(val)) => { + IntervalDayTime(Some(increment_decrement::(val))) + } + IntervalMonthDayNano(Some(val)) => { + IntervalMonthDayNano(Some(increment_decrement::(val))) + } + _ => value, // Unbounded values return without change. + } +} + +/// Returns the greater of the given interval bounds. Assumes that a `NULL` +/// value represents `NEG_INF`. +fn max_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first >= second) { + first.clone() + } else { + second.clone() + } +} + +/// Returns the lesser of the given interval bounds. Assumes that a `NULL` +/// value represents `INF`. +fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { + if !first.is_null() && (second.is_null() || first <= second) { + first.clone() + } else { + second.clone() + } +} + +/// This function updates the given intervals by enforcing (i.e. propagating) +/// the inequality `left > right` (or the `left >= right` inequality, if `strict` +/// is `true`). +/// +/// Returns a `Result` wrapping an `Option` containing the tuple of resulting +/// intervals. If the comparison is infeasible, returns `None`. +/// +/// Example usage: +/// ``` +/// use datafusion_common::DataFusionError; +/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// +/// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; +/// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; +/// let strict = false; +/// assert_eq!( +/// satisfy_greater(&left, &right, strict)?, +/// Some(( +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))?, +/// Interval::make(Some(500.0_f32), Some(1000.0_f32))? +/// )) +/// ); +/// Ok::<(), DataFusionError>(()) +/// ``` +/// +/// NOTE: This function only works with intervals of the same data type. +/// Attempting to compare intervals of different data types will lead +/// to an error. +pub fn satisfy_greater( + left: &Interval, + right: &Interval, + strict: bool, +) -> Result> { + if left.data_type().ne(&right.data_type()) { + return internal_err!( + "Intervals must have the same data type, lhs:{}, rhs:{}", + left.data_type(), + right.data_type() + ); + } + + if !left.upper.is_null() && left.upper <= right.lower { + if !strict && left.upper == right.lower { + // Singleton intervals: + return Ok(Some(( + Interval::new(left.upper.clone(), left.upper.clone()), + Interval::new(left.upper.clone(), left.upper.clone()), + ))); + } else { + // Left-hand side: <--======----0------------> + // Right-hand side: <------------0--======----> + // No intersection, infeasible to propagate: + return Ok(None); + } + } + + // Only the lower bound of left hand side and the upper bound of the right + // hand side can change after propagating the greater-than operation. + let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { + if strict { + next_value(right.lower.clone()) + } else { + right.lower.clone() + } + } else { + left.lower.clone() + }; + // Below code is asymmetric relative to the above if statement, because + // `None` compares less than `Some` in Rust. + let new_right_upper = if right.upper.is_null() + || (!left.upper.is_null() && left.upper <= right.upper) + { + if strict { + prev_value(left.upper.clone()) + } else { + left.upper.clone() + } + } else { + right.upper.clone() + }; + + Ok(Some(( + Interval::new(new_left_lower, left.upper.clone()), + Interval::new(right.lower.clone(), new_right_upper), + ))) +} + +/// Multiplies two intervals that both contain zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that contain zero within their +/// ranges. Returns an error if the multiplication of bounds fails. +/// +/// ```text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <-------=====0=====-------> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_multi_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, +) -> Interval { + if lhs.lower.is_null() + || lhs.upper.is_null() + || rhs.lower.is_null() + || rhs.upper.is_null() + { + return Interval::make_unbounded(dt).unwrap(); + } + // Since unbounded cases are handled above, we can safely + // use the utility functions here to eliminate code duplication. + let lower = min_of_bounds( + &mul_bounds::(dt, &lhs.lower, &rhs.upper), + &mul_bounds::(dt, &rhs.lower, &lhs.upper), + ); + let upper = max_of_bounds( + &mul_bounds::(dt, &lhs.upper, &rhs.upper), + &mul_bounds::(dt, &lhs.lower, &rhs.lower), + ); + // There is no possibility to create an invalid interval. + Interval::new(lower, upper) +} + +/// Multiplies two intervals when only left-hand side interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. The interval not containing zero, i.e. rhs, can lie +/// on either side of zero. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_single_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = mul_bounds::(dt, &lhs.upper, &rhs.lower); + let upper = mul_bounds::(dt, &lhs.lower, &rhs.lower); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = mul_bounds::(dt, &lhs.lower, &rhs.upper); + let upper = mul_bounds::(dt, &lhs.upper, &rhs.upper); + Interval::new(lower, upper) + } +} + +/// Multiplies two intervals when neither of them contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their product (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the multiplication of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn mul_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero: ScalarValue, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero && !lhs.upper.is_null(), + rhs.upper <= zero && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + mul_bounds::(dt, &lhs.upper, &rhs.upper), + mul_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.upper), + mul_bounds::(dt, &lhs.upper, &rhs.lower), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + mul_bounds::(dt, &rhs.lower, &lhs.upper), + mul_bounds::(dt, &rhs.upper, &lhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + mul_bounds::(dt, &lhs.lower, &rhs.lower), + mul_bounds::(dt, &lhs.upper, &rhs.upper), + ), + }; + Interval::new(lower, upper) +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// the former contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). This function +/// serves as a subroutine that handles the specific case when only `lhs` contains +/// zero within its range. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <-------=====0=====-------> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_lhs_zero_inclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + // With the following interval bounds, there is no possibility to create an invalid interval. + if rhs.upper <= zero_point.lower && !rhs.upper.is_null() { + // <-------=====0=====-------> + // <--======----0------------> + let lower = div_bounds::(dt, &lhs.upper, &rhs.upper); + let upper = div_bounds::(dt, &lhs.lower, &rhs.upper); + Interval::new(lower, upper) + } else { + // <-------=====0=====-------> + // <------------0--======----> + let lower = div_bounds::(dt, &lhs.lower, &rhs.lower); + let upper = div_bounds::(dt, &lhs.upper, &rhs.lower); + Interval::new(lower, upper) + } +} + +/// Divides the left-hand side interval by the right-hand side interval when +/// neither interval contains zero. +/// +/// This function takes in two intervals (`lhs` and `rhs`) as arguments and +/// returns their quotient (whose data type is known to be `dt`). It is +/// specifically designed to handle intervals that do not contain zero within +/// their ranges. Returns an error if the division of bounds fails. +/// +/// ``` text +/// Left-hand side: <--======----0------------> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <--======----0------------> +/// Right-hand side: <------------0--======----> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <--======----0------------> +/// +/// or +/// +/// Left-hand side: <------------0--======----> +/// Right-hand side: <------------0--======----> +/// ``` +/// +/// **Caution:** This function contains multiple calls to `unwrap()`. Therefore, +/// it should be used with caution. Currently, it is used in contexts where the +/// `DataType` (`dt`) is validated prior to calling this function. +fn div_helper_zero_exclusive( + dt: &DataType, + lhs: &Interval, + rhs: &Interval, + zero_point: &Interval, +) -> Interval { + let (lower, upper) = match ( + lhs.upper <= zero_point.lower && !lhs.upper.is_null(), + rhs.upper <= zero_point.lower && !rhs.upper.is_null(), + ) { + // With the following interval bounds, there is no possibility to create an invalid interval. + (true, true) => ( + // <--======----0------------> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.lower), + div_bounds::(dt, &lhs.lower, &rhs.upper), + ), + (true, false) => ( + // <--======----0------------> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.lower), + div_bounds::(dt, &lhs.upper, &rhs.upper), + ), + (false, true) => ( + // <------------0--======----> + // <--======----0------------> + div_bounds::(dt, &lhs.upper, &rhs.upper), + div_bounds::(dt, &lhs.lower, &rhs.lower), + ), + (false, false) => ( + // <------------0--======----> + // <------------0--======----> + div_bounds::(dt, &lhs.lower, &rhs.upper), + div_bounds::(dt, &lhs.upper, &rhs.lower), + ), + }; + Interval::new(lower, upper) +} + +/// This function computes the selectivity of an operation by computing the +/// cardinality ratio of the given input/output intervals. If this can not be +/// calculated for some reason, it returns `1.0` meaning fully selective (no +/// filtering). +pub fn cardinality_ratio(initial_interval: &Interval, final_interval: &Interval) -> f64 { + match (final_interval.cardinality(), initial_interval.cardinality()) { + (Some(final_interval), Some(initial_interval)) => { + (final_interval as f64) / (initial_interval as f64) + } + _ => 1.0, + } +} + +/// Cast scalar value to the given data type using an arrow kernel. +fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + +/// An [Interval] that also tracks null status using a boolean interval. +/// +/// This represents values that may be in a particular range or be null. +/// +/// # Examples +/// +/// ``` +/// use arrow::datatypes::DataType; +/// use datafusion_common::ScalarValue; +/// use datafusion_expr::interval_arithmetic::Interval; +/// use datafusion_expr::interval_arithmetic::NullableInterval; +/// +/// // [1, 2) U {NULL} +/// let maybe_null = NullableInterval::MaybeNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(1)), +/// ScalarValue::Int32(Some(2)), +/// ).unwrap(), +/// }; +/// +/// // (0, ∞) +/// let not_null = NullableInterval::NotNull { +/// values: Interval::try_new( +/// ScalarValue::Int32(Some(0)), +/// ScalarValue::Int32(None), +/// ).unwrap(), +/// }; +/// +/// // {NULL} +/// let null_interval = NullableInterval::Null { datatype: DataType::Int32 }; +/// +/// // {4} +/// let single_value = NullableInterval::from(ScalarValue::Int32(Some(4))); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NullableInterval { + /// The value is always null. This is typed so it can be used in physical + /// expressions, which don't do type coercion. + Null { datatype: DataType }, + /// The value may or may not be null. If it is non-null, its is within the + /// specified range. + MaybeNull { values: Interval }, + /// The value is definitely not null, and is within the specified range. + NotNull { values: Interval }, +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + } + } +} + +impl From for NullableInterval { + /// Create an interval that represents a single value. + fn from(value: ScalarValue) -> Self { + if value.is_null() { + Self::Null { + datatype: value.data_type(), + } + } else { + Self::NotNull { + values: Interval { + lower: value.clone(), + upper: value, + }, + } + } + } +} + +impl NullableInterval { + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn data_type(&self) -> DataType { + match self { + Self::Null { datatype } => datatype.clone(), + Self::MaybeNull { values } | Self::NotNull { values } => values.data_type(), + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), + } + } + + /// Apply the given operator to this interval and the given interval. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(3)), + /// ).unwrap(), + /// }; + /// let rhs = NullableInterval::NotNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// }); + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self, rhs) { + // NULL is distinct from NULL -> False + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, + // x is distinct from y -> x != y, + // if at least one of them is never null. + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values)?.not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } + } + _ => Interval::UNCERTAIN, + }; + // IsDistinctFrom never returns null. + Ok(Self::NotNull { values }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| i.not())?, + _ => { + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } else { + Ok(Self::Null { + datatype: self.data_type(), + }) + } + } + } + } + + /// Decide if this interval is a superset of, overlaps with, or + /// disjoint with `other` by returning `[true, true]`, `[false, true]` or + /// `[false, false]` respectively. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + left_values + .contains(right_values) + .map(|values| match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Self::NotNull { values } + } + _ => Self::MaybeNull { values }, + }) + } else { + Ok(Self::Null { + datatype: DataType::Boolean, + }) + } + } + + /// If the interval has collapsed to a single value, return that value. + /// Otherwise, returns `None`. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval::MaybeNull { + /// values: Interval::try_new( + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(4)), + /// ).unwrap(), + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + match self { + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower == values.upper && !values.lower.is_null() => + { + Some(values.lower.clone()) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::interval_arithmetic::{next_value, prev_value, satisfy_greater, Interval}; + + use arrow::datatypes::DataType; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn test_next_prev_value() -> Result<()> { + let zeros = vec![ + ScalarValue::new_zero(&DataType::UInt8)?, + ScalarValue::new_zero(&DataType::UInt16)?, + ScalarValue::new_zero(&DataType::UInt32)?, + ScalarValue::new_zero(&DataType::UInt64)?, + ScalarValue::new_zero(&DataType::Int8)?, + ScalarValue::new_zero(&DataType::Int16)?, + ScalarValue::new_zero(&DataType::Int32)?, + ScalarValue::new_zero(&DataType::Int64)?, + ]; + let ones = vec![ + ScalarValue::new_one(&DataType::UInt8)?, + ScalarValue::new_one(&DataType::UInt16)?, + ScalarValue::new_one(&DataType::UInt32)?, + ScalarValue::new_one(&DataType::UInt64)?, + ScalarValue::new_one(&DataType::Int8)?, + ScalarValue::new_one(&DataType::Int16)?, + ScalarValue::new_one(&DataType::Int32)?, + ScalarValue::new_one(&DataType::Int64)?, + ]; + zeros.into_iter().zip(ones).for_each(|(z, o)| { + assert_eq!(next_value(z.clone()), o); + assert_eq!(prev_value(o), z); + }); + + let values = vec![ + ScalarValue::new_zero(&DataType::Float32)?, + ScalarValue::new_zero(&DataType::Float64)?, + ]; + let eps = vec![ + ScalarValue::Float32(Some(1e-6)), + ScalarValue::Float64(Some(1e-6)), + ]; + values.into_iter().zip(eps).for_each(|(value, eps)| { + assert!(next_value(value.clone()) + .sub(value.clone()) + .unwrap() + .lt(&eps)); + assert!(value + .clone() + .sub(prev_value(value.clone())) + .unwrap() + .lt(&eps)); + assert_ne!(next_value(value.clone()), value); + assert_ne!(prev_value(value.clone()), value); + }); + + let min_max = vec![ + ( + ScalarValue::UInt64(Some(u64::MIN)), + ScalarValue::UInt64(Some(u64::MAX)), + ), + ( + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ), + ( + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(f32::MAX)), + ), + ( + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(f64::MAX)), + ), + ]; + let inf = vec![ + ScalarValue::UInt64(None), + ScalarValue::Int8(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ]; + min_max.into_iter().zip(inf).for_each(|((min, max), inf)| { + assert_eq!(next_value(max.clone()), inf); + assert_ne!(prev_value(max.clone()), max); + assert_ne!(prev_value(max.clone()), inf); + + assert_eq!(prev_value(min.clone()), inf); + assert_ne!(next_value(min.clone()), min); + assert_ne!(next_value(min.clone()), inf); + + assert_eq!(next_value(inf.clone()), inf); + assert_eq!(prev_value(inf.clone()), inf); + }); + + Ok(()) + } + + #[test] + fn test_new_interval() -> Result<()> { + use ScalarValue::*; + + let cases = vec![ + ( + (Boolean(None), Boolean(Some(false))), + Boolean(Some(false)), + Boolean(Some(false)), + ), + ( + (Boolean(Some(false)), Boolean(None)), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (Boolean(Some(false)), Boolean(Some(true))), + Boolean(Some(false)), + Boolean(Some(true)), + ), + ( + (UInt16(Some(u16::MAX)), UInt16(None)), + UInt16(Some(u16::MAX)), + UInt16(None), + ), + ( + (Int16(None), Int16(Some(-1000))), + Int16(None), + Int16(Some(-1000)), + ), + ( + (Float32(Some(f32::MAX)), Float32(Some(f32::MAX))), + Float32(Some(f32::MAX)), + Float32(Some(f32::MAX)), + ), + ( + (Float32(Some(f32::NAN)), Float32(Some(f32::MIN))), + Float32(None), + Float32(Some(f32::MIN)), + ), + ( + ( + Float64(Some(f64::NEG_INFINITY)), + Float64(Some(f64::INFINITY)), + ), + Float64(None), + Float64(None), + ), + ]; + for (inputs, lower, upper) in cases { + let result = Interval::try_new(inputs.0, inputs.1)?; + assert_eq!(result.clone().lower(), &lower); + assert_eq!(result.upper(), &upper); + } + + let invalid_intervals = vec![ + (Float32(Some(f32::INFINITY)), Float32(Some(100_f32))), + (Float64(Some(0_f64)), Float64(Some(f64::NEG_INFINITY))), + (Boolean(Some(true)), Boolean(Some(false))), + (Int32(Some(1000)), Int32(Some(-2000))), + (UInt64(Some(1)), UInt64(Some(0))), + ]; + for (lower, upper) in invalid_intervals { + Interval::try_new(lower, upper).expect_err( + "Given parameters should have given an invalid interval error", + ); + } + + Ok(()) + } + + #[test] + fn test_make_unbounded() -> Result<()> { + use ScalarValue::*; + + let unbounded_cases = vec![ + (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), + (DataType::UInt8, UInt8(None), UInt8(None)), + (DataType::UInt16, UInt16(None), UInt16(None)), + (DataType::UInt32, UInt32(None), UInt32(None)), + (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::Int8, Int8(None), Int8(None)), + (DataType::Int16, Int16(None), Int16(None)), + (DataType::Int32, Int32(None), Int32(None)), + (DataType::Int64, Int64(None), Int64(None)), + (DataType::Float32, Float32(None), Float32(None)), + (DataType::Float64, Float64(None), Float64(None)), + ]; + for (dt, lower, upper) in unbounded_cases { + let inf = Interval::make_unbounded(&dt)?; + assert_eq!(inf.clone().lower(), &lower); + assert_eq!(inf.upper(), &upper); + } + + Ok(()) + } + + #[test] + fn gt_lt_test() -> Result<()> { + let exactly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(0.0))), + next_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + prev_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in exactly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(0.0_f32)), + next_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + ScalarValue::Float32(Some(-1.0_f32)), + )?, + ), + ]; + for (first, second) in possibly_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt(first)?, Interval::UNCERTAIN); + } + + let not_gt_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + ScalarValue::Float32(Some(0.0_f32)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1.0_f32)), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in not_gt_cases { + assert_eq!(first.gt(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn gteq_lteq_test() -> Result<()> { + let exactly_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(-1000_i64), Some(1000_i64))?, + Interval::make(None, Some(-1500_i64))?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::try_new( + ScalarValue::Float32(Some(-1.0)), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + ScalarValue::Float32(Some(-1.0)), + )?, + ), + ]; + for (first, second) in exactly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_gteq_cases = vec![ + ( + Interval::make(Some(999_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1001_i64))?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0_f32))), + next_value(ScalarValue::Float32(Some(-1.0_f32))), + )?, + ), + ]; + for (first, second) in possibly_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.lt_eq(first)?, Interval::UNCERTAIN); + } + + let not_gteq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0_f32))), + prev_value(ScalarValue::Float32(Some(0.0_f32))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_gteq_cases { + assert_eq!(first.gt_eq(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.lt_eq(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn equal_test() -> Result<()> { + let exactly_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + ), + ( + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + Interval::make(Some(f64::MIN), Some(f64::MIN))?, + ), + ]; + for (first, second) in exactly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_TRUE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_TRUE); + } + + let possibly_eq_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(100.0_f32), Some(200.0_f32))?, + Interval::make(Some(0.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + ScalarValue::Float32(Some(0.0)), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + prev_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in possibly_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::UNCERTAIN); + assert_eq!(second.equal(first)?, Interval::UNCERTAIN); + } + + let not_eq_cases = vec![ + ( + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(999_i64))?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1001_i64), Some(1500_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(0.0))), + prev_value(ScalarValue::Float32(Some(0.0))), + )?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-1.0_f32), Some(-1.0_f32))?, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-1.0))), + next_value(ScalarValue::Float32(Some(-1.0))), + )?, + ), + ]; + for (first, second) in not_eq_cases { + assert_eq!(first.equal(second.clone())?, Interval::CERTAINLY_FALSE); + assert_eq!(second.equal(first)?, Interval::CERTAINLY_FALSE); + } + + Ok(()) + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))? + .and(Interval::make(Some(case.2), Some(case.3))?)?, + Interval::make(Some(case.4), Some(case.5))? + ); + } + Ok(()) + } + + #[test] + fn not_test() -> Result<()> { + let cases = vec![ + (false, true, false, true), + (false, false, true, true), + (true, true, false, false), + ]; + + for case in cases { + assert_eq!( + Interval::make(Some(case.0), Some(case.1))?.not()?, + Interval::make(Some(case.2), Some(case.3))? + ); + } + Ok(()) + } + + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::make(Some(1000_i64), None)?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), None)?, + Interval::make(Some(1000_i64), Some(2000_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(1000_i64), Some(1500_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(2000_u64))?, + Interval::make(Some(500_u64), None)?, + Interval::make(Some(500_u64), Some(2000_u64))?, + ), + ( + Interval::make(Some(0_u64), Some(0_u64))?, + Interval::make(Some(0_u64), None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), + ( + Interval::make(Some(1000.0_f32), None)?, + Interval::make(None, Some(1000.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1000.0_f32))?, + ), + ( + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + Interval::make(Some(0.0_f32), Some(1500.0_f32))?, + Interval::make(Some(1000.0_f32), Some(1500.0_f32))?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + Interval::make(Some(-1500.0_f64), Some(2000.0_f64))?, + Interval::make(Some(-1000.0_f64), Some(1500.0_f64))?, + ), + ( + Interval::make(Some(16.0_f64), Some(32.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(32.0_f64), Some(32.0_f64))?, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.intersect(second)?.unwrap(), expected) + } + + let empty_cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(999_i64))?, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(2000_i64), Some(3000_i64))?, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + ), + ]; + for (first, second) in empty_cases { + assert_eq!(first.intersect(second)?, None) + } + + Ok(()) + } + + #[test] + fn test_contains() -> Result<()> { + let possible_cases = vec![ + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1501_i64), Some(1999_i64))?, + Interval::CERTAINLY_TRUE, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make::(None, None)?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), Some(2000_i64))?, + Interval::make(Some(500), Some(1500_i64))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(16.0), Some(32.0))?, + Interval::make(Some(32.0), Some(64.0))?, + Interval::UNCERTAIN, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(0_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::make(Some(1500_i64), Some(2000_i64))?, + Interval::make(Some(1000_i64), Some(1499_i64))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + prev_value(ScalarValue::Float32(Some(1.0))), + prev_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ( + Interval::try_new( + next_value(ScalarValue::Float32(Some(1.0))), + next_value(ScalarValue::Float32(Some(1.0))), + )?, + Interval::make(Some(1.0_f32), Some(1.0_f32))?, + Interval::CERTAINLY_FALSE, + ), + ]; + for (first, second, expected) in possible_cases { + assert_eq!(first.contains(second)?, expected) + } + + Ok(()) + } + + #[test] + fn test_add() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(400_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(300_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-200_i64), Some(350_i64))?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much greater than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make(None, Some(300_f64))?, + ), + ]; + for case in cases { + let result = case.0.add(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_sub() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(i32::MAX), Some(i32::MAX))?, + Interval::make(Some(11_i32), Some(11_i32))?, + Interval::make(Some(i32::MAX - 11), Some(i32::MAX - 11))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(200_i64))?, + Interval::make(Some(-100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(200_i64), None)?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(None, Some(200_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(100_i64))?, + ), + ( + Interval::make(Some(200_i64), None)?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-300_i64), Some(150_i64))?, + Interval::make(Some(-50_i64), Some(500_i64))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(-10_i64), Some(-10_i64))?, + Interval::make(Some(i64::MIN + 10), Some(i64::MIN + 10))?, + ), + ( + Interval::make(Some(1), Some(i64::MAX))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(Some(1 - i64::MAX), Some(0))?, + ), + ( + Interval::make(Some(i64::MIN), Some(i64::MIN))?, + Interval::make(Some(i64::MAX), Some(i64::MAX))?, + Interval::make(None, Some(i64::MIN))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(4_u32), Some(6_u32))?, + Interval::make(None, Some(6_u32))?, + ), + ( + Interval::make(Some(2_u32), Some(10_u32))?, + Interval::make(Some(20_u32), Some(30_u32))?, + Interval::make(None, Some(0_u32))?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + // Since rounding mode is up, the result would be much larger than f32::MIN + // (f32::MIN = -3.4_028_235e38, the result is -3.4_028_233e38) + Interval::make( + None, + Some(-340282330000000000000000000000000000000.0_f32), + )?, + ), + ( + Interval::make(Some(100_f64), None)?, + Interval::make(None, Some(200_f64))?, + Interval::make(Some(-100_f64), None)?, + ), + ( + Interval::make(None, Some(100_f64))?, + Interval::make(None, Some(200_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.sub(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper(),) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_mul() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(None, Some(2_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(None, Some(4_i64))?, + ), + ( + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(2_i64), None)?, + ), + ( + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-3_i64), Some(15_i64))?, + Interval::make(Some(-6_i64), Some(30_i64))?, + ), + ( + Interval::make(Some(-0.0), Some(0.0))?, + Interval::make(None, Some(0.0))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(10_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(0_u32), Some(1_u32))?, + Interval::make(None, Some(2_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(None, Some(4_u32))?, + ), + ( + Interval::make(None, Some(2_u32))?, + Interval::make(Some(1_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u32), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(11_f32), Some(11_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(-10_f32), Some(-10_f32))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(1.0), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(f32::MAX), None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(f32::MIN))?, + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(None, Some(f32::MIN))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), None)?, + ), + ( + Interval::make(Some(1_f64), None)?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(None, Some(1_f64))?, + Interval::make(None, Some(2_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(-0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(1.0_f64))?, + Interval::make(Some(1_f64), Some(2_f64))?, + Interval::make(Some(0.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(1.0_f64))?, + Interval::make(Some(-1_f64), Some(2_f64))?, + Interval::make(Some(-1.0_f64), Some(2.0_f64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make::(None, Some(10.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make::(None, None)?, + ), + ]; + for case in cases { + let result = case.0.mul(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_div() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(50_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(-1_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(-100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(-50_i64))?, + ), + ( + Interval::make(Some(-200_i64), Some(100_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-200_i64), Some(100_i64))?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(1_i64), Some(2_i64))?, + Interval::make(Some(-100_i64), Some(200_i64))?, + ), + ( + Interval::make(Some(10_i64), Some(20_i64))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-1_i64), Some(2_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-100_i64), Some(200_i64))?, + Interval::make(Some(-2_i64), Some(1_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), None)?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(None, Some(0_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1_i64))?, + Interval::make(Some(100_i64), Some(200_i64))?, + Interval::make(Some(0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(1_u32), Some(2_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(None, Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(2_u32))?, + Interval::make(Some(5_u32), None)?, + ), + ( + Interval::make(Some(10_u32), Some(20_u32))?, + Interval::make(Some(0_u32), Some(0_u32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(10_u64), Some(20_u64))?, + Interval::make(Some(0_u64), Some(4_u64))?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(None, Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(Some(12_u64), Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make(Some(6_u64), None)?, + ), + ( + Interval::make(None, Some(48_u64))?, + Interval::make(Some(0_u64), Some(2_u64))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MAX), Some(f32::MAX))?, + Interval::make(Some(-0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), None)?, + Interval::make(Some(0.1_f32), Some(0.1_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(-10.0_f32), Some(10.0_f32))?, + Interval::make(Some(-0.1_f32), Some(-0.1_f32))?, + Interval::make(Some(-100.0_f32), Some(100.0_f32))?, + ), + ( + Interval::make(Some(-10.0_f32), Some(f32::MAX))?, + Interval::make::(None, None)?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + Interval::make(Some(1.0_f32), None)?, + Interval::make(Some(f32::MIN), Some(10.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(-0.0_f32), Some(0.0_f32))?, + Interval::make(None, Some(-0.0_f32))?, + Interval::make::(None, None)?, + ), + ( + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + Interval::make(Some(f32::MAX), None)?, + Interval::make(Some(0.0_f32), Some(0.0_f32))?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(0.0_f32), Some(4.0_f32))?, + Interval::make(Some(0.25_f32), None)?, + ), + ( + Interval::make(Some(1.0_f32), Some(2.0_f32))?, + Interval::make(Some(-4.0_f32), Some(-0.0_f32))?, + Interval::make(None, Some(-0.25_f32))?, + ), + ( + Interval::make(Some(-4.0_f64), Some(2.0_f64))?, + Interval::make(Some(10.0_f64), Some(20.0_f64))?, + Interval::make(Some(-0.4_f64), Some(0.2_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + Interval::make(None, Some(-0.0_f64))?, + Interval::make(Some(0.0_f64), None)?, + ), + ( + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make::(None, None)?, + Interval::make(Some(0.0_f64), None)?, + ), + ]; + for case in cases { + let result = case.0.div(case.1)?; + if case.0.data_type().is_floating() { + assert!( + result.lower().is_null() && case.2.lower().is_null() + || result.lower().le(case.2.lower()) + ); + assert!( + result.upper().is_null() && case.2.upper().is_null() + || result.upper().ge(case.2.upper()) + ); + } else { + assert_eq!(result, case.2); + } + } + + Ok(()) + } + + #[test] + fn test_cardinality_of_intervals() -> Result<()> { + // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and exponent fields same, + // we can represent 4503599627370496+1 different numbers by changing the mantissa + // (4503599627370496 = 2^52, since there are 52 bits in mantissa, and 2^23 = 8388608 for f32). + // TODO: Add tests for non-exponential boundary aligned intervals too. + let distinct_f64 = 4503599627370497; + let distinct_f32 = 8388609; + let intervals = [ + Interval::make(Some(0.25_f64), Some(0.50_f64))?, + Interval::make(Some(0.5_f64), Some(1.0_f64))?, + Interval::make(Some(1.0_f64), Some(2.0_f64))?, + Interval::make(Some(32.0_f64), Some(64.0_f64))?, + Interval::make(Some(-0.50_f64), Some(-0.25_f64))?, + Interval::make(Some(-32.0_f64), Some(-16.0_f64))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f64); + } + + let intervals = [ + Interval::make(Some(0.25_f32), Some(0.50_f32))?, + Interval::make(Some(-1_f32), Some(-0.5_f32))?, + ]; + for interval in intervals { + assert_eq!(interval.cardinality().unwrap(), distinct_f32); + } + + // The regular logarithmic distribution of floating-point numbers are + // only applicable outside of the `(-phi, phi)` interval where `phi` + // denotes the largest positive subnormal floating-point number. Since + // the following intervals include such subnormal points, we cannot use + // a simple powers-of-two type formula for our expectations. Therefore, + // we manually supply the actual expected cardinality. + let interval = Interval::make(Some(-0.0625), Some(0.0625))?; + assert_eq!(interval.cardinality().unwrap(), 9178336040581070850); + + let interval = Interval::try_new( + ScalarValue::UInt64(Some(u64::MIN + 1)), + ScalarValue::UInt64(Some(u64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Int64(Some(i64::MIN + 1)), + ScalarValue::Int64(Some(i64::MAX)), + )?; + assert_eq!(interval.cardinality().unwrap(), u64::MAX); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(-0.0_f32)), + ScalarValue::Float32(Some(0.0_f32)), + )?; + assert_eq!(interval.cardinality().unwrap(), 2); + + Ok(()) + } + + #[test] + fn test_satisfy_comparison() -> Result<()> { + let cases = vec![ + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + true, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + true, + Interval::make(Some(1000_i64), Some(1000_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + false, + Interval::make(Some(1000_i64), None)?, + Interval::make(None, Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + true, + Interval::make(Some(500_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + true, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make(Some(0_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(1500_i64))?, + false, + Interval::make(Some(501_i64), Some(1000_i64))?, + Interval::make(Some(500_i64), Some(999_i64))?, + ), + ( + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + false, + Interval::make(Some(500_i64), Some(1500_i64))?, + Interval::make(Some(0_i64), Some(1000_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + false, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(0_i64))?, + ), + ( + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make::(None, None)?, + true, + Interval::make(Some(1_i64), Some(1_i64))?, + Interval::make(None, Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + false, + Interval::make(Some(2_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make::(None, None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + true, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(1_i64), Some(1_i64))?, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + false, + Interval::try_new( + next_value(ScalarValue::Float32(Some(-500.0))), + ScalarValue::Float32(Some(1000.0)), + )?, + Interval::make(Some(-500_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + true, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(500.0_f32))?, + ), + ( + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + false, + Interval::make(Some(-500.0_f32), Some(500.0_f32))?, + Interval::try_new( + ScalarValue::Float32(Some(-1000.0_f32)), + prev_value(ScalarValue::Float32(Some(500.0_f32))), + )?, + ), + ( + Interval::make(Some(-1000.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + true, + Interval::make(Some(-500.0_f64), Some(1000.0_f64))?, + Interval::make(Some(-500.0_f64), Some(500.0_f64))?, + ), + ]; + for (first, second, includes_endpoints, left_modified, right_modified) in cases { + assert_eq!( + satisfy_greater(&first, &second, !includes_endpoints)?.unwrap(), + (left_modified, right_modified) + ); + } + + let infeasible_cases = vec![ + ( + Interval::make(None, Some(1000_i64))?, + Interval::make(Some(1000_i64), None)?, + false, + ), + ( + Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?, + Interval::make(Some(1500.0_f32), Some(2000.0_f32))?, + false, + ), + ]; + for (first, second, includes_endpoints) in infeasible_cases { + assert_eq!(satisfy_greater(&first, &second, !includes_endpoints)?, None); + } + + Ok(()) + } + + #[test] + fn test_interval_display() { + let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); + assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + + let interval = Interval::try_new( + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ScalarValue::Float32(Some(f32::INFINITY)), + ) + .unwrap(); + assert_eq!(format!("{}", interval), "[NULL, NULL]"); + } + + macro_rules! capture_mode_change { + ($TYPE:ty) => { + paste::item! { + capture_mode_change_helper!([], + [], + $TYPE); + } + }; + } + + macro_rules! capture_mode_change_helper { + ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { + fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { + Interval::try_new( + ScalarValue::try_from(Some(lower as $TYPE)).unwrap(), + ScalarValue::try_from(Some(upper as $TYPE)).unwrap(), + ) + .unwrap() + } + + fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { + assert!(expect_low || expect_high); + let interval1 = $CREATE_FN_NAME(input.0, input.0); + let interval2 = $CREATE_FN_NAME(input.1, input.1); + let result = interval1.add(&interval2).unwrap(); + let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); + assert!( + (!expect_low || result.lower < without_fe.lower) + && (!expect_high || result.upper > without_fe.upper) + ); + } + }; + } + + capture_mode_change!(f32); + capture_mode_change!(f64); + + #[cfg(all( + any(target_arch = "x86_64", target_arch = "aarch64"), + not(target_os = "windows") + ))] + #[test] + fn test_add_intervals_lower_affected_f32() { + // Lower is affected + let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 + let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 + capture_mode_change_f32((lower, upper), true, false); + + // Upper is affected + let lower = f32::from_bits(1072693248); //111111111100000000000000000000 + let upper = f32::from_bits(715827883); //101010101010101010101010101011 + capture_mode_change_f32((lower, upper), false, true); + + // Lower is affected + let lower = 1.0; // 0x3FF0000000000000 + let upper = 0.3; // 0x3FD3333333333333 + capture_mode_change_f64((lower, upper), true, false); + + // Upper is affected + let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF + let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F + capture_mode_change_f64((lower, upper), false, true); + } + + #[cfg(any( + not(any(target_arch = "x86_64", target_arch = "aarch64")), + target_os = "windows" + ))] + #[test] + fn test_next_impl_add_intervals_f64() { + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f64((lower, upper), true, true); + + let lower = 1.5; + let upper = 1.5; + capture_mode_change_f32((lower, upper), true, true); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5945480aba1d8..6172d17365adc 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -26,10 +26,20 @@ //! The [expr_fn] module contains functions for creating expressions. mod accumulator; -pub mod aggregate_function; -pub mod array_expressions; mod built_in_function; mod columnar_value; +mod literal; +mod nullif; +mod operator; +mod partition_evaluator; +mod signature; +mod table_source; +mod udaf; +mod udf; +mod udwf; + +pub mod aggregate_function; +pub mod array_expressions; pub mod conditional_expressions; pub mod expr; pub mod expr_fn; @@ -37,43 +47,42 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -pub mod function_err; -mod literal; +pub mod interval_arithmetic; pub mod logical_plan; -mod nullif; -mod operator; -mod signature; pub mod struct_expressions; -mod table_source; pub mod tree_node; pub mod type_coercion; -mod udaf; -mod udf; pub mod utils; pub mod window_frame; pub mod window_function; +pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::ColumnarValue; pub use expr::{ - Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, GroupingSet, Like, TryCast, + Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + Like, ScalarFunctionDefinition, TryCast, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; pub use function::{ - AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation, - StateTypeFunction, + AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, + ScalarFunctionImplementation, StateTypeFunction, }; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; -pub use signature::{Signature, TypeSignature, Volatility}; +pub use partition_evaluator::PartitionEvaluator; +pub use signature::{ + FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, +}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; +pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index dc7412b5946c2..2f04729af2edb 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -43,19 +43,19 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(*self)) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::from(self.as_ref())) } } @@ -88,6 +88,17 @@ macro_rules! make_literal { }; } +macro_rules! make_nonzero_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl Literal for $TYPE { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + } + } + }; +} + macro_rules! make_timestamp_literal { ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { #[doc = $DOC] @@ -114,6 +125,47 @@ make_literal!(u16, UInt16, "literal expression containing a u16"); make_literal!(u32, UInt32, "literal expression containing a u32"); make_literal!(u64, UInt64, "literal expression containing a u64"); +make_nonzero_literal!( + std::num::NonZeroI8, + Int8, + "literal expression containing an i8" +); +make_nonzero_literal!( + std::num::NonZeroI16, + Int16, + "literal expression containing an i16" +); +make_nonzero_literal!( + std::num::NonZeroI32, + Int32, + "literal expression containing an i32" +); +make_nonzero_literal!( + std::num::NonZeroI64, + Int64, + "literal expression containing an i64" +); +make_nonzero_literal!( + std::num::NonZeroU8, + UInt8, + "literal expression containing a u8" +); +make_nonzero_literal!( + std::num::NonZeroU16, + UInt16, + "literal expression containing a u16" +); +make_nonzero_literal!( + std::num::NonZeroU32, + UInt32, + "literal expression containing a u32" +); +make_nonzero_literal!( + std::num::NonZeroU64, + UInt64, + "literal expression containing a u64" +); + make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); @@ -124,10 +176,19 @@ make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); #[cfg(test)] mod test { + use std::num::NonZeroU32; + use super::*; use crate::expr_fn::col; use datafusion_common::ScalarValue; + #[test] + fn test_lit_nonzero() { + let expr = col("id").eq(lit(NonZeroU32::new(1).unwrap())); + let expected = col("id").eq(lit(ScalarValue::UInt32(Some(1)))); + assert_eq!(expr, expected); + } + #[test] fn test_lit_timestamp_nano() { let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 606b990cfe9ea..be2c45b901fa7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -17,37 +17,43 @@ //! This module provides a builder for creating LogicalPlans +use std::any::Any; +use std::cmp::Ordering; +use std::collections::{HashMap, HashSet}; +use std::convert::TryFrom; +use std::iter::zip; +use std::sync::Arc; + +use crate::dml::{CopyOptions, CopyTo}; +use crate::expr::Alias; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, rewrite_sort_cols_by_aggs, }; +use crate::logical_plan::{ + Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, + Window, +}; use crate::type_coercion::binary::comparison_coercion; -use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan}; -use crate::{and, binary_expr, DmlStatement, Operator, WriteOp}; +use crate::utils::{ + can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard, + expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, +}; use crate::{ - logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, ToStringifiedPlan, - Union, Unnest, Values, Window, - }, - utils::{ - can_hash, expand_qualified_wildcard, expand_wildcard, - find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, - }, - Expr, ExprSchemable, TableSource, + and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, + TableProviderFilterPushDown, TableSource, WriteOp, }; + use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, - ScalarValue, TableReference, ToDFSchema, + get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, + DFSchema, DFSchemaRef, DataFusionError, FileType, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; -use std::any::Any; -use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; -use std::convert::TryFrom; -use std::sync::Arc; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -126,13 +132,11 @@ impl LogicalPlanBuilder { /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. pub fn values(mut values: Vec>) -> Result { if values.is_empty() { - return Err(DataFusionError::Plan("Values list cannot be empty".into())); + return plan_err!("Values list cannot be empty"); } let n_cols = values[0].len(); if n_cols == 0 { - return Err(DataFusionError::Plan( - "Values list cannot be zero length".into(), - )); + return plan_err!("Values list cannot be zero length"); } let empty_schema = DFSchema::empty(); let mut field_types: Vec> = Vec::with_capacity(n_cols); @@ -143,12 +147,12 @@ impl LogicalPlanBuilder { let mut nulls: Vec<(usize, usize)> = Vec::new(); for (i, row) in values.iter().enumerate() { if row.len() != n_cols { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Inconsistent data length across values list: got {} values in row {} but expected {}", row.len(), i, n_cols - ))); + ); } field_types = row .iter() @@ -161,8 +165,7 @@ impl LogicalPlanBuilder { let data_type = expr.get_type(&empty_schema)?; if let Some(prev_data_type) = &field_types[j] { if prev_data_type != &data_type { - let err = format!("Inconsistent data type across values list at row {i} column {j}"); - return Err(DataFusionError::Plan(err)); + return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_data_type} but found {data_type}") } } Ok(Some(data_type)) @@ -231,17 +234,42 @@ impl LogicalPlanBuilder { Self::scan_with_filters(table_name, table_source, projection, vec![]) } + /// Create a [CopyTo] for copying the contents of this builder to the specified file(s) + pub fn copy_to( + input: LogicalPlan, + output_url: String, + file_format: FileType, + single_file_output: bool, + copy_options: CopyOptions, + ) -> Result { + Ok(Self::from(LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url, + file_format, + single_file_output, + copy_options, + }))) + } + /// Create a [DmlStatement] for inserting the contents of this builder into the named table pub fn insert_into( input: LogicalPlan, table_name: impl Into, table_schema: &Schema, + overwrite: bool, ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; + + let op = if overwrite { + WriteOp::InsertOverwrite + } else { + WriteOp::InsertInto + }; + Ok(Self::from(LogicalPlan::Dml(DmlStatement { table_name: table_name.into(), table_schema, - op: WriteOp::Insert, + op, input: Arc::new(input), }))) } @@ -253,44 +281,9 @@ impl LogicalPlanBuilder { projection: Option>, filters: Vec, ) -> Result { - let table_name = table_name.into(); - - if table_name.table().is_empty() { - return Err(DataFusionError::Plan( - "table_name cannot be empty".to_string(), - )); - } - - let schema = table_source.schema(); - - let projected_schema = projection - .as_ref() - .map(|p| { - DFSchema::new_with_metadata( - p.iter() - .map(|i| { - DFField::from_qualified( - table_name.clone(), - schema.field(*i).clone(), - ) - }) - .collect(), - schema.metadata().clone(), - ) - }) - .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema) - })?; - - let table_scan = LogicalPlan::TableScan(TableScan { - table_name, - source: table_source, - projected_schema: Arc::new(projected_schema), - projection, - filters, - fetch: None, - }); - Ok(Self::from(table_scan)) + TableScan::try_new(table_name, table_source, projection, filters, None) + .map(LogicalPlan::TableScan) + .map(Self::from) } /// Wrap a plan in a window @@ -335,7 +328,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator>, ) -> Result { - Ok(Self::from(project(self.plan, expr)?)) + project(self.plan, expr).map(Self::from) } /// Select the given column indices @@ -351,10 +344,9 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Ok(Self::from(LogicalPlan::Filter(Filter::try_new( - expr, - Arc::new(self.plan), - )?))) + Filter::try_new(expr, Arc::new(self.plan)) + .map(LogicalPlan::Filter) + .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan @@ -382,7 +374,7 @@ impl LogicalPlanBuilder { /// Apply an alias pub fn alias(self, alias: impl Into) -> Result { - Ok(Self::from(subquery_alias(self.plan, alias)?)) + subquery_alias(self.plan, alias).map(Self::from) } /// Add missing sort columns to all downstream projection @@ -437,7 +429,7 @@ impl LogicalPlanBuilder { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } expr.extend(missing_exprs); - Ok(project((*input).clone(), expr)?) + project((*input).clone(), expr) } _ => { let is_distinct = @@ -453,9 +445,7 @@ impl LogicalPlanBuilder { ) }) .collect::>>()?; - - let expr = curr_plan.expressions(); - from_plan(&curr_plan, &expr, &new_inputs) + curr_plan.with_new_inputs(&new_inputs) } } } @@ -478,7 +468,7 @@ impl LogicalPlanBuilder { // As described in https://github.com/apache/arrow-datafusion/issues/5293 let all_aliases = missing_exprs.iter().all(|e| { projection_exprs.iter().any(|proj_expr| { - if let Expr::Alias(expr, _) = proj_expr { + if let Expr::Alias(Alias { expr, .. }) = proj_expr { e == expr.as_ref() } else { false @@ -494,9 +484,7 @@ impl LogicalPlanBuilder { .map(|col| col.flat_name()) .collect::(); - Err(DataFusionError::Plan(format!( - "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list", - ))) + plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } /// Apply a sort @@ -548,46 +536,59 @@ impl LogicalPlanBuilder { fetch: None, }); - Ok(Self::from(LogicalPlan::Projection(Projection::try_new( - new_expr, - Arc::new(sort_plan), - )?))) + Projection::try_new(new_expr, Arc::new(sort_plan)) + .map(LogicalPlan::Projection) + .map(Self::from) } /// Apply a union, preserving duplicate rows pub fn union(self, plan: LogicalPlan) -> Result { - Ok(Self::from(union(self.plan, plan)?)) + union(self.plan, plan).map(Self::from) } /// Apply a union, removing duplicate rows pub fn union_distinct(self, plan: LogicalPlan) -> Result { - // unwrap top-level Distincts, to avoid duplication - let left_plan: LogicalPlan = match self.plan { - LogicalPlan::Distinct(Distinct { input }) => (*input).clone(), - _ => self.plan, - }; - let right_plan: LogicalPlan = match plan { - LogicalPlan::Distinct(Distinct { input }) => (*input).clone(), - _ => plan, - }; + let left_plan: LogicalPlan = self.plan; + let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(union(left_plan, right_plan)?), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + union(left_plan, right_plan)?, + ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(self.plan), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + self.plan, + ))))) } - /// Apply a join with on constraint. + /// Project first values of the specified expression list according to the provided + /// sorting expressions grouped by the `DISTINCT ON` clause expressions. + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + Ok(Self::from(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + )))) + } + + /// Apply a join to `right` using explicitly specified columns and an + /// optional filter expression. /// - /// Filter expression expected to contain non-equality predicates that can not be pushed - /// down to any of join inputs. - /// In case of outer join, filter applied to only matched rows. + /// See [`join_on`](Self::join_on) for a more concise way to specify the + /// join condition. Since DataFusion will automatically identify and + /// optimize equality predicates there is no performance difference between + /// this function and `join_on` + /// + /// `left_cols` and `right_cols` are used to form "equijoin" predicates (see + /// example below), which are then combined with the optional `filter` + /// expression. + /// + /// Note that in case of outer join, the `filter` is applied to only matched rows. pub fn join( self, right: LogicalPlan, @@ -598,6 +599,63 @@ impl LogicalPlanBuilder { self.join_detailed(right, join_type, join_keys, filter, false) } + /// Apply a join with using the specified expressions. + /// + /// Note that DataFusion automatically optimizes joins, including + /// identifying and optimizing equality predicates. + /// + /// # Example + /// + /// ``` + /// # use datafusion_expr::{Expr, col, LogicalPlanBuilder, + /// # logical_plan::builder::LogicalTableSource, logical_plan::JoinType,}; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{Schema, DataType, Field}; + /// # use datafusion_common::Result; + /// # fn main() -> Result<()> { + /// let example_schema = Arc::new(Schema::new(vec![ + /// Field::new("a", DataType::Int32, false), + /// Field::new("b", DataType::Int32, false), + /// Field::new("c", DataType::Int32, false), + /// ])); + /// let table_source = Arc::new(LogicalTableSource::new(example_schema)); + /// let left_table = table_source.clone(); + /// let right_table = table_source.clone(); + /// + /// let right_plan = LogicalPlanBuilder::scan("right", right_table, None)?.build()?; + /// + /// // Form the expression `(left.a != right.a)` AND `(left.b != right.b)` + /// let exprs = vec![ + /// col("left.a").eq(col("right.a")), + /// col("left.b").not_eq(col("right.b")) + /// ]; + /// + /// // Perform the equivalent of `left INNER JOIN right ON (a != a2 AND b != b2)` + /// // finding all pairs of rows from `left` and `right` where + /// // where `a = a2` and `b != b2`. + /// let plan = LogicalPlanBuilder::scan("left", left_table, None)? + /// .join_on(right_plan, JoinType::Inner, exprs)? + /// .build()?; + /// # Ok(()) + /// # } + /// ``` + pub fn join_on( + self, + right: LogicalPlan, + join_type: JoinType, + on_exprs: impl IntoIterator, + ) -> Result { + let filter = on_exprs.into_iter().reduce(Expr::and); + + self.join_detailed( + right, + join_type, + (Vec::::new(), Vec::::new()), + filter, + false, + ) + } + pub(crate) fn normalize( plan: &LogicalPlan, column: impl Into + Clone, @@ -611,8 +669,14 @@ impl LogicalPlanBuilder { ) } - /// Apply a join with on constraint and specified null equality - /// If null_equals_null is true then null == null, else null != null + /// Apply a join with on constraint and specified null equality. + /// + /// The behavior is the same as [`join`](Self::join) except that it allows + /// specifying the null equality behavior. + /// + /// If `null_equals_null=true`, rows where both join keys are `null` will be + /// emitted. Otherwise rows where either or both join keys are `null` will be + /// omitted. pub fn join_detailed( self, right: LogicalPlan, @@ -622,9 +686,7 @@ impl LogicalPlanBuilder { null_equals_null: bool, ) -> Result { if join_keys.0.len() != join_keys.1.len() { - return Err(DataFusionError::Plan( - "left_keys and right_keys were not the same length".to_string(), - )); + return plan_err!("left_keys and right_keys were not the same length"); } let filter = if let Some(expr) = filter { @@ -642,7 +704,7 @@ impl LogicalPlanBuilder { join_keys .0 .into_iter() - .zip(join_keys.1.into_iter()) + .zip(join_keys.1) .map(|(l, r)| { let l = l.into(); let r = r.into(); @@ -719,7 +781,7 @@ impl LogicalPlanBuilder { let on = left_keys .into_iter() - .zip(right_keys.into_iter()) + .zip(right_keys) .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) .collect(); let join_schema = @@ -754,7 +816,7 @@ impl LogicalPlanBuilder { .map(|c| Self::normalize(&right, c)) .collect::>()?; - let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect(); let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; let mut join_on: Vec<(Expr, Expr)> = vec![]; @@ -804,11 +866,12 @@ impl LogicalPlanBuilder { /// Apply a cross join pub fn cross_join(self, right: LogicalPlan) -> Result { - let schema = self.plan.schema().join(right.schema())?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; Ok(Self::from(LogicalPlan::CrossJoin(CrossJoin { left: Arc::new(self.plan), right: Arc::new(right), - schema: DFSchemaRef::new(schema), + schema: DFSchemaRef::new(join_schema), }))) } @@ -826,17 +889,11 @@ impl LogicalPlanBuilder { window_expr: impl IntoIterator>, ) -> Result { let window_expr = normalize_cols(window_expr, &self.plan)?; - let all_expr = window_expr.iter(); - validate_unique_names("Windows", all_expr.clone())?; - let mut window_fields: Vec = self.plan.schema().fields().clone(); - window_fields.extend_from_slice(&exprlist_to_fields(all_expr, &self.plan)?); - let metadata = self.plan.schema().metadata().clone(); - - Ok(Self::from(LogicalPlan::Window(Window { - input: Arc::new(self.plan), + validate_unique_names("Windows", &window_expr)?; + Ok(Self::from(LogicalPlan::Window(Window::try_new( window_expr, - schema: Arc::new(DFSchema::new_with_metadata(window_fields, metadata)?), - }))) + Arc::new(self.plan), + )?))) } /// Apply an aggregate: grouping on the `group_expr` expressions @@ -847,13 +904,30 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let group_expr = normalize_cols(group_expr, &self.plan)?; + let mut group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(self.plan), - group_expr, - aggr_expr, - )?))) + + // Rewrite groupby exprs according to functional dependencies + let group_by_expr_names = group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let schema = self.plan.schema(); + if let Some(target_indices) = + get_target_functional_dependencies(schema, &group_by_expr_names) + { + for idx in target_indices { + let field = schema.field(idx); + let expr = + Expr::Column(Column::new(field.qualifier().cloned(), field.name())); + if !group_expr.contains(&expr) { + group_expr.push(expr); + } + } + } + Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + .map(LogicalPlan::Aggregate) + .map(Self::from) } /// Create an expression to represent the explanation of the plan @@ -925,9 +999,9 @@ impl LogicalPlanBuilder { let right_len = right_plan.schema().fields().len(); if left_len != right_len { - return Err(DataFusionError::Plan(format!( + return plan_err!( "INTERSECT/EXCEPT query must have the same number of columns. Left is {left_len} and right is {right_len}." - ))); + ); } let join_keys = left_plan @@ -973,9 +1047,7 @@ impl LogicalPlanBuilder { filter: Option, ) -> Result { if equi_exprs.0.len() != equi_exprs.1.len() { - return Err(DataFusionError::Plan( - "left_keys and right_keys were not the same length".to_string(), - )); + return plan_err!("left_keys and right_keys were not the same length"); } let join_key_pairs = equi_exprs @@ -1007,9 +1079,9 @@ impl LogicalPlanBuilder { self.plan.schema().clone(), right.schema().clone(), )?.ok_or_else(|| - DataFusionError::Plan(format!( + plan_datafusion_err!( "can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})" - ))) + )) }) .collect::>>()?; @@ -1032,6 +1104,19 @@ impl LogicalPlanBuilder { pub fn unnest_column(self, column: impl Into) -> Result { Ok(Self::from(unnest(self.plan, column.into())?)) } + + /// Unnest the given column given [`UnnestOptions`] + pub fn unnest_column_with_options( + self, + column: impl Into, + options: UnnestOptions, + ) -> Result { + Ok(Self::from(unnest_with_options( + self.plan, + column.into(), + options, + )?)) + } } /// Creates a schema for a join operation. @@ -1093,10 +1178,15 @@ pub fn build_join_schema( right_fields.clone() } }; - + let func_dependencies = left.functional_dependencies().join( + right.functional_dependencies(), + join_type, + left_fields.len(), + ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) + let schema = DFSchema::new_with_metadata(fields, metadata)?; + schema.with_functional_dependencies(func_dependencies) } /// Errors if one or more expressions have equal names. @@ -1113,12 +1203,10 @@ pub(crate) fn validate_unique_names<'a>( Ok(()) }, Some((existing_position, existing_expr)) => { - Err(DataFusionError::Plan( - format!("{node_name} require unique expression names \ - but the expression \"{existing_expr:?}\" at position {existing_position} and \"{expr:?}\" \ - at position {position} have the same name. Consider aliasing (\"AS\") one of them.", + plan_err!("{node_name} require unique expression names \ + but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \ + at position {position} have the same name. Consider aliasing (\"AS\") one of them." ) - )) } } }) @@ -1133,7 +1221,7 @@ pub fn project_with_column_index( .into_iter() .enumerate() .map(|(i, e)| match e { - Expr::Alias(_, ref name) if name != schema.field(i).name() => { + Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { e.unalias().alias(schema.field(i).name()) } Expr::Column(Column { @@ -1145,9 +1233,8 @@ pub fn project_with_column_index( }) .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - alias_expr, input, schema, - )?)) + Projection::try_new_with_schema(alias_expr, input, schema) + .map(LogicalPlan::Projection) } /// Union two logical plans. @@ -1157,45 +1244,41 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result>>()? - .to_dfschema()?; + left_field.data_type() + ) + })?; + + Ok(DFField::new( + left_field.qualifier().cloned(), + left_field.name(), + data_type, + nullable, + )) + }) + .collect::>>()? + .to_dfschema()?; let inputs = vec![left_plan, right_plan] .into_iter() - .flat_map(|p| match p { - LogicalPlan::Union(Union { inputs, .. }) => inputs, - other_plan => vec![Arc::new(other_plan)], - }) .map(|p| { let plan = coerce_plan_expr_for_schema(&p, &union_schema)?; match plan { @@ -1212,7 +1295,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result>>()?; if inputs.is_empty() { - return Err(DataFusionError::Plan("Empty UNION".to_string())); + return plan_err!("Empty UNION"); } Ok(LogicalPlan::Union(Union { @@ -1230,31 +1313,29 @@ pub fn project( plan: LogicalPlan, expr: impl IntoIterator>, ) -> Result { + // TODO: move it into analyzer let input_schema = plan.schema(); let mut projected_expr = vec![]; for e in expr { let e = e.into(); match e { - Expr::Wildcard => { + Expr::Wildcard { qualifier: None } => { projected_expr.extend(expand_wildcard(input_schema, &plan, None)?) } - Expr::QualifiedWildcard { ref qualifier } => projected_expr - .extend(expand_qualified_wildcard(qualifier, input_schema, None)?), + Expr::Wildcard { + qualifier: Some(qualifier), + } => projected_expr.extend(expand_qualified_wildcard( + &qualifier, + input_schema, + None, + )?), _ => projected_expr .push(columnize_expr(normalize_col(e, &plan)?, input_schema)), } } validate_unique_names("Projections", projected_expr.iter())?; - let input_schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projected_expr, &plan)?, - plan.schema().metadata().clone(), - )?; - - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - projected_expr, - Arc::new(plan.clone()), - DFSchemaRef::new(input_schema), - )?)) + + Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) } /// Create a SubqueryAlias to wrap a LogicalPlan. @@ -1262,9 +1343,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - plan, alias, - )?)) + SubqueryAlias::try_new(plan, alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1317,7 +1396,7 @@ pub fn wrap_projection_for_join_if_necessary( // then a and cast(a as int) will use the same field name - `a` in projection schema. // https://github.com/apache/arrow-datafusion/issues/4478 if matches!(key, Expr::Cast(_)) || matches!(key, Expr::TryCast(_)) { - let alias = format!("{key:?}"); + let alias = format!("{key}"); key.clone().alias(alias) } else { key.clone() @@ -1375,10 +1454,26 @@ impl TableSource for LogicalTableSource { fn schema(&self) -> SchemaRef { self.table_schema.clone() } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) + } } -/// Create an unnest plan. +/// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, column: Column) -> Result { + unnest_with_options(input, column, UnnestOptions::new()) +} + +/// Create a [`LogicalPlan::Unnest`] plan with options +pub fn unnest_with_options( + input: LogicalPlan, + column: Column, + options: UnnestOptions, +) -> Result { let unnest_field = input.schema().field_from_column(&column)?; // Extract the type of the nested field in the list. @@ -1411,28 +1506,28 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result { }) .collect::>(); - let schema = Arc::new(DFSchema::new_with_metadata( - fields, - input_schema.metadata().clone(), - )?); + let metadata = input_schema.metadata().clone(); + let df_schema = DFSchema::new_with_metadata(fields, metadata)?; + // We can use the existing functional dependencies: + let deps = input_schema.functional_dependencies().clone(); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), column: unnested_field.qualified_column(), schema, + options, })) } #[cfg(test)] mod tests { - use crate::{expr, expr_fn::exists}; - use arrow::datatypes::{DataType, Field}; - use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; - + use super::*; use crate::logical_plan::StringifiedPlan; + use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery, sum}; - use super::*; - use crate::{col, in_subquery, lit, scalar_subquery, sum}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; #[test] fn plan_builder_simple() -> Result<()> { @@ -1481,7 +1576,7 @@ mod tests { let err = LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); assert_eq!( - err.to_string(), + err.strip_backtrace(), "Error during planning: table_name cannot be empty" ); } @@ -1532,7 +1627,7 @@ mod tests { let plan = table_scan(Some("t1"), &employee_schema(), None)? .join_using(t2, JoinType::Inner, vec!["id"])? - .project(vec![Expr::Wildcard])? + .project(vec![Expr::Wildcard { qualifier: None }])? .build()?; // id column should only show up once in projection @@ -1547,7 +1642,7 @@ mod tests { } #[test] - fn plan_builder_union_combined_single_union() -> Result<()> { + fn plan_builder_union() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; @@ -1558,11 +1653,12 @@ mod tests { .union(plan.build()?)? .build()?; - // output has only one union let expected = "Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ + \n Union\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; assert_eq!(expected, format!("{plan:?}")); @@ -1571,7 +1667,7 @@ mod tests { } #[test] - fn plan_builder_union_distinct_combined_single_union() -> Result<()> { + fn plan_builder_union_distinct() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))?; @@ -1582,13 +1678,16 @@ mod tests { .union_distinct(plan.build()?)? .build()?; - // output has only one union let expected = "\ Distinct:\ \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ + \n Distinct:\ + \n Union\ + \n Distinct:\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; assert_eq!(expected, format!("{plan:?}")); @@ -1607,8 +1706,8 @@ mod tests { let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err(); let err_msg2 = plan1.union_distinct(plan2.build()?).unwrap_err(); - assert_eq!(err_msg1.to_string(), expected); - assert_eq!(err_msg2.to_string(), expected); + assert_eq!(err_msg1.strip_backtrace(), expected); + assert_eq!(err_msg2.strip_backtrace(), expected); Ok(()) } @@ -1737,9 +1836,7 @@ mod tests { assert_eq!("id", &name); Ok(()) } - _ => Err(DataFusionError::Plan( - "Plan should have returned an DataFusionError::SchemaError".to_string(), - )), + _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), } } @@ -1766,9 +1863,7 @@ mod tests { assert_eq!("state", &name); Ok(()) } - _ => Err(DataFusionError::Plan( - "Plan should have returned an DataFusionError::SchemaError".to_string(), - )), + _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), } } @@ -1836,7 +1931,7 @@ mod tests { LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true) .unwrap_err(); - assert_eq!(err_msg1.to_string(), expected); + assert_eq!(err_msg1.strip_backtrace(), expected); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index e005f114719df..e74992d993734 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - parsers::CompressionTypeVariant, DFSchemaRef, OwnedTableReference, -}; -use datafusion_common::{Column, OwnedSchemaReference}; use std::collections::HashMap; use std::sync::Arc; use std::{ @@ -28,6 +24,11 @@ use std::{ use crate::{Expr, LogicalPlan}; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + Constraints, DFSchemaRef, OwnedSchemaReference, OwnedTableReference, +}; + /// Various types of DDL (CREATE / DROP) catalog manipulation #[derive(Clone, PartialEq, Eq, Hash)] pub enum DdlStatement { @@ -111,22 +112,17 @@ impl DdlStatement { match self.0 { DdlStatement::CreateExternalTable(CreateExternalTable { ref name, + constraints, .. }) => { - write!(f, "CreateExternalTable: {name:?}") + write!(f, "CreateExternalTable: {name:?}{constraints}") } DdlStatement::CreateMemoryTable(CreateMemoryTable { name, - primary_key, + constraints, .. }) => { - let pk: Vec = - primary_key.iter().map(|c| c.name.to_string()).collect(); - let mut pk = pk.join(", "); - if !pk.is_empty() { - pk = format!(" primary_key=[{pk}]"); - } - write!(f, "CreateMemoryTable: {name:?}{pk}") + write!(f, "CreateMemoryTable: {name:?}{constraints}") } DdlStatement::CreateView(CreateView { name, .. }) => { write!(f, "CreateView: {name:?}") @@ -196,6 +192,10 @@ pub struct CreateExternalTable { pub unbounded: bool, /// Table(provider) specific options pub options: HashMap, + /// The list of constraints in the schema, such as primary key, unique, etc. + pub constraints: Constraints, + /// Default values for columns + pub column_defaults: HashMap, } // Hashing refers to a subset of fields considered in PartialEq. @@ -222,14 +222,16 @@ impl Hash for CreateExternalTable { pub struct CreateMemoryTable { /// The table name pub name: OwnedTableReference, - /// The ordered list of columns in the primary key, or an empty vector if none - pub primary_key: Vec, + /// The list of constraints in the schema, such as primary key, unique, etc. + pub constraints: Constraints, /// The logical plan pub input: Arc, /// Option to not error if table already exists pub if_not_exists: bool, /// Option to replace table content if table already exists pub or_replace: bool, + /// Default values for columns + pub column_defaults: Vec<(String, Expr)>, } /// Creates a view. diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index c82689b2ccd74..112dbf74dba18 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -18,6 +18,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; +use datafusion_common::display::GraphvizBuilder; use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; use datafusion_common::DataFusionError; use std::fmt; @@ -123,37 +124,6 @@ pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ { Wrapper(schema) } -/// Logic related to creating DOT language graphs. -#[derive(Default)] -struct GraphvizBuilder { - id_gen: usize, -} - -impl GraphvizBuilder { - fn next_id(&mut self) -> usize { - self.id_gen += 1; - self.id_gen - } - - // write out the start of the subgraph cluster - fn start_cluster(&mut self, f: &mut fmt::Formatter, title: &str) -> fmt::Result { - writeln!(f, " subgraph cluster_{}", self.next_id())?; - writeln!(f, " {{")?; - writeln!(f, " graph[label={}]", Self::quoted(title)) - } - - // write out the end of the subgraph cluster - fn end_cluster(&mut self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, " }}") - } - - /// makes a quoted string suitable for inclusion in a graphviz chart - fn quoted(label: &str) -> String { - let label = label.replace('"', "_"); - format!("\"{label}\"") - } -} - /// Formats plans for graphical display using the `DOT` language. This /// format can be visualized using software from /// [`graphviz`](https://graphviz.org/) @@ -190,6 +160,14 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { pub fn post_visit_plan(&mut self) -> fmt::Result { self.graphviz_builder.end_cluster(self.f) } + + pub fn start_graph(&mut self) -> fmt::Result { + self.graphviz_builder.start_graph(self.f) + } + + pub fn end_graph(&mut self) -> fmt::Result { + self.graphviz_builder.end_graph(self.f) + } } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { @@ -213,22 +191,16 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { format!("{}", plan.display()) }; - writeln!( - self.f, - " {}[shape=box label={}]", - id, - GraphvizBuilder::quoted(&label) - ) - .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; + self.graphviz_builder + .add_node(self.f, id, &label, None) + .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; // Create an edge to our parent node, if any // parent_id -> id if let Some(parent_id) = self.parent_ids.last() { - writeln!( - self.f, - " {parent_id} -> {id} [arrowhead=none, arrowtail=normal, dir=back]" - ) - .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; + self.graphviz_builder + .add_edge(self.f, *parent_id, id) + .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; } self.parent_ids.push(id); diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 117a42cda9702..4cd56b89ac635 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -20,10 +20,70 @@ use std::{ sync::Arc, }; -use datafusion_common::{DFSchemaRef, OwnedTableReference}; +use datafusion_common::{ + file_options::StatementOptions, DFSchemaRef, FileType, FileTypeWriterOptions, + OwnedTableReference, +}; use crate::LogicalPlan; +/// Operator that copies the contents of a database to file(s) +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct CopyTo { + /// The relation that determines the tuples to write to the output file(s) + pub input: Arc, + /// The location to write the file(s) + pub output_url: String, + /// The file format to output (explicitly defined or inferred from file extension) + pub file_format: FileType, + /// If false, it is assumed output_url is a file to which all data should be written + /// regardless of input partitioning. Otherwise, output_url is assumed to be a directory + /// to which each output partition is written to its own output file + pub single_file_output: bool, + /// Arbitrary options as tuples + pub copy_options: CopyOptions, +} + +/// When the logical plan is constructed from SQL, CopyOptions +/// will contain arbitrary string tuples which must be parsed into +/// FileTypeWriterOptions. When the logical plan is constructed directly +/// from rust code (such as via the DataFrame API), FileTypeWriterOptions +/// can be provided directly, avoiding the run time cost and fallibility of +/// parsing string based options. +#[derive(Clone)] +pub enum CopyOptions { + /// Holds StatementOptions parsed from a SQL statement + SQLOptions(StatementOptions), + /// Holds FileTypeWriterOptions directly provided + WriterOptions(Box), +} + +impl PartialEq for CopyOptions { + fn eq(&self, other: &CopyOptions) -> bool { + match self { + Self::SQLOptions(statement1) => match other { + Self::SQLOptions(statement2) => statement1.eq(statement2), + Self::WriterOptions(_) => false, + }, + Self::WriterOptions(_) => false, + } + } +} + +impl Eq for CopyOptions {} + +impl std::hash::Hash for CopyOptions { + fn hash(&self, hasher: &mut H) + where + H: std::hash::Hasher, + { + match self { + Self::SQLOptions(statement) => statement.hash(hasher), + Self::WriterOptions(_) => (), + } + } +} + /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) #[derive(Clone, PartialEq, Eq, Hash)] @@ -38,21 +98,37 @@ pub struct DmlStatement { pub input: Arc, } +impl DmlStatement { + /// Return a descriptive name of this [`DmlStatement`] + pub fn name(&self) -> &str { + self.op.name() + } +} + #[derive(Clone, PartialEq, Eq, Hash)] pub enum WriteOp { - Insert, + InsertOverwrite, + InsertInto, Delete, Update, Ctas, } -impl Display for WriteOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl WriteOp { + /// Return a descriptive name of this [`WriteOp`] + pub fn name(&self) -> &str { match self { - WriteOp::Insert => write!(f, "Insert"), - WriteOp::Delete => write!(f, "Delete"), - WriteOp::Update => write!(f, "Update"), - WriteOp::Ctas => write!(f, "Ctas"), + WriteOp::InsertOverwrite => "Insert Overwrite", + WriteOp::InsertInto => "Insert Into", + WriteOp::Delete => "Delete", + WriteOp::Update => "Update", + WriteOp::Ctas => "Ctas", } } } + +impl Display for WriteOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 01862c3d5427e..bc722dd69acea 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -18,7 +18,7 @@ pub mod builder; mod ddl; pub mod display; -mod dml; +pub mod dml; mod extension; mod plan; mod statement; @@ -33,10 +33,11 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, EmptyRelation, Explain, - Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, + JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, + Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, + ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e19b327785a2c..d74015bf094d2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,36 +17,44 @@ //! Logical plan types -use crate::expr::InSubquery; -use crate::expr::{Exists, Placeholder}; +use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use super::dml::CopyTo; +use super::DdlStatement; +use crate::dml::CopyOptions; +use crate::expr::{Alias, Exists, InSubquery, Placeholder, Sort as SortExpr}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ - enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, from_plan, + enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, + split_conjunction, }; use crate::{ - build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, + build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, + ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, + TableSource, }; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeVisitor, VisitRecursion, + RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, + VisitRecursion, }; use datafusion_common::{ - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - Result, ScalarValue, + aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, + DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies, + OwnedTableReference, ParamValues, Result, UnnestOptions, }; -use std::collections::{HashMap, HashSet}; -use std::fmt::{self, Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -// backwards compatible +// backwards compatibility +pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; -use super::DdlStatement; - /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. @@ -64,61 +72,84 @@ pub enum LogicalPlan { /// expression (essentially a WHERE clause with a predicate /// expression). /// - /// Semantically, `` is evaluated for each row of the input; - /// If the value of `` is true, the input row is passed to - /// the output. If the value of `` is false, the row is - /// discarded. + /// Semantically, `` is evaluated for each row of the + /// input; If the value of `` is true, the input row is + /// passed to the output. If the value of `` is false + /// (or null), the row is discarded. Filter(Filter), - /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) + /// Windows input based on a set of window spec and window + /// function (e.g. SUM or RANK). This is used to implement SQL + /// window functions, and the `OVER` clause. Window(Window), /// Aggregates its input based on a set of grouping and aggregate - /// expressions (e.g. SUM). + /// expressions (e.g. SUM). This is used to implement SQL aggregates + /// and `GROUP BY`. Aggregate(Aggregate), - /// Sorts its input according to a list of sort expressions. + /// Sorts its input according to a list of sort expressions. This + /// is used to implement SQL `ORDER BY` Sort(Sort), - /// Join two logical plans on one or more join columns + /// Join two logical plans on one or more join columns. + /// This is used to implement SQL `JOIN` Join(Join), - /// Apply Cross Join to two logical plans + /// Apply Cross Join to two logical plans. + /// This is used to implement SQL `CROSS JOIN` CrossJoin(CrossJoin), - /// Repartition the plan based on a partitioning scheme. + /// Repartitions the input based on a partitioning scheme. This is + /// used to add parallelism and is sometimes referred to as an + /// "exchange" operator in other systems Repartition(Repartition), - /// Union multiple inputs + /// Union multiple inputs with the same schema into a single + /// output stream. This is used to implement SQL `UNION [ALL]` and + /// `INTERSECT [ALL]`. Union(Union), - /// Produces rows from a table provider by reference or from the context + /// Produces rows from a [`TableSource`], used to implement SQL + /// `FROM` tables or views. TableScan(TableScan), - /// Produces no rows: An empty relation with an empty schema + /// Produces no rows: An empty relation with an empty schema that + /// produces 0 or 1 row. This is used to implement SQL `SELECT` + /// that has no values in the `FROM` clause. EmptyRelation(EmptyRelation), - /// Subquery + /// Produces the output of running another query. This is used to + /// implement SQL subqueries Subquery(Subquery), /// Aliased relation provides, or changes, the name of a relation. SubqueryAlias(SubqueryAlias), /// Skip some number of rows, and then fetch some number of rows. Limit(Limit), - /// [`Statement`] + /// A DataFusion [`Statement`] such as `SET VARIABLE` or `START TRANSACTION` Statement(Statement), /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) - /// documentation for more details. + /// documentation for more details. This is used to implement SQL such as + /// `VALUES (1, 2), (3, 4)` Values(Values), /// Produces a relation with string representations of - /// various parts of the plan + /// various parts of the plan. This is used to implement SQL `EXPLAIN`. Explain(Explain), - /// Runs the actual plan, and then prints the physical plan with - /// with execution metrics. + /// Runs the input, and prints annotated physical plan as a string + /// with execution metric. This is used to implement SQL + /// `EXPLAIN ANALYZE`. Analyze(Analyze), - /// Extension operator defined outside of DataFusion + /// Extension operator defined outside of DataFusion. This is used + /// to extend DataFusion with custom relational operations that Extension(Extension), - /// Remove duplicate rows from the input + /// Remove duplicate rows from the input. This is used to + /// implement SQL `SELECT DISTINCT ...`. Distinct(Distinct), - /// Prepare a statement + /// Prepare a statement and find any bind parameters + /// (e.g. `?`). This is used to implement SQL-prepared statements. Prepare(Prepare), - /// Insert / Update / Delete + /// Data Manipulaton Language (DML): Insert / Update / Delete Dml(DmlStatement), - /// CREATE / DROP TABLES / VIEWS / SCHEMAs + /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), - /// Describe the schema of table + /// `COPY TO` for writing plan results to files + Copy(CopyTo), + /// Describe the schema of the table. This is used to implement the + /// SQL `DESCRIBE` command from MySQL. DescribeTable(DescribeTable), - /// Unnest a column that contains a nested list type. + /// Unnest a column that contains a nested list type such as an + /// ARRAY. This is used to implement SQL `UNNEST` Unnest(Unnest), } @@ -133,7 +164,8 @@ impl LogicalPlan { }) => projected_schema, LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct { input }) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, LogicalPlan::Window(Window { schema, .. }) => schema, LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), @@ -149,10 +181,11 @@ impl LogicalPlan { LogicalPlan::Analyze(analyze) => &analyze.schema, LogicalPlan::Extension(extension) => extension.node.schema(), LogicalPlan::Union(Union { schema, .. }) => schema, - LogicalPlan::DescribeTable(DescribeTable { dummy_schema, .. }) => { - dummy_schema + LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => { + output_schema } LogicalPlan::Dml(DmlStatement { table_schema, .. }) => table_schema, + LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, } @@ -199,6 +232,7 @@ impl LogicalPlan { | LogicalPlan::EmptyRelation(_) | LogicalPlan::Ddl(_) | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) | LogicalPlan::Values(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Union(_) @@ -229,6 +263,15 @@ impl LogicalPlan { ])) } + /// Returns the (fixed) output schema for `DESCRIBE` plans + pub fn describe_schema() -> Schema { + Schema::new(vec![ + Field::new("column_name", DataType::Utf8, false), + Field::new("data_type", DataType::Utf8, false), + Field::new("is_nullable", DataType::Utf8, false), + ]) + } + /// returns all expressions (non-recursively) in the current /// logical plan node. This does not include expressions in any /// children @@ -326,6 +369,16 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.clone().unwrap_or(vec![]).iter()) + .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -336,9 +389,10 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) + | LogicalPlan::Distinct(Distinct::All(_)) | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Prepare(_) => Ok(()), } @@ -363,10 +417,13 @@ impl LogicalPlan { LogicalPlan::Union(Union { inputs, .. }) => { inputs.iter().map(|arc| arc.as_ref()).collect() } - LogicalPlan::Distinct(Distinct { input }) => vec![input], + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => vec![input], LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::Dml(write) => vec![&write.input], + LogicalPlan::Copy(copy) => vec![©.input], LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], @@ -405,44 +462,567 @@ impl LogicalPlan { Ok(using_columns) } + /// returns the first output expression of this `LogicalPlan` node. + pub fn head_output_expr(&self) -> Result> { + match self { + LogicalPlan::Projection(projection) => { + Ok(Some(projection.expr.as_slice()[0].clone())) + } + LogicalPlan::Aggregate(agg) => { + if agg.group_expr.is_empty() { + Ok(Some(agg.aggr_expr.as_slice()[0].clone())) + } else { + Ok(Some(agg.group_expr.as_slice()[0].clone())) + } + } + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + Ok(Some(select_expr[0].clone())) + } + LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Distinct(Distinct::All(input)) + | LogicalPlan::Sort(Sort { input, .. }) + | LogicalPlan::Limit(Limit { input, .. }) + | LogicalPlan::Repartition(Repartition { input, .. }) + | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), + LogicalPlan::Join(Join { + left, + right, + join_type, + .. + }) => match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + if left.schema().fields().is_empty() { + right.head_output_expr() + } else { + left.head_output_expr() + } + } + JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), + JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), + }, + LogicalPlan::CrossJoin(cross) => { + if cross.left.schema().fields().is_empty() { + cross.right.head_output_expr() + } else { + cross.left.head_output_expr() + } + } + LogicalPlan::Union(union) => Ok(Some(Expr::Column( + union.schema.fields()[0].qualified_column(), + ))), + LogicalPlan::TableScan(table) => Ok(Some(Expr::Column( + table.projected_schema.fields()[0].qualified_column(), + ))), + LogicalPlan::SubqueryAlias(subquery_alias) => { + let expr_opt = subquery_alias.input.head_output_expr()?; + expr_opt + .map(|expr| { + Ok(Expr::Column(create_col_from_scalar_expr( + &expr, + subquery_alias.alias.to_string(), + )?)) + }) + .map_or(Ok(None), |v| v.map(Some)) + } + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Prepare(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Unnest(_) => Ok(None), + } + } + + /// Returns a copy of this `LogicalPlan` with the new inputs pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - from_plan(self, &self.expressions(), inputs) + // with_new_inputs use original expression, + // so we don't need to recompute Schema. + match &self { + LogicalPlan::Projection(projection) => { + // Schema of the projection may change + // when its input changes. Hence we should use + // `try_new` method instead of `try_new_with_schema`. + Projection::try_new(projection.expr.to_vec(), Arc::new(inputs[0].clone())) + .map(LogicalPlan::Projection) + } + LogicalPlan::Window(Window { window_expr, .. }) => Ok(LogicalPlan::Window( + Window::try_new(window_expr.to_vec(), Arc::new(inputs[0].clone()))?, + )), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => Aggregate::try_new( + // Schema of the aggregate may change + // when its input changes. Hence we should use + // `try_new` method instead of `try_new_with_schema`. + Arc::new(inputs[0].clone()), + group_expr.to_vec(), + aggr_expr.to_vec(), + ) + .map(LogicalPlan::Aggregate), + _ => self.with_new_exprs(self.expressions(), inputs), + } } - /// Convert a prepared [`LogicalPlan`] into its inner logical plan - /// with all params replaced with their corresponding values - pub fn with_param_values( - self, - param_values: Vec, + /// Returns a new `LogicalPlan` based on `self` with inputs and + /// expressions replaced. + /// + /// The exprs correspond to the same order of expressions returned + /// by [`Self::expressions`]. This function is used by optimizers + /// to rewrite plans using the following pattern: + /// + /// ```text + /// let new_inputs = optimize_children(..., plan, props); + /// + /// // get the plans expressions to optimize + /// let exprs = plan.expressions(); + /// + /// // potentially rewrite plan expressions + /// let rewritten_exprs = rewrite_exprs(exprs); + /// + /// // create new plan using rewritten_exprs in same position + /// let new_plan = plan.new_with_exprs(rewritten_exprs, new_inputs); + /// ``` + /// + /// Note: sometimes [`Self::with_new_exprs`] will use schema of + /// original plan, it will not change the scheam. Such as + /// `Projection/Aggregate/Window` + pub fn with_new_exprs( + &self, + mut expr: Vec, + inputs: &[LogicalPlan], ) -> Result { match self { - LogicalPlan::Prepare(prepare_lp) => { - // Verify if the number of params matches the number of values - if prepare_lp.data_types.len() != param_values.len() { - return Err(DataFusionError::Internal(format!( - "Expected {} parameters, got {}", - prepare_lp.data_types.len(), - param_values.len() - ))); - } + // Since expr may be different than the previous expr, schema of the projection + // may change. We need to use try_new method instead of try_new_with_schema method. + LogicalPlan::Projection(Projection { .. }) => { + Projection::try_new(expr, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Projection) + } + LogicalPlan::Dml(DmlStatement { + table_name, + table_schema, + op, + .. + }) => Ok(LogicalPlan::Dml(DmlStatement { + table_name: table_name.clone(), + table_schema: table_schema.clone(), + op: op.clone(), + input: Arc::new(inputs[0].clone()), + })), + LogicalPlan::Copy(CopyTo { + input: _, + output_url, + file_format, + copy_options, + single_file_output, + }) => Ok(LogicalPlan::Copy(CopyTo { + input: Arc::new(inputs[0].clone()), + output_url: output_url.clone(), + file_format: file_format.clone(), + single_file_output: *single_file_output, + copy_options: copy_options.clone(), + })), + LogicalPlan::Values(Values { schema, .. }) => { + Ok(LogicalPlan::Values(Values { + schema: schema.clone(), + values: expr + .chunks_exact(schema.fields().len()) + .map(|s| s.to_vec()) + .collect::>(), + })) + } + LogicalPlan::Filter { .. } => { + assert_eq!(1, expr.len()); + let predicate = expr.pop().unwrap(); + + // filter predicates should not contain aliased expressions so we remove any aliases + // before this logic was added we would have aliases within filters such as for + // benchmark q6: + // + // lineitem.l_shipdate >= Date32(\"8766\") + // AND lineitem.l_shipdate < Date32(\"9131\") + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= + // Decimal128(Some(49999999999999),30,15) + // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= + // Decimal128(Some(69999999999999),30,15) + // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) + + struct RemoveAliases {} + + impl TreeNodeRewriter for RemoveAliases { + type N = Expr; + + fn pre_visit(&mut self, expr: &Expr) -> Result { + match expr { + Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(RewriteRecursion::Stop) + } + Expr::Alias(_) => Ok(RewriteRecursion::Mutate), + _ => Ok(RewriteRecursion::Continue), + } + } - // Verify if the types of the params matches the types of the values - let iter = prepare_lp.data_types.iter().zip(param_values.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.get_datatype() { - return Err(DataFusionError::Internal(format!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.get_datatype(), - i - ))); + fn mutate(&mut self, expr: Expr) -> Result { + Ok(expr.unalias()) } } + let mut remove_aliases = RemoveAliases {}; + let predicate = predicate.rewrite(&mut remove_aliases)?; + + Filter::try_new(predicate, Arc::new(inputs[0].clone())) + .map(LogicalPlan::Filter) + } + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. + }) => match partitioning_scheme { + Partitioning::RoundRobinBatch(n) => { + Ok(LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::RoundRobinBatch(*n), + input: Arc::new(inputs[0].clone()), + })) + } + Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::Hash(expr, *n), + input: Arc::new(inputs[0].clone()), + })), + Partitioning::DistributeBy(_) => { + Ok(LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::DistributeBy(expr), + input: Arc::new(inputs[0].clone()), + })) + } + }, + LogicalPlan::Window(Window { + window_expr, + schema, + .. + }) => { + assert_eq!(window_expr.len(), expr.len()); + Ok(LogicalPlan::Window(Window { + input: Arc::new(inputs[0].clone()), + window_expr: expr, + schema: schema.clone(), + })) + } + LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { + // group exprs are the first expressions + let agg_expr = expr.split_off(group_expr.len()); + + Aggregate::try_new(Arc::new(inputs[0].clone()), expr, agg_expr) + .map(LogicalPlan::Aggregate) + } + LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { + expr, + input: Arc::new(inputs[0].clone()), + fetch: *fetch, + })), + LogicalPlan::Join(Join { + join_type, + join_constraint, + on, + null_equals_null, + .. + }) => { + let schema = + build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; + + let equi_expr_count = on.len(); + assert!(expr.len() >= equi_expr_count); + + // Assume that the last expr, if any, + // is the filter_expr (non equality predicate from ON clause) + let filter_expr = if expr.len() > equi_expr_count { + expr.pop() + } else { + None + }; + + // The first part of expr is equi-exprs, + // and the struct of each equi-expr is like `left-expr = right-expr`. + assert_eq!(expr.len(), equi_expr_count); + let new_on:Vec<(Expr,Expr)> = expr.into_iter().map(|equi_expr| { + // SimplifyExpression rule may add alias to the equi_expr. + let unalias_expr = equi_expr.clone().unalias(); + if let Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) = unalias_expr { + Ok((*left, *right)) + } else { + internal_err!( + "The front part expressions should be an binary equality expression, actual:{equi_expr}" + ) + } + }).collect::>>()?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + join_type: *join_type, + join_constraint: *join_constraint, + on: new_on, + filter: filter_expr, + schema: DFSchemaRef::new(schema), + null_equals_null: *null_equals_null, + })) + } + LogicalPlan::CrossJoin(_) => { + let left = inputs[0].clone(); + let right = inputs[1].clone(); + LogicalPlanBuilder::from(left).cross_join(right)?.build() + } + LogicalPlan::Subquery(Subquery { + outer_ref_columns, .. + }) => { + let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?; + Ok(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(subquery), + outer_ref_columns: outer_ref_columns.clone(), + })) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { + SubqueryAlias::try_new(inputs[0].clone(), alias.clone()) + .map(LogicalPlan::SubqueryAlias) + } + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + Ok(LogicalPlan::Limit(Limit { + skip: *skip, + fetch: *fetch, + input: Arc::new(inputs[0].clone()), + })) + } + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { + name, + if_not_exists, + or_replace, + column_defaults, + .. + })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + CreateMemoryTable { + input: Arc::new(inputs[0].clone()), + constraints: Constraints::empty(), + name: name.clone(), + if_not_exists: *if_not_exists, + or_replace: *or_replace, + column_defaults: column_defaults.clone(), + }, + ))), + LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + name, + or_replace, + definition, + .. + })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + input: Arc::new(inputs[0].clone()), + name: name.clone(), + or_replace: *or_replace, + definition: definition.clone(), + }))), + LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { + node: e.node.from_template(&expr, inputs), + })), + LogicalPlan::Union(Union { schema, .. }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema. + let schema = if schema.fields().len() == input_schema.fields().len() { + schema + } else { + input_schema + }; + Ok(LogicalPlan::Union(Union { + inputs: inputs.iter().cloned().map(Arc::new).collect(), + schema: schema.clone(), + })) + } + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), + Distinct::On(DistinctOn { + on_expr, + select_expr, + .. + }) => { + let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); + let select_expr = expr.split_off(on_expr.len()); + Distinct::On(DistinctOn::try_new( + expr, + select_expr, + if !sort_expr.is_empty() { + Some(sort_expr) + } else { + None + }, + Arc::new(inputs[0].clone()), + )?) + } + }; + Ok(LogicalPlan::Distinct(distinct)) + } + LogicalPlan::Analyze(a) => { + assert!(expr.is_empty()); + assert_eq!(inputs.len(), 1); + Ok(LogicalPlan::Analyze(Analyze { + verbose: a.verbose, + schema: a.schema.clone(), + input: Arc::new(inputs[0].clone()), + })) + } + LogicalPlan::Explain(e) => { + assert!( + expr.is_empty(), + "Invalid EXPLAIN command. Expression should empty" + ); + assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: Arc::new(inputs[0].clone()), + stringified_plans: e.stringified_plans.clone(), + schema: e.schema.clone(), + logical_optimization_succeeded: e.logical_optimization_succeeded, + })) + } + LogicalPlan::Prepare(Prepare { + name, data_types, .. + }) => Ok(LogicalPlan::Prepare(Prepare { + name: name.clone(), + data_types: data_types.clone(), + input: Arc::new(inputs[0].clone()), + })), + LogicalPlan::TableScan(ts) => { + assert!(inputs.is_empty(), "{self:?} should have no inputs"); + Ok(LogicalPlan::TableScan(TableScan { + filters: expr, + ..ts.clone() + })) + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Statement(_) => { + // All of these plan types have no inputs / exprs so should not be called + assert!(expr.is_empty(), "{self:?} should have no exprs"); + assert!(inputs.is_empty(), "{self:?} should have no inputs"); + Ok(self.clone()) + } + LogicalPlan::DescribeTable(_) => Ok(self.clone()), + LogicalPlan::Unnest(Unnest { + column, + schema, + options, + .. + }) => { + // Update schema with unnested column type. + let input = Arc::new(inputs[0].clone()); + let nested_field = input.schema().field_from_column(column)?; + let unnested_field = schema.field_from_column(column)?; + let fields = input + .schema() + .fields() + .iter() + .map(|f| { + if f == nested_field { + unnested_field.clone() + } else { + f.clone() + } + }) + .collect::>(); + + let schema = Arc::new( + DFSchema::new_with_metadata( + fields, + input.schema().metadata().clone(), + )? + // We can use the existing functional dependencies as is: + .with_functional_dependencies( + input.schema().functional_dependencies().clone(), + )?, + ); + + Ok(LogicalPlan::Unnest(Unnest { + input, + column: column.clone(), + schema, + options: options.clone(), + })) + } + } + } + /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] + /// with the specified `param_values`. + /// + /// [`LogicalPlan::Prepare`] are + /// converted to their inner logical plan for execution. + /// + /// # Example + /// ``` + /// # use arrow::datatypes::{Field, Schema, DataType}; + /// use datafusion_common::ScalarValue; + /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan, placeholder}; + /// # let schema = Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # ]); + /// // Build SELECT * FROM t1 WHRERE id = $1 + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$1"))).unwrap() + /// .build().unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = $1\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// // Fill in the parameter $1 with a literal 3 + /// let plan = plan.with_param_values(vec![ + /// ScalarValue::from(3i32) // value at index 0 --> $1 + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// // Note you can also used named parameters + /// // Build SELECT * FROM t1 WHRERE id = $my_param + /// let plan = table_scan(Some("t1"), &schema, None).unwrap() + /// .filter(col("id").eq(placeholder("$my_param"))).unwrap() + /// .build().unwrap() + /// // Fill in the parameter $my_param with a literal 3 + /// .with_param_values(vec![ + /// ("my_param", ScalarValue::from(3i32)), + /// ]).unwrap(); + /// + /// assert_eq!( + /// "Filter: t1.id = Int32(3)\ + /// \n TableScan: t1", + /// plan.display_indent().to_string() + /// ); + /// + /// ``` + pub fn with_param_values( + self, + param_values: impl Into, + ) -> Result { + let param_values = param_values.into(); + match self { + LogicalPlan::Prepare(prepare_lp) => { + param_values.verify(&prepare_lp.data_types)?; let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } - _ => Ok(self), + _ => self.replace_params_with_values(¶m_values), } } @@ -453,7 +1033,13 @@ impl LogicalPlan { pub fn max_rows(self: &LogicalPlan) -> Option { match self { LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), - LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(), + LogicalPlan::Filter(filter) => { + if filter.is_scalar() { + Some(1) + } else { + filter.input.max_rows() + } + } LogicalPlan::Window(Window { input, .. }) => input.max_rows(), LogicalPlan::Aggregate(Aggregate { input, group_expr, .. @@ -524,13 +1110,16 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, - LogicalPlan::Distinct(Distinct { input }) => input.max_rows(), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, LogicalPlan::Ddl(_) | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Prepare(_) | LogicalPlan::Statement(_) @@ -540,7 +1129,7 @@ impl LogicalPlan { } impl LogicalPlan { - /// applies collect to any subqueries in the plan + /// applies `op` to any subqueries in the plan pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> where F: FnMut(&Self) -> datafusion_common::Result, @@ -592,17 +1181,22 @@ impl LogicalPlan { Ok(()) } - /// Return a logical plan with all placeholders/params (e.g $1 $2, - /// ...) replaced with corresponding values provided in the - /// params_values + /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, + /// ...) replaced with corresponding values provided in + /// `params_values` + /// + /// See [`Self::with_param_values`] for examples and usage pub fn replace_params_with_values( &self, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { let new_exprs = self .expressions() .into_iter() - .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .map(|e| { + let e = e.infer_placeholder_types(self.schema())?; + Self::replace_placeholders_with_values(e, param_values) + }) .collect::>>()?; let new_inputs_with_values = self @@ -611,10 +1205,10 @@ impl LogicalPlan { .map(|inp| inp.replace_params_with_values(param_values)) .collect::>>()?; - from_plan(self, &new_exprs, &new_inputs_with_values) + self.with_new_exprs(new_exprs, &new_inputs_with_values) } - /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes + /// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes pub fn get_parameter_types( &self, ) -> Result>, DataFusionError> { @@ -628,9 +1222,7 @@ impl LogicalPlan { match (prev, data_type) { (Some(Some(prev)), Some(dt)) => { if prev != dt { - Err(DataFusionError::Plan(format!( - "Conflicting types for {id}" - )))?; + plan_err!("Conflicting types for {id}")?; } } (_, Some(dt)) => { @@ -653,38 +1245,15 @@ impl LogicalPlan { /// corresponding values provided in the params_values fn replace_placeholders_with_values( expr: Expr, - param_values: &[ScalarValue], + param_values: &ParamValues, ) -> Result { expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { - if id.is_empty() || id == "$0" { - return Err(DataFusionError::Plan( - "Empty placeholder id".to_string(), - )); - } - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if Some(value.get_datatype()) != *data_type { - return Err(DataFusionError::Internal(format!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.get_datatype() - ))); - } + let value = + param_values.get_placeholders_with_values(id, data_type)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value.clone()))) + Ok(Transformed::Yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = @@ -703,7 +1272,9 @@ impl LogicalPlan { // Various implementations for printing out LogicalPlans impl LogicalPlan { /// Return a `format`able structure that produces a single line - /// per node. For example: + /// per node. + /// + /// # Example /// /// ```text /// Projection: employee.id @@ -822,14 +1393,10 @@ impl LogicalPlan { struct Wrapper<'a>(&'a LogicalPlan); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - writeln!( - f, - "// Begin DataFusion GraphViz Plan (see https://graphviz.org)" - )?; - writeln!(f, "digraph {{")?; - let mut visitor = GraphvizVisitor::new(f); + visitor.start_graph()?; + visitor.pre_visit_plan("LogicalPlan")?; self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; @@ -839,8 +1406,7 @@ impl LogicalPlan { self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; - writeln!(f, "}}")?; - writeln!(f, "// End DataFusion GraphViz Plan")?; + visitor.end_graph()?; Ok(()) } } @@ -942,15 +1508,24 @@ impl LogicalPlan { } if !full_filter.is_empty() { - write!(f, ", full_filters={full_filter:?}")?; + write!( + f, + ", full_filters=[{}]", + expr_vec_fmt!(full_filter) + )?; }; if !partial_filter.is_empty() { - write!(f, ", partial_filters={partial_filter:?}")?; + write!( + f, + ", partial_filters=[{}]", + expr_vec_fmt!(partial_filter) + )?; } if !unsupported_filters.is_empty() { write!( f, - ", unsupported_filters={unsupported_filters:?}" + ", unsupported_filters=[{}]", + expr_vec_fmt!(unsupported_filters) )?; } } @@ -967,24 +1542,48 @@ impl LogicalPlan { if i > 0 { write!(f, ", ")?; } - write!(f, "{expr_item:?}")?; + write!(f, "{expr_item}")?; } Ok(()) } LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => { write!(f, "Dml: op=[{op}] table=[{table_name}]") } + LogicalPlan::Copy(CopyTo { + input: _, + output_url, + file_format, + single_file_output, + copy_options, + }) => { + let op_str = match copy_options { + CopyOptions::SQLOptions(statement) => statement + .clone() + .into_inner() + .iter() + .map(|(k, v)| format!("{k} {v}")) + .collect::>() + .join(", "), + CopyOptions::WriterOptions(_) => "".into(), + }; + + write!(f, "CopyTo: format={file_format} output_url={output_url} single_file_output={single_file_output} options: ({op_str})") + } LogicalPlan::Ddl(ddl) => { write!(f, "{}", ddl.display()) } LogicalPlan::Filter(Filter { predicate: ref expr, .. - }) => write!(f, "Filter: {expr:?}"), + }) => write!(f, "Filter: {expr}"), LogicalPlan::Window(Window { ref window_expr, .. }) => { - write!(f, "WindowAggr: windowExpr=[{window_expr:?}]") + write!( + f, + "WindowAggr: windowExpr=[[{}]]", + expr_vec_fmt!(window_expr) + ) } LogicalPlan::Aggregate(Aggregate { ref group_expr, @@ -992,7 +1591,9 @@ impl LogicalPlan { .. }) => write!( f, - "Aggregate: groupBy=[{group_expr:?}], aggr=[{aggr_expr:?}]" + "Aggregate: groupBy=[[{}]], aggr=[[{}]]", + expr_vec_fmt!(group_expr), + expr_vec_fmt!(aggr_expr) ), LogicalPlan::Sort(Sort { expr, fetch, .. }) => { write!(f, "Sort: ")?; @@ -1000,7 +1601,7 @@ impl LogicalPlan { if i > 0 { write!(f, ", ")?; } - write!(f, "{expr_item:?}")?; + write!(f, "{expr_item}")?; } if let Some(a) = fetch { write!(f, ", fetch={a}")?; @@ -1054,7 +1655,7 @@ impl LogicalPlan { } Partitioning::Hash(expr, n) => { let hash_expr: Vec = - expr.iter().map(|e| format!("{e:?}")).collect(); + expr.iter().map(|e| format!("{e}")).collect(); write!( f, "Repartition: Hash({}) partition_count={}", @@ -1064,7 +1665,7 @@ impl LogicalPlan { } Partitioning::DistributeBy(expr) => { let dist_by_expr: Vec = - expr.iter().map(|e| format!("{e:?}")).collect(); + expr.iter().map(|e| format!("{e}")).collect(); write!( f, "Repartition: DistributeBy({})", @@ -1093,9 +1694,21 @@ impl LogicalPlan { LogicalPlan::Statement(statement) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(Distinct { .. }) => { - write!(f, "Distinct:") - } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => write!(f, "Distinct:"), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }) => write!( + f, + "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", + expr_vec_fmt!(on_expr), + expr_vec_fmt!(select_expr), + if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + ), + }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), @@ -1167,11 +1780,8 @@ pub struct Projection { impl Projection { /// Create a new Projection pub fn try_new(expr: Vec, input: Arc) -> Result { - let schema = Arc::new(DFSchema::new_with_metadata( - exprlist_to_fields(&expr, &input)?, - input.schema().metadata().clone(), - )?); - Self::try_new_with_schema(expr, input, schema) + let projection_schema = projection_schema(&input, &expr)?; + Self::try_new_with_schema(expr, input, projection_schema) } /// Create a new Projection using the specified output schema @@ -1181,7 +1791,7 @@ impl Projection { schema: DFSchemaRef, ) -> Result { if expr.len() != schema.fields().len() { - return Err(DataFusionError::Plan(format!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()))); + return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } Ok(Self { expr, @@ -1204,13 +1814,30 @@ impl Projection { schema, } } +} - pub fn try_from_plan(plan: &LogicalPlan) -> Result<&Projection> { - match plan { - LogicalPlan::Projection(it) => Ok(it), - _ => plan_err!("Could not coerce into Projection!"), - } - } +/// Computes the schema of the result produced by applying a projection to the input logical plan. +/// +/// # Arguments +/// +/// * `input`: A reference to the input `LogicalPlan` for which the projection schema +/// will be computed. +/// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. +/// +/// # Returns +/// +/// A `Result` containing an `Arc` representing the schema of the result +/// produced by the projection operation. If the schema computation is successful, +/// the `Result` will contain the schema; otherwise, it will contain an error. +pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { + let mut schema = DFSchema::new_with_metadata( + exprlist_to_fields(exprs, input)?, + input.schema().metadata().clone(), + )?; + schema = schema.with_functional_dependencies(calc_func_dependencies_for_project( + exprs, input, + )?)?; + Ok(Arc::new(schema)) } /// Aliased subquery @@ -1233,8 +1860,13 @@ impl SubqueryAlias { ) -> Result { let alias = alias.into(); let schema: Schema = plan.schema().as_ref().clone().into(); - let schema = - DFSchemaRef::new(DFSchema::try_from_qualified_schema(&alias, &schema)?); + // Since schema is the same, other than qualifier, we can use existing + // functional dependencies: + let func_dependencies = plan.schema().functional_dependencies().clone(); + let schema = DFSchemaRef::new( + DFSchema::try_from_qualified_schema(&alias, &schema)? + .with_functional_dependencies(func_dependencies)?, + ); Ok(SubqueryAlias { input: Arc::new(plan), alias, @@ -1272,29 +1904,89 @@ impl Filter { // ignore errors resolving the expression against the schema. if let Ok(predicate_type) = predicate.get_type(input.schema()) { if predicate_type != DataType::Boolean { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" - ))); + ); } } // filter predicates should not be aliased - if let Expr::Alias(expr, alias) = predicate { - return Err(DataFusionError::Plan(format!( + if let Expr::Alias(Alias { expr, name, .. }) = predicate { + return plan_err!( "Attempted to create Filter predicate with \ - expression `{expr}` aliased as '{alias}'. Filter predicates should not be \ + expression `{expr}` aliased as '{name}'. Filter predicates should not be \ aliased." - ))); + ); } Ok(Self { predicate, input }) } - pub fn try_from_plan(plan: &LogicalPlan) -> Result<&Filter> { - match plan { - LogicalPlan::Filter(it) => Ok(it), - _ => plan_err!("Could not coerce into Filter!"), + /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? + /// + /// This function will return `true` if its predicate contains a conjunction of + /// `col(a) = `, where its schema has a unique filter that is covered + /// by this conjunction. + /// + /// For example, for the table: + /// ```sql + /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER); + /// ``` + /// `Filter(a = 2).is_scalar() == true` + /// , whereas + /// `Filter(b = 2).is_scalar() == false` + /// and + /// `Filter(a = 2 OR b = 2).is_scalar() == false` + fn is_scalar(&self) -> bool { + let schema = self.input.schema(); + + let functional_dependencies = self.input.schema().functional_dependencies(); + let unique_keys = functional_dependencies.iter().filter(|dep| { + let nullable = dep.nullable + && dep + .source_indices + .iter() + .any(|&source| schema.field(source).is_nullable()); + !nullable + && dep.mode == Dependency::Single + && dep.target_indices.len() == schema.fields().len() + }); + + let exprs = split_conjunction(&self.predicate); + let eq_pred_cols: HashSet<_> = exprs + .iter() + .filter_map(|expr| { + let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + else { + return None; + }; + // This is a no-op filter expression + if left == right { + return None; + } + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Column(_)) => None, + (Expr::Column(c), _) | (_, Expr::Column(c)) => { + Some(schema.index_of_column(c).unwrap()) + } + _ => None, + } + }) + .collect(); + + // If we have a functional dependence that is a subset of our predicate, + // this filter is scalar + for key in unique_keys { + if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) { + return true; + } } + false } } @@ -1309,6 +2001,29 @@ pub struct Window { pub schema: DFSchemaRef, } +impl Window { + /// Create a new window operator. + pub fn try_new(window_expr: Vec, input: Arc) -> Result { + let mut window_fields: Vec = input.schema().fields().clone(); + window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), &input)?); + let metadata = input.schema().metadata().clone(); + + // Update functional dependencies for window: + let mut window_func_dependencies = + input.schema().functional_dependencies().clone(); + window_func_dependencies.extend_target_indices(window_fields.len()); + + Ok(Window { + input, + window_expr, + schema: Arc::new( + DFSchema::new_with_metadata(window_fields, metadata)? + .with_functional_dependencies(window_func_dependencies)?, + ), + }) + } +} + /// Produces rows from a table provider by reference or from the context #[derive(Clone)] pub struct TableScan { @@ -1348,6 +2063,61 @@ impl Hash for TableScan { } } +impl TableScan { + /// Initialize TableScan with appropriate schema from the given + /// arguments. + pub fn try_new( + table_name: impl Into, + table_source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, + ) -> Result { + let table_name = table_name.into(); + + if table_name.table().is_empty() { + return plan_err!("table_name cannot be empty"); + } + let schema = table_source.schema(); + let func_dependencies = FunctionalDependencies::new_from_constraints( + table_source.constraints(), + schema.fields.len(), + ); + let projected_schema = projection + .as_ref() + .map(|p| { + let projected_func_dependencies = + func_dependencies.project_functional_dependencies(p, p.len()); + let df_schema = DFSchema::new_with_metadata( + p.iter() + .map(|i| { + DFField::from_qualified( + table_name.clone(), + schema.field(*i).clone(), + ) + }) + .collect(), + schema.metadata().clone(), + )?; + df_schema.with_functional_dependencies(projected_func_dependencies) + }) + .unwrap_or_else(|| { + let df_schema = + DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; + df_schema.with_functional_dependencies(func_dependencies) + })?; + let projected_schema = Arc::new(projected_schema); + Ok(Self { + table_name, + source: table_source, + projection, + projected_schema, + filters, + fetch, + }) + } +} + /// Apply Cross Join to two logical plans #[derive(Clone, PartialEq, Eq, Hash)] pub struct CrossJoin { @@ -1390,12 +2160,33 @@ pub struct Prepare { } /// Describe the schema of table +/// +/// # Example output: +/// +/// ```sql +/// ❯ describe traces; +/// +--------------------+-----------------------------+-------------+ +/// | column_name | data_type | is_nullable | +/// +--------------------+-----------------------------+-------------+ +/// | attributes | Utf8 | YES | +/// | duration_nano | Int64 | YES | +/// | end_time_unix_nano | Int64 | YES | +/// | service.name | Dictionary(Int32, Utf8) | YES | +/// | span.kind | Utf8 | YES | +/// | span.name | Utf8 | YES | +/// | span_id | Dictionary(Int32, Utf8) | YES | +/// | time | Timestamp(Nanosecond, None) | NO | +/// | trace_id | Dictionary(Int32, Utf8) | YES | +/// | otel.status_code | Utf8 | YES | +/// | parent_span_id | Utf8 | YES | +/// +--------------------+-----------------------------+-------------+ +/// ``` #[derive(Clone, PartialEq, Eq, Hash)] pub struct DescribeTable { /// Table schema pub schema: Arc, - /// Dummy schema - pub dummy_schema: DFSchemaRef, + /// schema of describe table output + pub output_schema: DFSchemaRef, } /// Produces a relation with string representations of @@ -1460,9 +2251,93 @@ pub struct Limit { /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Distinct { +pub enum Distinct { + /// Plain `DISTINCT` referencing all selection expressions + All(Arc), + /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns + On(DistinctOn), +} + +/// Removes duplicate rows from the input +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct DistinctOn { + /// The `DISTINCT ON` clause expression list + pub on_expr: Vec, + /// The selected projection expression list + pub select_expr: Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, + /// The schema description of the DISTINCT ON output + pub schema: DFSchemaRef, +} + +impl DistinctOn { + /// Create a new `DistinctOn` struct. + pub fn try_new( + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + input: Arc, + ) -> Result { + if on_expr.is_empty() { + return plan_err!("No `ON` expressions provided"); + } + + let on_expr = normalize_cols(on_expr, input.as_ref())?; + + let schema = DFSchema::new_with_metadata( + exprlist_to_fields(&select_expr, &input)?, + input.schema().metadata().clone(), + )?; + + let mut distinct_on = DistinctOn { + on_expr, + select_expr, + sort_expr: None, + input, + schema: Arc::new(schema), + }; + + if let Some(sort_expr) = sort_expr { + distinct_on = distinct_on.with_sort_expr(sort_expr)?; + } + + Ok(distinct_on) + } + + /// Try to update `self` with a new sort expressions. + /// + /// Validates that the sort expressions are a super-set of the `ON` expressions. + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + + // Check that the left-most sort expressions are the same as the `ON` expressions. + let mut matched = true; + for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { + match sort { + Expr::Sort(SortExpr { expr, .. }) => { + if on != &**expr { + matched = false; + break; + } + } + _ => return plan_err!("Not a sort expression: {sort}"), + } + } + + if self.on_expr.len() > sort_expr.len() || !matched { + return plan_err!( + "SELECT DISTINCT ON expressions must match initial ORDER BY expressions" + ); + } + + self.sort_expr = Some(sort_expr); + Ok(self) + } } /// Aggregates its input based on a set of grouping and aggregate @@ -1489,12 +2364,26 @@ impl Aggregate { aggr_expr: Vec, ) -> Result { let group_expr = enumerate_grouping_sets(group_expr)?; + + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, &input)?, - input.schema().metadata().clone(), - )?; + + let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?; + + // Even columns that cannot be null will become nullable when used in a grouping set. + if is_grouping_set { + fields = fields + .into_iter() + .map(|field| field.with_nullable(true)) + .collect::>(); + } + + fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?); + + let schema = + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?; + Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -1510,19 +2399,25 @@ impl Aggregate { schema: DFSchemaRef, ) -> Result { if group_expr.is_empty() && aggr_expr.is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Aggregate requires at least one grouping or aggregate expression" - .to_string(), - )); + ); } let group_expr_count = grouping_set_expr_count(&group_expr)?; if schema.fields().len() != group_expr_count + aggr_expr.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Aggregate schema has wrong number of fields. Expected {} got {}", group_expr_count + aggr_expr.len(), schema.fields().len() - ))); + ); } + + let aggregate_func_dependencies = + calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; + let new_schema = schema.as_ref().clone(); + let schema = Arc::new( + new_schema.with_functional_dependencies(aggregate_func_dependencies)?, + ); Ok(Self { input, group_expr, @@ -1531,14 +2426,79 @@ impl Aggregate { }) } - pub fn try_from_plan(plan: &LogicalPlan) -> Result<&Aggregate> { - match plan { - LogicalPlan::Aggregate(it) => Ok(it), - _ => plan_err!("Could not coerce into Aggregate!"), - } + /// Get the length of the group by expression in the output schema + /// This is not simply group by expression length. Expression may be + /// GroupingSet, etc. In these case we need to get inner expression lengths. + pub fn group_expr_len(&self) -> Result { + grouping_set_expr_count(&self.group_expr) } } +/// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. +fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr + .iter() + .any(|expr| matches!(expr, Expr::GroupingSet(_))) +} + +/// Calculates functional dependencies for aggregate expressions. +fn calc_func_dependencies_for_aggregate( + // Expressions in the GROUP BY clause: + group_expr: &[Expr], + // Input plan of the aggregate: + input: &LogicalPlan, + // Aggregate schema + aggr_schema: &DFSchema, +) -> Result { + // We can do a case analysis on how to propagate functional dependencies based on + // whether the GROUP BY in question contains a grouping set expression: + // - If so, the functional dependencies will be empty because we cannot guarantee + // that GROUP BY expression results will be unique. + // - Otherwise, it may be possible to propagate functional dependencies. + if !contains_grouping_set(group_expr) { + let group_by_expr_names = group_expr + .iter() + .map(|item| item.display_name()) + .collect::>>()?; + let aggregate_func_dependencies = aggregate_functional_dependencies( + input.schema(), + &group_by_expr_names, + aggr_schema, + ); + Ok(aggregate_func_dependencies) + } else { + Ok(FunctionalDependencies::empty()) + } +} + +/// This function projects functional dependencies of the `input` plan according +/// to projection expressions `exprs`. +fn calc_func_dependencies_for_project( + exprs: &[Expr], + input: &LogicalPlan, +) -> Result { + let input_fields = input.schema().fields(); + // Calculate expression indices (if present) in the input schema. + let proj_indices = exprs + .iter() + .filter_map(|expr| { + let expr_name = match expr { + Expr::Alias(alias) => { + format!("{}", alias.expr) + } + _ => format!("{}", expr), + }; + input_fields + .iter() + .position(|item| item.qualified_name() == expr_name) + }) + .collect::>(); + Ok(input + .schema() + .functional_dependencies() + .project_functional_dependencies(&proj_indices, exprs.len())) +} + /// Sorts its input according to a list of sort expressions. #[derive(Clone, PartialEq, Eq, Hash)] pub struct Sort { @@ -1587,7 +2547,7 @@ impl Join { let on: Vec<(Expr, Expr)> = column_on .0 .into_iter() - .zip(column_on.1.into_iter()) + .zip(column_on.1) .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) .collect(); let join_schema = @@ -1650,94 +2610,8 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents which type of plan, when storing multiple -/// for use in EXPLAIN plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum PlanType { - /// The initial LogicalPlan provided to DataFusion - InitialLogicalPlan, - /// The LogicalPlan which results from applying an analyzer pass - AnalyzedLogicalPlan { - /// The name of the analyzer which produced this plan - analyzer_name: String, - }, - /// The LogicalPlan after all analyzer passes have been applied - FinalAnalyzedLogicalPlan, - /// The LogicalPlan which results from applying an optimizer pass - OptimizedLogicalPlan { - /// The name of the optimizer which produced this plan - optimizer_name: String, - }, - /// The final, fully optimized LogicalPlan that was converted to a physical plan - FinalLogicalPlan, - /// The initial physical plan, prepared for execution - InitialPhysicalPlan, - /// The ExecutionPlan which results from applying an optimizer pass - OptimizedPhysicalPlan { - /// The name of the optimizer which produced this plan - optimizer_name: String, - }, - /// The final, fully optimized physical which would be executed - FinalPhysicalPlan, -} - -impl Display for PlanType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), - PlanType::AnalyzedLogicalPlan { analyzer_name } => { - write!(f, "logical_plan after {analyzer_name}") - } - PlanType::FinalAnalyzedLogicalPlan => write!(f, "analyzed_logical_plan"), - PlanType::OptimizedLogicalPlan { optimizer_name } => { - write!(f, "logical_plan after {optimizer_name}") - } - PlanType::FinalLogicalPlan => write!(f, "logical_plan"), - PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), - PlanType::OptimizedPhysicalPlan { optimizer_name } => { - write!(f, "physical_plan after {optimizer_name}") - } - PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), - } - } -} - -/// Represents some sort of execution plan, in String form -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct StringifiedPlan { - /// An identifier of what type of plan this string represents - pub plan_type: PlanType, - /// The string representation of the plan - pub plan: Arc, -} - -impl StringifiedPlan { - /// Create a new Stringified plan of `plan_type` with string - /// representation `plan` - pub fn new(plan_type: PlanType, plan: impl Into) -> Self { - StringifiedPlan { - plan_type, - plan: Arc::new(plan.into()), - } - } - - /// returns true if this plan should be displayed. Generally - /// `verbose_mode = true` will display all available plans - pub fn should_display(&self, verbose_mode: bool) -> bool { - match self.plan_type { - PlanType::FinalLogicalPlan | PlanType::FinalPhysicalPlan => true, - _ => verbose_mode, - } - } -} - -/// Trait for something that can be formatted as a stringified plan -pub trait ToStringifiedPlan { - /// Create a stringified plan with the specified type - fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan; -} - -/// Unnest a column that contains a nested list type. +/// Unnest a column that contains a nested list type. See +/// [`UnnestOptions`] for more details. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Unnest { /// The incoming logical plan @@ -1746,17 +2620,25 @@ pub struct Unnest { pub column: Column, /// The output schema, containing the unnested field column. pub schema: DFSchemaRef, + /// Options + pub options: UnnestOptions, } #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use super::*; + use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, exists, in_subquery, lit}; + use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeVisitor; - use datafusion_common::{DFSchema, TableReference}; - use std::collections::HashMap; + use datafusion_common::{ + not_impl_err, Constraint, DFSchema, ScalarValue, TableReference, + }; fn employee_schema() -> Schema { Schema::new(vec![ @@ -1830,31 +2712,46 @@ mod tests { fn test_display_graphviz() -> Result<()> { let plan = display_plan()?; + let expected_graphviz = r#" +// Begin DataFusion GraphViz Plan, +// display it online here: https://dreampuf.github.io/GraphvizOnline + +digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Projection: employee_csv.id"] + 3[shape=box label="Filter: employee_csv.state IN ()"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Subquery:"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: employee_csv projection=[state]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + 6[shape=box label="TableScan: employee_csv projection=[id, state]"] + 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_7 + { + graph[label="Detailed LogicalPlan"] + 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] + 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] + 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] + 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] + 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] + } +} +// End DataFusion GraphViz Plan +"#; + // just test for a few key lines in the output rather than the // whole thing to make test mainteance easier. let graphviz = format!("{}", plan.display_graphviz()); - assert!( - graphviz.contains( - r#"// Begin DataFusion GraphViz Plan (see https://graphviz.org)"# - ), - "\n{}", - plan.display_graphviz() - ); - assert!( - graphviz.contains( - r#"[shape=box label="TableScan: employee_csv projection=[id, state]"]"# - ), - "\n{}", - plan.display_graphviz() - ); - assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"]"#), - "\n{}", plan.display_graphviz()); - assert!( - graphviz.contains(r#"// End DataFusion GraphViz Plan"#), - "\n{}", - plan.display_graphviz() - ); + assert_eq!(expected_graphviz, graphviz); Ok(()) } @@ -1873,9 +2770,7 @@ mod tests { LogicalPlan::Filter { .. } => "pre_visit Filter", LogicalPlan::TableScan { .. } => "pre_visit TableScan", _ => { - return Err(DataFusionError::NotImplemented( - "unknown plan type".to_string(), - )) + return not_impl_err!("unknown plan type"); } }; @@ -1889,9 +2784,7 @@ mod tests { LogicalPlan::Filter { .. } => "post_visit Filter", LogicalPlan::TableScan { .. } => "post_visit TableScan", _ => { - return Err(DataFusionError::NotImplemented( - "unknown plan type".to_string(), - )) + return not_impl_err!("unknown plan type"); } }; @@ -2025,9 +2918,7 @@ mod tests { fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { - return Err(DataFusionError::NotImplemented( - "Error in pre_visit".to_string(), - )); + return not_impl_err!("Error in pre_visit"); } self.inner.pre_visit(plan) @@ -2035,9 +2926,7 @@ mod tests { fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { - return Err(DataFusionError::NotImplemented( - "Error in post_visit".to_string(), - )); + return not_impl_err!("Error in post_visit"); } self.inner.post_visit(plan) @@ -2051,14 +2940,11 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); - - if let Err(DataFusionError::NotImplemented(e)) = res { - assert_eq!("Error in pre_visit", e); - } else { - panic!("Expected an error"); - } - + let res = plan.visit(&mut visitor).unwrap_err(); + assert_eq!( + "This feature is not implemented: Error in pre_visit", + res.strip_backtrace() + ); assert_eq!( visitor.inner.strings, vec!["pre_visit Projection", "pre_visit Filter"] @@ -2072,13 +2958,11 @@ mod tests { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); - if let Err(DataFusionError::NotImplemented(e)) = res { - assert_eq!("Error in post_visit", e); - } else { - panic!("Expected an error"); - } - + let res = plan.visit(&mut visitor).unwrap_err(); + assert_eq!( + "This feature is not implemented: Error in post_visit", + res.strip_backtrace() + ); assert_eq!( visitor.inner.strings, vec![ @@ -2101,7 +2985,7 @@ mod tests { })), empty_schema, ); - assert_eq!("Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)", format!("{}", p.err().unwrap())); + assert_eq!(p.err().unwrap().strip_backtrace(), "Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); Ok(()) } @@ -2196,15 +3080,13 @@ mod tests { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder(""))) .unwrap() .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + let param_values = vec![ScalarValue::Int32(Some(42))]; + plan.replace_params_with_values(¶m_values.clone().into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); // test $0 placeholder @@ -2212,15 +3094,146 @@ mod tests { let plan = table_scan(TableReference::none(), &schema, None) .unwrap() - .filter(col("id").eq(Expr::Placeholder(Placeholder::new( - "$0".into(), - Some(DataType::Int32), - )))) + .filter(col("id").eq(placeholder("$0"))) .unwrap() .build() .unwrap(); - plan.replace_params_with_values(&[42i32.into()]) + plan.replace_params_with_values(¶m_values.into()) .expect_err("unexpectedly succeeded to replace an invalid placeholder"); } + + #[test] + fn test_nullable_schema_after_grouping_set() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate( + vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("foo")], + vec![col("bar")], + ]))], + vec![count(lit(true))], + ) + .unwrap() + .build() + .unwrap(); + + let output_schema = plan.schema(); + + assert!(output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(),); + assert!(output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable()); + } + + #[test] + fn test_filter_is_scalar() { + // test empty placeholder + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let source = Arc::new(LogicalTableSource::new(schema)); + let schema = Arc::new( + DFSchema::try_from_qualified_schema( + TableReference::bare("tab"), + &source.schema(), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source: source.clone(), + projection: None, + projected_schema: schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(!filter.is_scalar()); + let unique_schema = Arc::new( + schema + .as_ref() + .clone() + .with_functional_dependencies( + FunctionalDependencies::new_from_constraints( + Some(&Constraints::new_unverified(vec![Constraint::Unique( + vec![0], + )])), + 1, + ), + ) + .unwrap(), + ); + let scan = Arc::new(LogicalPlan::TableScan(TableScan { + table_name: TableReference::bare("tab"), + source, + projection: None, + projected_schema: unique_schema.clone(), + filters: vec![], + fetch: None, + })); + let col = schema.field(0).qualified_column(); + + let filter = Filter::try_new( + Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + scan, + ) + .unwrap(); + assert!(filter.is_scalar()); + } + + #[test] + fn test_transform_explain() { + let schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Int32, false), + ]); + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .explain(false, false) + .unwrap() + .build() + .unwrap(); + + let external_filter = + col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))); + + // after transformation, because plan is not the same anymore, + // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs + let plan = plan + .transform(&|plan| match plan { + LogicalPlan::TableScan(table) => { + let filter = Filter::try_new( + external_filter.clone(), + Arc::new(LogicalPlan::TableScan(table)), + ) + .unwrap(); + Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + } + x => Ok(Transformed::No(x)), + }) + .unwrap(); + + let expected = "Explain\ + \n Filter: foo = Boolean(true)\ + \n TableScan: ?table?"; + let actual = format!("{}", plan.display_indent()); + assert_eq!(expected.to_string(), actual) + } } diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 686a681e36459..57888a11d426c 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -53,9 +53,13 @@ pub enum Operator { And, /// Logical OR, like `||` Or, - /// IS DISTINCT FROM + /// `IS DISTINCT FROM` (see [`distinct`]) + /// + /// [`distinct`]: arrow::compute::kernels::cmp::distinct IsDistinctFrom, - /// IS NOT DISTINCT FROM + /// `IS NOT DISTINCT FROM` (see [`not_distinct`]) + /// + /// [`not_distinct`]: arrow::compute::kernels::cmp::not_distinct IsNotDistinctFrom, /// Case sensitive regex match RegexMatch, @@ -69,7 +73,7 @@ pub enum Operator { BitwiseAnd, /// Bitwise or, like `|` BitwiseOr, - /// Bitwise xor, like `#` + /// Bitwise xor, such as `^` in MySQL or `#` in PostgreSQL BitwiseXor, /// Bitwise right, like `>>` BitwiseShiftRight, @@ -77,6 +81,10 @@ pub enum Operator { BitwiseShiftLeft, /// String concat StringConcat, + /// At arrow, like `@>` + AtArrow, + /// Arrow at, like `<@` + ArrowAt, } impl Operator { @@ -108,7 +116,9 @@ impl Operator { | Operator::BitwiseXor | Operator::BitwiseShiftRight | Operator::BitwiseShiftLeft - | Operator::StringConcat => None, + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt => None, } } @@ -167,6 +177,8 @@ impl Operator { Operator::LtEq => Some(Operator::GtEq), Operator::Gt => Some(Operator::Lt), Operator::GtEq => Some(Operator::LtEq), + Operator::AtArrow => Some(Operator::ArrowAt), + Operator::ArrowAt => Some(Operator::AtArrow), Operator::IsDistinctFrom | Operator::IsNotDistinctFrom | Operator::Plus @@ -214,7 +226,9 @@ impl Operator { | Operator::BitwiseShiftLeft | Operator::BitwiseShiftRight | Operator::BitwiseXor - | Operator::StringConcat => 0, + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt => 0, } } } @@ -243,10 +257,12 @@ impl fmt::Display for Operator { Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", Operator::BitwiseAnd => "&", Operator::BitwiseOr => "|", - Operator::BitwiseXor => "#", + Operator::BitwiseXor => "BIT_XOR", Operator::BitwiseShiftRight => ">>", Operator::BitwiseShiftLeft => "<<", Operator::StringConcat => "||", + Operator::AtArrow => "@>", + Operator::ArrowAt => "<@", }; write!(f, "{display}") } @@ -351,6 +367,7 @@ impl ops::Neg for Expr { } } +/// Support `NOT ` fluent style impl Not for Expr { type Output = Self; @@ -361,19 +378,27 @@ impl Not for Expr { expr, pattern, escape_char, - }) => Expr::Like(Like::new(!negated, expr, pattern, escape_char)), - Expr::ILike(Like { - negated, + case_insensitive, + }) => Expr::Like(Like::new( + !negated, expr, pattern, escape_char, - }) => Expr::ILike(Like::new(!negated, expr, pattern, escape_char)), + case_insensitive, + )), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, - }) => Expr::SimilarTo(Like::new(!negated, expr, pattern, escape_char)), + case_insensitive, + }) => Expr::SimilarTo(Like::new( + !negated, + expr, + pattern, + escape_char, + case_insensitive, + )), _ => Expr::Not(Box::new(self)), } } @@ -387,57 +412,57 @@ mod tests { fn test_operators() { // Add assert_eq!( - format!("{:?}", lit(1u32) + lit(2u32)), + format!("{}", lit(1u32) + lit(2u32)), "UInt32(1) + UInt32(2)" ); // Sub assert_eq!( - format!("{:?}", lit(1u32) - lit(2u32)), + format!("{}", lit(1u32) - lit(2u32)), "UInt32(1) - UInt32(2)" ); // Mul assert_eq!( - format!("{:?}", lit(1u32) * lit(2u32)), + format!("{}", lit(1u32) * lit(2u32)), "UInt32(1) * UInt32(2)" ); // Div assert_eq!( - format!("{:?}", lit(1u32) / lit(2u32)), + format!("{}", lit(1u32) / lit(2u32)), "UInt32(1) / UInt32(2)" ); // Rem assert_eq!( - format!("{:?}", lit(1u32) % lit(2u32)), + format!("{}", lit(1u32) % lit(2u32)), "UInt32(1) % UInt32(2)" ); // BitAnd assert_eq!( - format!("{:?}", lit(1u32) & lit(2u32)), + format!("{}", lit(1u32) & lit(2u32)), "UInt32(1) & UInt32(2)" ); // BitOr assert_eq!( - format!("{:?}", lit(1u32) | lit(2u32)), + format!("{}", lit(1u32) | lit(2u32)), "UInt32(1) | UInt32(2)" ); // BitXor assert_eq!( - format!("{:?}", lit(1u32) ^ lit(2u32)), - "UInt32(1) # UInt32(2)" + format!("{}", lit(1u32) ^ lit(2u32)), + "UInt32(1) BIT_XOR UInt32(2)" ); // Shl assert_eq!( - format!("{:?}", lit(1u32) << lit(2u32)), + format!("{}", lit(1u32) << lit(2u32)), "UInt32(1) << UInt32(2)" ); // Shr assert_eq!( - format!("{:?}", lit(1u32) >> lit(2u32)), + format!("{}", lit(1u32) >> lit(2u32)), "UInt32(1) >> UInt32(2)" ); // Neg - assert_eq!(format!("{:?}", -lit(1u32)), "(- UInt32(1))"); + assert_eq!(format!("{}", -lit(1u32)), "(- UInt32(1))"); // Not - assert_eq!(format!("{:?}", !lit(1u32)), "NOT UInt32(1)"); + assert_eq!(format!("{}", !lit(1u32)), "NOT UInt32(1)"); } } diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs new file mode 100644 index 0000000000000..0a765b30b0736 --- /dev/null +++ b/datafusion/expr/src/partition_evaluator.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partition evaluation module + +use arrow::array::ArrayRef; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue}; +use std::fmt::Debug; +use std::ops::Range; + +use crate::window_state::WindowAggState; + +/// Partition evaluator for Window Functions +/// +/// # Background +/// +/// An implementation of this trait is created and used for each +/// partition defined by an `OVER` clause and is instantiated by +/// the DataFusion runtime. +/// +/// For example, evaluating `window_func(val) OVER (PARTITION BY col)` +/// on the following data: +/// +/// ```text +/// col | val +/// --- + ---- +/// A | 10 +/// A | 10 +/// C | 20 +/// D | 30 +/// D | 30 +/// ``` +/// +/// Will instantiate three `PartitionEvaluator`s, one each for the +/// partitions defined by `col=A`, `col=B`, and `col=C`. +/// +/// ```text +/// col | val +/// --- + ---- +/// A | 10 <--- partition 1 +/// A | 10 +/// +/// col | val +/// --- + ---- +/// C | 20 <--- partition 2 +/// +/// col | val +/// --- + ---- +/// D | 30 <--- partition 3 +/// D | 30 +/// ``` +/// +/// Different methods on this trait will be called depending on the +/// capabilities described by [`supports_bounded_execution`], +/// [`uses_window_frame`], and [`include_rank`], +/// +/// When implementing a new `PartitionEvaluator`, implement +/// corresponding evaluator according to table below. +/// +/// # Implementation Table +/// +/// |[`uses_window_frame`]|[`supports_bounded_execution`]|[`include_rank`]|function_to_implement| +/// |---|---|----|----| +/// |false (default) |false (default) |false (default) | [`evaluate_all`] | +/// |false |true |false | [`evaluate`] | +/// |false |true/false |true | [`evaluate_all_with_rank`] | +/// |true |true/false |true/false | [`evaluate`] | +/// +/// [`evaluate`]: Self::evaluate +/// [`evaluate_all`]: Self::evaluate_all +/// [`evaluate_all_with_rank`]: Self::evaluate_all_with_rank +/// [`uses_window_frame`]: Self::uses_window_frame +/// [`include_rank`]: Self::include_rank +/// [`supports_bounded_execution`]: Self::supports_bounded_execution +pub trait PartitionEvaluator: Debug + Send { + /// When the window frame has a fixed beginning (e.g UNBOUNDED + /// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and + /// NTH_VALUE do not need the (unbounded) input once they have + /// seen a certain amount of input. + /// + /// `memoize` is called after each input batch is processed, and + /// such functions can save whatever they need and modify + /// [`WindowAggState`] appropriately to allow rows to be pruned + fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> { + Ok(()) + } + + /// If `uses_window_frame` flag is `false`. This method is used to + /// calculate required range for the window function during + /// stateful execution. + /// + /// Generally there is no required range, hence by default this + /// returns smallest range(current row). e.g seeing current row is + /// enough to calculate window result (such as row_number, rank, + /// etc) + fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { + if self.uses_window_frame() { + exec_err!("Range should be calculated from window frame") + } else { + Ok(Range { + start: idx, + end: idx + 1, + }) + } + } + + /// Evaluate a window function on an entire input partition. + /// + /// This function is called once per input *partition* for window + /// functions that *do not use* values from the window frame, + /// such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, `PERCENT_RANK`, + /// `CUME_DIST`, `LEAD`, `LAG`). + /// + /// It produces the result of all rows in a single pass. It + /// expects to receive the entire partition as the `value` and + /// must produce an output column with one output row for every + /// input row. + /// + /// `num_rows` is requied to correctly compute the output in case + /// `values.len() == 0` + /// + /// Implementing this function is an optimization: certain window + /// functions are not affected by the window frame definition or + /// the query doesn't have a frame, and `evaluate` skips the + /// (costly) window frame boundary calculation and the overhead of + /// calling `evaluate` for each output row. + /// + /// For example, the `LAG` built in window function does not use + /// the values of its window frame (it can be computed in one shot + /// on the entire partition with `Self::evaluate_all` regardless of the + /// window defined in the `OVER` clause) + /// + /// ```sql + /// lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + /// ``` + /// + /// However, `avg()` computes the average in the window and thus + /// does use its window frame + /// + /// ```sql + /// avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + /// ``` + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + // When window frame boundaries are not used and evaluator supports bounded execution + // We can calculate evaluate result by repeatedly calling `self.evaluate` `num_rows` times + // If user wants to implement more efficient version, this method should be overwritten + // Default implementation may behave suboptimally (For instance `NumRowEvaluator` overwrites it) + if !self.uses_window_frame() && self.supports_bounded_execution() { + let res = (0..num_rows) + .map(|idx| self.evaluate(values, &self.get_range(idx, num_rows)?)) + .collect::>>()?; + ScalarValue::iter_to_array(res) + } else { + not_impl_err!("evaluate_all is not implemented by default") + } + } + + /// Evaluate window function on a range of rows in an input + /// partition.x + /// + /// This is the simplest and most general function to implement + /// but also the least performant as it creates output one row at + /// a time. It is typically much faster to implement stateful + /// evaluation using one of the other specialized methods on this + /// trait. + /// + /// Returns a [`ScalarValue`] that is the value of the window + /// function within `range` for the entire partition. Argument + /// `values` contains the evaluation result of function arguments + /// and evaluation results of ORDER BY expressions. If function has a + /// single argument, `values[1..]` will contain ORDER BY expression results. + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &Range, + ) -> Result { + not_impl_err!("evaluate is not implemented by default") + } + + /// [`PartitionEvaluator::evaluate_all_with_rank`] is called for window + /// functions that only need the rank of a row within its window + /// frame. + /// + /// Evaluate the partition evaluator against the partition using + /// the row ranks. For example, `RANK(col)` produces + /// + /// ```text + /// col | rank + /// --- + ---- + /// A | 1 + /// A | 1 + /// C | 3 + /// D | 4 + /// D | 5 + /// ``` + /// + /// For this case, `num_rows` would be `5` and the + /// `ranks_in_partition` would be called with + /// + /// ```text + /// [ + /// (0,1), + /// (2,2), + /// (3,4), + /// ] + /// ``` + fn evaluate_all_with_rank( + &self, + _num_rows: usize, + _ranks_in_partition: &[Range], + ) -> Result { + not_impl_err!("evaluate_partition_with_rank is not implemented by default") + } + + /// Can the window function be incrementally computed using + /// bounded memory? + /// + /// See the table on [`Self`] for what functions to implement + fn supports_bounded_execution(&self) -> bool { + false + } + + /// Does the window function use the values from the window frame, + /// if one is specified? + /// + /// See the table on [`Self`] for what functions to implement + fn uses_window_frame(&self) -> bool { + false + } + + /// Can this function be evaluated with (only) rank + /// + /// See the table on [`Self`] for what functions to implement + fn include_rank(&self) -> bool { + false + } +} diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index a2caba4fb8bbd..685601523f9bb 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -20,86 +20,197 @@ use arrow::datatypes::DataType; +/// Constant that is used as a placeholder for any valid timezone. +/// This is used where a function can accept a timestamp type with any +/// valid timezone, it exists to avoid the need to enumerate all possible +/// timezones. See [`TypeSignature`] for more details. +/// +/// Type coercion always ensures that functions will be executed using +/// timestamp arrays that have a valid time zone. Functions must never +/// return results with this timezone. +pub const TIMEZONE_WILDCARD: &str = "+TZ"; + ///A function's volatility, which defines the functions eligibility for certain optimizations #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { - /// Immutable - An immutable function will always return the same output when given the same - /// input. An example of this is [super::BuiltinScalarFunction::Cos]. + /// An immutable function will always return the same output when given the same + /// input. An example of this is [super::BuiltinScalarFunction::Cos]. DataFusion + /// will attempt to inline immutable functions during planning. Immutable, - /// Stable - A stable function may return different values given the same input across different + /// A stable function may return different values given the same input across different /// queries but must return the same value for a given input within a query. An example of - /// this is [super::BuiltinScalarFunction::Now]. + /// this is [super::BuiltinScalarFunction::Now]. DataFusion + /// will attempt to inline `Stable` functions during planning, when possible. + /// For query `select col1, now() from t1`, it might take a while to execute but + /// `now()` column will be the same for each output row, which is evaluated + /// during planning. Stable, - /// Volatile - A volatile function may change the return value from evaluation to evaluation. + /// A volatile function may change the return value from evaluation to evaluation. /// Multiple invocations of a volatile function may return different results when used in the - /// same query. An example of this is [super::BuiltinScalarFunction::Random]. + /// same query. An example of this is [super::BuiltinScalarFunction::Random]. DataFusion + /// can not evaluate such functions during planning. + /// In the query `select col1, random() from t1`, `random()` function will be evaluated + /// for each output row, resulting in a unique random value for each row. Volatile, } -/// A function's type signature, which defines the function's supported argument types. +/// A function's type signature defines the types of arguments the function supports. +/// +/// Functions typically support only a few different types of arguments compared to the +/// different datatypes in Arrow. To make functions easy to use, when possible DataFusion +/// automatically coerces (add casts to) function arguments so they match the type signature. +/// +/// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query +/// that calles `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically +/// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. +/// +/// # Data Types +/// Types to match are represented using Arrow's [`DataType`]. [`DataType::Timestamp`] has an optional variable +/// timezone specification. To specify a function can handle a timestamp with *ANY* timezone, use +/// the [`TIMEZONE_WILDCARD`]. For example: +/// +/// ``` +/// # use arrow::datatypes::{DataType, TimeUnit}; +/// # use datafusion_expr::{TIMEZONE_WILDCARD, TypeSignature}; +/// let type_signature = TypeSignature::Exact(vec![ +/// // A nanosecond precision timestamp with ANY timezone +/// // matches Timestamp(Nanosecond, Some("+0:00")) +/// // matches Timestamp(Nanosecond, Some("+5:00")) +/// // does not match Timestamp(Nanosecond, None) +/// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), +/// ]); +/// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeSignature { - /// arbitrary number of arguments of an common type out of a list of valid types - // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` + /// One or more arguments of an common type out of a list of valid types. + /// + /// # Examples + /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), - /// arbitrary number of arguments of an arbitrary but equal type - // A function such as `array` is `VariadicEqual` - // The first argument decides the type used for coercion + /// One or more arguments of an arbitrary but equal type. + /// DataFusion attempts to coerce all argument types to match the first argument's type + /// + /// # Examples + /// A function such as `array` is `VariadicEqual` VariadicEqual, - /// arbitrary number of arguments with arbitrary types + /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types - // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` + /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// + /// # Examples + /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` + /// 2. A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` Uniform(usize, Vec), - /// exact number of arguments of an exact type + /// Exact number of arguments of an exact type Exact(Vec), - /// fixed number of arguments of arbitrary types + /// Fixed number of arguments of arbitrary types + /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` Any(usize), - /// One of a list of signatures + /// Matches exactly one of a list of [`TypeSignature`]s. Coercion is attempted to match + /// the signatures in order, and stops after the first success, if any. + /// + /// # Examples + /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` + /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), } -/// The signature of a function defines the supported argument types -/// and its volatility. +impl TypeSignature { + pub(crate) fn to_string_repr(&self) -> Vec { + match self { + TypeSignature::Variadic(types) => { + vec![format!("{}, ..", Self::join_types(types, "/"))] + } + TypeSignature::Uniform(arg_count, valid_types) => { + vec![std::iter::repeat(Self::join_types(valid_types, "/")) + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::Exact(types) => { + vec![Self::join_types(types, ", ")] + } + TypeSignature::Any(arg_count) => { + vec![std::iter::repeat("Any") + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], + TypeSignature::OneOf(sigs) => { + sigs.iter().flat_map(|s| s.to_string_repr()).collect() + } + } + } + + /// Helper function to join types with specified delimiter. + pub(crate) fn join_types( + types: &[T], + delimiter: &str, + ) -> String { + types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(delimiter) + } + + /// Check whether 0 input argument is valid for given `TypeSignature` + pub fn supports_zero_argument(&self) -> bool { + match &self { + TypeSignature::Exact(vec) => vec.is_empty(), + TypeSignature::Uniform(0, _) | TypeSignature::Any(0) => true, + TypeSignature::OneOf(types) => types + .iter() + .any(|type_sig| type_sig.supports_zero_argument()), + _ => false, + } + } +} + +/// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. +/// +/// DataFusion will automatically coerce (cast) argument types to one of the supported +/// function signatures, if possible. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Signature { - /// type_signature - The types that the function accepts. See [TypeSignature] for more information. + /// The data types that the function accepts. See [TypeSignature] for more information. pub type_signature: TypeSignature, - /// volatility - The volatility of the function. See [Volatility] for more information. + /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, } impl Signature { - /// new - Creates a new Signature from any type signature and the volatility. + /// Creates a new Signature from a given type signature and volatility. pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { Signature { type_signature, volatility, } } - /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types. + /// An arbitrary number of arguments with the same type, from those listed in `common_types`. pub fn variadic(common_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Variadic(common_types), volatility, } } - /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. + /// An arbitrary number of arguments of the same type. pub fn variadic_equal(volatility: Volatility) -> Self { Self { type_signature: TypeSignature::VariadicEqual, volatility, } } - /// variadic_any - Creates a variadic signature that represents an arbitrary number of arguments of any type. + /// An arbitrary number of arguments of any type. pub fn variadic_any(volatility: Volatility) -> Self { Self { type_signature: TypeSignature::VariadicAny, volatility, } } - /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types. + /// A fixed number of arguments of the same type, from those listed in `valid_types`. pub fn uniform( arg_count: usize, valid_types: Vec, @@ -110,21 +221,21 @@ impl Signature { volatility, } } - /// exact - Creates a signature which must match the types in exact_types in order. + /// Exactly matches the types in `exact_types`, in order. pub fn exact(exact_types: Vec, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::Exact(exact_types), volatility, } } - /// any - Creates a signature which can a be made of any type but of a specified number + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::Any(arg_count), volatility, } } - /// one_of Creates a signature which can match any of the [TypeSignature]s which are passed in. + /// Any one of a list of [TypeSignature]s. pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::OneOf(type_signatures), @@ -132,3 +243,59 @@ impl Signature { } } } + +/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. +/// Each element of this vector corresponds to an argument and indicates whether +/// the function's behavior is monotonic, or non-monotonic/unknown for that argument, namely: +/// - `None` signifies unknown monotonicity or non-monotonicity. +/// - `Some(true)` indicates that the function is monotonically increasing w.r.t. the argument in question. +/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question. +pub type FuncMonotonicity = Vec>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn supports_zero_argument_tests() { + // Testing `TypeSignature`s which supports 0 arg + let positive_cases = vec![ + TypeSignature::Exact(vec![]), + TypeSignature::Uniform(0, vec![DataType::Float64]), + TypeSignature::Any(0), + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Any(0), + TypeSignature::Uniform(1, vec![DataType::Int8]), + ]), + ]; + + for case in positive_cases { + assert!( + case.supports_zero_argument(), + "Expected {:?} to support zero arguments", + case + ); + } + + // Testing `TypeSignature`s which doesn't support 0 arg + let negative_cases = vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Uniform(1, vec![DataType::Float64]), + TypeSignature::Any(1), + TypeSignature::VariadicAny, + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Uniform(1, vec![DataType::Int8]), + ]), + ]; + + for case in negative_cases { + assert!( + !case.supports_zero_argument(), + "Expected {:?} not to support zero arguments", + case + ); + } + } +} diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index 2f5a8923e8bb6..565f48c1c5a9e 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -18,8 +18,10 @@ //! Table source use crate::{Expr, LogicalPlan}; + use arrow::datatypes::SchemaRef; -use datafusion_common::Result; +use datafusion_common::{Constraints, Result}; + use std::any::Any; /// Indicates whether and how a filter expression can be handled by a @@ -28,14 +30,14 @@ use std::any::Any; pub enum TableProviderFilterPushDown { /// The expression cannot be used by the provider. Unsupported, - /// The expression can be used to help minimise the data retrieved, - /// but the provider cannot guarantee that all returned tuples - /// satisfy the filter. The Filter plan node containing this expression - /// will be preserved. + /// The expression can be used to reduce the data retrieved, + /// but the provider cannot guarantee it will omit all tuples that + /// may be filtered. In this case, DataFusion will apply an additional + /// `Filter` operation after the scan to ensure all rows are filtered correctly. Inexact, - /// The provider guarantees that all returned data satisfies this - /// filter expression. The Filter plan node containing this expression - /// will be removed. + /// The provider **guarantees** that it will omit **all** tuples that are + /// filtered by the filter expression. This is the fastest option, if available + /// as DataFusion will not apply additional filtering. Exact, } @@ -64,6 +66,11 @@ pub trait TableSource: Sync + Send { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; + /// Get primary key indices, if one exists. + fn constraints(&self) -> Option<&Constraints> { + None + } + /// Get the type of this table for metadata/catalog purposes. fn table_type(&self) -> TableType { TableType::Base @@ -96,4 +103,9 @@ pub trait TableSource: Sync + Send { fn get_logical_plan(&self) -> Option<&LogicalPlan> { None } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 3ecf54c9ce26d..1098842716b9e 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,13 +18,14 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, Cast, GetIndexedField, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, - TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; -use crate::Expr; -use datafusion_common::tree_node::VisitRecursion; -use datafusion_common::{tree_node::TreeNode, Result}; +use crate::{Expr, GetFieldAccess}; + +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { fn apply_children(&self, op: &mut F) -> Result @@ -32,7 +33,7 @@ impl TreeNode for Expr { F: FnMut(&Self) -> Result, { let children = match self { - Expr::Alias(expr, _) + Expr::Alias(Alias{expr,..}) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -47,12 +48,23 @@ impl TreeNode for Expr { | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref().clone()], - Expr::GetIndexedField(GetIndexedField { expr, .. }) => { - vec![expr.as_ref().clone()] + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let expr = expr.as_ref().clone(); + match field { + GetFieldAccess::ListIndex {key} => { + vec![key.as_ref().clone(), expr] + }, + GetFieldAccess::ListRange {start, stop} => { + vec![start.as_ref().clone(), stop.as_ref().clone(), expr] + } + GetFieldAccess::NamedStructField {name: _name} => { + vec![expr] + } + } } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => { + Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { args.clone() } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { @@ -65,14 +77,12 @@ impl TreeNode for Expr { | Expr::Literal(_) | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard {..} | Expr::Placeholder (_) => vec![], Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { vec![left.as_ref().clone(), right.as_ref().clone()] } Expr::Like(Like { expr, pattern, .. }) - | Expr::ILike(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { vec![expr.as_ref().clone(), pattern.as_ref().clone()] } @@ -98,7 +108,7 @@ impl TreeNode for Expr { expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { + => { let mut expr_vec = args.clone(); if let Some(f) = filter { @@ -147,9 +157,11 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Alias(expr, name) => { - Expr::Alias(transform_boxed(expr, &mut transform)?, name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, Expr::Exists { .. } => self, @@ -177,33 +189,26 @@ impl TreeNode for Expr { expr, pattern, escape_char, + case_insensitive, }) => Expr::Like(Like::new( negated, transform_boxed(expr, &mut transform)?, transform_boxed(pattern, &mut transform)?, escape_char, - )), - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => Expr::ILike(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, + case_insensitive, )), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, + case_insensitive, }) => Expr::SimilarTo(Like::new( negated, transform_boxed(expr, &mut transform)?, transform_boxed(pattern, &mut transform)?, escape_char, + case_insensitive, )), Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), Expr::IsNotNull(expr) => { @@ -271,12 +276,19 @@ impl TreeNode for Expr { asc, nulls_first, )), - Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - Expr::ScalarUDF(ScalarUDF { args, fun }) => { - Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( + ScalarFunction::new(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( + ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::WindowFunction(WindowFunction { args, fun, @@ -292,17 +304,40 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Expr::AggregateFunction(AggregateFunction::new( + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -319,24 +354,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, @@ -346,14 +364,11 @@ impl TreeNode for Expr { transform_vec(list, &mut transform)?, negated, )), - Expr::Wildcard => Expr::Wildcard, - Expr::QualifiedWildcard { qualifier } => { - Expr::QualifiedWildcard { qualifier } - } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::Wildcard { qualifier } => Expr::Wildcard { qualifier }, + Expr::GetIndexedField(GetIndexedField { expr, field }) => { Expr::GetIndexedField(GetIndexedField::new( transform_boxed(expr, &mut transform)?, - key, + field, )) } Expr::Placeholder(Placeholder { id, data_type }) => { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 4f02bf3dfd2a3..7128b575978a3 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,8 +17,10 @@ use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{DataFusionError, Result}; + +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use std::ops::Deref; use crate::{AggregateFunction, Signature, TypeSignature}; @@ -74,6 +76,8 @@ pub static TIMESTAMPS: &[DataType] = &[ pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; +pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; + pub static TIMES: &[DataType] = &[ DataType::Time32(TimeUnit::Second), DataType::Time32(TimeUnit::Millisecond), @@ -88,6 +92,7 @@ pub fn coerce_types( input_types: &[DataType], signature: &Signature, ) -> Result> { + use DataType::*; // Validate input_types matches (at least one of) the func signature. check_arg_count(agg_fun, input_types, &signature.type_signature)?; @@ -104,24 +109,44 @@ pub fn coerce_types( AggregateFunction::Sum => { // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_sum_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) + let v = match &input_types[0] { + Decimal128(p, s) => Decimal128(*p, *s), + Decimal256(p, s) => Decimal256(*p, *s), + d if d.is_signed_integer() => Int64, + d if d.is_unsigned_integer() => UInt64, + d if d.is_floating() => Float64, + Dictionary(_, v) => { + return coerce_types(agg_fun, &[v.as_ref().clone()], signature) + } + _ => { + return plan_err!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, + input_types[0] + ) + } + }; + Ok(vec![v]) } AggregateFunction::Avg => { // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval - if !is_avg_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) + let v = match &input_types[0] { + Decimal128(p, s) => Decimal128(*p, *s), + Decimal256(p, s) => Decimal256(*p, *s), + d if d.is_numeric() => Float64, + Dictionary(_, v) => { + return coerce_types(agg_fun, &[v.as_ref().clone()], signature) + } + _ => { + return plan_err!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, + input_types[0] + ) + } + }; + Ok(vec![v]) } AggregateFunction::BitAnd | AggregateFunction::BitOr @@ -129,10 +154,11 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. if !is_bit_and_or_xor_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -140,127 +166,131 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. if !is_bool_and_or_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Variance => { - if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::VariancePop => { + AggregateFunction::Variance | AggregateFunction::VariancePop => { if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Covariance => { - if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } - Ok(input_types.to_vec()) + Ok(vec![Float64, Float64]) } - AggregateFunction::CovariancePop => { + AggregateFunction::Covariance | AggregateFunction::CovariancePop => { if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Stddev => { - if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } - Ok(input_types.to_vec()) + Ok(vec![Float64, Float64]) } - AggregateFunction::StddevPop => { + AggregateFunction::Stddev | AggregateFunction::StddevPop => { if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } - Ok(input_types.to_vec()) + Ok(vec![Float64]) } AggregateFunction::Correlation => { if !is_correlation_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } - Ok(input_types.to_vec()) + Ok(vec![Float64, Float64]) + } + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => { + let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat(); + let input_types_valid = // number of input already checked before + valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); + if !input_types_valid { + return plan_err!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, + input_types[0] + ); + } + Ok(vec![Float64, Float64]) } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The percentile sample points count for {:?} must be integer, not {:?}.", agg_fun, input_types[2] - ))); + ); } let mut result = input_types.to_vec(); if can_coerce_from(&DataType::Float64, &input_types[1]) { result[1] = DataType::Float64; } else { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Could not coerce the percent argument for {:?} to Float64. Was {:?}.", agg_fun, input_types[1] - ))); + ); } Ok(result) } AggregateFunction::ApproxPercentileContWithWeight => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The weight argument for {:?} does not support inputs of type {:?}.", - agg_fun, input_types[1] - ))); + agg_fun, + input_types[1] + ); } if !matches!(input_types[2], DataType::Float64) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, input_types[2] - ))); + agg_fun, + input_types[2] + ); } Ok(input_types.to_vec()) } AggregateFunction::ApproxMedian => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -268,6 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -284,22 +331,22 @@ fn check_arg_count( match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { if input_types.len() != *agg_count { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} expects {:?} arguments, but {:?} were provided", agg_fun, agg_count, input_types.len() - ))); + ); } } TypeSignature::Exact(types) => { if types.len() != input_types.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} expects {:?} arguments, but {:?} were provided", agg_fun, types.len(), input_types.len() - ))); + ); } } TypeSignature::OneOf(variants) => { @@ -307,24 +354,24 @@ fn check_arg_count( .iter() .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); if !ok { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not accept {:?} function arguments.", agg_fun, input_types.len() - ))); + ); } } TypeSignature::VariadicAny => { if input_types.is_empty() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {agg_fun:?} expects at least one argument" - ))); + ); } } _ => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Aggregate functions do not support this {signature:?}" - ))); + ); } } Ok(()) @@ -349,23 +396,22 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { - arg_type if SIGNED_INTEGERS.contains(arg_type) => Ok(DataType::Int64), - arg_type if UNSIGNED_INTEGERS.contains(arg_type) => Ok(DataType::UInt64), - // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc, - // the result type of floating-point is FLOAT64 with the double precision. - DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), + DataType::Int64 => Ok(DataType::Int64), + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Float64 => Ok(DataType::Float64), DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } - DataType::Dictionary(_, dict_value_type) => { - sum_return_type(dict_value_type.as_ref()) + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) } - other => Err(DataFusionError::Plan(format!( - "SUM does not support type \"{other:?}\"" - ))), + other => plan_err!("SUM does not support type \"{other:?}\""), } } @@ -374,9 +420,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "VAR does not support {arg_type:?}" - ))) + plan_err!("VAR does not support {arg_type:?}") } } @@ -385,9 +429,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "COVAR does not support {arg_type:?}" - ))) + plan_err!("COVAR does not support {arg_type:?}") } } @@ -396,9 +438,7 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "CORR does not support {arg_type:?}" - ))) + plan_err!("CORR does not support {arg_type:?}") } } @@ -407,9 +447,7 @@ pub fn stddev_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "STDDEV does not support {arg_type:?}" - ))) + plan_err!("STDDEV does not support {arg_type:?}") } } @@ -423,13 +461,18 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal128(new_precision, new_scale)) } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(dict_value_type.as_ref()) } - other => Err(DataFusionError::Plan(format!( - "AVG does not support {other:?}" - ))), + other => plan_err!("AVG does not support {other:?}"), } } @@ -441,13 +484,16 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal128(new_precision, *scale)) } + DataType::Decimal256(precision, scale) => { + // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) } - other => Err(DataFusionError::Plan(format!( - "AVG does not support {other:?}" - ))), + other => plan_err!("AVG does not support {other:?}"), } } @@ -467,7 +513,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _)) + || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) ), } } @@ -480,7 +526,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _)) + || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) ), } } @@ -536,10 +582,18 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; - use crate::aggregate_function; use arrow::datatypes::DataType; #[test] @@ -547,25 +601,25 @@ mod tests { // test input args with error number input types let fun = AggregateFunction::Min; let input_types = vec![DataType::Int64, DataType::Int32]; - let signature = aggregate_function::signature(&fun); + let signature = fun.signature(); let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().to_string()); + assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); // test input args is invalid data type for sum or avg let fun = AggregateFunction::Sum; let input_types = vec![DataType::Utf8]; - let signature = aggregate_function::signature(&fun); + let signature = fun.signature(); let result = coerce_types(&fun, &input_types, &signature); assert_eq!( "Error during planning: The function Sum does not support inputs of type Utf8.", - result.unwrap_err().to_string() + result.unwrap_err().strip_backtrace() ); let fun = AggregateFunction::Avg; - let signature = aggregate_function::signature(&fun); + let signature = fun.signature(); let result = coerce_types(&fun, &input_types, &signature); assert_eq!( "Error during planning: The function Avg does not support inputs of type Utf8.", - result.unwrap_err().to_string() + result.unwrap_err().strip_backtrace() ); // test count, array_agg, approx_distinct, min, max. @@ -580,29 +634,39 @@ mod tests { let input_types = vec![ vec![DataType::Int32], vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal256(1, 1)], vec![DataType::Utf8], ]; for fun in funs { for input_type in &input_types { - let signature = aggregate_function::signature(&fun); - let result = coerce_types(&fun, input_type, &signature); - assert_eq!(*input_type, result.unwrap()); - } - } - // test sum, avg - let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; - let input_types = vec![ - vec![DataType::Int32], - vec![DataType::Float32], - vec![DataType::Decimal128(20, 3)], - ]; - for fun in funs { - for input_type in &input_types { - let signature = aggregate_function::signature(&fun); + let signature = fun.signature(); let result = coerce_types(&fun, input_type, &signature); assert_eq!(*input_type, result.unwrap()); } } + // test sum + let fun = AggregateFunction::Sum; + let signature = fun.signature(); + let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap(); + assert_eq!(r[0], DataType::Int64); + let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap(); + assert_eq!(r[0], DataType::Float64); + let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap(); + assert_eq!(r[0], DataType::Decimal128(20, 3)); + let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); + assert_eq!(r[0], DataType::Decimal256(20, 3)); + + // test avg + let fun = AggregateFunction::Avg; + let signature = fun.signature(); + let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap(); + assert_eq!(r[0], DataType::Float64); + let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap(); + assert_eq!(r[0], DataType::Float64); + let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap(); + assert_eq!(r[0], DataType::Decimal128(20, 3)); + let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); + assert_eq!(r[0], DataType::Decimal256(20, 3)); // ApproxPercentileCont input types let input_types = vec![ @@ -618,8 +682,7 @@ mod tests { vec![DataType::Float64, DataType::Float64], ]; for input_type in &input_types { - let signature = - aggregate_function::signature(&AggregateFunction::ApproxPercentileCont); + let signature = AggregateFunction::ApproxPercentileCont.signature(); let result = coerce_types( &AggregateFunction::ApproxPercentileCont, input_type, diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index f8a04de45bb45..dd9449198796a 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -17,243 +17,238 @@ //! Coercion rules for matching argument types for binary operators +use std::sync::Arc; + +use crate::Operator; + +use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::DataFusionError; -use datafusion_common::Result; +use datafusion_common::{ + exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result, +}; -use crate::type_coercion::{is_datetime, is_decimal, is_interval, is_numeric}; -use crate::Operator; +/// The type signature of an instantiation of binary operator expression such as +/// `lhs + rhs` +/// +/// Note this is different than [`crate::signature::Signature`] which +/// describes the type signature of a function. +struct Signature { + /// The type to coerce the left argument to + lhs: DataType, + /// The type to coerce the right argument to + rhs: DataType, + /// The return type of the expression + ret: DataType, +} -/// Returns the result type of applying mathematics operations such as -/// `+` to arguments of `lhs_type` and `rhs_type`. -fn mathematics_temporal_result_type( - lhs_type: &DataType, - rhs_type: &DataType, -) -> Option { - use arrow::datatypes::DataType::*; - use arrow::datatypes::IntervalUnit::*; - use arrow::datatypes::TimeUnit::*; +impl Signature { + /// A signature where the inputs are the same type as the output + fn uniform(t: DataType) -> Self { + Self { + lhs: t.clone(), + rhs: t.clone(), + ret: t, + } + } - if !is_interval(lhs_type) - && !is_interval(rhs_type) - && !is_datetime(lhs_type) - && !is_datetime(rhs_type) - { - return None; - }; + /// A signature where the inputs are the same type with a boolean output + fn comparison(t: DataType) -> Self { + Self { + lhs: t.clone(), + rhs: t, + ret: DataType::Boolean, + } + } +} - match (lhs_type, rhs_type) { - // datetime +/- interval - (Interval(_), Timestamp(_, _)) => Some(rhs_type.clone()), - (Timestamp(_, _), Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date32) => Some(rhs_type.clone()), - (Date32, Interval(_)) => Some(lhs_type.clone()), - (Interval(_), Date64) => Some(rhs_type.clone()), - (Date64, Interval(_)) => Some(lhs_type.clone()), - // interval +/- - (Interval(l), Interval(h)) if l == h => Some(lhs_type.clone()), - (Interval(_), Interval(_)) => Some(Interval(MonthDayNano)), - // timestamp - timestamp - (Timestamp(Second, _), Timestamp(Second, _)) - | (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => { - Some(Interval(DayTime)) +/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` +fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; + use Operator::*; + match op { + Eq | + NotEq | + Lt | + LtEq | + Gt | + GtEq | + IsDistinctFrom | + IsNotDistinctFrom => { + comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common argument type for comparison operation {lhs} {op} {rhs}" + ) + }) } - (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) - | (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => { - Some(Interval(MonthDayNano)) + And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { + // Logical binary boolean operators can only be evaluated for + // boolean or null arguments. + Ok(Signature::uniform(DataType::Boolean)) + } else { + plan_err!( + "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" + ) } - (Timestamp(_, _), Timestamp(_, _)) => None, - // date - date - (Date32, Date32) => Some(Interval(DayTime)), - (Date64, Date64) => Some(Interval(MonthDayNano)), - (Date32, Date64) | (Date64, Date32) => Some(Interval(MonthDayNano)), - // date - timestamp, timestamp - date - (Date32, Timestamp(_, _)) - | (Timestamp(_, _), Date32) - | (Date64, Timestamp(_, _)) - | (Timestamp(_, _), Date64) => { - // TODO: make get_result_type must after coerce type. - // if type isn't coerced, we need get common type, and then get result type. - let common_type = temporal_coercion(lhs_type, rhs_type); - common_type.and_then(|t| mathematics_temporal_result_type(&t, &t)) + RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common argument type for regex operation {lhs} {op} {rhs}" + ) + }) + } + BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { + bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common type for bitwise operation {lhs} {op} {rhs}" + ) + }) + } + StringConcat => { + string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common string type for string concat operation {lhs} {op} {rhs}" + ) + }) + } + AtArrow | ArrowAt => { + // ArrowAt and AtArrow check for whether one array is contained in another. + // The result type is boolean. Signature::comparison defines this signature. + // Operation has nothing to do with comparison + array_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common array type for arrow operation {lhs} {op} {rhs}" + ) + }) + } + Plus | Minus | Multiply | Divide | Modulo => { + let get_result = |lhs, rhs| { + use arrow::compute::kernels::numeric::*; + let l = new_empty_array(lhs); + let r = new_empty_array(rhs); + + let result = match op { + Plus => add_wrapping(&l, &r), + Minus => sub_wrapping(&l, &r), + Multiply => mul_wrapping(&l, &r), + Divide => div(&l, &r), + Modulo => rem(&l, &r), + _ => unreachable!(), + }; + result.map(|x| x.data_type().clone()) + }; + + if let Ok(ret) = get_result(lhs, rhs) { + // Temporal arithmetic, e.g. Date32 + Interval + Ok(Signature{ + lhs: lhs.clone(), + rhs: rhs.clone(), + ret, + }) + } else if let Some(coerced) = temporal_coercion(lhs, rhs) { + // Temporal arithmetic by first coercing to a common time representation + // e.g. Date32 - Timestamp + let ret = get_result(&coerced, &coerced).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for temporal operation {coerced} {op} {coerced}: {e}" + ) + })?; + Ok(Signature{ + lhs: coerced.clone(), + rhs: coerced, + ret, + }) + } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { + // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) + let ret = get_result(&lhs, &rhs).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for decimal operation {lhs} {op} {rhs}: {e}" + ) + })?; + Ok(Signature{ + lhs, + rhs, + ret, + }) + } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { + // Numeric arithmetic, e.g. Int32 + Int32 + Ok(Signature::uniform(numeric)) + } else { + plan_err!( + "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types" + ) + } } - _ => None, } } /// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( - lhs_type: &DataType, + lhs: &DataType, op: &Operator, - rhs_type: &DataType, + rhs: &DataType, ) -> Result { - if op.is_numerical_operators() && any_decimal(lhs_type, rhs_type) { - let (coerced_lhs_type, coerced_rhs_type) = - math_decimal_coercion(lhs_type, rhs_type); - - let lhs_type = coerced_lhs_type.unwrap_or(lhs_type.clone()); - let rhs_type = coerced_rhs_type.unwrap_or(rhs_type.clone()); - - if op.is_numerical_operators() { - if let Some(result_type) = - decimal_op_mathematics_type(op, &lhs_type, &rhs_type) - { - return Ok(result_type); - } - } - } - let result = match op { - Operator::And - | Operator::Or - | Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::Gt - | Operator::GtEq - | Operator::LtEq - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom => Some(DataType::Boolean), - Operator::Plus | Operator::Minus - if is_datetime(lhs_type) && is_datetime(rhs_type) - || (is_interval(lhs_type) && is_interval(rhs_type)) - || (is_datetime(lhs_type) && is_interval(rhs_type)) - || (is_interval(lhs_type) && is_datetime(rhs_type)) => - { - mathematics_temporal_result_type(lhs_type, rhs_type) - } - // following same with `coerce_types` - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type), - Operator::Plus - | Operator::Minus - | Operator::Modulo - | Operator::Divide - | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type), - Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type), - }; - - result.ok_or(DataFusionError::Plan(format!( - "Unsupported argument types. Can not evaluate {lhs_type:?} {op} {rhs_type:?}" - ))) + signature(lhs, op, rhs).map(|sig| sig.ret) } -/// Coercion rules for all binary operators. Returns the 'coerce_types' -/// is returns the type the arguments should be coerced to -/// -/// Returns None if no suitable type can be found. -pub fn coerce_types( - lhs_type: &DataType, +/// Returns the coerced input types for a binary expression evaluating the `op` with the left and right hand types +pub fn get_input_types( + lhs: &DataType, op: &Operator, - rhs_type: &DataType, -) -> Result { - // This result MUST be compatible with `binary_coerce` - let result = match op { - Operator::BitwiseAnd - | Operator::BitwiseOr - | Operator::BitwiseXor - | Operator::BitwiseShiftRight - | Operator::BitwiseShiftLeft => bitwise_coercion(lhs_type, rhs_type), - Operator::And | Operator::Or => match (lhs_type, rhs_type) { - // logical binary boolean operators can only be evaluated in bools or nulls - (DataType::Boolean, DataType::Boolean) - | (DataType::Null, DataType::Null) - | (DataType::Boolean, DataType::Null) - | (DataType::Null, DataType::Boolean) => Some(DataType::Boolean), - _ => None, - }, - // logical comparison operators have their own rules, and always return a boolean - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::Gt - | Operator::GtEq - | Operator::LtEq - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom => comparison_coercion(lhs_type, rhs_type), - Operator::Plus | Operator::Minus - if is_interval(lhs_type) && is_interval(rhs_type) => - { - temporal_coercion(lhs_type, rhs_type) - } - Operator::Minus if is_datetime(lhs_type) && is_datetime(rhs_type) => { - temporal_coercion(lhs_type, rhs_type) - } - // for math expressions, the final value of the coercion is also the return type - // because coercion favours higher information types - Operator::Plus - | Operator::Minus - | Operator::Modulo - | Operator::Divide - | Operator::Multiply => mathematics_numerical_coercion(lhs_type, rhs_type), - Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch => regex_coercion(lhs_type, rhs_type), - // "||" operator has its own rules, and always return a string type - Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type), - }; - - // re-write the error message of failed coercions to include the operator's information - result.ok_or(DataFusionError::Plan(format!("{lhs_type:?} {op} {rhs_type:?} can't be evaluated because there isn't a common type to coerce the types to"))) + rhs: &DataType, +) -> Result<(DataType, DataType)> { + signature(lhs, op, rhs).map(|sig| (sig.lhs, sig.rhs)) } /// Coercion rules for mathematics operators between decimal and non-decimal types. -pub fn math_decimal_coercion( +fn math_decimal_coercion( lhs_type: &DataType, rhs_type: &DataType, -) -> (Option, Option) { +) -> Option<(DataType, DataType)> { use arrow::datatypes::DataType::*; - if both_decimal(lhs_type, rhs_type) { - return (None, None); - } - match (lhs_type, rhs_type) { - (Null, dec_type @ Decimal128(_, _)) => (Some(dec_type.clone()), None), - (dec_type @ Decimal128(_, _), Null) => (None, Some(dec_type.clone())), - (Dictionary(key_type, value_type), _) => { - let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type); - let lhs_type = value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))); - (lhs_type, rhs_type) + (Dictionary(_, value_type), _) => { + let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?; + Some((value_type, rhs_type)) } - (_, Dictionary(key_type, value_type)) => { - let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type); - let rhs_type = value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))); - (lhs_type, rhs_type) + (_, Dictionary(_, value_type)) => { + let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; + Some((lhs_type, value_type)) + } + (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { + Some((dec_type.clone(), dec_type.clone())) } - (Decimal128(_, _), Float32 | Float64) => (Some(Float64), Some(Float64)), - (Float32 | Float64, Decimal128(_, _)) => (Some(Float64), Some(Float64)), - (Decimal128(_, _), _) => { - let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type); - (None, converted_decimal_type) + (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { + Some((lhs_type.clone(), rhs_type.clone())) } - (_, Decimal128(_, _)) => { - let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type); - (converted_decimal_type, None) + // Unlike with comparison we don't coerce to a decimal in the case of floating point + // numbers, instead falling back to floating point arithmetic instead + (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => { + Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?)) } - _ => (None, None), + (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => { + Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) + } + (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal256(rhs_type)?, + )), + (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some(( + coerce_numeric_type_to_decimal256(lhs_type)?, + rhs_type.clone(), + )), + _ => None, } } /// Returns the output type of applying bitwise operations such as /// `&`, `|`, or `xor`to arguments of `lhs_type` and `rhs_type`. -pub(crate) fn bitwise_coercion( - left_type: &DataType, - right_type: &DataType, -) -> Option { +fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option { use arrow::datatypes::DataType::*; if !both_numeric_or_null_and_numeric(left_type, right_type) { @@ -289,9 +284,7 @@ pub(crate) fn bitwise_coercion( } } -/// Returns the output type of applying comparison operations such as -/// `eq`, `not eq`, `lt`, `lteq`, `gt`, and `gteq` to arguments -/// of `lhs_type` and `rhs_type`. +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type == rhs_type { // same type => equality is possible @@ -303,30 +296,72 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8, _) if is_numeric(rhs_type) => Some(Utf8), - (LargeUtf8, _) if is_numeric(rhs_type) => Some(LargeUtf8), - (_, Utf8) if is_numeric(lhs_type) => Some(Utf8), - (_, LargeUtf8) if is_numeric(lhs_type) => Some(LargeUtf8), + (Utf8, _) if rhs_type.is_numeric() => Some(Utf8), + (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8), + (_, Utf8) if lhs_type.is_numeric() => Some(Utf8), + (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8), _ => None, } } -/// Returns the output type of applying numeric operations such as `=` -/// to arguments `lhs_type` and `rhs_type` if both are numeric +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation +/// where one is temporal and one is `Utf8`/`LargeUtf8`. +/// +/// Note this cannot be performed in case of arithmetic as there is insufficient information +/// to correctly determine the type of argument. Consider +/// +/// ```sql +/// timestamp > now() - '1 month' +/// interval > now() - '1970-01-2021' +/// ``` +/// +/// In the absence of a full type inference system, we can't determine the correct type +/// to parse the string argument +fn string_temporal_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + fn match_rule(l: &DataType, r: &DataType) -> Option { + match (l, r) { + // Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp + (Utf8, temporal) | (LargeUtf8, temporal) => match temporal { + Date32 | Date64 => Some(temporal.clone()), + Time32(_) | Time64(_) => { + if is_time_with_valid_unit(temporal.to_owned()) { + Some(temporal.to_owned()) + } else { + None + } + } + Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + _ => None, + }, + _ => None, + } + } + + match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation +/// where one both are numeric fn comparison_binary_numeric_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + if !lhs_type.is_numeric() || !rhs_type.is_numeric() { return None; }; @@ -338,12 +373,17 @@ fn comparison_binary_numeric_coercion( // these are ordered from most informative to least informative so // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { - // support decimal data type for comparison operation + // Prefer decimal data type over floating point for comparison operation (Decimal128(_, _), Decimal128(_, _)) => { get_wider_decimal_type(lhs_type, rhs_type) } (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), + (Decimal256(_, _), Decimal256(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } + (Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), + (_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), // The following match arms encode the following logic: Given the two @@ -381,28 +421,22 @@ fn comparison_binary_numeric_coercion( } } -/// Returns the output type of applying numeric operations such as `=` -/// to a decimal type `decimal_type` and `other_type` +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of +/// a comparison operation where one is a decimal fn get_comparison_common_decimal_type( decimal_type: &DataType, other_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; - let other_decimal_type = &match other_type { - // This conversion rule is from spark - // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 - Int8 => Decimal128(3, 0), - Int16 => Decimal128(5, 0), - Int32 => Decimal128(10, 0), - Int64 => Decimal128(20, 0), - Float32 => Decimal128(14, 7), - Float64 => Decimal128(30, 15), - _ => { - return None; + match decimal_type { + Decimal128(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal256(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) } - }; - match (decimal_type, &other_decimal_type) { - (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2), _ => None, } } @@ -422,14 +456,70 @@ fn get_wider_decimal_type( let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); Some(create_decimal_type((range + s) as u8, s)) } + (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = *s1.max(s2); + let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); + Some(create_decimal256_type((range + s) as u8, s)) + } (_, _) => None, } } +/// Returns the wider type among arguments `lhs` and `rhs`. +/// The wider type is the type that can safely represent values from both types +/// without information loss. Returns an Error if types are incompatible. +pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { + use arrow::datatypes::DataType::*; + Ok(match (lhs, rhs) { + (lhs, rhs) if lhs == rhs => lhs.clone(), + // Right UInt is larger than left UInt. + (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) | + // Right Int is larger than left Int. + (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) | + // Right Float is larger than left Float. + (Float16, Float32 | Float64) | (Float32, Float64) | + // Right String is larger than left String. + (Utf8, LargeUtf8) | + // Any right type is wider than a left hand side Null. + (Null, _) => rhs.clone(), + // Left UInt is larger than right UInt. + (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) | + // Left Int is larger than right Int. + (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | + // Left Float is larger than right Float. + (Float32 | Float64, Float16) | (Float64, Float32) | + // Left String is larget than right String. + (LargeUtf8, Utf8) | + // Any left type is wider than a right hand side Null. + (_, Null) => lhs.clone(), + (List(lhs_field), List(rhs_field)) => { + let field_type = + get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; + if lhs_field.name() != rhs_field.name() { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both {lhs} and {rhs}." + )); + } + assert_eq!(lhs_field.name(), rhs_field.name()); + let field_name = lhs_field.name(); + let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); + List(Arc::new(Field::new(field_name, field_type, nullable))) + } + (_, _) => { + return Err(exec_datafusion_err!( + "There is no wider type that can represent both {lhs} and {rhs}." + )); + } + }) +} + /// Convert the numeric data type to the decimal data type. /// Now, we just support the signed integer type and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 match numeric_type { Int8 => Some(Decimal128(3, 0)), Int16 => Some(Decimal128(5, 0)), @@ -442,6 +532,24 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { } } +/// Convert the numeric data type to the decimal data type. +/// Now, we just support the signed integer type and floating-point type. +fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + match numeric_type { + Int8 => Some(Decimal256(3, 0)), + Int16 => Some(Decimal256(5, 0)), + Int32 => Some(Decimal256(10, 0)), + Int64 => Some(Decimal256(20, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float32 => Some(Decimal256(14, 7)), + Float64 => Some(Decimal256(30, 15)), + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -461,10 +569,8 @@ fn mathematics_numerical_coercion( (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { mathematics_numerical_coercion(lhs_value_type, rhs_value_type) } - (Dictionary(key_type, value_type), _) => { - let value_type = mathematics_numerical_coercion(value_type, rhs_type); - value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))) + (Dictionary(_, value_type), _) => { + mathematics_numerical_coercion(value_type, rhs_type) } (_, Dictionary(_, value_type)) => { mathematics_numerical_coercion(lhs_type, value_type) @@ -490,141 +596,29 @@ fn create_decimal_type(precision: u8, scale: i8) -> DataType { ) } -/// Returns the coerced type of applying mathematics operations on decimal types. -/// Two sides of the mathematics operation will be coerced to the same type. Note -/// that we don't coerce the decimal operands in analysis phase, but do it in the -/// execution phase because this is not idempotent. -pub fn coercion_decimal_mathematics_type( - mathematics_op: &Operator, - left_decimal_type: &DataType, - right_decimal_type: &DataType, -) -> Option { - use arrow::datatypes::DataType::*; - match (left_decimal_type, right_decimal_type) { - // The promotion rule from spark - // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 - (Decimal128(_, _), Decimal128(_, _)) => match mathematics_op { - Operator::Plus | Operator::Minus => decimal_op_mathematics_type( - mathematics_op, - left_decimal_type, - right_decimal_type, - ), - Operator::Divide | Operator::Modulo => { - get_wider_decimal_type(left_decimal_type, right_decimal_type) - } - _ => None, - }, - _ => None, - } -} - -/// Returns the output type of applying mathematics operations on decimal types. -/// The rule is from spark. Note that this is different to the coerced type applied -/// to two sides of the arithmetic operation. -pub fn decimal_op_mathematics_type( - mathematics_op: &Operator, - left_decimal_type: &DataType, - right_decimal_type: &DataType, -) -> Option { - use arrow::datatypes::DataType::*; - match (left_decimal_type, right_decimal_type) { - // The coercion rule from spark - // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 - (Decimal128(p1, s1), Decimal128(p2, s2)) => { - match mathematics_op { - Operator::Plus | Operator::Minus => { - // max(s1, s2) - let result_scale = *s1.max(s2); - // max(s1, s2) + max(p1-s1, p2-s2) + 1 - let result_precision = - result_scale + (*p1 as i8 - *s1).max(*p2 as i8 - *s2) + 1; - Some(create_decimal_type(result_precision as u8, result_scale)) - } - Operator::Multiply => { - // s1 + s2 - let result_scale = *s1 + *s2; - // p1 + p2 + 1 - let result_precision = *p1 + *p2 + 1; - Some(create_decimal_type(result_precision, result_scale)) - } - Operator::Divide => { - // max(6, s1 + p2 + 1) - let result_scale = 6.max(*s1 + *p2 as i8 + 1); - // p1 - s1 + s2 + max(6, s1 + p2 + 1) - let result_precision = result_scale + *p1 as i8 - *s1 + *s2; - Some(create_decimal_type(result_precision as u8, result_scale)) - } - Operator::Modulo => { - // max(s1, s2) - let result_scale = *s1.max(s2); - // min(p1-s1, p2-s2) + max(s1, s2) - let result_precision = - result_scale + (*p1 as i8 - *s1).min(*p2 as i8 - *s2); - Some(create_decimal_type(result_precision as u8, result_scale)) - } - _ => None, - } - } - (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { - decimal_op_mathematics_type( - mathematics_op, - lhs_value_type.as_ref(), - rhs_value_type.as_ref(), - ) - } - (Dictionary(key_type, value_type), _) => { - let value_type = decimal_op_mathematics_type( - mathematics_op, - value_type.as_ref(), - right_decimal_type, - ); - value_type - .map(|value_type| Dictionary(key_type.clone(), Box::new(value_type))) - } - (_, Dictionary(_, value_type)) => decimal_op_mathematics_type( - mathematics_op, - left_decimal_type, - value_type.as_ref(), - ), - _ => None, - } +fn create_decimal256_type(precision: u8, scale: i8) -> DataType { + DataType::Decimal256( + DECIMAL256_MAX_PRECISION.min(precision), + DECIMAL256_MAX_SCALE.min(scale), + ) } /// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (_, Null) => is_numeric(lhs_type), - (Null, _) => is_numeric(rhs_type), + (_, Null) => lhs_type.is_numeric(), + (Null, _) => rhs_type.is_numeric(), (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { - is_numeric(lhs_value_type) && is_numeric(rhs_value_type) + lhs_value_type.is_numeric() && rhs_value_type.is_numeric() } - (Dictionary(_, value_type), _) => is_numeric(value_type) && is_numeric(rhs_type), - (_, Dictionary(_, value_type)) => is_numeric(lhs_type) && is_numeric(value_type), - _ => is_numeric(lhs_type) && is_numeric(rhs_type), - } -} - -/// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal -fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (_, Null) => is_decimal(lhs_type), - (Null, _) => is_decimal(rhs_type), - (Decimal128(_, _), Decimal128(_, _)) => true, - (Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type), - (_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type), - _ => false, - } -} - -/// Determine if at least of one of lhs and rhs is decimal -pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type), - (_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type), - (_, _) => is_decimal(lhs_type) || is_decimal(rhs_type), + (Dictionary(_, value_type), _) => { + value_type.is_numeric() && rhs_type.is_numeric() + } + (_, Dictionary(_, value_type)) => { + lhs_type.is_numeric() && value_type.is_numeric() + } + _ => lhs_type.is_numeric() && rhs_type.is_numeric(), } } @@ -673,10 +667,21 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_concat_internal_coercion(from_type, &LargeUtf8) } + // TODO: cast between array elements (#6558) + (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()), _ => None, }) } +fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + // TODO: cast between array elements (#6558) + if lhs_type.equals_datatype(rhs_type) { + Some(lhs_type.to_owned()) + } else { + None + } +} + fn string_concat_internal_coercion( from_type: &DataType, to_type: &DataType, @@ -688,8 +693,9 @@ fn string_concat_internal_coercion( } } -/// Coercion rules for Strings: the type that both lhs and rhs can be -/// casted to for the purpose of a string computation +/// Coercion rules for string types (Utf8/LargeUtf8): If at least one argument is +/// a string type and both arguments can be coerced into a string type, coerce +/// to string type. fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -697,6 +703,44 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option (LargeUtf8, Utf8) => Some(LargeUtf8), (Utf8, LargeUtf8) => Some(LargeUtf8), (LargeUtf8, LargeUtf8) => Some(LargeUtf8), + // TODO: cast between array elements (#6558) + (List(_), List(_)) => Some(lhs_type.clone()), + (List(_), _) => Some(lhs_type.clone()), + (_, List(_)) => Some(rhs_type.clone()), + _ => None, + } +} + +/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8): +/// If one argument is binary and the other is a string then coerce to string +/// (e.g. for `like`) +fn binary_to_string_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Binary, Utf8) => Some(Utf8), + (Binary, LargeUtf8) => Some(LargeUtf8), + (LargeBinary, Utf8) => Some(LargeUtf8), + (LargeBinary, LargeUtf8) => Some(LargeUtf8), + (Utf8, Binary) => Some(Utf8), + (Utf8, LargeBinary) => Some(LargeUtf8), + (LargeUtf8, Binary) => Some(LargeUtf8), + (LargeUtf8, LargeBinary) => Some(LargeUtf8), + _ => None, + } +} + +/// Coercion rules for binary types (Binary/LargeBinary): If at least one argument is +/// a binary type and both arguments can be coerced into a binary type, coerce +/// to binary type. +fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Binary | Utf8, Binary) | (Binary, Utf8) => Some(Binary), + (LargeBinary | Binary | Utf8 | LargeUtf8, LargeBinary) + | (LargeBinary, Binary | Utf8 | LargeUtf8) => Some(LargeBinary), _ => None, } } @@ -705,6 +749,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) + .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -737,30 +782,9 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Interval(MonthDayNano)), (Date64, Date32) | (Date32, Date64) => Some(Date64), - (Utf8, Date32) | (Date32, Utf8) => Some(Date32), - (Utf8, Date64) | (Date64, Utf8) => Some(Date64), - (Utf8, Time32(unit)) | (Time32(unit), Utf8) => { - match is_time_with_valid_unit(Time32(unit.clone())) { - false => None, - true => Some(Time32(unit.clone())), - } - } - (Utf8, Time64(unit)) | (Time64(unit), Utf8) => { - match is_time_with_valid_unit(Time64(unit.clone())) { - false => None, - true => Some(Time64(unit.clone())), - } - } - (Timestamp(_, tz), Utf8) | (Utf8, Timestamp(_, tz)) => { - Some(Timestamp(Nanosecond, tz.clone())) - } (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } @@ -807,7 +831,7 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { match (lhs_type, rhs_type) { @@ -824,29 +848,20 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { #[cfg(test)] mod tests { - use arrow::datatypes::DataType; - - use datafusion_common::assert_contains; - use datafusion_common::DataFusionError; - use datafusion_common::Result; - + use super::*; use crate::Operator; - use super::*; + use arrow::datatypes::DataType; + use datafusion_common::{assert_contains, Result}; #[test] fn test_coercion_error() -> Result<()> { let result_type = - coerce_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8); + get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8); - if let Err(DataFusionError::Plan(e)) = result_type { - assert_eq!(e, "Float32 + Utf8 can't be evaluated because there isn't a common type to coerce the types to"); - Ok(()) - } else { - Err(DataFusionError::Internal( - "Coercion should have returned an DataFusionError::Internal".to_string(), - )) - } + let e = result_type.unwrap_err(); + assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types"); + Ok(()) } #[test] @@ -885,12 +900,14 @@ mod tests { for (i, input_type) in input_types.iter().enumerate() { let expect_type = &result_types[i]; for op in comparison_op_types { - let result_type = coerce_types(&input_decimal, &op, input_type)?; - assert_eq!(expect_type, &result_type); + let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?; + assert_eq!(expect_type, &lhs); + assert_eq!(expect_type, &rhs); } } // negative test - let result_type = coerce_types(&input_decimal, &Operator::Eq, &DataType::Boolean); + let result_type = + get_input_types(&input_decimal, &Operator::Eq, &DataType::Boolean); assert!(result_type.is_err()); Ok(()) } @@ -921,53 +938,6 @@ mod tests { coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), DataType::Decimal128(30, 15) ); - - let op = Operator::Plus; - let left_decimal_type = DataType::Decimal128(10, 3); - let right_decimal_type = DataType::Decimal128(20, 4); - let result = coercion_decimal_mathematics_type( - &op, - &left_decimal_type, - &right_decimal_type, - ); - assert_eq!(DataType::Decimal128(21, 4), result.unwrap()); - let op = Operator::Minus; - let result = coercion_decimal_mathematics_type( - &op, - &left_decimal_type, - &right_decimal_type, - ); - assert_eq!(DataType::Decimal128(21, 4), result.unwrap()); - let op = Operator::Multiply; - let result = coercion_decimal_mathematics_type( - &op, - &left_decimal_type, - &right_decimal_type, - ); - assert_eq!(None, result); - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(31, 7), result.unwrap()); - let op = Operator::Divide; - let result = coercion_decimal_mathematics_type( - &op, - &left_decimal_type, - &right_decimal_type, - ); - assert_eq!(DataType::Decimal128(20, 4), result.unwrap()); - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(35, 24), result.unwrap()); - let op = Operator::Modulo; - let result = coercion_decimal_mathematics_type( - &op, - &left_decimal_type, - &right_decimal_type, - ); - assert_eq!(DataType::Decimal128(20, 4), result.unwrap()); - let result = - decimal_op_mathematics_type(&op, &left_decimal_type, &right_decimal_type); - assert_eq!(DataType::Decimal128(11, 4), result.unwrap()); } #[test] @@ -987,10 +957,13 @@ mod tests { let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Utf8)); - // Can not coerce values of Binary to int, cannot support this + // Since we can coerce values of Utf8 to Binary can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None); + assert_eq!( + dictionary_coercion(&lhs_type, &rhs_type, true), + Some(Binary) + ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; @@ -1009,36 +982,99 @@ mod tests { ); } + /// Test coercion rules for binary operators + /// + /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that the + /// the result type is `$RESULT_TYPE` macro_rules! test_coercion_binary_rule { - ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ - let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; - assert_eq!(result, $C_TYPE); + ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ + let (lhs, rhs) = get_input_types(&$LHS_TYPE, &$OP, &$RHS_TYPE)?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + }}; + } + + /// Test coercion rules for like + /// + /// Applies coercion rules for both + /// * `$LHS_TYPE LIKE $RHS_TYPE` + /// * `$RHS_TYPE LIKE $LHS_TYPE` + /// + /// And asserts the result type is `$RESULT_TYPE` + macro_rules! test_like_rule { + ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{ + println!("Coercing {} LIKE {}", $LHS_TYPE, $RHS_TYPE); + let result = like_coercion(&$LHS_TYPE, &$RHS_TYPE); + assert_eq!(result, $RESULT_TYPE); + // reverse the order + let result = like_coercion(&$RHS_TYPE, &$LHS_TYPE); + assert_eq!(result, $RESULT_TYPE); }}; } #[test] fn test_date_timestamp_arithmetic_error() -> Result<()> { - let common_type = coerce_types( + let (lhs, rhs) = get_input_types( &DataType::Timestamp(TimeUnit::Nanosecond, None), &Operator::Minus, &DataType::Timestamp(TimeUnit::Millisecond, None), )?; - assert_eq!(common_type.to_string(), "Timestamp(Millisecond, None)"); + assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)"); + assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)"); - let err = coerce_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) + let err = get_input_types(&DataType::Date32, &Operator::Plus, &DataType::Date64) .unwrap_err() .to_string(); - assert_contains!(&err, "Date32 + Date64 can't be evaluated because there isn't a common type to coerce the types to"); + + assert_contains!( + &err, + "Cannot get result type for temporal operation Date64 + Date64" + ); Ok(()) } #[test] - fn test_type_coercion() -> Result<()> { - // test like coercion rule - let result = like_coercion(&DataType::Utf8, &DataType::Utf8); - assert_eq!(result, Some(DataType::Utf8)); + fn test_like_coercion() { + // string coerce to strings + test_like_rule!(DataType::Utf8, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeUtf8, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Utf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeUtf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + + // Also coerce binary to strings + test_like_rule!(DataType::Binary, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeBinary, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Binary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + } + #[test] + fn test_type_coercion() -> Result<()> { test_coercion_binary_rule!( DataType::Utf8, DataType::Date32, @@ -1227,29 +1263,13 @@ mod tests { fn test_math_decimal_coercion_rule( lhs_type: DataType, rhs_type: DataType, - mathematics_op: Operator, - expected_lhs_type: Option, - expected_rhs_type: Option, - expected_coerced_type: Option, - expected_output_type: DataType, + expected_lhs_type: DataType, + expected_rhs_type: DataType, ) { // The coerced types for lhs and rhs, if any of them is not decimal - let (l, r) = math_decimal_coercion(&lhs_type, &rhs_type); - assert_eq!(l, expected_lhs_type); - assert_eq!(r, expected_rhs_type); - - let lhs_type = l.unwrap_or(lhs_type); - let rhs_type = r.unwrap_or(rhs_type); - - // The coerced type of decimal math expression, applied during expression evaluation - let coerced_type = - coercion_decimal_mathematics_type(&mathematics_op, &lhs_type, &rhs_type); - assert_eq!(coerced_type, expected_coerced_type); - - // The output type of decimal math expression - let output_type = - decimal_op_mathematics_type(&mathematics_op, &lhs_type, &rhs_type).unwrap(); - assert_eq!(output_type, expected_output_type); + let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!(lhs_type, expected_lhs_type); + assert_eq!(rhs_type, expected_rhs_type); } #[test] @@ -1257,60 +1277,42 @@ mod tests { test_math_decimal_coercion_rule( DataType::Decimal128(10, 2), DataType::Decimal128(10, 2), - Operator::Plus, - None, - None, - Some(DataType::Decimal128(11, 2)), - DataType::Decimal128(11, 2), + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Plus, - Some(DataType::Decimal128(10, 0)), - None, - Some(DataType::Decimal128(13, 2)), - DataType::Decimal128(13, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Minus, - Some(DataType::Decimal128(10, 0)), - None, - Some(DataType::Decimal128(13, 2)), - DataType::Decimal128(13, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Multiply, - Some(DataType::Decimal128(10, 0)), - None, - None, - DataType::Decimal128(21, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Divide, - Some(DataType::Decimal128(10, 0)), - None, - Some(DataType::Decimal128(12, 2)), - DataType::Decimal128(23, 11), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), ); test_math_decimal_coercion_rule( DataType::Int32, DataType::Decimal128(10, 2), - Operator::Modulo, - Some(DataType::Decimal128(10, 0)), - None, - Some(DataType::Decimal128(12, 2)), + DataType::Decimal128(10, 0), DataType::Decimal128(10, 2), ); @@ -1391,6 +1393,70 @@ mod tests { DataType::Decimal128(15, 3) ); + // Binary + test_coercion_binary_rule!( + DataType::Binary, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::Utf8, + Operator::Eq, + DataType::Binary + ); + + // LargeBinary + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Binary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Utf8, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeUtf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Operator::Eq, + DataType::LargeBinary + ); + // TODO add other data type Ok(()) } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d86914325fc98..79b5742384953 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::signature::TIMEZONE_WILDCARD; use crate::{Signature, TypeSignature}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; /// Performs type coercion for function arguments. /// @@ -34,8 +35,17 @@ pub fn data_types( signature: &Signature, ) -> Result> { if current_types.is_empty() { - return Ok(vec![]); + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -45,6 +55,8 @@ pub fn data_types( return Ok(current_types.to_vec()); } + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. for valid_types in valid_types { if let Some(types) = maybe_data_types(&valid_types, current_types) { return Ok(types); @@ -52,12 +64,14 @@ pub fn data_types( } // none possible -> Error - Err(DataFusionError::Plan(format!( + plan_err!( "Coercion from {:?} to the signature {:?} failed.", - current_types, &signature.type_signature - ))) + current_types, + &signature.type_signature + ) } +/// Returns a Vec of all possible valid argument types for the given signature. fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], @@ -84,11 +98,11 @@ fn get_valid_types( TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::Any(number) => { if current_types.len() != *number { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function expected {} arguments but received {}", number, current_types.len() - ))); + ); } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } @@ -102,7 +116,12 @@ fn get_valid_types( Ok(valid_types) } -/// Try to coerce current_types into valid_types. +/// Try to coerce the current argument types to match the given `valid_types`. +/// +/// For example, if a function `func` accepts arguments of `(int64, int64)`, +/// but was called with `(int32, int64)`, this function could match the +/// valid_types by coercing the first argument to `int64`, and would return +/// `Some([int64, int64])`. fn maybe_data_types( valid_types: &[DataType], current_types: &[DataType], @@ -119,8 +138,8 @@ fn maybe_data_types( new_type.push(current_type.clone()) } else { // attempt to coerce - if can_coerce_from(valid_type, current_type) { - new_type.push(valid_type.clone()) + if let Some(valid_type) = coerced_from(valid_type, current_type) { + new_type.push(valid_type) } else { // not possible return None; @@ -135,69 +154,123 @@ fn maybe_data_types( /// /// See the module level documentation for more detail on coercion. pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { - use self::DataType::*; - if type_into == type_from { return true; } - // Null can convert to most of types + if let Some(coerced) = coerced_from(type_into, type_from) { + return coerced == *type_into; + } + false +} + +fn coerced_from<'a>( + type_into: &'a DataType, + type_from: &'a DataType, +) -> Option { + use self::DataType::*; + match type_into { - Int8 => matches!(type_from, Null | Int8), - Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8), - Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16), - Int64 => matches!( - type_from, - Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 - ), - UInt8 => matches!(type_from, Null | UInt8), - UInt16 => matches!(type_from, Null | UInt8 | UInt16), - UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32), - UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64), - Float32 => matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - ), - Float64 => matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - ), - Timestamp(TimeUnit::Nanosecond, _) => { - matches!( + // coerced into type_into + Int8 if matches!(type_from, Null | Int8) => Some(type_into.clone()), + Int16 if matches!(type_from, Null | Int8 | Int16 | UInt8) => { + Some(type_into.clone()) + } + Int32 if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => { + Some(type_into.clone()) + } + Int64 + if matches!( type_from, - Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8 - ) + Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 + ) => + { + Some(type_into.clone()) + } + UInt8 if matches!(type_from, Null | UInt8) => Some(type_into.clone()), + UInt16 if matches!(type_from, Null | UInt8 | UInt16) => Some(type_into.clone()), + UInt32 if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => { + Some(type_into.clone()) } - Interval(_) => { - matches!(type_from, Utf8 | LargeUtf8) + UInt64 if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => { + Some(type_into.clone()) } - Utf8 | LargeUtf8 => true, - Null => can_cast_types(type_from, type_into), - _ => false, + Float32 + if matches!( + type_from, + Null | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + ) => + { + Some(type_into.clone()) + } + Float64 + if matches!( + type_from, + Null | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Decimal128(_, _) + ) => + { + Some(type_into.clone()) + } + Timestamp(TimeUnit::Nanosecond, None) + if matches!( + type_from, + Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8 + ) => + { + Some(type_into.clone()) + } + Interval(_) if matches!(type_from, Utf8 | LargeUtf8) => Some(type_into.clone()), + // Any type can be coerced into strings + Utf8 | LargeUtf8 => Some(type_into.clone()), + Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { + match type_from { + Timestamp(_, Some(from_tz)) => { + Some(Timestamp(unit.clone(), Some(from_tz.clone()))) + } + Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { + // In the absence of any other information assume the time zone is "+00" (UTC). + Some(Timestamp(unit.clone(), Some("+00".into()))) + } + _ => None, + } + } + Timestamp(_, Some(_)) + if matches!( + type_from, + Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8 + ) => + { + Some(type_into.clone()) + } + + // cannot coerce + _ => None, } } #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, TimeUnit}; #[test] fn test_maybe_data_types() { @@ -229,6 +302,20 @@ mod tests { vec![DataType::Boolean, DataType::UInt16], Some(vec![DataType::Boolean, DataType::UInt32]), ), + // UTF8 -> Timestamp + ( + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())), + ], + vec![DataType::Utf8, DataType::Utf8, DataType::Utf8], + Some(vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())), + ]), + ), ]; for case in cases { diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 0881bce98d6ae..86005da3dafa7 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -49,6 +49,7 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _), ) } @@ -57,15 +58,6 @@ pub fn is_null(dt: &DataType) -> bool { *dt == DataType::Null } -/// Determine whether the given data type `dt` represents numeric values. -pub fn is_numeric(dt: &DataType) -> bool { - is_signed_numeric(dt) - || matches!( - dt, - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 - ) -} - /// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) @@ -91,5 +83,5 @@ pub fn is_utf8_or_large_utf8(dt: &DataType) -> bool { /// Determine whether the given data type `dt` is a `Decimal`. pub fn is_decimal(dt: &DataType) -> bool { - matches!(dt, DataType::Decimal128(_, _)) + matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) } diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs index c53054e82112f..634558094ae79 100644 --- a/datafusion/expr/src/type_coercion/other.rs +++ b/datafusion/expr/src/type_coercion/other.rs @@ -28,9 +28,8 @@ pub fn get_coerce_type_for_list( ) -> Option { list_types .iter() - .fold(Some(expr_type.clone()), |left, right_type| match left { - None => None, - Some(left_type) => comparison_coercion(&left_type, right_type), + .try_fold(expr_type.clone(), |left_type, right_type| { + comparison_coercion(&left_type, right_type) }) } @@ -47,11 +46,9 @@ pub fn get_coerce_type_for_case_expression( }; when_or_then_types .iter() - .fold(Some(case_or_else_type), |left, right_type| match left { - // failed to find a valid coercion in a previous iteration - None => None, + .try_fold(case_or_else_type, |left_type, right_type| { // TODO: now just use the `equal` coercion rule for case when. If find the issue, and // refactor again. - Some(left_type) => comparison_coercion(&left_type, right_type), + comparison_coercion(&left_type, right_type) }) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 6c3690e283d2b..cfbca4ab1337a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -15,29 +15,48 @@ // specific language governing permissions and limitations // under the License. -//! Udaf module contains functions and structs supporting user-defined aggregate functions. +//! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::Expr; +use crate::{Accumulator, Expr}; use crate::{ - AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction, + AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; -/// Logical representation of a user-defined aggregate function (UDAF) -/// A UDAF is different from a UDF in that it is stateful across batches. +/// Logical representation of a user-defined [aggregate function] (UDAF). +/// +/// An aggregate function combines the values from multiple input rows +/// into a single output "aggregate" (summary) row. It is different +/// from a scalar function because it is stateful across batches. User +/// defined aggregate functions can be used as normal SQL aggregate +/// functions (`GROUP BY` clause) as well as window functions (`OVER` +/// clause). +/// +/// `AggregateUDF` provides DataFusion the information needed to plan +/// and call aggregate functions, including name, type information, +/// and a factory function to create [`Accumulator`], which peform the +/// actual aggregation. +/// +/// For more information, please see [the examples]. +/// +/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process +/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function +/// [`Accumulator`]: crate::Accumulator #[derive(Clone)] pub struct AggregateUDF { /// name - pub name: String, - /// signature - pub signature: Signature, + name: String, + /// Signature (input arguments) + signature: Signature, /// Return type - pub return_type: ReturnTypeFunction, + return_type: ReturnTypeFunction, /// actual implementation - pub accumulator: AccumulatorFunctionImplementation, + accumulator: AccumulatorFactoryFunction, /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, + state_type: StateTypeFunction, } impl Debug for AggregateUDF { @@ -71,7 +90,7 @@ impl AggregateUDF { name: &str, signature: &Signature, return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFunctionImplementation, + accumulator: &AccumulatorFactoryFunction, state_type: &StateTypeFunction, ) -> Self { Self { @@ -83,14 +102,48 @@ impl AggregateUDF { } } - /// creates a logical expression with a call of the UDAF - /// This utility allows using the UDAF without requiring access to the registry. + /// creates an [`Expr`] that calls the aggregate function. + /// + /// This utility allows using the UDAF without requiring access to + /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return an accumualator the given aggregate, given + /// its return datatype. + pub fn accumulator(&self, return_type: &DataType) -> Result> { + (self.accumulator)(return_type) + } + + /// Return the type of the intermediate state used by this aggregator, given + /// its return datatype. Supports multi-phase aggregations + pub fn state_type(&self, return_type: &DataType) -> Result> { + // old API returns an Arc for some reason, try and unwrap it here + let res = (self.state_type)(return_type)?; + Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985d..3a18ca2d25e82 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -15,23 +15,31 @@ // specific language governing permissions and limitations // under the License. -//! Udf module contains foundational types that are used to represent UDFs in DataFusion. +//! [`ScalarUDF`]: Scalar User Defined Functions use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -/// Logical representation of a UDF. +/// Logical representation of a Scalar User Defined Function. +/// +/// A scalar function produces a single row output for each row of input. +/// +/// This struct contains the information DataFusion needs to plan and invoke +/// functions such name, type signature, return type, and actual implementation. +/// #[derive(Clone)] pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, /// actual implementation /// /// The fn param is the wrapped function but be aware that the function will @@ -40,7 +48,9 @@ pub struct ScalarUDF { /// will be passed. In that case the single element is a null array to indicate /// the batch's row count (so that the generative zero-argument function can know /// the result array size). - pub fun: ScalarFunctionImplementation, + fun: ScalarFunctionImplementation, + /// Optional aliases for the function. This list should NOT include the value of `name` as well + aliases: Vec, } impl Debug for ScalarUDF { @@ -81,12 +91,55 @@ impl ScalarUDF { signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), + aliases: vec![], } } + /// Adds additional names that can be used to invoke this function, in addition to `name` + pub fn with_aliases( + mut self, + aliases: impl IntoIterator, + ) -> Self { + self.aliases + .extend(aliases.into_iter().map(|s| s.to_string())); + self + } + /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) + Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Arc::new(self.clone()), + args, + )) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + pub fn aliases(&self) -> &[String] { + &self.aliases } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return the actual implementation + pub fn fun(&self) -> ScalarFunctionImplementation { + self.fun.clone() + } + + // TODO maybe add an invoke() method that runs the actual function? } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs new file mode 100644 index 0000000000000..c233ee84b32da --- /dev/null +++ b/datafusion/expr/src/udwf.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`WindowUDF`]: User Defined Window Functions + +use crate::{ + Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, + WindowFrame, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::{ + fmt::{self, Debug, Display, Formatter}, + sync::Arc, +}; + +/// Logical representation of a user-defined window function (UDWF) +/// A UDWF is different from a UDF in that it is stateful across batches. +/// +/// See the documetnation on [`PartitionEvaluator`] for more details +/// +/// [`PartitionEvaluator`]: crate::PartitionEvaluator +#[derive(Clone)] +pub struct WindowUDF { + /// name + name: String, + /// signature + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// Return the partition evaluator + partition_evaluator_factory: PartitionEvaluatorFactory, +} + +impl Debug for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("WindowUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &"") + .field("partition_evaluator_factory", &"") + .finish_non_exhaustive() + } +} + +/// Defines how the WindowUDF is shown to users +impl Display for WindowUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PartialEq for WindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl Eq for WindowUDF {} + +impl std::hash::Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl WindowUDF { + /// Create a new WindowUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + partition_evaluator_factory: &PartitionEvaluatorFactory, + ) -> Self { + Self { + name: name.to_string(), + signature: signature.clone(), + return_type: return_type.clone(), + partition_evaluator_factory: partition_evaluator_factory.clone(), + } + } + + /// creates a [`Expr`] that calls the window function given + /// the `partition_by`, `order_by`, and `window_frame` definition + /// + /// This utility allows using the UDWF without requiring access to + /// the registry, such as with the DataFrame API. + pub fn call( + &self, + args: Vec, + partition_by: Vec, + order_by: Vec, + window_frame: WindowFrame, + ) -> Expr { + let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + + Expr::WindowFunction(crate::expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + }) + } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + + /// Returns this function's signature (what input types are accepted) + pub fn signature(&self) -> &Signature { + &self.signature + } + + /// Return the type of the function given its input types + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + + /// Return a `PartitionEvaluator` for evaluating this window function + pub fn partition_evaluator_factory(&self) -> Result> { + (self.partition_evaluator_factory)() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c2eabea85727a..abdd7f5f57f61 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,30 +17,27 @@ //! Expression utilities -use crate::expr::{Sort, WindowFunction}; -use crate::logical_plan::builder::build_join_schema; -use crate::logical_plan::{ - Aggregate, Analyze, Distinct, Extension, Filter, Join, Limit, Partitioning, Prepare, - Projection, Repartition, Sort as SortPlan, Subquery, SubqueryAlias, Union, Unnest, - Values, Window, -}; +use std::cmp::Ordering; +use std::collections::HashSet; +use std::sync::Arc; + +use crate::expr::{Alias, Sort, WindowFunction}; +use crate::expr_rewriter::strip_outer_reference; +use crate::logical_plan::Aggregate; +use crate::signature::{Signature, TypeSignature}; use crate::{ - BinaryExpr, Cast, CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr, - ExprSchemable, GroupingSet, LogicalPlan, LogicalPlanBuilder, Operator, TableScan, - TryCast, + and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, + Operator, TryCast, }; + use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion, -}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, TableReference, }; + use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; -use std::cmp::Ordering; -use std::collections::HashSet; -use std::sync::Arc; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions @@ -60,10 +57,9 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { if group_expr.len() > 1 { - return Err(DataFusionError::Plan( + return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" - .to_string(), - )); + ); } Ok(grouping_set.distinct_expr().len()) } else { @@ -114,7 +110,7 @@ fn powerset(slice: &[T]) -> Result>, String> { fn check_grouping_set_size_limit(size: usize) -> Result<()> { let max_grouping_set_size = 65535; if size > max_grouping_set_size { - return Err(DataFusionError::Plan(format!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"))); + return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"); } Ok(()) @@ -124,7 +120,7 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> { fn check_grouping_sets_size_limit(size: usize) -> Result<()> { let max_grouping_sets_size = 4096; if size > max_grouping_sets_size { - return Err(DataFusionError::Plan(format!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"))); + return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"); } Ok(()) @@ -210,8 +206,8 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { grouping_sets.iter().map(|e| e.iter().collect()).collect() } Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { - let grouping_sets = - powerset(group_exprs).map_err(DataFusionError::Plan)?; + let grouping_sets = powerset(group_exprs) + .map_err(|e| plan_datafusion_err!("{}", e))?; check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets } @@ -252,10 +248,9 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { if group_expr.len() > 1 { - return Err(DataFusionError::Plan( + return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" - .to_string(), - )); + ); } Ok(grouping_set.distinct_expr()) } else { @@ -275,11 +270,10 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { // implementation, so that in the future if someone adds // new Expr types, they will check here as well Expr::ScalarVariable(_, _) - | Expr::Alias(_, _) + | Expr::Alias(_) | Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Like { .. } - | Expr::ILike { .. } | Expr::SimilarTo { .. } | Expr::Not(_) | Expr::IsNotNull(_) @@ -297,17 +291,14 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::GetIndexedField { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} @@ -319,15 +310,15 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { /// Find excluded columns in the schema, if any /// SELECT * EXCLUDE(col1, col2), would return `vec![col1, col2]` fn get_excluded_columns( - opt_exclude: Option, - opt_except: Option, + opt_exclude: Option<&ExcludeSelectItem>, + opt_except: Option<&ExceptSelectItem>, schema: &DFSchema, qualifier: &Option, ) -> Result> { let mut idents = vec![]; if let Some(excepts) = opt_except { - idents.push(excepts.first_element); - idents.extend(excepts.additional_elements); + idents.push(&excepts.first_element); + idents.extend(&excepts.additional_elements); } if let Some(exclude) = opt_exclude { match exclude { @@ -341,9 +332,7 @@ fn get_excluded_columns( // if HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. if n_elem != unique_idents.len() { - return Err(DataFusionError::Plan( - "EXCLUDE or EXCEPT contains duplicate column names".to_string(), - )); + return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); } let mut result = vec![]; @@ -390,7 +379,7 @@ fn get_exprs_except_skipped( pub fn expand_wildcard( schema: &DFSchema, plan: &LogicalPlan, - wildcard_options: Option, + wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let using_columns = plan.using_columns()?; let mut columns_to_skip = using_columns @@ -420,7 +409,7 @@ pub fn expand_wildcard( .. }) = wildcard_options { - get_excluded_columns(opt_exclude, opt_except, schema, &None)? + get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, &None)? } else { vec![] }; @@ -433,7 +422,7 @@ pub fn expand_wildcard( pub fn expand_qualified_wildcard( qualifier: &str, schema: &DFSchema, - wildcard_options: Option, + wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { let qualifier = TableReference::from(qualifier); let qualified_fields: Vec = schema @@ -442,19 +431,24 @@ pub fn expand_qualified_wildcard( .cloned() .collect(); if qualified_fields.is_empty() { - return Err(DataFusionError::Plan(format!( - "Invalid qualifier {qualifier}" - ))); + return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = - DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?; + DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? + // We can use the functional dependencies as is, since it only stores indices: + .with_functional_dependencies(schema.functional_dependencies().clone())?; let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, .. }) = wildcard_options { - get_excluded_columns(opt_exclude, opt_except, schema, &Some(qualifier))? + get_excluded_columns( + opt_exclude.as_ref(), + opt_except.as_ref(), + schema, + &Some(qualifier), + )? } else { vec![] }; @@ -479,9 +473,7 @@ pub fn generate_sort_key( Expr::Sort(Sort { expr, .. }) => { Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) } - _ => Err(DataFusionError::Plan( - "Order by only accepts sort expressions".to_string(), - )), + _ => plan_err!("Order by only accepts sort expressions"), }) .collect::>>()?; @@ -512,7 +504,6 @@ pub fn generate_sort_key( let res = final_sort_keys .into_iter() .zip(is_partition_flag) - .map(|(lhs, rhs)| (lhs, rhs)) .collect::>(); Ok(res) } @@ -598,22 +589,19 @@ pub fn group_window_expr_by_sort_keys( } Ok(()) } - other => Err(DataFusionError::Internal(format!( - "Impossibly got non-window expr {other:?}", - ))), + other => internal_err!( + "Impossibly got non-window expr {other:?}" + ), })?; Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } @@ -724,308 +712,16 @@ where /// // create new plan using rewritten_exprs in same position /// let new_plan = from_plan(&plan, rewritten_exprs, new_inputs); /// ``` +/// +/// Notice: sometimes [from_plan] will use schema of original plan, it don't change schema! +/// Such as `Projection/Aggregate/Window` +#[deprecated(since = "31.0.0", note = "use LogicalPlan::with_new_exprs instead")] pub fn from_plan( plan: &LogicalPlan, expr: &[Expr], inputs: &[LogicalPlan], ) -> Result { - match plan { - LogicalPlan::Projection(Projection { schema, .. }) => { - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - expr.to_vec(), - Arc::new(inputs[0].clone()), - schema.clone(), - )?)) - } - LogicalPlan::Dml(DmlStatement { - table_name, - table_schema, - op, - .. - }) => Ok(LogicalPlan::Dml(DmlStatement { - table_name: table_name.clone(), - table_schema: table_schema.clone(), - op: op.clone(), - input: Arc::new(inputs[0].clone()), - })), - LogicalPlan::Values(Values { schema, .. }) => Ok(LogicalPlan::Values(Values { - schema: schema.clone(), - values: expr - .chunks_exact(schema.fields().len()) - .map(|s| s.to_vec()) - .collect::>(), - })), - LogicalPlan::Filter { .. } => { - assert_eq!(1, expr.len()); - let predicate = expr[0].clone(); - - // filter predicates should not contain aliased expressions so we remove any aliases - // before this logic was added we would have aliases within filters such as for - // benchmark q6: - // - // lineitem.l_shipdate >= Date32(\"8766\") - // AND lineitem.l_shipdate < Date32(\"9131\") - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - // Decimal128(Some(49999999999999),30,15) - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - // Decimal128(Some(69999999999999),30,15) - // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - - struct RemoveAliases {} - - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) - } - Expr::Alias(_, _) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), - } - } - - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) - } - } - - let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; - - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(inputs[0].clone()), - )?)) - } - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { - Partitioning::RoundRobinBatch(n) => { - Ok(LogicalPlan::Repartition(Repartition { - partitioning_scheme: Partitioning::RoundRobinBatch(*n), - input: Arc::new(inputs[0].clone()), - })) - } - Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition { - partitioning_scheme: Partitioning::Hash(expr.to_owned(), *n), - input: Arc::new(inputs[0].clone()), - })), - Partitioning::DistributeBy(_) => Ok(LogicalPlan::Repartition(Repartition { - partitioning_scheme: Partitioning::DistributeBy(expr.to_owned()), - input: Arc::new(inputs[0].clone()), - })), - }, - LogicalPlan::Window(Window { - window_expr, - schema, - .. - }) => Ok(LogicalPlan::Window(Window { - input: Arc::new(inputs[0].clone()), - window_expr: expr[0..window_expr.len()].to_vec(), - schema: schema.clone(), - })), - LogicalPlan::Aggregate(Aggregate { - group_expr, schema, .. - }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(inputs[0].clone()), - expr[0..group_expr.len()].to_vec(), - expr[group_expr.len()..].to_vec(), - schema.clone(), - )?)), - LogicalPlan::Sort(SortPlan { fetch, .. }) => Ok(LogicalPlan::Sort(SortPlan { - expr: expr.to_vec(), - input: Arc::new(inputs[0].clone()), - fetch: *fetch, - })), - LogicalPlan::Join(Join { - join_type, - join_constraint, - on, - null_equals_null, - .. - }) => { - let schema = - build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; - - let equi_expr_count = on.len(); - assert!(expr.len() >= equi_expr_count); - - // The preceding part of expr is equi-exprs, - // and the struct of each equi-expr is like `left-expr = right-expr`. - let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| { - // SimplifyExpression rule may add alias to the equi_expr. - let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op:Operator::Eq, right }) = unalias_expr { - Ok((*left, *right)) - } else { - Err(DataFusionError::Internal(format!( - "The front part expressions should be an binary equiality expression, actual:{equi_expr}" - ))) - } - }).collect::>>()?; - - // Assume that the last expr, if any, - // is the filter_expr (non equality predicate from ON clause) - let filter_expr = - (expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone()); - - Ok(LogicalPlan::Join(Join { - left: Arc::new(inputs[0].clone()), - right: Arc::new(inputs[1].clone()), - join_type: *join_type, - join_constraint: *join_constraint, - on: new_on, - filter: filter_expr, - schema: DFSchemaRef::new(schema), - null_equals_null: *null_equals_null, - })) - } - LogicalPlan::CrossJoin(_) => { - let left = inputs[0].clone(); - let right = inputs[1].clone(); - LogicalPlanBuilder::from(left).cross_join(right)?.build() - } - LogicalPlan::Subquery(Subquery { - outer_ref_columns, .. - }) => { - let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?; - Ok(LogicalPlan::Subquery(Subquery { - subquery: Arc::new(subquery), - outer_ref_columns: outer_ref_columns.clone(), - })) - } - LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - inputs[0].clone(), - alias.clone(), - )?)) - } - LogicalPlan::Limit(Limit { skip, fetch, .. }) => Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, - input: Arc::new(inputs[0].clone()), - })), - LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - name, - if_not_exists, - or_replace, - .. - })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( - CreateMemoryTable { - input: Arc::new(inputs[0].clone()), - primary_key: vec![], - name: name.clone(), - if_not_exists: *if_not_exists, - or_replace: *or_replace, - }, - ))), - LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - name, - or_replace, - definition, - .. - })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - input: Arc::new(inputs[0].clone()), - name: name.clone(), - or_replace: *or_replace, - definition: definition.clone(), - }))), - LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { - node: e.node.from_template(expr, inputs), - })), - LogicalPlan::Union(Union { schema, .. }) => Ok(LogicalPlan::Union(Union { - inputs: inputs.iter().cloned().map(Arc::new).collect(), - schema: schema.clone(), - })), - LogicalPlan::Distinct(Distinct { .. }) => Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(inputs[0].clone()), - })), - LogicalPlan::Analyze(a) => { - assert!(expr.is_empty()); - assert_eq!(inputs.len(), 1); - Ok(LogicalPlan::Analyze(Analyze { - verbose: a.verbose, - schema: a.schema.clone(), - input: Arc::new(inputs[0].clone()), - })) - } - LogicalPlan::Explain(_) => { - // Explain should be handled specially in the optimizers; - // If this check cannot pass it means some optimizer pass is - // trying to optimize Explain directly - if expr.is_empty() { - return Err(DataFusionError::Plan( - "Invalid EXPLAIN command. Expression is empty".to_string(), - )); - } - - if inputs.is_empty() { - return Err(DataFusionError::Plan( - "Invalid EXPLAIN command. Inputs are empty".to_string(), - )); - } - - Ok(plan.clone()) - } - LogicalPlan::Prepare(Prepare { - name, data_types, .. - }) => Ok(LogicalPlan::Prepare(Prepare { - name: name.clone(), - data_types: data_types.clone(), - input: Arc::new(inputs[0].clone()), - })), - LogicalPlan::TableScan(ts) => { - assert!(inputs.is_empty(), "{plan:?} should have no inputs"); - Ok(LogicalPlan::TableScan(TableScan { - filters: expr.to_vec(), - ..ts.clone() - })) - } - LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Statement(_) => { - // All of these plan types have no inputs / exprs so should not be called - assert!(expr.is_empty(), "{plan:?} should have no exprs"); - assert!(inputs.is_empty(), "{plan:?} should have no inputs"); - Ok(plan.clone()) - } - LogicalPlan::DescribeTable(_) => Ok(plan.clone()), - LogicalPlan::Unnest(Unnest { column, schema, .. }) => { - // Update schema with unnested column type. - let input = Arc::new(inputs[0].clone()); - let nested_field = input.schema().field_from_column(column)?; - let unnested_field = schema.field_from_column(column)?; - let fields = input - .schema() - .fields() - .iter() - .map(|f| { - if f == nested_field { - unnested_field.clone() - } else { - f.clone() - } - }) - .collect::>(); - - let schema = Arc::new(DFSchema::new_with_metadata( - fields, - input.schema().metadata().clone(), - )?); - - Ok(LogicalPlan::Unnest(Unnest { - input, - column: column.clone(), - schema, - })) - } - } + plan.with_new_exprs(expr.to_vec(), inputs) } /// Find all columns referenced from an aggregate query @@ -1037,11 +733,7 @@ fn agg_cols(agg: &Aggregate) -> Vec { .collect() } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - plan: &LogicalPlan, - agg: &Aggregate, -) -> Result> { +fn exprlist_to_fields_aggregate(exprs: &[Expr], agg: &Aggregate) -> Result> { let agg_cols = agg_cols(agg); let mut fields = vec![]; for expr in exprs { @@ -1050,7 +742,7 @@ fn exprlist_to_fields_aggregate( // resolve against schema of input to aggregate fields.push(expr.to_field(agg.input.schema())?); } - _ => fields.push(expr.to_field(plan.schema())?), + _ => fields.push(expr.to_field(&agg.schema)?), } } Ok(fields) @@ -1067,15 +759,7 @@ pub fn exprlist_to_fields<'a>( // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to // look at the input to the aggregate instead. let fields = match plan { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - LogicalPlan::Window(window) => match window.input.as_ref() { - LogicalPlan::Aggregate(agg) => { - Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) - } - _ => None, - }, + LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(&exprs, agg)), _ => None, }; if let Some(fields) = fields { @@ -1106,9 +790,11 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(inner_expr, name) => { - columnize_expr(*inner_expr, input_schema).alias(name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { expr: Box::new(columnize_expr(*expr, input_schema)), data_type, @@ -1205,7 +891,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt64 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, None) => match time_unit { + DataType::Timestamp(time_unit, _) => match time_unit { TimeUnit::Second => true, TimeUnit::Millisecond => true, TimeUnit::Microsecond => true, @@ -1222,6 +908,8 @@ pub fn can_hash(data_type: &DataType) -> bool { { DataType::is_dictionary_key_type(key_type) } + DataType::List(_) => true, + DataType::LargeList(_) => true, _ => false, } } @@ -1264,39 +952,288 @@ pub fn find_valid_equijoin_key_pair( return Ok(None); } - let l_is_left = - check_all_columns_from_schema(&left_using_columns, left_schema.clone())?; - let r_is_right = - check_all_columns_from_schema(&right_using_columns, right_schema.clone())?; + if check_all_columns_from_schema(&left_using_columns, left_schema.clone())? + && check_all_columns_from_schema(&right_using_columns, right_schema.clone())? + { + return Ok(Some((left_key.clone(), right_key.clone()))); + } else if check_all_columns_from_schema(&right_using_columns, left_schema)? + && check_all_columns_from_schema(&left_using_columns, right_schema)? + { + return Ok(Some((right_key.clone(), left_key.clone()))); + } - let r_is_left_and_l_is_right = || { - let result = - check_all_columns_from_schema(&right_using_columns, left_schema.clone())? - && check_all_columns_from_schema( - &left_using_columns, - right_schema.clone(), - )?; + Ok(None) +} - Result::<_>::Ok(result) - }; +/// Creates a detailed error message for a function with wrong signature. +/// +/// For example, a query like `select round(3.14, 1.1);` would yield: +/// ```text +/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. +/// Candidate functions: +/// round(Float64, Int64) +/// round(Float32, Int64) +/// round(Float64) +/// round(Float32) +/// ``` +pub fn generate_signature_error_msg( + func_name: &str, + func_signature: Signature, + input_expr_types: &[DataType], +) -> String { + let candidate_signatures = func_signature + .type_signature + .to_string_repr() + .iter() + .map(|args_str| format!("\t{func_name}({args_str})")) + .collect::>() + .join("\n"); + + format!( + "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", + func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures + ) +} + +/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { + split_conjunction_impl(expr, vec![]) +} - let join_key_pair = match (l_is_left, r_is_right) { - (true, true) => Some((left_key.clone(), right_key.clone())), - (_, _) if r_is_left_and_l_is_right()? => { - Some((right_key.clone(), left_key.clone())) +fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + let exprs = split_conjunction_impl(left, exprs); + split_conjunction_impl(right, exprs) } - _ => None, - }; + Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// This is often used to "split" filter expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::split_conjunction_owned; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_conjunction_owned to split them +/// assert_eq!(split_conjunction_owned(expr), split); +/// ``` +pub fn split_conjunction_owned(expr: Expr) -> Vec { + split_binary_owned(expr, Operator::And) +} + +/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// This is often used to "split" expressions such as `col1 = 5 +/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit, Operator}; +/// # use datafusion_expr::utils::split_binary_owned; +/// # use std::ops::Add; +/// // a=1 + b=2 +/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use split_binary_owned to split them +/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// ``` +pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { + split_binary_owned_impl(expr, op, vec![]) +} + +fn split_binary_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_binary_owned_impl(*left, operator, exprs); + split_binary_owned_impl(*right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => { + split_binary_owned_impl(*expr, operator, exprs) + } + other => { + exprs.push(other); + exprs + } + } +} + +/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` +/// +/// See [`split_binary_owned`] for more details and an example. +pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { + split_binary_impl(expr, op, vec![]) +} + +fn split_binary_impl<'a>( + expr: &'a Expr, + operator: Operator, + mut exprs: Vec<&'a Expr>, +) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + let exprs = split_binary_impl(left, operator, exprs); + split_binary_impl(right, operator, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs), + other => { + exprs.push(other); + exprs + } + } +} - Ok(join_key_pair) +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical AND. +/// +/// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::conjunction; +/// // a=1 AND b=2 +/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use conjunction to join them together with `AND` +/// assert_eq!(conjunction(split), Some(expr)); +/// ``` +pub fn conjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.and(expr)) +} + +/// Combines an array of filter expressions into a single filter +/// expression consisting of the input filter expressions joined with +/// logical OR. +/// +/// Returns None if the filters array is empty. +pub fn disjunction(filters: impl IntoIterator) -> Option { + filters.into_iter().reduce(|accum, expr| accum.or(expr)) +} + +/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with +/// its predicate be all `predicates` ANDed. +pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { + // reduce filters to a single filter with an AND + let predicate = predicates + .iter() + .skip(1) + .fold(predicates[0].clone(), |acc, predicate| { + and(acc, (*predicate).to_owned()) + }); + + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(plan), + )?)) +} + +/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and +/// one not in the subquery (closed upon from outer scope) +/// +/// # Arguments +/// +/// * `exprs` - List of expressions that may or may not be joins +/// +/// # Return value +/// +/// Tuple of (expressions containing joins, remaining non-join expressions) +pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { + let mut joins = vec![]; + let mut others = vec![]; + for filter in exprs.into_iter() { + // If the expression contains correlated predicates, add it to join filters + if filter.contains_outer() { + if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) + { + joins.push(strip_outer_reference((*filter).clone())); + } + } else { + others.push((*filter).clone()); + } + } + + Ok((joins, others)) +} + +/// Returns the first (and only) element in a slice, or an error +/// +/// # Arguments +/// +/// * `slice` - The slice to extract from +/// +/// # Return value +/// +/// The first element, or an error +pub fn only_or_err(slice: &[T]) -> Result<&T> { + match slice { + [it] => Ok(it), + [] => plan_err!("No items found!"), + _ => plan_err!("More than one item found!"), + } +} + +/// merge inputs schema into a single schema. +pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { + if inputs.len() == 1 { + inputs[0].schema().clone().as_ref().clone() + } else { + inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ) + } } #[cfg(test)] mod tests { use super::*; use crate::{ - col, cube, expr, grouping_set, rollup, AggregateFunction, WindowFrame, - WindowFunction, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, + WindowFrame, WindowFunction, }; #[test] @@ -1499,22 +1436,22 @@ mod tests { // 1. col let sets = enumerate_grouping_sets(vec![simple_col.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!("[simple_col]", &result); // 2. cube let sets = enumerate_grouping_sets(vec![cube.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!("[CUBE (col1, col2, col3)]", &result); // 3. rollup let sets = enumerate_grouping_sets(vec![rollup.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!("[ROLLUP (col1, col2, col3)]", &result); // 4. col + cube let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!( "[GROUPING SETS (\ (simple_col), \ @@ -1530,7 +1467,7 @@ mod tests { // 5. col + rollup let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!( "[GROUPING SETS (\ (simple_col), \ @@ -1543,7 +1480,7 @@ mod tests { // 6. col + grouping_set let sets = enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!( "[GROUPING SETS (\ (simple_col, col1, col2, col3))]", @@ -1556,7 +1493,7 @@ mod tests { grouping_set, rollup.clone(), ])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!( "[GROUPING SETS (\ (simple_col, col1, col2, col3), \ @@ -1568,7 +1505,7 @@ mod tests { // 8. col + cube + rollup let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?; - let result = format!("{sets:?}"); + let result = format!("[{}]", expr_vec_fmt!(sets)); assert_eq!( "[GROUPING SETS (\ (simple_col), \ @@ -1608,4 +1545,143 @@ mod tests { Ok(()) } + #[test] + fn test_split_conjunction() { + let expr = col("a"); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_conjunction_two() { + let expr = col("a").eq(lit(5)).and(col("b")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_alias() { + let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); + let expr1 = col("a").eq(lit(5)); + let expr2 = col("b"); // has no alias + + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr1, &expr2]); + } + + #[test] + fn test_split_conjunction_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + let result = split_conjunction(&expr); + assert_eq!(result, vec![&expr]); + } + + #[test] + fn test_split_binary_owned() { + let expr = col("a"); + assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + } + + #[test] + fn test_split_binary_owned_two() { + assert_eq!( + split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_binary_owned_different_op() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!( + // expr is connected by OR, but pass in AND + split_binary_owned(expr.clone(), Operator::And), + vec![expr] + ); + } + + #[test] + fn test_split_conjunction_owned() { + let expr = col("a"); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_split_conjunction_owned_two() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + vec![col("a").eq(lit(5)), col("b")] + ); + } + + #[test] + fn test_split_conjunction_owned_alias() { + assert_eq!( + split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + vec![ + col("a").eq(lit(5)), + // no alias on b + col("b"), + ] + ); + } + + #[test] + fn test_conjunction_empty() { + assert_eq!(conjunction(vec![]), None); + } + + #[test] + fn test_conjunction() { + // `[A, B, C]` + let expr = conjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A AND B) AND C` + assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); + + // which is different than `A AND (B AND C)` + assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); + } + + #[test] + fn test_disjunction_empty() { + assert_eq!(disjunction(vec![]), None); + } + + #[test] + fn test_disjunction() { + // `[A, B, C]` + let expr = disjunction(vec![col("a"), col("b"), col("c")]); + + // --> `(A OR B) OR C` + assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); + + // which is different than `A OR (B OR C)` + assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); + } + + #[test] + fn test_split_conjunction_owned_or() { + let expr = col("a").eq(lit(5)).or(col("b")); + assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + } + + #[test] + fn test_collect_expr() -> Result<()> { + let mut accum: HashSet = HashSet::new(); + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + expr_to_columns( + &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &mut accum, + )?; + assert_eq!(1, accum.len()); + assert!(accum.contains(&Column::from_name("a"))); + Ok(()) + } } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2d1882788be4..2701ca1ecf3b1 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,7 +23,9 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use crate::expr::Sort; +use crate::Expr; +use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; use std::convert::{From, TryFrom}; @@ -68,14 +70,14 @@ impl TryFrom for WindowFrame { if let WindowFrameBound::Following(val) = &start_bound { if val.is_null() { - plan_error( - "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING", + plan_err!( + "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING" )? } } else if let WindowFrameBound::Preceding(val) = &end_bound { if val.is_null() { - plan_error( - "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING", + plan_err!( + "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING" )? } }; @@ -142,31 +144,57 @@ impl WindowFrame { } } -/// Construct equivalent explicit window frames for implicit corner cases. -/// With this processing, we may assume in downstream code that RANGE/GROUPS -/// frames contain an appropriate ORDER BY clause. -pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { +/// Regularizes ORDER BY clause for window definition for implicit corner cases. +pub fn regularize_window_order_by( + frame: &WindowFrame, + order_by: &mut Vec, +) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent in two edge cases. + // column. However, an ORDER BY clause may be absent or present but with + // more than one column in two edge cases: + // 1. start bound is UNBOUNDED or CURRENT ROW + // 2. end bound is CURRENT ROW or UNBOUNDED. + // In these cases, we regularize the ORDER BY clause if the ORDER BY clause + // is absent. If an ORDER BY clause is present but has more than one column, + // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. if (frame.start_bound.is_unbounded() || frame.start_bound == WindowFrameBound::CurrentRow) && (frame.end_bound == WindowFrameBound::CurrentRow || frame.end_bound.is_unbounded()) { - if order_bys == 0 { - frame.units = WindowFrameUnits::Rows; - frame.start_bound = - WindowFrameBound::Preceding(ScalarValue::UInt64(None)); - frame.end_bound = WindowFrameBound::Following(ScalarValue::UInt64(None)); + // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause + // with constant value as sort key. + // If an ORDER BY clause is present but has more than one column, it is + // unchanged. + if order_by.is_empty() { + order_by.push(Expr::Sort(Sort::new( + Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), + true, + false, + ))); } - } else { - plan_error("RANGE requires exactly one ORDER BY column")? + } + } + Ok(()) +} + +/// Checks if given window frame is valid. In particular, if the frame is RANGE +/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. +pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { + if frame.units == WindowFrameUnits::Range && order_bys != 1 { + // See `regularize_window_order_by`. + if !(frame.start_bound.is_unbounded() + || frame.start_bound == WindowFrameBound::CurrentRow) + || !(frame.end_bound == WindowFrameBound::CurrentRow + || frame.end_bound.is_unbounded()) + { + plan_err!("RANGE requires exactly one ORDER BY column")? } } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { - plan_error("GROUPS requires an ORDER BY clause")? + plan_err!("GROUPS requires an ORDER BY clause")? }; - Ok(frame) + Ok(()) } /// There are five ways to describe starting and ending frame boundaries: @@ -241,9 +269,9 @@ pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result let result = match *value { ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, e => { - return Err(DataFusionError::SQL(ParserError(format!( + return sql_err!(ParserError(format!( "INTERVAL expression cannot be {e:?}" - )))); + ))); } }; if let Some(leading_field) = leading_field { @@ -252,16 +280,12 @@ pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result result } } - _ => plan_error( - "Invalid window frame: frame offsets must be non negative integers", + _ => plan_err!( + "Invalid window frame: frame offsets must be non negative integers" )?, }))) } -fn plan_error(err_message: &str) -> Result { - Err(DataFusionError::Plan(err_message.to_string())) -} - impl fmt::Display for WindowFrameBound { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -336,7 +360,7 @@ mod tests { }; let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - err.to_string(), + err.strip_backtrace(), "Error during planning: Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING".to_owned() ); @@ -347,7 +371,7 @@ mod tests { }; let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - err.to_string(), + err.strip_backtrace(), "Error during planning: Invalid window frame: end bound cannot be UNBOUNDED PRECEDING".to_owned() ); diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index a5b58d173c1ad..610f1ecaeae91 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -19,13 +19,13 @@ //! sets of rows that are related to the current query row. //! //! see also -//! use crate::aggregate_function::AggregateFunction; use crate::type_coercion::functions::data_types; -use crate::{aggregate_function, AggregateUDF, Signature, TypeSignature, Volatility}; +use crate::utils; +use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; use arrow::datatypes::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; use strum_macros::EnumIter; @@ -33,11 +33,14 @@ use strum_macros::EnumIter; /// WindowFunction #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WindowFunction { - /// window function that leverages an aggregate function + /// A built in aggregate function that leverages an aggregate function AggregateFunction(AggregateFunction), - /// window function that leverages a built-in window function + /// A a built-in window function BuiltInWindowFunction(BuiltInWindowFunction), + /// A user defined aggregate function AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), } /// Find DataFusion's built-in window function by name. @@ -69,6 +72,7 @@ impl fmt::Display for WindowFunction { WindowFunction::AggregateFunction(fun) => fun.fmt(f), WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunction::WindowUDF(fun) => fun.fmt(f), } } } @@ -142,105 +146,159 @@ impl FromStr for BuiltInWindowFunction { "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in window function named {name}" - ))) - } + _ => return plan_err!("There is no built-in window function named {name}"), }) } } /// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] pub fn return_type( fun: &WindowFunction, input_expr_types: &[DataType], ) -> Result { - match fun { - WindowFunction::AggregateFunction(fun) => { - aggregate_function::return_type(fun, input_expr_types) - } - WindowFunction::BuiltInWindowFunction(fun) => { - return_type_for_built_in(fun, input_expr_types) - } - WindowFunction::AggregateUDF(fun) => { - Ok((*(fun.return_type)(input_expr_types)?).clone()) + fun.return_type(input_expr_types) +} + +impl WindowFunction { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), + WindowFunction::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), + WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), } } } /// Returns the datatype of the built-in window function -fn return_type_for_built_in( - fun: &BuiltInWindowFunction, - input_expr_types: &[DataType], -) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature_for_built_in(fun))?; - - match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } } /// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] pub fn signature(fun: &WindowFunction) -> Signature { - match fun { - WindowFunction::AggregateFunction(fun) => aggregate_function::signature(fun), - WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), - WindowFunction::AggregateUDF(fun) => fun.signature.clone(), + fun.signature() +} + +impl WindowFunction { + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + match self { + WindowFunction::AggregateFunction(fun) => fun.signature(), + WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunction::AggregateUDF(fun) => fun.signature().clone(), + WindowFunction::WindowUDF(fun) => fun.signature().clone(), + } } } /// the signatures supported by the built-in window function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `BuiltInWindowFunction::signature` instead" +)] pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) + fun.signature() +} + +impl BuiltInWindowFunction { + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } - BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } #[cfg(test)] mod tests { use super::*; + use strum::IntoEnumIterator; #[test] fn test_count_return_type() -> Result<()> { let fun = find_df_window_func("count").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Int64, observed); - let observed = return_type(&fun, &[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64])?; assert_eq!(DataType::Int64, observed); Ok(()) @@ -249,10 +307,10 @@ mod tests { #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&fun, &[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64])?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -261,10 +319,10 @@ mod tests { #[test] fn test_last_value_return_type() -> Result<()> { let fun = find_df_window_func("last_value").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&fun, &[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -273,10 +331,10 @@ mod tests { #[test] fn test_lead_return_type() -> Result<()> { let fun = find_df_window_func("lead").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&fun, &[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -285,10 +343,10 @@ mod tests { #[test] fn test_lag_return_type() -> Result<()> { let fun = find_df_window_func("lag").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&fun, &[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -297,10 +355,10 @@ mod tests { #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); - let observed = return_type(&fun, &[DataType::Utf8, DataType::UInt64])?; + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&fun, &[DataType::Float64, DataType::UInt64])?; + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -309,7 +367,7 @@ mod tests { #[test] fn test_percent_rank_return_type() -> Result<()> { let fun = find_df_window_func("percent_rank").unwrap(); - let observed = return_type(&fun, &[])?; + let observed = fun.return_type(&[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -318,12 +376,21 @@ mod tests { #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = find_df_window_func("cume_dist").unwrap(); - let observed = return_type(&fun, &[])?; + let observed = fun.return_type(&[])?; assert_eq!(DataType::Float64, observed); Ok(()) } + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + #[test] fn test_window_function_case_insensitive() -> Result<()> { let names = vec![ @@ -399,4 +466,18 @@ mod tests { ); assert_eq!(find_df_window_func("not_exist"), None) } + + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } } diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_state.rs similarity index 84% rename from datafusion/physical-expr/src/window/window_frame_state.rs rename to datafusion/expr/src/window_state.rs index e23a58a09b668..de88396d9b0e7 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -15,19 +15,101 @@ // specific language governing permissions and limitations // under the License. -//! This module provides utilities for window frame index calculations -//! depending on the window frame mode: RANGE, ROWS, GROUPS. - -use arrow::array::ArrayRef; -use arrow::compute::kernels::sort::SortOptions; -use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -use std::cmp::min; -use std::collections::VecDeque; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; +//! Structures used to hold window function state (for implementing WindowUDFs) + +use std::{collections::VecDeque, ops::Range, sync::Arc}; + +use arrow::{ + array::ArrayRef, + compute::{concat, SortOptions}, + datatypes::DataType, + record_batch::RecordBatch, +}; +use datafusion_common::{ + internal_err, + utils::{compare_rows, get_row_at_idx, search_in_slice}, + DataFusionError, Result, ScalarValue, +}; + +use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + +/// Holds the state of evaluating a window function +#[derive(Debug)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub window_frame_range: Range, + pub window_frame_ctx: Option, + /// The index of the last row that its result is calculated inside the partition record batch buffer. + pub last_calculated_index: usize, + /// The offset of the deleted row number + pub offset_pruned_rows: usize, + /// Stores the results calculated by window frame + pub out_col: ArrayRef, + /// Keeps track of how many rows should be generated to be in sync with input record_batch. + // (For each row in the input record batch we need to generate a window result). + pub n_row_result_missing: usize, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. }) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } + + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } + + pub fn new(out_type: &DataType) -> Result { + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?; + Ok(Self { + window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: empty_out_col, + n_row_result_missing: 0, + is_end: false, + }) + } +} /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] @@ -120,24 +202,24 @@ impl WindowFrameContext { WindowFrameBound::CurrentRow => idx, // UNBOUNDED FOLLOWING WindowFrameBound::Following(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'" - ))) + ) } WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - min(idx + n as usize, length) + std::cmp::min(idx + n as usize, length) } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal("Rows should be Uint".to_string())) + return internal_err!("Rows should be Uint") } }; let end = match window_frame.end_bound { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'" - ))) + ) } WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= n as usize { @@ -150,17 +232,28 @@ impl WindowFrameContext { // UNBOUNDED FOLLOWING WindowFrameBound::Following(ScalarValue::UInt64(None)) => length, WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { - min(idx + n as usize + 1, length) + std::cmp::min(idx + n as usize + 1, length) } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return Err(DataFusionError::Internal("Rows should be Uint".to_string())) + return internal_err!("Rows should be Uint") } }; Ok(Range { start, end }) } } +/// State for each unique partition determined according to PARTITION BY column(s) +#[derive(Debug)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// Flag indicating whether we have received all data for this partition + pub is_end: bool, + /// Number of rows emitted for each partition + pub n_out_row: usize, +} + /// This structure encapsulates all the state information we require as we scan /// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER @@ -429,10 +522,9 @@ impl WindowFrameStateGroups { if let ScalarValue::UInt64(Some(value)) = delta { *value as usize } else { - return Err(DataFusionError::Internal( + return internal_err!( "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame" - .to_string(), - )); + ); } } else { 0 @@ -510,7 +602,7 @@ impl WindowFrameStateGroups { Ok(match (SIDE, SEARCH_SIDE) { // Window frame start: (true, _) => { - let group_idx = min(group_idx, self.group_end_indices.len()); + let group_idx = std::cmp::min(group_idx, self.group_end_indices.len()); if group_idx > 0 { // Normally, start at the boundary of the previous group. self.group_end_indices[group_idx - 1].1 @@ -531,7 +623,7 @@ impl WindowFrameStateGroups { } // Window frame end, FOLLOWING n (false, false) => { - let group_idx = min( + let group_idx = std::cmp::min( self.current_group_idx + delta, self.group_end_indices.len() - 1, ); @@ -547,11 +639,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result @@ -61,35 +58,30 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { let window_expr = window .window_expr .iter() - .map(|expr| expr.clone().rewrite(&mut rewriter)) + .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Window(Window { - input: window.input.clone(), - window_expr, - schema: rewrite_schema(&window.schema), - }))) + Ok(Transformed::Yes( + LogicalPlanBuilder::from((*window.input).clone()) + .window(window_expr)? + .build()?, + )) } LogicalPlan::Aggregate(agg) => { let aggr_expr = agg .aggr_expr .iter() - .map(|expr| expr.clone().rewrite(&mut rewriter)) + .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; Ok(Transformed::Yes(LogicalPlan::Aggregate( - Aggregate::try_new_with_schema( - agg.input.clone(), - agg.group_expr.clone(), - aggr_expr, - rewrite_schema(&agg.schema), - )?, + Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } LogicalPlan::Sort(Sort { expr, input, fetch }) => { let sort_expr = expr .iter() - .map(|expr| expr.clone().rewrite(&mut rewriter)) + .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; Ok(Transformed::Yes(LogicalPlan::Sort(Sort { expr: sort_expr, @@ -101,21 +93,16 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { let projection_expr = projection .expr .iter() - .map(|expr| expr.clone().rewrite(&mut rewriter)) + .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; Ok(Transformed::Yes(LogicalPlan::Projection( - Projection::try_new_with_schema( - projection_expr, - projection.input, - // rewrite_schema(projection.schema.clone()), - rewrite_schema(&projection.schema), - )?, + Projection::try_new(projection_expr, projection.input)?, ))) } LogicalPlan::Filter(Filter { predicate, input, .. }) => { - let predicate = predicate.rewrite(&mut rewriter)?; + let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( predicate, input, )?))) @@ -132,15 +119,6 @@ impl TreeNodeRewriter for CountWildcardRewriter { fn mutate(&mut self, old_expr: Expr) -> Result { let new_expr = match old_expr.clone() { - Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => { - Expr::Column(Column { - name: name.replace( - COUNT_STAR, - count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), - ), - relation: relation.clone(), - }) - } Expr::WindowFunction(expr::WindowFunction { fun: window_function::WindowFunction::AggregateFunction( @@ -151,32 +129,39 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, window_frame, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ), - args: vec![lit(COUNT_STAR_EXPANSION)], - partition_by, - order_by, - window_frame, - }), + Expr::Wildcard { qualifier: None } => { + Expr::WindowFunction(expr::WindowFunction { + fun: window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by, + order_by, + window_frame, + }) + } _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { - Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], - distinct, - filter, - order_by, - }), + Expr::Wildcard { qualifier: None } => { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], + distinct, + filter, + order_by, + )) + } _ => old_expr, }, @@ -233,30 +218,6 @@ impl TreeNodeRewriter for CountWildcardRewriter { Ok(new_expr) } } -fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef { - let new_fields = schema - .fields() - .iter() - .map(|field| { - let mut name = field.field().name().clone(); - if name.contains(COUNT_STAR) { - name = name.replace( - COUNT_STAR, - count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), - ); - } - DFField::new( - field.qualifier().cloned(), - &name, - field.data_type().clone(), - field.is_nullable(), - ) - }) - .collect::>(); - DFSchemaRef::new( - DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(), - ) -} #[cfg(test)] mod tests { @@ -267,8 +228,8 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, - max, out_ref_col, scalar_subquery, AggregateFunction, Expr, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -283,14 +244,14 @@ mod tests { fn test_count_wildcard_on_sort() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? - .sort(vec![count(Expr::Wildcard).sort(true, false)])? + .aggregate(vec![col("b")], vec![count(wildcard())])? + .project(vec![count(wildcard())])? + .sort(vec![count(wildcard()).sort(true, false)])? .build()?; - let expected = "Sort: COUNT(UInt8(1)) ASC NULLS LAST [COUNT(UInt8(1)):Int64;N]\ - \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(UInt8(1))]] [b:UInt32, COUNT(UInt8(1)):Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ + \n Projection: COUNT(*) [COUNT(*):Int64;N]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] [b:UInt32, COUNT(*):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -304,19 +265,19 @@ mod tests { col("a"), Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, ), ))? .build()?; let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(UInt8(1)):Int64;N]\ - \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; + \n Subquery: [COUNT(*):Int64;N]\ + \n Projection: COUNT(*) [COUNT(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -328,18 +289,18 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan_t1) .filter(exists(Arc::new( LogicalPlanBuilder::from(table_scan_t2) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?, )))? .build()?; let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(UInt8(1)):Int64;N]\ - \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; + \n Subquery: [COUNT(*):Int64;N]\ + \n Projection: COUNT(*) [COUNT(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -382,7 +343,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(AggregateFunction::Count), - vec![Expr::Wildcard], + vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], WindowFrame { @@ -393,12 +354,12 @@ mod tests { end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), }, ))])? - .project(vec![count(Expr::Wildcard)])? + .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ + \n WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -406,13 +367,13 @@ mod tests { fn test_count_wildcard_on_aggregate() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![count(wildcard())])? + .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -420,13 +381,13 @@ mod tests { fn test_count_wildcard_on_nesting() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![max(count(Expr::Wildcard))])? - .project(vec![count(Expr::Wildcard)])? + .aggregate(Vec::::new(), vec![max(count(wildcard()))])? + .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(UInt8(1)))]] [MAX(COUNT(UInt8(1))):Int64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: COUNT(UInt8(1)) AS COUNT(*) [COUNT(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(UInt8(1))) AS MAX(COUNT(*))]] [MAX(COUNT(*)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 3d0dabdd377ce..90af7aec82935 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -126,7 +126,7 @@ fn generate_projection_expr( )); } } else { - exprs.push(Expr::Wildcard); + exprs.push(Expr::Wildcard { qualifier: None }); } Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 436bb3a060447..14d5ddf473786 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -40,7 +40,7 @@ use std::time::Instant; /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// -/// For example, it may resolve [`Expr]s into more specific forms such +/// For example, it may resolve [`Expr`]s into more specific forms such /// as a subquery reference, to do type coercion to ensure the types /// of operands are correct. /// diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7cdedc06b4530..7c5b70b19af0a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -16,10 +16,11 @@ // under the License. use crate::analyzer::check_plan; -use crate::utils::{collect_subquery_cols, split_conjunction}; +use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, @@ -42,11 +43,11 @@ pub fn check_subquery_expr( if let Expr::ScalarSubquery(subquery) = expr { // Scalar subquery should only return one column if subquery.subquery.schema().fields().len() > 1 { - return Err(datafusion_common::DataFusionError::Plan(format!( + return plan_err!( "Scalar subquery should only return one column, but found {}: {}", subquery.subquery.schema().fields().len(), - subquery.subquery.schema().field_names().join(", "), - ))); + subquery.subquery.schema().field_names().join(", ") + ); } // Correlated scalar subquery must be aggregated to return at most one row if !subquery.outer_ref_columns.is_empty() { @@ -71,10 +72,9 @@ pub fn check_subquery_expr( { Ok(()) } else { - Err(DataFusionError::Plan( + plan_err!( "Correlated scalar subquery must be aggregated to return at most one row" - .to_string(), - )) + ) } } }?; @@ -84,33 +84,40 @@ pub fn check_subquery_expr( LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic - Err(DataFusionError::Plan( + plan_err!( "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" - .to_string(), - )) + ) } else { Ok(()) } }, - _ => Err(DataFusionError::Plan( + _ => plan_err!( "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - .to_string(), - )) + ) }?; } check_correlations_in_subquery(inner_plan, true) } else { + if let Expr::InSubquery(subquery) = expr { + // InSubquery should only return one column + if subquery.subquery.subquery.schema().fields().len() > 1 { + return plan_err!( + "InSubquery should only return one column, but found {}: {}", + subquery.subquery.subquery.schema().fields().len(), + subquery.subquery.subquery.schema().field_names().join(", ") + ); + } + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), - _ => Err(DataFusionError::Plan( + _ => plan_err!( "In/Exist subquery can only be used in \ Projection, Filter, Window functions, Aggregate and Join plan nodes" - .to_string(), - )), + ), }?; check_correlations_in_subquery(inner_plan, false) } @@ -132,9 +139,7 @@ fn check_inner_plan( can_contain_outer_ref: bool, ) -> Result<()> { if !can_contain_outer_ref && contains_outer_reference(inner_plan) { - return Err(DataFusionError::Plan( - "Accessing outer reference columns is not allowed in the plan".to_string(), - )); + return plan_err!("Accessing outer reference columns is not allowed in the plan"); } // We want to support as many operators as possible inside the correlated subquery match inner_plan { @@ -156,9 +161,9 @@ fn check_inner_plan( .filter(|expr| !can_pullup_over_aggregation(expr)) .collect::>(); if is_aggregate && is_scalar && !maybe_unsupport.is_empty() { - return Err(DataFusionError::Plan(format!( - "Correlated column is not allowed in predicate: {predicate:?}" - ))); + return plan_err!( + "Correlated column is not allowed in predicate: {predicate}" + ); } check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) } @@ -221,9 +226,8 @@ fn check_inner_plan( Ok(()) } }, - _ => Err(DataFusionError::Plan( - "Unsupported operator in the subquery plan.".to_string(), - )), + LogicalPlan::Extension(_) => Ok(()), + _ => plan_err!("Unsupported operator in the subquery plan."), } } @@ -239,10 +243,9 @@ fn check_aggregation_in_scalar_subquery( agg: &Aggregate, ) -> Result<()> { if agg.aggr_expr.is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Correlated scalar subquery must be aggregated to return at most one row" - .to_string(), - )); + ); } if !agg.group_expr.is_empty() { let correlated_exprs = get_correlated_expressions(inner_plan)?; @@ -258,10 +261,9 @@ fn check_aggregation_in_scalar_subquery( if !group_columns.all(|group| inner_subquery_cols.contains(&group)) { // Group BY columns must be a subset of columns in the correlated expressions - return Err(DataFusionError::Plan( + return plan_err!( "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns" - .to_string(), - )); + ); } } Ok(()) @@ -331,11 +333,64 @@ fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty() }); if mixed { - Err(DataFusionError::Plan( + plan_err!( "Window expressions should not contain a mixed of outer references and inner columns" - .to_string(), - )) + ) } else { Ok(()) } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datafusion_common::{DFSchema, DFSchemaRef}; + use datafusion_expr::{Extension, UserDefinedLogicalNodeCore}; + + use super::*; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockUserDefinedLogicalPlan { + empty_schema: DFSchemaRef, + } + + impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan { + fn name(&self) -> &str { + "MockUserDefinedLogicalPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &datafusion_common::DFSchemaRef { + &self.empty_schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "MockUserDefinedLogicalPlan") + } + + fn from_template(&self, _exprs: &[Expr], _inputs: &[LogicalPlan]) -> Self { + Self { + empty_schema: self.empty_schema.clone(), + } + } + } + + #[test] + fn wont_fail_extension_plan() { + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(MockUserDefinedLogicalPlan { + empty_schema: DFSchemaRef::new(DFSchema::empty()), + }), + }); + + check_inner_plan(&plan, false, false, true).unwrap(); + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0d0061a5e435d..91611251d9dd9 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -23,31 +23,34 @@ use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - ScalarUDF, WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ - any_decimal, coerce_types, comparison_coercion, like_coercion, math_decimal_coercion, + comparison_coercion, get_input_types, like_coercion, }; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; -use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_utf8}; -use datafusion_expr::utils::from_plan; +use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, - is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, - Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, + type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, + ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, + Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, }; -use datafusion_expr::{ExprSchemable, Signature}; use crate::analyzer::AnalyzerRule; -use crate::utils::{merge_schema, rewrite_preserving_name}; #[derive(Default)] pub struct TypeCoercion {} @@ -108,13 +111,13 @@ fn analyze_internal( }) .collect::>>()?; - // TODO: use from_plan after fix https://github.com/apache/arrow-datafusion/issues/6613 + // TODO: with_new_exprs can't change the schema, so we need to do this here match &plan { LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( new_expr, Arc::new(new_inputs[0].clone()), )?)), - _ => from_plan(plan, &new_expr, &new_inputs), + _ => plan.with_new_exprs(new_expr, &new_inputs), } } @@ -159,11 +162,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; let expr_type = expr.get_type(&self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); - let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(DataFusionError::Plan( - format!( + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" ), - ))?; + )?; let new_subquery = Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, @@ -205,103 +207,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter { expr, pattern, escape_char, + case_insensitive, }) => { let left_type = expr.get_type(&self.schema)?; let right_type = pattern.get_type(&self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { - DataFusionError::Plan(format!( - "There isn't a common type to coerce {left_type} and {right_type} in LIKE expression" - )) + let op_name = if case_insensitive { + "ILIKE" + } else { + "LIKE" + }; + plan_datafusion_err!( + "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" + ) })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new(negated, expr, pattern, escape_char)); - Ok(expr) - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let left_type = expr.get_type(&self.schema)?; - let right_type = pattern.get_type(&self.schema)?; - let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { - DataFusionError::Plan(format!( - "There isn't a common type to coerce {left_type} and {right_type} in ILIKE expression" - )) - })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::ILike(Like::new(negated, expr, pattern, escape_char)); + let expr = Expr::Like(Like::new( + negated, + expr, + pattern, + escape_char, + case_insensitive, + )); Ok(expr) } - Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) => { - // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419 - let left_type = left.get_type(&self.schema)?; - let right_type = right.get_type(&self.schema)?; - match (&left_type, &right_type) { - // Handle some case about Interval. - ( - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - &DataType::Interval(_), - ) if matches!(op, Operator::Plus | Operator::Minus) => Ok(expr), - ( - &DataType::Interval(_), - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - ) if matches!(op, Operator::Plus) => Ok(expr), - (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) - if op.is_numerical_operators() => - { - if matches!(op, Operator::Minus) { - Ok(expr) - } else { - Err(DataFusionError::Internal(format!( - "Unsupported operation {op:?} between {left_type:?} and {right_type:?}" - ))) - } - } - // For numerical operations between decimals, we don't coerce the types. - // But if only one of the operands is decimal, we cast the other operand to decimal - // if the other operand is integer. If the other operand is float, we cast the - // decimal operand to float. - (lhs_type, rhs_type) - if op.is_numerical_operators() - && any_decimal(lhs_type, rhs_type) => - { - let (coerced_lhs_type, coerced_rhs_type) = - math_decimal_coercion(lhs_type, rhs_type); - let new_left = if let Some(lhs_type) = coerced_lhs_type { - left.clone().cast_to(&lhs_type, &self.schema)? - } else { - left.as_ref().clone() - }; - let new_right = if let Some(rhs_type) = coerced_rhs_type { - right.clone().cast_to(&rhs_type, &self.schema)? - } else { - right.as_ref().clone() - }; - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(new_left), - op, - Box::new(new_right), - )); - Ok(expr) - } - _ => { - let common_type = coerce_types(&left_type, &op, &right_type)?; - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.clone().cast_to(&common_type, &self.schema)?), - op, - Box::new(right.clone().cast_to(&common_type, &self.schema)?), - )); - Ok(expr) - } - } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + let (left_type, right_type) = get_input_types( + &left.get_type(&self.schema)?, + &op, + &right.get_type(&self.schema)?, + )?; + + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left.cast_to(&left_type, &self.schema)?), + op, + Box::new(right.cast_to(&right_type, &self.schema)?), + ))) } Expr::Between(Between { expr, @@ -352,9 +294,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); match result_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ))), + ), Some(coerced_type) => { // find the coerced type let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; @@ -377,58 +319,66 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let case = coerce_case_expression(case, &self.schema)?; Ok(Expr::Case(case)) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - let expr = Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)); - Ok(expr) - } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let nex_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &function::signature(&fun), - )?; - let expr = Expr::ScalarFunction(ScalarFunction::new(fun, nex_expr)); - Ok(expr) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let new_args = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature(), + )?; + let new_args = coerce_arguments_for_fun( + new_args.as_slice(), + &self.schema, + &fun, + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &aggregate_function::signature(&fun), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature, - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -438,6 +388,19 @@ impl TreeNodeRewriter for TypeCoercionRewriter { }) => { let window_frame = coerce_window_frame(window_frame, &self.schema, &order_by)?; + + let args = match &fun { + window_function::WindowFunction::AggregateFunction(fun) => { + coerce_agg_exprs_for_signature( + fun, + &args, + &self.schema, + &fun.signature(), + )? + } + _ => args, + }; + let expr = Expr::WindowFunction(WindowFunction::new( fun, args, @@ -488,11 +451,7 @@ fn coerce_scalar_range_aware( // If type coercion fails, check if the largest type in family works: if let Some(largest_type) = get_widest_type_in_family(target_type) { coerce_scalar(largest_type, value).map_or_else( - |_| { - Err(DataFusionError::Execution(format!( - "Cannot cast {value:?} to {target_type:?}" - ))) - }, + |_| exec_err!("Cannot cast {value:?} to {target_type:?}"), |_| ScalarValue::try_from(target_type), ) } else { @@ -544,19 +503,20 @@ fn coerce_window_frame( let target_type = match window_frame.units { WindowFrameUnits::Range => { if let Some(col_type) = current_types.first() { - if is_numeric(col_type) || is_utf8_or_large_utf8(col_type) { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { col_type } else if is_datetime(col_type) { &DataType::Interval(IntervalUnit::MonthDayNano) } else { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Cannot run range queries on datatype: {col_type:?}" - ))); + ); } } else { - return Err(DataFusionError::Internal( - "ORDER BY column cannot be empty".to_string(), - )); + return internal_err!("ORDER BY column cannot be empty"); } } WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, @@ -571,8 +531,8 @@ fn coerce_window_frame( // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { let left_type = expr.get_type(schema)?; - coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; - expr.clone().cast_to(&DataType::Boolean, schema) + get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; + cast_expr(expr, &DataType::Boolean, schema) } /// Returns `expressions` coerced to types compatible with @@ -602,11 +562,76 @@ fn coerce_arguments_for_signature( .collect::>>() } +fn coerce_arguments_for_fun( + expressions: &[Expr], + schema: &DFSchema, + fun: &BuiltinScalarFunction, +) -> Result> { + if expressions.is_empty() { + return Ok(vec![]); + } + + let mut expressions: Vec = expressions.to_vec(); + + // Cast Fixedsizelist to List for array functions + if *fun == BuiltinScalarFunction::MakeArray { + expressions = expressions + .into_iter() + .map(|expr| { + let data_type = expr.get_type(schema).unwrap(); + if let DataType::FixedSizeList(field, _) = data_type { + let field = field.as_ref().clone(); + let to_type = DataType::List(Arc::new(field)); + expr.cast_to(&to_type, schema) + } else { + Ok(expr) + } + }) + .collect::>>()?; + } + + if *fun == BuiltinScalarFunction::MakeArray { + // Find the final data type for the function arguments + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new_type = current_types + .iter() + .skip(1) + .fold(current_types.first().unwrap().clone(), |acc, x| { + comparison_coercion(&acc, x).unwrap_or(acc) + }); + + return expressions + .iter() + .zip(current_types) + .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) + .collect(); + } + Ok(expressions) +} + /// Cast `expr` to the specified type, if possible fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result { expr.clone().cast_to(to_type, schema) } +/// Cast array `expr` to the specified type, if possible +fn cast_array_expr( + expr: &Expr, + from_type: &DataType, + to_type: &DataType, + schema: &DFSchema, +) -> Result { + if from_type.equals_datatype(&DataType::Null) { + Ok(expr.clone()) + } else { + cast_expr(expr, to_type, schema) + } +} + /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. @@ -630,7 +655,7 @@ fn coerce_agg_exprs_for_signature( input_exprs .iter() .enumerate() - .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema)) + .map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema)) .collect::>>() } @@ -694,20 +719,20 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let coerced_type = get_coerce_type_for_case_expression(&when_types, Some(case_type)); coerced_type.ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ to common types in CASE WHEN expression" - )) + ) }) }) .transpose()?; let then_else_coerce_type = get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( || { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ to common types in CASE WHEN expression" - )) + ) }, )?; @@ -749,16 +774,17 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { mod test { use std::sync::Arc; + use arrow::array::{FixedSizeListArray, Int32Array}; use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::Field; use datafusion_common::tree_node::TreeNode; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::{ - cast, col, concat, concat_ws, create_udaf, is_true, - AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, BinaryExpr, - BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Filter, Operator, - StateTypeFunction, Subquery, + cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, + AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, + ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, }; use datafusion_expr::{ lit, @@ -769,7 +795,7 @@ mod test { use datafusion_physical_expr::expressions::AvgAccumulator; use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; use crate::test::assert_analyzed_plan_eq; @@ -823,7 +849,7 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), @@ -844,7 +870,7 @@ mod test { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), @@ -858,24 +884,25 @@ mod test { .err() .unwrap(); assert_eq!( - r#"Context("type_coercion", Plan("Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed."))"#, - &format!("{err:?}") - ); + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", + err.strip_backtrace() + ); Ok(()) } #[test] fn scalar_function() -> Result<()> { + // test that automatic argument type coercion for scalar functions work let empty = empty(); let lit_expr = lit(10i64); - let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs; + let fun: BuiltinScalarFunction = BuiltinScalarFunction::Acos; let scalar_function_expr = Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr])); let plan = LogicalPlan::Projection(Projection::try_new( vec![scalar_function_expr], empty, )?); - let expected = "Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation"; + let expected = "Projection: acos(CAST(Int64(10) AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } @@ -884,20 +911,16 @@ mod test { let empty = empty(); let my_avg = create_udaf( "MY_AVG", - DataType::Float64, + vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -913,12 +936,8 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Float64))); let state_type: StateTypeFunction = Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); - let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::new( "MY_AVG", &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), @@ -926,9 +945,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); @@ -937,8 +957,8 @@ mod test { .err() .unwrap(); assert_eq!( - r#"Context("type_coercion", Plan("Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed."))"#, - &format!("{err:?}") + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.", + err.strip_backtrace() ); Ok(()) } @@ -955,7 +975,7 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(Int64(12))\n EmptyRelation"; + let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; let empty = empty_with_type(DataType::Int32); @@ -968,7 +988,7 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(a)\n EmptyRelation"; + let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; Ok(()) } @@ -984,10 +1004,13 @@ mod test { None, None, )); - let err = Projection::try_new(vec![agg_expr], empty).err().unwrap(); + let err = Projection::try_new(vec![agg_expr], empty) + .err() + .unwrap() + .strip_backtrace(); assert_eq!( - "Plan(\"The function Avg does not support inputs of type Utf8.\")", - &format!("{err:?}") + "Error during planning: No function matches the given name and argument types 'AVG(Utf8)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tAVG(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64)", + err ); Ok(()) } @@ -1012,7 +1035,7 @@ mod test { let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = - "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ + "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Literal(Int32(1)), Literal(Int8(4)), Literal(Int64(8))]) })\ \n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; @@ -1031,7 +1054,7 @@ mod test { })); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = - "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\ + "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Literal(Int32(1)), Literal(Int8(4)), Literal(Int64(8))]) })\ \n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } @@ -1081,9 +1104,9 @@ mod test { let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains("Int64 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to")); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); + let err = ret.unwrap_err().to_string(); + assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); // is not true let expr = col("a").is_not_true(); @@ -1114,7 +1137,7 @@ mod test { // like : utf8 like "abc" let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None)); + let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; @@ -1122,7 +1145,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None)); + let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \ @@ -1131,7 +1154,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let like_expr = Expr::Like(Like::new(false, expr, pattern, None)); + let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); @@ -1143,7 +1166,7 @@ mod test { // ilike let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::ILike(Like::new(false, expr, pattern, None)); + let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; @@ -1151,7 +1174,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); - let ilike_expr = Expr::ILike(Like::new(false, expr, pattern, None)); + let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8) AS a ILIKE NULL \ @@ -1160,7 +1183,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); - let ilike_expr = Expr::ILike(Like::new(false, expr, pattern, None)); + let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); @@ -1183,9 +1206,9 @@ mod test { let empty = empty_with_type(DataType::Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains("Utf8 IS DISTINCT FROM Boolean can't be evaluated because there isn't a common type to coerce the types to")); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); + let err = ret.unwrap_err().to_string(); + assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); // is not unknown let expr = col("a").is_not_unknown(); @@ -1226,6 +1249,57 @@ mod test { Ok(()) } + #[test] + fn test_casting_for_fixed_size_list() -> Result<()> { + let val = lit(ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 3, + Arc::new(Int32Array::from(vec![1, 2, 3])), + None, + ), + ))); + let expr = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![val.clone()], + )); + let schema = Arc::new(DFSchema::new_with_metadata( + vec![DFField::new_unqualified( + "item", + DataType::FixedSizeList( + Arc::new(Field::new("a", DataType::Int32, true)), + 3, + ), + true, + )], + std::collections::HashMap::new(), + )?); + let mut rewriter = TypeCoercionRewriter { schema }; + let result = expr.rewrite(&mut rewriter)?; + + let schema = Arc::new(DFSchema::new_with_metadata( + vec![DFField::new_unqualified( + "item", + DataType::List(Arc::new(Field::new("a", DataType::Int32, true))), + true, + )], + std::collections::HashMap::new(), + )?); + let expected_casted_expr = cast_expr( + &val, + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + &schema, + )?; + + let expected = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![expected_casted_expr], + )); + + assert_eq!(result, expected); + Ok(()) + } + #[test] fn test_type_coercion_rewrite() -> Result<()> { // gt @@ -1398,7 +1472,7 @@ mod test { }; let err = coerce_case_expression(case, &schema).unwrap_err(); assert_eq!( - err.to_string(), + err.strip_backtrace(), "Error during planning: \ Failed to coerce case (Interval(MonthDayNano)) and \ when ([Float32, Binary, Utf8]) to common types in \ @@ -1416,7 +1490,7 @@ mod test { }; let err = coerce_case_expression(case, &schema).unwrap_err(); assert_eq!( - err.to_string(), + err.strip_backtrace(), "Error during planning: \ Failed to coerce then ([Date32, Float32, Binary]) and \ else (Some(Timestamp(Nanosecond, None))) to common types \ diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0f63ecc2cc70c..1d21407a69850 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,20 +20,20 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +use crate::{utils, OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, }; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::{ - col, - logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, - Expr, ExprSchemable, +use datafusion_expr::expr::Alias; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; - -use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including /// - the expression itself (cloned) @@ -110,12 +110,7 @@ impl CommonSubexprEliminate { projection: &Projection, config: &dyn OptimizerConfig, ) -> Result { - let Projection { - expr, - input, - schema, - .. - } = projection; + let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; @@ -123,11 +118,9 @@ impl CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?; - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - pop_expr(&mut new_expr)?, - Arc::new(new_input), - schema.clone(), - )?)) + // Since projection expr changes, schema changes also. Use try_new method. + Projection::try_new(pop_expr(&mut new_expr)?, Arc::new(new_input)) + .map(LogicalPlan::Projection) } fn try_optimize_filter( @@ -161,9 +154,7 @@ impl CommonSubexprEliminate { Arc::new(new_input), )?)) } else { - Err(DataFusionError::Internal( - "Failed to pop predicate expr".to_string(), - )) + internal_err!("Failed to pop predicate expr") } } @@ -202,7 +193,6 @@ impl CommonSubexprEliminate { group_expr, aggr_expr, input, - schema, .. } = aggregate; let mut expr_set = ExprSet::new(); @@ -248,12 +238,17 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(new_input), - new_group_expr, - new_aggr_expr, - schema.clone(), - )?)) + // Alias aggregation expressions if they have changed + let new_aggr_expr = new_aggr_expr + .iter() + .zip(aggr_expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.clone().alias_if_changed(old_expr.display_name()?) + }) + .collect::>>()?; + // Since group_epxr changes, schema changes also. Use try_new method. + Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) + .map(LogicalPlan::Aggregate) } else { let mut agg_exprs = vec![]; @@ -264,22 +259,18 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.clone().alias(&id)); } _ => { - return Err(DataFusionError::Internal( - "expr_set invalid state".to_string(), - )); + return internal_err!("expr_set invalid state"); } } } let mut proj_exprs = vec![]; for expr in &new_group_expr { - let out_col: Column = - expr.to_field(&new_input_schema)?.qualified_column(); - proj_exprs.push(Expr::Column(out_col)); + extract_expressions(expr, &new_input_schema, &mut proj_exprs)? } for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { if expr_rewritten == expr_orig { - if let Expr::Alias(expr, name) = expr_rewritten { + if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { @@ -369,6 +360,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Distinct(_) | LogicalPlan::Extension(_) | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan @@ -383,7 +375,7 @@ impl OptimizerRule for CommonSubexprEliminate { Ok(Some(build_recover_project_plan( &original_schema, optimized_plan, - ))) + )?)) } plan => Ok(plan), } @@ -453,9 +445,7 @@ fn build_common_expr_project_plan( project_exprs.push(expr.clone().alias(&id)); } _ => { - return Err(DataFusionError::Internal( - "expr_set invalid state".to_string(), - )); + return internal_err!("expr_set invalid state"); } } } @@ -476,16 +466,35 @@ fn build_common_expr_project_plan( /// the "intermediate" projection plan built in [build_common_expr_project_plan]. /// /// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. -fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); - LogicalPlan::Projection( - Projection::try_new(col_exprs, Arc::new(input)) - .expect("Cannot build projection plan from an invalid schema"), - ) + Ok(LogicalPlan::Projection(Projection::try_new( + col_exprs, + Arc::new(input), + )?)) +} + +fn extract_expressions( + expr: &Expr, + schema: &DFSchema, + result: &mut Vec, +) -> Result<()> { + if let Expr::GroupingSet(groupings) = expr { + for e in groupings.distinct_expr() { + result.push(Expr::Column(e.to_field(schema)?.qualified_column())) + } + } else { + result.push(Expr::Column(expr.to_field(schema)?.qualified_column())); + } + + Ok(()) } /// Which type of [expressions](Expr) should be considered for rewriting? @@ -500,10 +509,9 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } @@ -516,13 +524,10 @@ impl ExprMask { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Sort { .. } - | Expr::Wildcard + | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); match self { Self::Normal => is_normal_minus_aggregates || is_aggr, @@ -705,9 +710,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { Ok(RewriteRecursion::Skip) } } - _ => Err(DataFusionError::Internal( - "expr_set invalid state".to_string(), - )), + _ => internal_err!("expr_set invalid state"), } } @@ -773,8 +776,8 @@ mod test { avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, }; use datafusion_expr::{ - AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, + Signature, StateTypeFunction, Volatility, }; use crate::optimizer::OptimizerContext; @@ -898,11 +901,10 @@ mod test { assert_eq!(inputs, &[DataType::UInt32]); Ok(Arc::new(DataType::UInt32)) }); - let accumulator: AccumulatorFunctionImplementation = - Arc::new(|_| unimplemented!()); + let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -911,6 +913,7 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) @@ -1218,7 +1221,7 @@ mod test { .map(|field| (field.name(), field.data_type())) .collect(); let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}"); - let expected = r###"[ + let expected = r#"[ ( "a", UInt64, @@ -1231,7 +1234,7 @@ mod test { "c", UInt64, ), -]"###; +]"#; assert_eq!(expected, formatted_fields_with_datatype); } @@ -1252,4 +1255,52 @@ mod test { Ok(()) } + + #[test] + fn test_extract_expressions_from_grouping_set() -> Result<()> { + let mut result = Vec::with_capacity(3); + let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]); + let schema = DFSchema::new_with_metadata( + vec![ + DFField::new_unqualified("a", DataType::Int32, false), + DFField::new_unqualified("b", DataType::Int32, false), + DFField::new_unqualified("c", DataType::Int32, false), + ], + HashMap::default(), + )?; + extract_expressions(&grouping, &schema, &mut result)?; + + assert!(result.len() == 3); + Ok(()) + } + + #[test] + fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> { + let mut result = Vec::with_capacity(2); + let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]); + let schema = DFSchema::new_with_metadata( + vec![ + DFField::new_unqualified("a", DataType::Int32, false), + DFField::new_unqualified("b", DataType::Int32, false), + ], + HashMap::default(), + )?; + extract_expressions(&grouping, &schema, &mut result)?; + + assert!(result.len() == 2); + Ok(()) + } + + #[test] + fn test_extract_expressions_from_col() -> Result<()> { + let mut result = Vec::with_capacity(1); + let schema = DFSchema::new_with_metadata( + vec![DFField::new_unqualified("a", DataType::Int32, false)], + HashMap::default(), + )?; + extract_expressions(&col("a"), &schema, &mut result)?; + + assert!(result.len() == 1); + Ok(()) + } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs new file mode 100644 index 0000000000000..b1000f042c987 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate.rs @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use crate::utils::collect_subquery_cols; +use datafusion_common::tree_node::{ + RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, +}; +use datafusion_common::{plan_err, Result}; +use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; +use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_physical_expr::execution_props::ExecutionProps; +use std::collections::{BTreeSet, HashMap}; +use std::ops::Deref; + +/// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. +/// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. +pub struct PullUpCorrelatedExpr { + pub join_filters: Vec, + // mapping from the plan to its holding correlated columns + pub correlated_subquery_cols_map: HashMap>, + pub in_predicate_opt: Option, + // indicate whether it is Exists(Not Exists) SubQuery + pub exists_sub_query: bool, + // indicate whether the correlated expressions can pull up or not + pub can_pull_up: bool, + // indicate whether need to handle the Count bug during the pull up process + pub need_handle_count_bug: bool, + // mapping from the plan to its expressions' evaluation result on empty batch + pub collected_count_expr_map: HashMap, + // pull up having expr, which must be evaluated after the Join + pub pull_up_having_expr: Option, +} + +/// Used to indicate the unmatched rows from the inner(subquery) table after the left out Join +/// This is used to handle the Count bug +pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; + +/// Mapping from expr display name to its evaluation result on empty record batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is 'ScalarValue(2)') +pub type ExprResultMap = HashMap; + +impl TreeNodeRewriter for PullUpCorrelatedExpr { + type N = LogicalPlan; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Filter(_) => Ok(RewriteRecursion::Continue), + LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { + let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); + if plan_hold_outer { + // the unsupported case + self.can_pull_up = false; + Ok(RewriteRecursion::Stop) + } else { + Ok(RewriteRecursion::Continue) + } + } + LogicalPlan::Limit(_) => { + let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); + match (self.exists_sub_query, plan_hold_outer) { + (false, true) => { + // the unsupported case + self.can_pull_up = false; + Ok(RewriteRecursion::Stop) + } + _ => Ok(RewriteRecursion::Continue), + } + } + _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { + // the unsupported cases, the plan expressions contain out reference columns(like window expressions) + self.can_pull_up = false; + Ok(RewriteRecursion::Stop) + } + _ => Ok(RewriteRecursion::Continue), + } + } + + fn mutate(&mut self, plan: LogicalPlan) -> Result { + let subquery_schema = plan.schema().clone(); + match &plan { + LogicalPlan::Filter(plan_filter) => { + let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + let (mut join_filters, subquery_filters) = + find_join_exprs(subquery_filter_exprs)?; + if let Some(in_predicate) = &self.in_predicate_opt { + // in_predicate may be already included in the join filters, remove it from the join filters first. + join_filters = remove_duplicated_filter(join_filters, in_predicate); + } + let correlated_subquery_cols = + collect_subquery_cols(&join_filters, subquery_schema)?; + for expr in join_filters { + if !self.join_filters.contains(&expr) { + self.join_filters.push(expr) + } + } + + let mut expr_result_map_for_count_bug = HashMap::new(); + let pull_up_expr_opt = if let Some(expr_result_map) = + self.collected_count_expr_map.get(plan_filter.input.deref()) + { + if let Some(expr) = conjunction(subquery_filters.clone()) { + filter_exprs_evaluation_result_on_empty_batch( + &expr, + plan_filter.input.schema().clone(), + expr_result_map, + &mut expr_result_map_for_count_bug, + )? + } else { + None + } + } else { + None + }; + + match (&pull_up_expr_opt, &self.pull_up_having_expr) { + (Some(_), Some(_)) => { + // Error path + plan_err!("Unsupported Subquery plan") + } + (Some(_), None) => { + self.pull_up_having_expr = pull_up_expr_opt; + let new_plan = + LogicalPlanBuilder::from((*plan_filter.input).clone()) + .build()?; + self.correlated_subquery_cols_map + .insert(new_plan.clone(), correlated_subquery_cols); + Ok(new_plan) + } + (None, _) => { + // if the subquery still has filter expressions, restore them. + let mut plan = + LogicalPlanBuilder::from((*plan_filter.input).clone()); + if let Some(expr) = conjunction(subquery_filters) { + plan = plan.filter(expr)? + } + let new_plan = plan.build()?; + self.correlated_subquery_cols_map + .insert(new_plan.clone(), correlated_subquery_cols); + Ok(new_plan) + } + } + } + LogicalPlan::Projection(projection) + if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => + { + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + // add missing columns to Projection + let mut missing_exprs = + self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?; + + let mut expr_result_map_for_count_bug = HashMap::new(); + if let Some(expr_result_map) = + self.collected_count_expr_map.get(projection.input.deref()) + { + proj_exprs_evaluation_result_on_empty_batch( + &projection.expr, + projection.input.schema().clone(), + expr_result_map, + &mut expr_result_map_for_count_bug, + )?; + if !expr_result_map_for_count_bug.is_empty() { + // has count bug + let un_matched_row = Expr::Column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR.to_string(), + )); + // add the unmatched rows indicator to the Projection expressions + missing_exprs.push(un_matched_row); + } + } + + let new_plan = LogicalPlanBuilder::from((*projection.input).clone()) + .project(missing_exprs)? + .build()?; + if !expr_result_map_for_count_bug.is_empty() { + self.collected_count_expr_map + .insert(new_plan.clone(), expr_result_map_for_count_bug); + } + Ok(new_plan) + } + LogicalPlan::Aggregate(aggregate) + if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => + { + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + // add missing columns to Aggregation's group expressions + let mut missing_exprs = self.collect_missing_exprs( + &aggregate.group_expr, + &local_correlated_cols, + )?; + + // if the original group expressions are empty, need to handle the Count bug + let mut expr_result_map_for_count_bug = HashMap::new(); + if self.need_handle_count_bug + && aggregate.group_expr.is_empty() + && !missing_exprs.is_empty() + { + agg_exprs_evaluation_result_on_empty_batch( + &aggregate.aggr_expr, + aggregate.input.schema().clone(), + &mut expr_result_map_for_count_bug, + )?; + if !expr_result_map_for_count_bug.is_empty() { + // has count bug + let un_matched_row = + Expr::Literal(ScalarValue::Boolean(Some(true))) + .alias(UN_MATCHED_ROW_INDICATOR); + // add the unmatched rows indicator to the Aggregation's group expressions + missing_exprs.push(un_matched_row); + } + } + let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) + .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? + .build()?; + if !expr_result_map_for_count_bug.is_empty() { + self.collected_count_expr_map + .insert(new_plan.clone(), expr_result_map_for_count_bug); + } + Ok(new_plan) + } + LogicalPlan::SubqueryAlias(alias) => { + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + let mut new_correlated_cols = BTreeSet::new(); + for col in local_correlated_cols.iter() { + new_correlated_cols + .insert(Column::new(Some(alias.alias.clone()), col.name.clone())); + } + self.correlated_subquery_cols_map + .insert(plan.clone(), new_correlated_cols); + if let Some(input_map) = + self.collected_count_expr_map.get(alias.input.deref()) + { + self.collected_count_expr_map + .insert(plan.clone(), input_map.clone()); + } + Ok(plan) + } + LogicalPlan::Limit(limit) => { + let input_expr_map = self + .collected_count_expr_map + .get(limit.input.deref()) + .cloned(); + // handling the limit clause in the subquery + let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) + { + // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) + (true, false) => { + if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: limit.input.schema().clone(), + }) + } else { + LogicalPlanBuilder::from((*limit.input).clone()).build()? + } + } + _ => plan, + }; + if let Some(input_map) = input_expr_map { + self.collected_count_expr_map + .insert(new_plan.clone(), input_map); + } + Ok(new_plan) + } + _ => Ok(plan), + } + } +} + +impl PullUpCorrelatedExpr { + fn collect_missing_exprs( + &self, + exprs: &[Expr], + correlated_subquery_cols: &BTreeSet, + ) -> Result> { + let mut missing_exprs = vec![]; + for expr in exprs { + if !missing_exprs.contains(expr) { + missing_exprs.push(expr.clone()) + } + } + for col in correlated_subquery_cols.iter() { + let col_expr = Expr::Column(col.clone()); + if !missing_exprs.contains(&col_expr) { + missing_exprs.push(col_expr) + } + } + if let Some(pull_up_having) = &self.pull_up_having_expr { + let filter_apply_columns = pull_up_having.to_columns()?; + for col in filter_apply_columns { + let col_expr = Expr::Column(col); + if !missing_exprs.contains(&col_expr) { + missing_exprs.push(col_expr) + } + } + } + Ok(missing_exprs) + } +} + +fn collect_local_correlated_cols( + plan: &LogicalPlan, + all_cols_map: &HashMap>, + local_cols: &mut BTreeSet, +) { + for child in plan.inputs() { + if let Some(cols) = all_cols_map.get(child) { + local_cols.extend(cols.clone()); + } + // SubqueryAlias is treated as the leaf node + if !matches!(child, LogicalPlan::SubqueryAlias(_)) { + collect_local_correlated_cols(child, all_cols_map, local_cols); + } + } +} + +fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { + filters + .into_iter() + .filter(|filter| { + if filter == in_predicate { + return false; + } + + // ignore the binary order + !match (filter, in_predicate) { + (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { + (a_expr.op == b_expr.op) + && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) + || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) + } + _ => false, + } + }) + .collect::>() +} + +fn agg_exprs_evaluation_result_on_empty_batch( + agg_expr: &[Expr], + schema: DFSchemaRef, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result<()> { + for e in agg_expr.iter() { + let result_expr = e.clone().transform_up(&|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. } => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + } + _ => Transformed::No(expr), + }; + Ok(new_expr) + })?; + + let result_expr = result_expr.unalias(); + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(schema.clone()); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { + expr_result_map_for_count_bug.insert(e.display_name()?, result_expr); + } + } + Ok(()) +} + +fn proj_exprs_evaluation_result_on_empty_batch( + proj_expr: &[Expr], + schema: DFSchemaRef, + input_expr_result_map_for_count_bug: &ExprResultMap, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result<()> { + for expr in proj_expr.iter() { + let result_expr = expr.clone().transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::Yes(result_expr.clone())) + } else { + Ok(Transformed::No(expr)) + } + } else { + Ok(Transformed::No(expr)) + } + })?; + if result_expr.ne(expr) { + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(schema.clone()); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + let expr_name = match expr { + Expr::Alias(Alias { name, .. }) => name.to_string(), + Expr::Column(Column { relation: _, name }) => name.to_string(), + _ => expr.display_name()?, + }; + expr_result_map_for_count_bug.insert(expr_name, result_expr); + } + } + Ok(()) +} + +fn filter_exprs_evaluation_result_on_empty_batch( + filter_expr: &Expr, + schema: DFSchemaRef, + input_expr_result_map_for_count_bug: &ExprResultMap, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result> { + let result_expr = filter_expr.clone().transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::Yes(result_expr.clone())) + } else { + Ok(Transformed::No(expr)) + } + } else { + Ok(Transformed::No(expr)) + } + })?; + let pull_up_expr = if result_expr.ne(filter_expr) { + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(schema); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + match &result_expr { + // evaluate to false or null on empty batch, no need to pull up + Expr::Literal(ScalarValue::Null) + | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + // evaluate to true on empty batch, need to pull up the expr + Expr::Literal(ScalarValue::Boolean(Some(true))) => { + for (name, exprs) in input_expr_result_map_for_count_bug { + expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); + } + Some(filter_expr.clone()) + } + // can not evaluate statically + _ => { + for input_expr in input_expr_result_map_for_count_bug.values() { + let new_expr = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(result_expr.clone()), + Box::new(input_expr.clone()), + )], + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + }); + expr_result_map_for_count_bug + .insert(new_expr.display_name()?, new_expr); + } + None + } + } + } else { + for (name, exprs) in input_expr_result_map_for_count_bug { + expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); + } + None + }; + Ok(pull_up_expr) +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 5ecafe6e37147..450336376a239 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,30 +15,29 @@ // specific language governing permissions and limitations // under the License. -use crate::alias::AliasGenerator; +use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; -use crate::utils::{ - collect_subquery_cols, conjunction, extract_join_filters, only_or_err, - replace_qualified_name, split_conjunction, -}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Column, DataFusionError, Result}; +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; -use datafusion_expr::expr_rewriter::unnormalize_col; -use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; +use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; +use datafusion_expr::logical_plan::{JoinType, Subquery}; +use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ - exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Distinct, Expr, Filter, + exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; use log::debug; +use std::collections::BTreeSet; use std::ops::Deref; use std::sync::Arc; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins #[derive(Default)] -pub struct DecorrelatePredicateSubquery { - alias: AliasGenerator, -} +pub struct DecorrelatePredicateSubquery {} impl DecorrelatePredicateSubquery { #[allow(missing_docs)] @@ -115,7 +114,9 @@ impl OptimizerRule for DecorrelatePredicateSubquery { // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { - if let Some(plan) = build_join(&subquery, &cur_input, &self.alias)? { + if let Some(plan) = + build_join(&subquery, &cur_input, config.alias_generator())? + { cur_input = plan; } else { // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter @@ -198,43 +199,91 @@ impl OptimizerRule for DecorrelatePredicateSubquery { fn build_join( query_info: &SubqueryInfo, left: &LogicalPlan, - alias: &AliasGenerator, + alias: Arc, ) -> Result> { - let in_predicate = query_info - .where_in_expr + let where_in_expr_opt = &query_info.where_in_expr; + let in_predicate_opt = where_in_expr_opt .clone() - .map(|in_expr| { - let projection = Projection::try_from_plan(&query_info.query.subquery) - .map_err(|e| context!("a projection is required", e))?; - // TODO add the validate logic to Analyzer - let subquery_expr = only_or_err(projection.expr.as_slice()) - .map_err(|e| context!("single expression projection required", e))?; - - // in_predicate may be also include in the join filters - Ok(Expr::eq(in_expr, subquery_expr.clone())) + .map(|where_in_expr| { + query_info + .query + .subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |expr| { + Ok(Expr::eq(where_in_expr, expr)) + }) }) - .map_or(Ok(None), |v: Result| v.map(Some))?; + .map_or(Ok(None), |v| v.map(Some))?; let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); - if let Some((join_filter, subquery_plan)) = - pull_up_correlated_expr(subquery, in_predicate, &subquery_alias)? - { - let sub_query_alias = LogicalPlanBuilder::from(subquery_plan) - .alias(subquery_alias.clone())? - .build()?; + + let mut pull_up = PullUpCorrelatedExpr { + join_filters: vec![], + correlated_subquery_cols_map: Default::default(), + in_predicate_opt: in_predicate_opt.clone(), + exists_sub_query: in_predicate_opt.is_none(), + can_pull_up: true, + need_handle_count_bug: false, + collected_count_expr_map: Default::default(), + pull_up_having_expr: None, + }; + let new_plan = subquery.clone().rewrite(&mut pull_up)?; + if !pull_up.can_pull_up { + return Ok(None); + } + + let sub_query_alias = LogicalPlanBuilder::from(new_plan) + .alias(subquery_alias.to_string())? + .build()?; + let mut all_correlated_cols = BTreeSet::new(); + pull_up + .correlated_subquery_cols_map + .values() + .for_each(|cols| all_correlated_cols.extend(cols.clone())); + + // alias the join filter + let join_filter_opt = + conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { + replace_qualified_name(filter, &all_correlated_cols, &subquery_alias) + .map(Option::Some) + })?; + + if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { + ( + Some(join_filter), + Some(Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + })), + ) => { + let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); + Some(in_predicate.and(join_filter)) + } + (Some(join_filter), _) => Some(join_filter), + ( + _, + Some(Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + })), + ) => { + let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); + Some(in_predicate) + } + _ => None, + } { // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; let new_plan = LogicalPlanBuilder::from(left.clone()) - .join( - sub_query_alias, - join_type, - (Vec::::new(), Vec::::new()), - Some(join_filter), - )? + .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; debug!( "predicate subquery optimized:\n{}", @@ -246,141 +295,6 @@ fn build_join( } } -/// This function pull up the correlated expressions(contains outer reference columns) from the inner subquery's [Filter]. -/// It adds the inner reference columns to the [Projection] of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. -/// -/// This function can't handle the non-correlated subquery, and will return None. -fn pull_up_correlated_expr( - subquery: &LogicalPlan, - in_predicate_opt: Option, - subquery_alias: &str, -) -> Result> { - match subquery { - LogicalPlan::Distinct(subqry_distinct) => { - let distinct_input = &subqry_distinct.input; - let optimized_plan = pull_up_correlated_expr( - distinct_input, - in_predicate_opt, - subquery_alias, - )? - .map(|(filters, right)| { - ( - filters, - LogicalPlan::Distinct(Distinct { - input: Arc::new(right), - }), - ) - }); - Ok(optimized_plan) - } - LogicalPlan::Projection(projection) => { - // extract join filters from the inner subquery's Filter - let (mut join_filters, subquery_input) = - extract_join_filters(&projection.input)?; - if in_predicate_opt.is_none() && join_filters.is_empty() { - // cannot rewrite non-correlated subquery - return Ok(None); - } - - if let Some(in_predicate) = &in_predicate_opt { - // in_predicate may be already included in the join filters, remove it from the join filters first. - join_filters = remove_duplicated_filter(join_filters, in_predicate); - } - let input_schema = subquery_input.schema(); - let correlated_subquery_cols = - collect_subquery_cols(&join_filters, input_schema.clone())?; - - // add missing columns to projection - let mut project_exprs: Vec = - if let Some(Expr::BinaryExpr(BinaryExpr { - left: _, - op: Operator::Eq, - right, - })) = &in_predicate_opt - { - if !matches!(right.deref(), Expr::Column(_)) { - vec![right.deref().clone().alias(format!( - "{:?}", - unnormalize_col(right.deref().clone()) - ))] - } else { - vec![right.deref().clone()] - } - } else { - vec![] - }; - // the inner reference cols need to added to the projection if they are missing. - for col in correlated_subquery_cols.iter() { - let col_expr = Expr::Column(col.clone()); - if !project_exprs.contains(&col_expr) { - project_exprs.push(col_expr) - } - } - - // alias the join filter - let join_filter_opt = - conjunction(join_filters).map_or(Ok(None), |filter| { - replace_qualified_name( - filter, - &correlated_subquery_cols, - subquery_alias, - ) - .map(Option::Some) - })?; - - let join_filter = if let Some(Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - })) = in_predicate_opt - { - let right_expr_name = - format!("{:?}", unnormalize_col(right.deref().clone())); - let right_col = - Column::new(Some(subquery_alias.to_string()), right_expr_name); - let in_predicate = - Expr::eq(left.deref().clone(), Expr::Column(right_col)); - join_filter_opt - .map(|filter| in_predicate.clone().and(filter)) - .unwrap_or_else(|| in_predicate) - } else { - join_filter_opt.ok_or_else(|| { - DataFusionError::Internal( - "join filters should not be empty".to_string(), - ) - })? - }; - - let right = LogicalPlanBuilder::from(subquery_input) - .project(project_exprs)? - .build()?; - Ok(Some((join_filter, right))) - } - _ => Ok(None), - } -} - -fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { - filters - .into_iter() - .filter(|filter| { - if filter == in_predicate { - return false; - } - - // ignore the binary order - !match (filter, in_predicate) { - (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { - (a_expr.op == b_expr.op) - && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) - || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) - } - _ => false, - } - }) - .collect::>() -} - struct SubqueryInfo { query: Subquery, where_in_expr: Option, @@ -908,11 +822,11 @@ mod tests { .build()?; // Maybe okay if the table only has a single column? - assert_optimizer_err( - Arc::new(DecorrelatePredicateSubquery::new()), - &plan, - "a projection is required", - ); + let expected = "check_analyzed_plan\ + \ncaused by\ + \nError during planning: InSubquery should only return one column, but found 4"; + assert_analyzer_check_err(vec![], &plan, expected); + Ok(()) } @@ -968,10 +882,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1) AS o_custkey + Int32(1), orders.o_custkey [o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_optimized_plan_eq_display_indent( @@ -1003,11 +917,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimizer_err( - Arc::new(DecorrelatePredicateSubquery::new()), - &plan, - "single expression projection required", - ); + let expected = "check_analyzed_plan\ + \ncaused by\ + \nError during planning: InSubquery should only return one column"; + assert_analyzer_check_err(vec![], &plan, expected); + Ok(()) } @@ -1179,10 +1093,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\ - \n Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * UInt32(2):UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]\ + \n Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq_display_indent( @@ -1213,10 +1127,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * UInt32(2):UInt32, a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]\ \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; @@ -1249,10 +1163,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; @@ -1292,14 +1206,14 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]\ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq_display_indent( @@ -1460,12 +1374,11 @@ mod tests { .build()?; // Other rule will pushdown `customer.c_custkey = 1`, - // TODO revisit the logic, is it a valid physical plan when no cols in projection? let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 []\ - \n Projection: []\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_optimized_plan_equal(&plan, expected) @@ -1579,7 +1492,10 @@ mod tests { fn exists_subquery_no_projection() -> Result<()> { let sq = Arc::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .filter( + out_ref_col(DataType::Int64, "customer.c_custkey") + .eq(col("orders.o_custkey")), + )? .build()?, ); @@ -1588,7 +1504,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_equal(&plan, expected) } /// Test for correlated exists expressions @@ -1612,8 +1534,8 @@ mod tests { let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_optimized_plan_equal(&plan, expected) @@ -1692,8 +1614,8 @@ mod tests { let expected = "Projection: test.c [c:UInt32]\ \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ + \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -1752,11 +1674,11 @@ mod tests { \n LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Projection: sq1.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ + \n Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]\ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ - \n Projection: sq2.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]\ + \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -1781,8 +1703,8 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]\ + \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -1834,9 +1756,9 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Distinct: [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ + \n Distinct: [c:UInt32, a:UInt32]\ + \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -1862,9 +1784,9 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Distinct: [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]\ + \n Distinct: [sq.b + sq.c:UInt32, a:UInt32]\ + \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -1890,9 +1812,9 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ - \n Distinct: [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]\ + \n Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]\ + \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 533566a0bf695..cf9a59d6b892f 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,13 +20,13 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator}; +use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; #[derive(Default)] pub struct EliminateCrossJoin; @@ -60,30 +60,23 @@ impl OptimizerRule for EliminateCrossJoin { let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; let mut all_inputs: Vec = vec![]; - match &input { - LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/arrow-datafusion/issues/4844 - if join.filter.is_some() { - return Ok(None); - } - - flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?; - } - LogicalPlan::CrossJoin(_) => { - flatten_join_inputs( - &input, - &mut possible_join_keys, - &mut all_inputs, - )?; - } + let did_flat_successfully = match &input { + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) + | LogicalPlan::CrossJoin(_) => try_flatten_join_inputs( + &input, + &mut possible_join_keys, + &mut all_inputs, + )?, _ => { return utils::optimize_children(self, plan, config); } + }; + + if !did_flat_successfully { + return Ok(None); } let predicate = &filter.predicate; @@ -137,13 +130,20 @@ impl OptimizerRule for EliminateCrossJoin { } } -fn flatten_join_inputs( +/// Recursively accumulate possible_join_keys and inputs from inner joins (including cross joins). +/// Returns a boolean indicating whether the flattening was successful. +fn try_flatten_join_inputs( plan: &LogicalPlan, possible_join_keys: &mut Vec<(Expr, Expr)>, all_inputs: &mut Vec, -) -> Result<()> { +) -> Result { let children = match plan { - LogicalPlan::Join(join) => { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + if join.filter.is_some() { + // The filter of inner join will lost, skip this rule. + // issue: https://github.com/apache/arrow-datafusion/issues/4844 + return Ok(false); + } possible_join_keys.extend(join.on.clone()); let left = &*(join.left); let right = &*(join.right); @@ -155,28 +155,25 @@ fn flatten_join_inputs( vec![left, right] } _ => { - return Err(DataFusionError::Plan( - "flatten_join_inputs just can call join/cross_join".to_string(), - )); + return plan_err!("flatten_join_inputs just can call join/cross_join"); } }; for child in children.iter() { match *child { - LogicalPlan::Join(left_join) => { - if left_join.join_type == JoinType::Inner { - flatten_join_inputs(child, possible_join_keys, all_inputs)?; - } else { - all_inputs.push((*child).clone()); + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) + | LogicalPlan::CrossJoin(_) => { + if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? { + return Ok(false); } } - LogicalPlan::CrossJoin(_) => { - flatten_join_inputs(child, possible_join_keys, all_inputs)?; - } _ => all_inputs.push((*child).clone()), } } - Ok(()) + Ok(true) } fn find_inner_join( @@ -197,13 +194,10 @@ fn find_inner_join( )?; // Save join keys - match key_pair { - Some((valid_l, valid_r)) => { - if can_hash(&valid_l.get_type(left_input.schema())?) { - join_keys.push((valid_l, valid_r)); - } + if let Some((valid_l, valid_r)) = key_pair { + if can_hash(&valid_l.get_type(left_input.schema())?) { + join_keys.push((valid_l, valid_r)); } - _ => continue, } } @@ -298,39 +292,33 @@ fn remove_join_expressions( join_keys: &HashSet<(Expr, Expr)>, ) -> Result> { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => { - if join_keys.contains(&(*left.clone(), *right.clone())) - || join_keys.contains(&(*right.clone(), *left.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } - } - Operator::And => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match op { + Operator::Eq => { + if join_keys.contains(&(*left.clone(), *right.clone())) + || join_keys.contains(&(*right.clone(), *left.clone())) + { + Ok(None) + } else { + Ok(Some(expr.clone())) + } } - } - // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Operator::Or => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(or(ll, rr))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), + // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. + Operator::And | Operator::Or => { + let l = remove_join_expressions(left, join_keys)?; + let r = remove_join_expressions(right, join_keys)?; + match (l, r) { + (Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr( + BinaryExpr::new(Box::new(ll), *op, Box::new(rr)), + ))), + (Some(ll), _) => Ok(Some(ll)), + (_, Some(rr)) => Ok(Some(rr)), + _ => Ok(None), + } } + _ => Ok(Some(expr.clone())), } - _ => Ok(Some(expr.clone())), - }, + } _ => Ok(Some(expr.clone())), } } @@ -365,6 +353,12 @@ mod tests { assert_eq!(plan.schema(), optimized_plan.schema()) } + fn assert_optimization_rule_fails(plan: &LogicalPlan) { + let rule = EliminateCrossJoin::new(); + let optimized_plan = rule.try_optimize(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.is_none()); + } + #[test] fn eliminate_cross_with_simple_and() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -533,6 +527,30 @@ mod tests { Ok(()) } + #[test] + /// See https://github.com/apache/arrow-datafusion/issues/7530 + fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + let t3 = test_table_scan_with_name("t3")?; + + // could not eliminate to inner join with filter + let plan = LogicalPlanBuilder::from(t1) + .join( + t3, + JoinType::Inner, + (vec!["t1.a"], vec!["t3.a"]), + Some(col("t1.a").gt(lit(20u32))), + )? + .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)? + .filter(col("t1.a").gt(lit(15u32)))? + .build()?; + + assert_optimization_rule_fails(&plan); + + Ok(()) + } + #[test] /// ```txt /// filter: a.id = b.id and a.id = c.id diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 00abcdcc68aa2..0dbebcc8a0519 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -77,7 +77,7 @@ impl OptimizerRule for EliminateJoin { mod tests { use crate::eliminate_join::EliminateJoin; use crate::test::*; - use datafusion_common::{Column, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; @@ -89,10 +89,9 @@ mod tests { #[test] fn join_on_false() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::::new(), Vec::::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(false)))), )? .build()?; @@ -104,10 +103,9 @@ mod tests { #[test] fn join_on_true() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .join( + .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - (Vec::::new(), Vec::::new()), Some(Expr::Literal(ScalarValue::Boolean(Some(true)))), )? .build()?; diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 7844ca7909fce..4386253740aaa 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -97,7 +97,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs new file mode 100644 index 0000000000000..5771ea2e19a29 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to replace nested unions to single union. +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::{Distinct, LogicalPlan, Union}; +use std::sync::Arc; + +#[derive(Default)] +/// An optimization rule that replaces nested unions with a single union. +pub struct EliminateNestedUnion; + +impl EliminateNestedUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateNestedUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .flat_map(extract_plans_from_union) + .collect::>(); + + Ok(Some(LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + }))) + } + LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .collect::>(); + + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { + inputs, + schema: schema.clone(), + }), + ))))) + } + _ => Ok(None), + }, + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_nested_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +fn extract_plans_from_union(plan: &Arc) -> Vec> { + match plan.as_ref() { + LogicalPlan::Union(Union { inputs, schema }) => inputs + .iter() + .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .collect::>(), + _ => vec![plan.clone()], + } +} + +fn extract_plan_from_distinct(plan: &Arc) -> &Arc { + match plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(plan)) => plan, + _ => plan, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, logical_plan::table_scan}; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_distinct_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "Union\ + \n Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .union(plan_builder.clone().build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct(plan_builder.clone().distinct()?.build()?)? + .union(plan_builder.clone().distinct()?.build()?)? + .union_distinct(plan_builder.clone().build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + // We don't need to use project_with_column_index in logical optimizer, + // after LogicalPlanBuilder::union, we already have all equal expression aliases + #[test] + fn eliminate_nested_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .build()?, + )? + .union_distinct( + plan_builder + .clone() + .project(vec![col("id").alias("_id"), col("key"), col("value")])? + .build()?, + )? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table\ + \n Projection: table.id AS id, table.key, table.value\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union(table_2.build()?)? + .union(table_3.build()?)? + .build()?; + + let expected = "Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> { + let table_1 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, false), + ]), + None, + )?; + + let table_2 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let table_3 = table_scan( + Some("table_1"), + &Schema::new(vec![ + Field::new("id", DataType::Int16, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float32, false), + ]), + None, + )?; + + let plan = table_1 + .union_distinct(table_2.build()?)? + .union_distinct(table_3.build()?)? + .build()?; + + let expected = "Distinct:\ + \n Union\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1\ + \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ + \n TableScan: table_1"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs new file mode 100644 index 0000000000000..70ee490346ffb --- /dev/null +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to eliminate one union. +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; + +use crate::optimizer::ApplyOrder; + +#[derive(Default)] +/// An optimization rule that eliminates union with one element. +pub struct EliminateOneUnion; + +impl EliminateOneUnion { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateOneUnion { + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => { + Ok(inputs.first().map(|input| input.as_ref().clone())) + } + _ => Ok(None), + } + } + + fn name(&self) -> &str { + "eliminate_one_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{ + expr_rewriter::coerce_plan_expr_for_schema, + logical_plan::{table_scan, Union}, + }; + use std::sync::Arc; + + fn schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ]) + } + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq_with_rules( + vec![Arc::new(EliminateOneUnion::new())], + plan, + expected, + ) + } + + #[test] + fn eliminate_nothing() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .build()?; + + let expected = "\ + Union\ + \n TableScan: table\ + \n TableScan: table"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn eliminate_one_union() -> Result<()> { + let table_plan = coerce_plan_expr_for_schema( + &table_scan(Some("table"), &schema(), None)?.build()?, + &schema().to_dfschema()?, + )?; + let schema = table_plan.schema().clone(); + let single_union_plan = LogicalPlan::Union(Union { + inputs: vec![Arc::new(table_plan)], + schema, + }); + + let expected = "TableScan: table"; + assert_optimized_plan_equal(&single_union_plan, expected) + } +} diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 8dfdfae035a12..e4d57f0209a46 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -18,10 +18,7 @@ //! Optimizer rule to eliminate left/right/full join to inner join if possible. use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, Result}; -use datafusion_expr::{ - logical_plan::{Join, JoinType, LogicalPlan}, - utils::from_plan, -}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::{Expr, Operator}; use crate::optimizer::ApplyOrder; @@ -109,7 +106,7 @@ impl OptimizerRule for EliminateOuterJoin { schema: join.schema.clone(), null_equals_null: join.null_equals_null, }); - let new_plan = from_plan(plan, &plan.expressions(), &[new_join])?; + let new_plan = plan.with_new_inputs(&[new_join])?; Ok(Some(new_plan)) } _ => Ok(None), diff --git a/datafusion/optimizer/src/eliminate_project.rs b/datafusion/optimizer/src/eliminate_project.rs deleted file mode 100644 index 5c43b8d12cb6a..0000000000000 --- a/datafusion/optimizer/src/eliminate_project.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchemaRef, Result}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Expr, Projection}; - -/// Optimization rule that eliminate unnecessary [LogicalPlan::Projection]. -#[derive(Default)] -pub struct EliminateProjection; - -impl EliminateProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for EliminateProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(projection) => { - let child_plan = projection.input.as_ref(); - match child_plan { - LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) - | LogicalPlan::Union(_) - | LogicalPlan::Filter(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Sort(_) => { - if can_eliminate(projection, child_plan.schema()) { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - _ => { - if plan.schema() == child_plan.schema() { - Ok(Some(child_plan.clone())) - } else { - Ok(None) - } - } - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "eliminate_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub(crate) fn can_eliminate(projection: &Projection, schema: &DFSchemaRef) -> bool { - if projection.expr.len() != schema.fields().len() { - return false; - } - for (i, e) in projection.expr.iter().enumerate() { - match e { - Expr::Column(c) => { - let d = schema.fields().get(i).unwrap(); - if c != &d.qualified_column() && c != &d.unqualified_column() { - return false; - } - } - _ => return false, - } - } - true -} diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 20b9c629712c6..24664d57c38d8 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -15,20 +15,29 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to extract equijoin expr from filter +//! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates use crate::optimizer::ApplyOrder; -use crate::utils::split_conjunction; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; +use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair, split_conjunction}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; use std::sync::Arc; // equijoin predicate type EquijoinPredicate = (Expr, Expr); -/// Optimization rule that extract equijoin expr from the filter +/// Optimizer that splits conjunctive join predicates into equijoin +/// predicates and (other) filter predicates. +/// +/// Join algorithms are often highly optimized for equality predicates such as `x = y`, +/// often called `equijoin` predicates, so it is important to locate such predicates +/// and treat them specially. +/// +/// For example, `SELECT ... FROM A JOIN B ON (A.x = B.y AND B.z > 50)` +/// has one equijoin predicate (`A.x = B.y`) and one filter predicate (`B.z > 50`). +/// See [find_valid_equijoin_key_pair] for more information on what predicates +/// are considered equijoins. #[derive(Default)] pub struct ExtractEquijoinPredicate; @@ -151,7 +160,6 @@ mod tests { use super::*; use crate::test::*; use arrow::datatypes::DataType; - use datafusion_common::Column; use datafusion_expr::{ col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; @@ -172,12 +180,7 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::::new(), Vec::::new()), - Some(col("t1.a").eq(col("t2.a"))), - )? + .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))? .build()?; let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ @@ -192,10 +195,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? .build()?; @@ -212,10 +214,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( (col("t1.a") + lit(10i64)) .gt_eq(col("t2.a") * lit(2u32)) @@ -263,10 +264,9 @@ mod tests { let t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( t2, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t1.c") .eq(col("t2.c")) @@ -291,10 +291,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t2.a") .eq(col("t3.a")) @@ -303,10 +302,9 @@ mod tests { )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t1.a") .eq(col("t2.a")) @@ -330,10 +328,9 @@ mod tests { let t3 = test_table_scan_with_name("t3")?; let input = LogicalPlanBuilder::from(t2) - .join( + .join_on( t3, JoinType::Left, - (Vec::::new(), Vec::::new()), Some( col("t2.a") .eq(col("t3.a")) @@ -342,10 +339,9 @@ mod tests { )? .build()?; let plan = LogicalPlanBuilder::from(t1) - .join( + .join_on( input, JoinType::Left, - (Vec::::new(), Vec::::new()), Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; @@ -373,12 +369,7 @@ mod tests { ) .alias("t1.a + 1 = t2.a + 2"); let plan = LogicalPlanBuilder::from(t1) - .join( - t2, - JoinType::Left, - (Vec::::new(), Vec::::new()), - Some(filter), - )? + .join_on(t2, JoinType::Left, Some(filter))? .build()?; let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index d500debf59d4c..b54facc5d6825 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -15,20 +15,21 @@ // specific language governing permissions and limitations // under the License. -pub mod alias; pub mod analyzer; pub mod common_subexpr_eliminate; +pub mod decorrelate; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; pub mod eliminate_join; pub mod eliminate_limit; +pub mod eliminate_nested_union; +pub mod eliminate_one_union; pub mod eliminate_outer_join; -pub mod eliminate_project; pub mod extract_equijoin_predicate; pub mod filter_null_join_keys; -pub mod merge_projection; +pub mod optimize_projections; pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs deleted file mode 100644 index d551283015a97..0000000000000 --- a/datafusion/optimizer/src/merge_projection.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::optimizer::ApplyOrder; -use datafusion_common::Result; -use datafusion_expr::{Expr, LogicalPlan, Projection}; -use std::collections::HashMap; - -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; - -/// Optimization rule that merge [LogicalPlan::Projection]. -#[derive(Default)] -pub struct MergeProjection; - -impl MergeProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for MergeProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Projection(parent_projection) => { - match parent_projection.input.as_ref() { - LogicalPlan::Projection(child_projection) => { - let replace_map = collect_projection_expr(child_projection); - let new_exprs = parent_projection - .expr - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .enumerate() - .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = parent_projection.schema.fields() - [i] - .qualified_name(); - if e.display_name()? == parent_expr { - Ok(e) - } else { - Ok(e.alias(parent_expr)) - } - } - Err(e) => Err(e), - }) - .collect::>>()?; - let new_plan = - LogicalPlan::Projection(Projection::try_new_with_schema( - new_exprs, - child_projection.input.clone(), - parent_projection.schema.clone(), - )?); - Ok(Some( - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan), - )) - } - _ => Ok(None), - } - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "merge_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias - let expr = projection.expr[i].clone().unalias(); - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -#[cfg(test)] -mod tests { - use crate::merge_projection::MergeProjection; - use datafusion_common::Result; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan, - Operator, - }; - use std::sync::Arc; - - use crate::test::*; - - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(MergeProjection::new()), plan, expected) - } - - #[test] - fn merge_two_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_three_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .project(vec![col("a")])? - .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? - .build()?; - - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } - - #[test] - fn merge_alias() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")])? - .project(vec![col("a").alias("alias")])? - .build()?; - - let expected = "Projection: test.a AS alias\ - \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) - } -} diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs new file mode 100644 index 0000000000000..7ae9f7edf5e51 --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -0,0 +1,1062 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to prune unnecessary columns from intermediate schemas +//! inside the [`LogicalPlan`]. This rule: +//! - Removes unnecessary columns that do not appear at the output and/or are +//! not used during any computation step. +//! - Adds projections to decrease table column size before operators that +//! benefit from a smaller memory footprint at its input. +//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. + +use std::collections::HashSet; +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, +}; +use datafusion_expr::expr::{Alias, ScalarFunction, ScalarFunctionDefinition}; +use datafusion_expr::{ + logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, + Expr, GroupingSet, Projection, TableScan, Window, +}; + +use hashbrown::HashMap; +use itertools::{izip, Itertools}; + +/// A rule for optimizing logical plans by removing unused columns/fields. +/// +/// `OptimizeProjections` is an optimizer rule that identifies and eliminates +/// columns from a logical plan that are not used by downstream operations. +/// This can improve query performance and reduce unnecessary data processing. +/// +/// The rule analyzes the input logical plan, determines the necessary column +/// indices, and then removes any unnecessary columns. It also removes any +/// unnecessary projections from the plan tree. +#[derive(Default)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = (0..plan.schema().fields().len()).collect::>(); + optimize_projections(plan, config, &indices) + } + + fn name(&self) -> &str { + "optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } +} + +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. +/// +/// # Parameters +/// +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream operations. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occured during the optimization process. +fn optimize_projections( + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + // `child_required_indices` stores + // - indices of the columns required for each child + // - a flag indicating whether putting a projection above children is beneficial for the parent. + // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. + let child_required_indices: Vec<(Vec, bool)> = match plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, true)) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. + let exprs = plan.expressions(); + plan.inputs() + .into_iter() + .map(|input| { + get_all_required_indices(indices, input, exprs.iter()) + .map(|idxs| (idxs, false)) + }) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .iter() + .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) + .collect::>() + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::Extension(_) + | LogicalPlan::DescribeTable(_) => { + // These operators have no inputs, so stop the optimization process. + // TODO: Add support for `LogicalPlan::Extension`. + return Ok(None); + } + LogicalPlan::Projection(proj) => { + return if let Some(proj) = merge_consecutive_projections(proj)? { + Ok(Some( + rewrite_projection_given_requirements(&proj, config, indices)? + // Even if we cannot optimize the projection, merge if possible: + .unwrap_or_else(|| LogicalPlan::Projection(proj)), + )) + } else { + rewrite_projection_given_requirements(proj, config, indices) + }; + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_group_exprs); + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + for idx in aggregate_reqs.iter_mut() { + *idx -= n_group_exprs; + } + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.display_name()) + .collect::>>()?; + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + let required_indices = + merge_slices(&simplest_groupby_indices, &group_by_reqs); + get_at_indices(&aggregate.group_expr, &required_indices) + } else { + aggregate.group_expr.clone() + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; + + let aggregate_input = if let Some(input) = + optimize_projections(&aggregate.input, config, &necessary_indices)? + { + input + } else { + aggregate.input.as_ref().clone() + }; + + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + let necessary_exprs = get_required_exprs(schema, &necessary_indices); + let (aggregate_input, _) = + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; + } + + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + return Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); + } + LogicalPlan::Window(window) => { + // Split parent requirements to child and window expression sections: + let n_input_fields = window.input.schema().fields().len(); + let (child_reqs, mut window_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < n_input_fields); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + for idx in window_reqs.iter_mut() { + *idx -= n_input_fields; + } + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = get_all_required_indices( + &child_reqs, + &window.input, + new_window_expr.iter(), + )?; + let window_child = if let Some(new_window_child) = + optimize_projections(&window.input, config, &required_indices)? + { + new_window_child + } else { + window.input.as_ref().clone() + }; + + return if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Some(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `old_child`, because `required_indices` + // refers to `old_child`. + let required_exprs = + get_required_exprs(window.input.schema(), &required_indices); + let (window_child, _) = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(|window| Some(LogicalPlan::Window(window))) + }; + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let exprs = plan.expressions(); + let left_indices = + get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; + let right_indices = + get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_indices, true), (right_indices, true)] + } + LogicalPlan::CrossJoin(cross_join) => { + let left_len = cross_join.left.schema().fields().len(); + let (left_child_indices, right_child_indices) = + split_join_requirements(left_len, indices, &JoinType::Inner); + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![(left_child_indices, true), (right_child_indices, true)] + } + LogicalPlan::TableScan(table_scan) => { + let schema = table_scan.source.schema(); + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = with_indices(&table_scan.projection, schema, |map| { + indices.iter().map(|&idx| map[idx]).collect() + }); + + return TableScan::try_new( + table_scan.table_name.clone(), + table_scan.source.clone(), + Some(projection), + table_scan.filters.clone(), + table_scan.fetch, + ) + .map(|table| Some(LogicalPlan::TableScan(table))); + } + }; + + let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) + .map(|((required_indices, projection_beneficial), child)| { + let (input, is_changed) = if let Some(new_input) = + optimize_projections(child, config, &required_indices)? + { + (new_input, true) + } else { + (child.clone(), false) + }; + let project_exprs = get_required_exprs(child.schema(), &required_indices); + let (input, proj_added) = if projection_beneficial { + add_projection_on_top_if_helpful(input, project_exprs)? + } else { + (input, false) + }; + Ok((is_changed || proj_added).then_some(input)) + }) + .collect::>>()?; + if new_inputs.iter().all(|child| child.is_none()) { + // All children are the same in this case, no need to change the plan: + Ok(None) + } else { + // At least one of the children is changed: + let new_inputs = izip!(new_inputs, plan.inputs()) + // If new_input is `None`, this means child is not changed, so use + // `old_child` during construction: + .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) + .collect::>(); + plan.with_new_inputs(&new_inputs).map(Some) + } +} + +/// This function applies the given function `f` to the projection indices +/// `proj_indices` if they exist. Otherwise, applies `f` to a default set +/// of indices according to `schema`. +fn with_indices( + proj_indices: &Option>, + schema: SchemaRef, + mut f: F, +) -> Vec +where + F: FnMut(&[usize]) -> Vec, +{ + match proj_indices { + Some(indices) => f(indices.as_slice()), + None => { + let range: Vec = (0..schema.fields.len()).collect(); + f(range.as_slice()) + } + } +} + +/// Merges consecutive projections. +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. +/// +/// # Parameters +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. +fn merge_consecutive_projections(proj: &Projection) -> Result> { + let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { + return Ok(None); + }; + + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::::new(); + for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { + for col in columns.into_iter() { + *column_referral_map.entry(col.clone()).or_default() += 1; + } + } + + // If an expression is non-trivial and appears more than once, consecutive + // projections will benefit from a compute-once approach. For details, see: + // https://github.com/apache/arrow-datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 + && !is_expr_trivial( + &prev_projection.expr + [prev_projection.schema.index_of_column(&col).unwrap()], + ) + }) { + return Ok(None); + } + + // If all the expression of the top projection can be rewritten, do so and + // create a new projection: + let new_exprs = proj + .expr + .iter() + .map(|expr| rewrite_expr(expr, prev_projection)) + .collect::>>>()?; + if let Some(new_exprs) = new_exprs { + let new_exprs = new_exprs + .into_iter() + .zip(proj.expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.alias_if_changed(old_expr.name_for_alias()?) + }) + .collect::>>()?; + Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + } else { + Ok(None) + } +} + +/// Trim the given expression by removing any unnecessary layers of aliasing. +/// If the expression is an alias, the function returns the underlying expression. +/// Otherwise, it returns the given expression as is. +/// +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. +/// +/// Consider: +/// +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` +/// +/// After merge, we want to produce: +/// +/// ```text +/// Projection(a + b as sum1) +/// --Source(a, b) +/// ``` +/// +/// Without trimming, we would end up with: +/// +/// ```text +/// Projection((a as a1 + b as b1) as sum1) +/// --Source(a, b) +/// ``` +fn trim_expr(expr: Expr) -> Expr { + match expr { + Expr::Alias(alias) => trim_expr(*alias.expr), + _ => expr, + } +} + +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +// Exit early when there is no rewrite to do. +macro_rules! rewrite_expr_with_check { + ($expr:expr, $input:expr) => { + if let Some(value) = rewrite_expr($expr, $input)? { + value + } else { + return Ok(None); + } + }; +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { + let result = match expr { + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(col)?; + input.expr[idx].clone() + } + Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( + Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), + binary.op, + Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), + )), + Expr::Alias(alias) => Expr::Alias(Alias::new( + trim_expr(rewrite_expr_with_check!(&alias.expr, input)), + alias.relation.clone(), + alias.name.clone(), + )), + Expr::Literal(_) => expr.clone(), + Expr::Cast(cast) => { + let new_expr = rewrite_expr_with_check!(&cast.expr, input); + Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) + } + Expr::ScalarFunction(scalar_fn) => { + // TODO: Support UDFs. + let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else { + return Ok(None); + }; + return Ok(scalar_fn + .args + .iter() + .map(|expr| rewrite_expr(expr, input)) + .collect::>>()? + .map(|new_args| { + Expr::ScalarFunction(ScalarFunction::new(fun, new_args)) + })); + } + // Unsupported type for consecutive projection merge analysis. + _ => return Ok(None), + }; + Ok(Some(result)) +} + +/// Retrieves a set of outer-referenced columns by the given expression, `expr`. +/// Note that the `Expr::to_columns()` function doesn't return these columns. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// +/// # Returns +/// +/// If the function can safely infer all outer-referenced columns, returns a +/// `Some(HashSet)` containing these columns. Otherwise, returns `None`. +fn outer_columns(expr: &Expr) -> Option> { + let mut columns = HashSet::new(); + outer_columns_helper(expr, &mut columns).then_some(columns) +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expression, `expr`. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper(expr: &Expr, columns: &mut HashSet) -> bool { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col.clone()); + true + } + Expr::BinaryExpr(binary_expr) => { + outer_columns_helper(&binary_expr.left, columns) + && outer_columns_helper(&binary_expr.right, columns) + } + Expr::ScalarSubquery(subquery) => { + let exprs = subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Exists(exists) => { + let exprs = exists.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns), + Expr::InSubquery(insubquery) => { + let exprs = insubquery.subquery.outer_ref_columns.iter(); + outer_columns_helper_multi(exprs, columns) + } + Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns), + Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns), + Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns), + Expr::AggregateFunction(aggregate_fn) => { + outer_columns_helper_multi(aggregate_fn.args.iter(), columns) + && aggregate_fn + .order_by + .as_ref() + .map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns)) + && aggregate_fn + .filter + .as_ref() + .map_or(true, |filter| outer_columns_helper(filter, columns)) + } + Expr::WindowFunction(window_fn) => { + outer_columns_helper_multi(window_fn.args.iter(), columns) + && outer_columns_helper_multi(window_fn.order_by.iter(), columns) + && outer_columns_helper_multi(window_fn.partition_by.iter(), columns) + } + Expr::GroupingSet(groupingset) => match groupingset { + GroupingSet::GroupingSets(multi_exprs) => multi_exprs + .iter() + .all(|e| outer_columns_helper_multi(e.iter(), columns)), + GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => { + outer_columns_helper_multi(exprs.iter(), columns) + } + }, + Expr::ScalarFunction(scalar_fn) => { + outer_columns_helper_multi(scalar_fn.args.iter(), columns) + } + Expr::Like(like) => { + outer_columns_helper(&like.expr, columns) + && outer_columns_helper(&like.pattern, columns) + } + Expr::InList(in_list) => { + outer_columns_helper(&in_list.expr, columns) + && outer_columns_helper_multi(in_list.list.iter(), columns) + } + Expr::Case(case) => { + let when_then_exprs = case + .when_then_expr + .iter() + .flat_map(|(first, second)| [first.as_ref(), second.as_ref()]); + outer_columns_helper_multi(when_then_exprs, columns) + && case + .expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + && case + .else_expr + .as_ref() + .map_or(true, |expr| outer_columns_helper(expr, columns)) + } + Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true, + _ => false, + } +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +/// +/// Returns `true` if it can safely collect all outer-referenced columns. +/// Otherwise, returns `false`. +fn outer_columns_helper_multi<'a>( + mut exprs: impl Iterator, + columns: &mut HashSet, +) -> bool { + exprs.all(|e| outer_columns_helper(e, columns)) +} + +/// Generates the required expressions (columns) that reside at `indices` of +/// the given `input_schema`. +/// +/// # Arguments +/// +/// * `input_schema` - A reference to the input schema. +/// * `indices` - A slice of `usize` indices specifying required columns. +/// +/// # Returns +/// +/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. +fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { + let fields = input_schema.fields(); + indices + .iter() + .map(|&idx| Expr::Column(fields[idx].qualified_column())) + .collect() +} + +/// Get indices of the fields referred to by any expression in `exprs` within +/// the given schema (`input_schema`). +/// +/// # Arguments +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `exprs`: An iterator of expressions for which we want to find necessary +/// field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate all `exprs` successfully. +fn indices_referred_by_exprs<'a>( + input_schema: &DFSchemaRef, + exprs: impl Iterator, +) -> Result> { + let indices = exprs + .map(|expr| indices_referred_by_expr(input_schema, expr)) + .collect::>>()?; + Ok(indices + .into_iter() + .flatten() + // Make sure no duplicate entries exist and indices are ordered: + .sorted() + .dedup() + .collect()) +} + +/// Get indices of the fields referred to by the given expression `expr` within +/// the given schema (`input_schema`). +/// +/// # Parameters +/// +/// * `input_schema`: The input schema to analyze for index requirements. +/// * `expr`: An expression for which we want to find necessary field indices. +/// +/// # Returns +/// +/// A [`Result`] object containing the indices of all required fields in +/// `input_schema` to calculate `expr` successfully. +fn indices_referred_by_expr( + input_schema: &DFSchemaRef, + expr: &Expr, +) -> Result> { + let mut cols = expr.to_columns()?; + // Get outer-referenced columns: + if let Some(outer_cols) = outer_columns(expr) { + cols.extend(outer_cols); + } else { + // Expression is not known to contain outer columns or not. Hence, do + // not assume anything and require all the schema indices at the input: + return Ok((0..input_schema.fields().len()).collect()); + } + Ok(cols + .iter() + .flat_map(|col| input_schema.index_of_column(col)) + .collect()) +} + +/// Gets all required indices for the input; i.e. those required by the parent +/// and those referred to by `exprs`. +/// +/// # Parameters +/// +/// * `parent_required_indices` - A slice of indices required by the parent plan. +/// * `input` - The input logical plan to analyze for index requirements. +/// * `exprs` - An iterator of expressions used to determine required indices. +/// +/// # Returns +/// +/// A `Result` containing a vector of `usize` indices containing all the required +/// indices. +fn get_all_required_indices<'a>( + parent_required_indices: &[usize], + input: &LogicalPlan, + exprs: impl Iterator, +) -> Result> { + indices_referred_by_exprs(input.schema(), exprs) + .map(|indices| merge_slices(parent_required_indices, &indices)) +} + +/// Retrieves the expressions at specified indices within the given slice. Ignores +/// any invalid indices. +/// +/// # Parameters +/// +/// * `exprs` - A slice of expressions to index into. +/// * `indices` - A slice of indices specifying the positions of expressions sought. +/// +/// # Returns +/// +/// A vector of expressions corresponding to specified indices. +fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { + indices + .iter() + // Indices may point to further places than `exprs` len. + .filter_map(|&idx| exprs.get(idx).cloned()) + .collect() +} + +/// Merges two slices into a single vector with sorted (ascending) and +/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` +/// will produce `[1, 2, 3, 6]`. +fn merge_slices(left: &[T], right: &[T]) -> Vec { + // Make sure to sort before deduping, which removes the duplicates: + left.iter() + .cloned() + .chain(right.iter().cloned()) + .sorted() + .dedup() + .collect() +} + +/// Splits requirement indices for a join into left and right children based on +/// the join type. +/// +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. +/// +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. +/// +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. +/// +/// # Parameters +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: &[usize], + join_type: &JoinType, +) -> (Vec, Vec) { + match join_type { + // In these cases requirements are split between left/right children: + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + let (left_reqs, mut right_reqs): (Vec, Vec) = + indices.iter().partition(|&&idx| idx < left_len); + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + for idx in right_reqs.iter_mut() { + *idx -= left_len; + } + (left_reqs, right_reqs) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. + JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), + } +} + +/// Adds a projection on top of a logical plan if doing so reduces the number +/// of columns for the parent operator. +/// +/// This function takes a `LogicalPlan` and a list of projection expressions. +/// If the projection is beneficial (it reduces the number of columns in the +/// plan) a new `LogicalPlan` with the projection is created and returned, along +/// with a `true` flag. If the projection doesn't reduce the number of columns, +/// the original plan is returned with a `false` flag. +/// +/// # Parameters +/// +/// * `plan` - The input `LogicalPlan` to potentially add a projection to. +/// * `project_exprs` - A list of expressions for the projection. +/// +/// # Returns +/// +/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` +/// (with or without the added projection) and a `bool` flag indicating if a +/// projection was added (`true`) or not (`false`). +fn add_projection_on_top_if_helpful( + plan: LogicalPlan, + project_exprs: Vec, +) -> Result<(LogicalPlan, bool)> { + // Make sure projection decreases the number of columns, otherwise it is unnecessary. + if project_exprs.len() >= plan.schema().fields().len() { + Ok((plan, false)) + } else { + Projection::try_new(project_exprs, Arc::new(plan)) + .map(|proj| (LogicalPlan::Projection(proj), true)) + } +} + +/// Rewrite the given projection according to the fields required by its +/// ancestors. +/// +/// # Parameters +/// +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_projection_given_requirements( + proj: &Projection, + config: &dyn OptimizerConfig, + indices: &[usize], +) -> Result> { + let exprs_used = get_at_indices(&proj.expr, indices); + let required_indices = + indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; + return if let Some(input) = + optimize_projections(&proj.input, config, &required_indices)? + { + if &projection_schema(&input, &exprs_used)? == input.schema() { + Ok(Some(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(|proj| Some(LogicalPlan::Projection(proj))) + } + } else if exprs_used.len() < proj.expr.len() { + // Projection expression used is different than the existing projection. + // In this case, even if the child doesn't change, we should update the + // projection to use fewer columns: + if &projection_schema(&proj.input, &exprs_used)? == proj.input.schema() { + Ok(Some(proj.input.as_ref().clone())) + } else { + Projection::try_new(exprs_used, proj.input.clone()) + .map(|proj| Some(LogicalPlan::Projection(proj))) + } + } else { + // Projection doesn't change. + Ok(None) + }; +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::optimize_projections::OptimizeProjections; + use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, TableReference}; + use datafusion_expr::{ + binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, + table_scan, Expr, LogicalPlan, Operator, + }; + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + } + + #[test] + fn merge_two_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_three_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .project(vec![col("a")])? + .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? + .build()?; + + let expected = "Projection: Int32(1) + test.a\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .project(vec![col("a").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn merge_nested_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("alias1").alias("alias2")])? + .project(vec![col("alias2").alias("alias")])? + .build()?; + + let expected = "Projection: test.a AS alias\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn test_nested_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("foo", DataType::Int32, false)]); + + let groups: Vec = vec![]; + + let plan = table_scan(TableReference::none(), &schema, None) + .unwrap() + .aggregate(groups.clone(), vec![count(lit(1))]) + .unwrap() + .aggregate(groups, vec![count(lit(1))]) + .unwrap() + .build() + .unwrap(); + + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Projection: \ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n TableScan: ?table? projection=[]"; + assert_optimized_plan_equal(&plan, expected) + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index f2e6c340d7380..0dc34cb809eb6 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,10 @@ //! Query optimizer traits +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Instant; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -24,16 +28,16 @@ use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; +use crate::eliminate_nested_union::EliminateNestedUnion; +use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; -use crate::eliminate_project::EliminateProjection; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; -use crate::merge_projection::MergeProjection; +use crate::optimize_projections::OptimizeProjections; use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; -use crate::push_down_projection::PushDownProjection; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -41,14 +45,14 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use chrono::{DateTime, Utc}; + +use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; + +use chrono::{DateTime, Utc}; use log::{debug, warn}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Instant; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient @@ -80,6 +84,9 @@ pub trait OptimizerConfig { /// time is used as the value for now() fn query_execution_start_time(&self) -> DateTime; + /// Return alias generator used to generate unique aliases for subqueries + fn alias_generator(&self) -> Arc; + fn options(&self) -> &ConfigOptions; } @@ -91,6 +98,9 @@ pub struct OptimizerContext { /// expressions such as `now()` to use a literal value instead query_execution_start_time: DateTime, + /// Alias generator used to generate unique aliases for subqueries + alias_generator: Arc, + options: ConfigOptions, } @@ -102,6 +112,7 @@ impl OptimizerContext { Self { query_execution_start_time: Utc::now(), + alias_generator: Arc::new(AliasGenerator::new()), options, } } @@ -148,6 +159,10 @@ impl OptimizerConfig for OptimizerContext { self.query_execution_start_time } + fn alias_generator(&self) -> Arc { + self.alias_generator.clone() + } + fn options(&self) -> &ConfigOptions { &self.options } @@ -208,6 +223,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(EliminateNestedUnion::new()), Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(ReplaceDistinctWithAggregate::new()), @@ -219,7 +235,6 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), - Arc::new(MergeProjection::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -227,6 +242,8 @@ impl Optimizer { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(PropagateEmptyRelation::new()), + // Must be after PropagateEmptyRelation + Arc::new(EliminateOneUnion::new()), Arc::new(FilterNullJoinKeys::default()), Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit @@ -238,10 +255,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again. - Arc::new(PushDownLimit::new()), + Arc::new(OptimizeProjections::new()), ]; Self::with_rules(rules) @@ -276,11 +290,16 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { - let result = self.optimize_recursively(rule, &new_plan, config); - + let result = + self.optimize_recursively(rule, &new_plan, config) + .and_then(|plan| { + if let Some(plan) = &plan { + assert_schema_is_the_same(rule.name(), &new_plan, plan)?; + } + Ok(plan) + }); match result { Ok(Some(plan)) => { - assert_schema_is_the_same(rule.name(), &new_plan, &plan)?; new_plan = plan; observer(&new_plan, rule.as_ref()); log_plan(rule.name(), &new_plan); @@ -363,7 +382,7 @@ impl Optimizer { }) .collect::>(); - Ok(Some(plan.with_new_inputs(new_inputs.as_slice())?)) + Ok(Some(plan.with_new_inputs(&new_inputs)?)) } /// Use a rule to optimize the whole plan. @@ -405,7 +424,7 @@ impl Optimizer { /// Returns an error if plans have different schemas. /// /// It ignores metadata and nullability. -fn assert_schema_is_the_same( +pub(crate) fn assert_schema_is_the_same( rule_name: &str, prev_plan: &LogicalPlan, new_plan: &LogicalPlan, @@ -416,8 +435,7 @@ fn assert_schema_is_the_same( if !equivalent { let e = DataFusionError::Internal(format!( - "Optimizer rule '{}' failed, due to generate a different schema, original schema: {:?}, new schema: {:?}", - rule_name, + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", prev_plan.schema(), new_plan.schema() )); @@ -432,15 +450,18 @@ fn assert_schema_is_the_same( #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + + use super::ApplyOrder; use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; - use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; + + use datafusion_common::{ + plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; - use std::sync::{Arc, Mutex}; - - use super::ApplyOrder; #[test] fn skip_failing_rule() { @@ -465,7 +486,7 @@ mod tests { assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", - err.to_string() + err.strip_backtrace() ); } @@ -479,20 +500,31 @@ mod tests { }); let err = opt.optimize(&plan, &config, &observe).unwrap_err(); assert_eq!( - "get table_scan rule\ncaused by\n\ - Internal error: Optimizer rule 'get table_scan rule' failed, due to generate a different schema, \ - original schema: DFSchema { fields: [], metadata: {} }, \ + "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ + Internal error: Failed due to a difference in schemas, \ + original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }], \ - metadata: {} }. \ - This was likely caused by a bug in DataFusion's code \ + metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }.\ + \nThis was likely caused by a bug in DataFusion's code \ and we would welcome that you file an bug report in our issue tracker", - err.to_string() + err.strip_backtrace() ); } + #[test] + fn skip_generate_different_schema() { + let opt = Optimizer::with_rules(vec![Arc::new(GetTableScanRule {})]); + let config = OptimizerContext::new().with_skip_failing_rules(true); + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }); + opt.optimize(&plan, &config, &observe).unwrap(); + } + #[test] fn generate_same_schema_different_metadata() -> Result<()> { // if the plan creates more metadata than previously (because @@ -601,7 +633,7 @@ mod tests { _: &LogicalPlan, _: &dyn OptimizerConfig, ) -> Result> { - Err(DataFusionError::Plan("rule failed".to_string())) + plan_err!("rule failed") } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 01e16058ec32b..040b69fc8bf3f 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; use std::sync::Arc; @@ -156,9 +156,7 @@ fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { }; Ok((left_empty, right_empty)) } - _ => Err(DataFusionError::Plan( - "plan just can have two child".to_string(), - )), + _ => plan_err!("plan just can have two child"), } } @@ -177,21 +175,18 @@ fn empty_child(plan: &LogicalPlan) -> Result> { } _ => Ok(None), }, - _ => Err(DataFusionError::Plan( - "plan just can have one child".to_string(), - )), + _ => plan_err!("plan just can have one child"), } } #[cfg(test)] mod tests { use crate::eliminate_filter::EliminateFilter; - use crate::optimizer::Optimizer; + use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, test_table_scan, test_table_scan_fields, - test_table_scan_with_name, + assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, + test_table_scan_fields, test_table_scan_with_name, }; - use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFField, DFSchema, ScalarValue}; use datafusion_expr::logical_plan::table_scan; @@ -210,21 +205,15 @@ mod tests { plan: &LogicalPlan, expected: &str, ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![ - Arc::new(EliminateFilter::new()), - Arc::new(PropagateEmptyRelation::new()), - ]); - let config = &mut OptimizerContext::new() - .with_max_passes(1) - .with_skip_failing_rules(false); - let optimized_plan = optimizer - .optimize(plan, config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); - Ok(()) + assert_optimized_plan_eq_with_rules( + vec![ + Arc::new(EliminateFilter::new()), + Arc::new(EliminateNestedUnion::new()), + Arc::new(PropagateEmptyRelation::new()), + ], + plan, + expected, + ) } #[test] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 175c6a118c7c0..c090fb849a823 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -12,50 +12,116 @@ // specific language governing permissions and limitations // under the License. -//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan +//! [`PushDownFilter`] Moves filters so they are applied as early as possible in +//! the plan. use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, split_conjunction}; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFSchema, DataFusionError, Result}; +use datafusion_common::{ + internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, +}; +use datafusion_expr::expr::Alias; +use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; +use datafusion_expr::Volatility; use datafusion_expr::{ and, expr_rewriter::replace_col, logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, - utils::from_plan, - BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, + or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition, + TableProviderFilterPushDown, }; use itertools::Itertools; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -/// Push Down Filter optimizer rule pushes filter clauses down the plan +/// Optimizer rule for pushing (moving) filter expressions down in a plan so +/// they are applied as early as possible. +/// /// # Introduction -/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)). -/// An example of a filter-commutative operation is a projection; a counter-example is `limit`. /// -/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B) -/// can commute with a filter that depends on A only, but does not commute with a filter that depends -/// on SUM(B). +/// The goal of this rule is to improve query performance by eliminating +/// redundant work. +/// +/// For example, given a plan that sorts all values where `a > 10`: +/// +/// ```text +/// Filter (a > 10) +/// Sort (a, b) +/// ``` +/// +/// A better plan is to filter the data *before* the Sort, which sorts fewer +/// rows and therefore does less work overall: +/// +/// ```text +/// Sort (a, b) +/// Filter (a > 10) <-- Filter is moved before the sort +/// ``` +/// +/// However it is not always possible to push filters down. For example, given a +/// plan that finds the top 3 values and then keeps only those that are greater +/// than 10, if the filter is pushed below the limit it would produce a +/// different result. +/// +/// ```text +/// Filter (a > 10) <-- can not move this Filter before the limit +/// Limit (fetch=3) +/// Sort (a, b) +/// ``` +/// +/// +/// More formally, a filter-commutative operation is an operation `op` that +/// satisfies `filter(op(data)) = op(filter(data))`. +/// +/// The filter-commutative property is plan and column-specific. A filter on `a` +/// can be pushed through a `Aggregate(group_by = [a], agg=[SUM(b))`. However, a +/// filter on `SUM(b)` can not be pushed through the same aggregate. +/// +/// # Handling Conjunctions +/// +/// It is possible to only push down **part** of a filter expression if is +/// connected with `AND`s (more formally if it is a "conjunction"). +/// +/// For example, given the following plan: +/// +/// ```text +/// Filter(a > 10 AND SUM(b) < 5) +/// Aggregate(group_by = [a], agg = [SUM(b)) +/// ``` +/// +/// The `a > 10` is commutative with the `Aggregate` but `SUM(b) < 5` is not. +/// Therefore it is possible to only push part of the expression, resulting in: +/// +/// ```text +/// Filter(SUM(b) < 5) +/// Aggregate(group_by = [a], agg = [SUM(b)) +/// Filter(a > 10) +/// ``` /// -/// This optimizer commutes filters with filter-commutative operations to push the filters -/// the closest possible to the scans, re-writing the filter expressions by every -/// projection that changes the filter's expression. +/// # Handling Column Aliases /// -/// Filter: b Gt Int64(10) -/// Projection: a AS b +/// This optimizer must sometimes handle re-writing filter expressions when they +/// pushed, for example if there is a projection that aliases `a+1` to `"b"`: /// -/// is optimized to +/// ```text +/// Filter (b > 10) +/// Projection: [a+1 AS "b"] <-- changes the name of `a+1` to `b` +/// ``` /// -/// Projection: a AS b -/// Filter: a Gt Int64(10) <--- changed from b to a +/// To apply the filter prior to the `Projection`, all references to `b` must be +/// rewritten to `a+1`: /// -/// This performs a single pass through the plan. When it passes through a filter, it stores that filter, -/// and when it reaches a node that does not commute with it, it adds the filter to that place. -/// When it passes through a projection, it re-writes the filter's expression taking into account that projection. -/// When multiple filters would have been written, it `AND` their expressions into a single expression. +/// ```text +/// Projection: a AS "b" +/// Filter: (a + 1 > 10) <--- changed from b to a + 1 +/// ``` +/// # Implementation Notes +/// +/// This implementation performs a single pass through the plan, "pushing" down +/// filters. When it passes through a filter, it stores that filter, and when it +/// reaches a plan node that does not commute with that filter, it adds the +/// filter to that place. When it passes through a projection, it re-writes the +/// filter's expression taking into account that projection. #[derive(Default)] pub struct PushDownFilter {} @@ -93,9 +159,7 @@ fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), }, LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => Err(DataFusionError::Internal( - "lr_is_preserved only valid for JOIN nodes".to_string(), - )), + _ => internal_err!("lr_is_preserved only valid for JOIN nodes"), } } @@ -113,12 +177,10 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { JoinType::LeftAnti => Ok((false, true)), JoinType::RightAnti => Ok((true, false)), }, - LogicalPlan::CrossJoin(_) => Err(DataFusionError::Internal( - "on_lr_is_preserved cannot be applied to CROSSJOIN nodes".to_string(), - )), - _ => Err(DataFusionError::Internal( - "on_lr_is_preserved only valid for JOIN nodes".to_string(), - )), + LogicalPlan::CrossJoin(_) => { + internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") + } + _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"), } } @@ -160,14 +222,16 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarUDF(..) => { + | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(_), + .. + }) => { is_evaluate = false; Ok(VisitRecursion::Stop) } - Expr::Alias(_, _) + Expr::Alias(_) | Expr::BinaryExpr(_) | Expr::Like(_) - | Expr::ILike(_) | Expr::SimilarTo(_) | Expr::Not(_) | Expr::IsNotNull(_) @@ -189,12 +253,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } - | Expr::GroupingSet(_) => Err(DataFusionError::Internal( - "Unsupported predicate type".to_string(), - )), + | Expr::Wildcard { .. } + | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -463,7 +523,7 @@ fn push_down_all_join( if !join_conditions.is_empty() { new_exprs.push(join_conditions.into_iter().reduce(Expr::and).unwrap()); } - let plan = from_plan(join_plan, &new_exprs, &[left, right])?; + let plan = join_plan.with_new_exprs(new_exprs, &[left, right])?; if keep_predicates.is_empty() { Ok(plan) @@ -485,9 +545,7 @@ fn push_down_join( parent_predicate: Option<&Expr>, ) -> Result> { let predicates = match parent_predicate { - Some(parent_predicate) => { - utils::split_conjunction_owned(parent_predicate.clone()) - } + Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), None => vec![], }; @@ -495,8 +553,8 @@ fn push_down_join( let on_filters = join .filter .as_ref() - .map(|e| utils::split_conjunction_owned(e.clone())) - .unwrap_or_else(Vec::new); + .map(|e| split_conjunction_owned(e.clone())) + .unwrap_or_default(); let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { @@ -617,7 +675,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| (*e).clone()) .collect::>(); let new_predicate = conjunction(new_predicates).ok_or_else(|| { - DataFusionError::Plan("at least one expression exists".to_string()) + plan_datafusion_err!("at least one expression exists") })?; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, @@ -658,32 +716,60 @@ impl OptimizerRule for PushDownFilter { child_plan.with_new_inputs(&[new_filter])? } LogicalPlan::Projection(projection) => { - // A projection is filter-commutable, but re-writes all predicate expressions + // A projection is filter-commutable if it do not contain volatile predicates or contain volatile + // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. - let replace_map = projection - .schema - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(expr, _) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - (field.qualified_name(), expr) - }) - .collect::>(); + let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = + projection + .schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + // strip alias, as they should not be part of filters + let expr = match &projection.expr[i] { + Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), + expr => expr.clone(), + }; + + (field.qualified_name(), expr) + }) + .partition(|(_, value)| is_volatile_expression(value)); - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(filter.predicate.clone(), &replace_map)?, - projection.input.clone(), - )?); + let mut push_predicates = vec![]; + let mut keep_predicates = vec![]; + for expr in split_conjunction_owned(filter.predicate.clone()).into_iter() + { + if contain(&expr, &volatile_map) { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); + } + } - child_plan.with_new_inputs(&[new_filter])? + match conjunction(push_predicates) { + Some(expr) => { + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + let new_filter = LogicalPlan::Filter(Filter::try_new( + replace_cols_by_name(expr, &non_volatile_map)?, + projection.input.clone(), + )?); + + match conjunction(keep_predicates) { + None => child_plan.with_new_inputs(&[new_filter])?, + Some(keep_predicate) => { + let child_plan = + child_plan.with_new_inputs(&[new_filter])?; + LogicalPlan::Filter(Filter::try_new( + keep_predicate, + Arc::new(child_plan), + )?) + } + } + } + None => return Ok(None), + } } LogicalPlan::Union(union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); @@ -716,7 +802,7 @@ impl OptimizerRule for PushDownFilter { .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) .collect::>>()?; - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -748,8 +834,7 @@ impl OptimizerRule for PushDownFilter { )?), None => (*agg.input).clone(), }; - let new_agg = - from_plan(&filter.input, &filter.input.expressions(), &vec![child])?; + let new_agg = filter.input.with_new_inputs(&vec![child])?; match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, @@ -765,7 +850,7 @@ impl OptimizerRule for PushDownFilter { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); push_down_all_join( predicates, vec![], @@ -781,7 +866,7 @@ impl OptimizerRule for PushDownFilter { let results = scan .source .supports_filters_pushdown(filter_predicates.as_slice())?; - let zip = filter_predicates.iter().zip(results.into_iter()); + let zip = filter_predicates.iter().zip(results); let new_scan_filters = zip .clone() @@ -820,7 +905,7 @@ impl OptimizerRule for PushDownFilter { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = utils::split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate.clone()); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; @@ -888,6 +973,58 @@ pub fn replace_cols_by_name( }) } +/// check whether the expression is volatile predicates +fn is_volatile_expression(e: &Expr) -> bool { + let mut is_volatile = false; + e.apply(&mut |expr| { + Ok(match expr { + Expr::ScalarFunction(f) => match &f.func_def { + ScalarFunctionDefinition::BuiltIn(fun) + if fun.volatility() == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::UDF(fun) + if fun.signature().volatility == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + _ => VisitRecursion::Continue, + }, + _ => VisitRecursion::Continue, + }) + }) + .unwrap(); + is_volatile +} + +/// check whether the expression uses the columns in `check_map`. +fn contain(e: &Expr, check_map: &HashMap) -> bool { + let mut is_contain = false; + e.apply(&mut |expr| { + Ok(if let Expr::Column(c) = &expr { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + VisitRecursion::Stop + } + None => VisitRecursion::Continue, + } + } else { + VisitRecursion::Continue + }) + }) + .unwrap(); + is_contain +} + #[cfg(test)] mod tests { use super::*; @@ -900,9 +1037,9 @@ mod tests { use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, - Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, - UserDefinedLogicalNodeCore, + and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum, + BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, + TableType, UserDefinedLogicalNodeCore, }; use std::fmt::{Debug, Formatter}; use std::sync::Arc; @@ -925,7 +1062,7 @@ mod tests { ]); let mut optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? @@ -953,8 +1090,7 @@ mod tests { // filter is before projection let expected = "\ Projection: test.a, test.b\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -981,9 +1117,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(lit(0i64).eq(lit(1i64)))? .build()?; - let expected = "\ - Filter: Int64(0) = Int64(1)\ - \n TableScan: test"; + let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -999,8 +1133,7 @@ mod tests { let expected = "\ Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1014,8 +1147,7 @@ mod tests { // filter of key aggregation is commutative let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ - \n Filter: test.a > Int64(10)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1040,8 +1172,7 @@ mod tests { .build()?; let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ - \n Filter: test.b + test.a > Int64(10)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1071,8 +1202,7 @@ mod tests { // filter is before projection let expected = "\ Projection: test.a AS b, test.c\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1116,8 +1246,7 @@ mod tests { // filter is before projection let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n Filter: test.a * Int32(2) + test.c = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1149,8 +1278,7 @@ mod tests { let expected = "\ Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1213,8 +1341,7 @@ mod tests { // Push filter below NoopPlan let expected = "\ NoopPlan\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { @@ -1231,8 +1358,7 @@ mod tests { let expected = "\ Filter: test.c = Int64(2)\ \n NoopPlan\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { @@ -1248,10 +1374,8 @@ mod tests { // Push filter below NoopPlan for each child branch let expected = "\ NoopPlan\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]\ + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { @@ -1268,10 +1392,8 @@ mod tests { let expected = "\ Filter: test.c = Int64(2)\ \n NoopPlan\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]\ + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1304,8 +1426,7 @@ mod tests { Filter: SUM(test.c) > Int64(10)\ \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ - \n Filter: test.a > Int64(10)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1339,8 +1460,7 @@ mod tests { Filter: SUM(test.c) > Int64(10) AND SUM(test.c) < Int64(20)\ \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ - \n Filter: test.a > Int64(10)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1376,10 +1496,8 @@ mod tests { .build()?; // filter appears below Union let expected = "Union\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test\ - \n Filter: test2.a = Int64(1)\ - \n TableScan: test2"; + \n TableScan: test, full_filters=[test.a = Int64(1)]\ + \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1399,12 +1517,10 @@ mod tests { // filter appears below Union let expected = "Union\n SubqueryAlias: test2\ \n Projection: test.a AS b\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n SubqueryAlias: test2\ \n Projection: test.a AS b\ - \n Filter: test.a = Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a = Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1432,11 +1548,9 @@ mod tests { let expected = "Projection: test.a, test1.d\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.a = Int32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ - \n Filter: test1.d > Int32(2)\ - \n TableScan: test1"; + \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1461,11 +1575,9 @@ mod tests { let expected = "Projection: test.a, test1.a\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.a = Int32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ - \n Filter: test1.a > Int32(2)\ - \n TableScan: test1"; + \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1498,8 +1610,7 @@ mod tests { \n Filter: test.a >= Int64(1)\ \n Limit: skip=0, fetch=1\ \n Projection: test.a\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test"; + \n TableScan: test, full_filters=[test.a <= Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1553,6 +1664,10 @@ mod tests { // not part of the test assert_eq!(format!("{plan:?}"), expected); + let expected = "\ + TestUserDefined\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]"; + assert_optimized_plan_eq(&plan, expected) } @@ -1588,11 +1703,9 @@ mod tests { // filter sent to side before the join let expected = "\ Inner Join: test.a = test2.a\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ - \n Filter: test2.a <= Int64(1)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1627,11 +1740,9 @@ mod tests { // filter sent to side before the join let expected = "\ Inner Join: Using test.a = test2.a\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ - \n Filter: test2.a <= Int64(1)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1713,8 +1824,7 @@ mod tests { let expected = "\ Inner Join: test.a = test2.a\ \n Projection: test.a, test.b\ - \n Filter: test.b <= Int64(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) @@ -1829,8 +1939,7 @@ mod tests { // filter sent to left side of the join, not the right let expected = "\ Left Join: Using test.a = test2.a\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) @@ -1870,8 +1979,7 @@ mod tests { Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ - \n Filter: test2.a <= Int64(1)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1912,11 +2020,9 @@ mod tests { let expected = "\ Inner Join: test.a = test2.a Filter: test.b < test2.b\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.c > UInt32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ - \n Filter: test2.c > UInt32(4)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1956,11 +2062,9 @@ mod tests { let expected = "\ Inner Join: test.a = test2.a\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.b > UInt32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ - \n Filter: test2.c > UInt32(4)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1998,11 +2102,9 @@ mod tests { let expected = "\ Inner Join: test.a = test2.b\ \n Projection: test.a\ - \n Filter: test.a > UInt32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ - \n Filter: test2.b > UInt32(1)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2045,8 +2147,7 @@ mod tests { \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ - \n Filter: test2.c > UInt32(4)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2087,8 +2188,7 @@ mod tests { let expected = "\ Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.a > UInt32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) @@ -2314,8 +2414,7 @@ Projection: a, b // rewrite filter col b to test.a let expected = "\ Projection: test.a AS b, test.c\ - \n Filter: test.a > Int64(10) AND test.c > Int64(10)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; assert_optimized_plan_eq(&plan, expected) @@ -2347,8 +2446,7 @@ Projection: a, b let expected = "\ Projection: b, test.c\ \n Projection: test.a AS b, test.c\ - \n Filter: test.a > Int64(10) AND test.c > Int64(10)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; assert_optimized_plan_eq(&plan, expected) @@ -2374,9 +2472,7 @@ Projection: a, b // rewrite filter col b to test.a, col d to test.c let expected = "\ Projection: test.a AS b, test.c AS d\ - \n Filter: test.a > Int64(10) AND test.c > Int64(10)\ - \n TableScan: test\ - "; + \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2415,11 +2511,9 @@ Projection: a, b let expected = "\ Inner Join: c = d\ \n Projection: test.a AS c\ - \n Filter: test.a > UInt32(1)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ - \n Filter: test2.b > UInt32(1)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2447,9 +2541,7 @@ Projection: a, b // rewrite filter col b to test.a let expected = "\ Projection: test.a AS b, test.c\ - \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n TableScan: test\ - "; + \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; assert_optimized_plan_eq(&plan, expected) } @@ -2481,9 +2573,7 @@ Projection: a, b let expected = "\ Projection: b, test.c\ \n Projection: test.a AS b, test.c\ - \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n TableScan: test\ - "; + \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; assert_optimized_plan_eq(&plan, expected) } @@ -2517,11 +2607,10 @@ Projection: a, b // rewrite filter col b to test.a let expected_after = "\ Projection: test.a AS b, test.c\ - \n Filter: test.a IN ()\ + \n TableScan: test, full_filters=[test.a IN ()]\ \n Subquery:\ \n Projection: sq.c\ - \n TableScan: sq\ - \n TableScan: test"; + \n TableScan: sq"; assert_optimized_plan_eq(&plan, expected_after) } @@ -2582,8 +2671,7 @@ Projection: a, b Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ - \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; @@ -2631,11 +2719,9 @@ Projection: a, b // Both side will be pushed down. let expected = "\ LeftSemi Join: test1.a = test2.a\ - \n Filter: test1.b > UInt32(1)\ - \n TableScan: test1\ + \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ - \n Filter: test2.b > UInt32(2)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2674,11 +2760,9 @@ Projection: a, b // Both side will be pushed down. let expected = "\ RightSemi Join: test1.a = test2.a\ - \n Filter: test1.b > UInt32(1)\ - \n TableScan: test1\ + \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ - \n Filter: test2.b > UInt32(2)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2724,8 +2808,7 @@ Projection: a, b \n Projection: test1.a, test1.b\ \n TableScan: test1\ \n Projection: test2.a, test2.b\ - \n Filter: test2.b > UInt32(2)\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; assert_optimized_plan_eq(&plan, expected) } @@ -2768,10 +2851,84 @@ Projection: a, b // For right anti, filter of the left side can be pushed down. let expected = "RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\ \n Projection: test1.a, test1.b\ - \n Filter: test1.b > UInt32(1)\ - \n TableScan: test1\ + \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) } + + #[test] + fn test_push_down_volatile_function_in_aggregate() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .project(vec![ + col("a"), + sum(col("b")), + add(random(), lit(1)).alias("r"), + ])? + .alias("t")? + .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ + \n SubqueryAlias: t\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected_after = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.5)\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; + assert_optimized_plan_eq(&plan, expected_after) + } + + #[test] + fn test_push_down_volatile_function_in_join() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan).build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::Inner, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .project(vec![col("test1.a").alias("a"), random().alias("r")])? + .alias("t")? + .filter(col("t.r").gt(lit(0.8)))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.r > Float64(0.8)\ + \n SubqueryAlias: t\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.8)\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_optimized_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index c64dfc578b960..10cc1879aeeb9 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -18,593 +18,27 @@ //! Projection Push Down optimizer rule ensures that only referenced columns are //! loaded into memory -use crate::eliminate_project::can_eliminate; -use crate::optimizer::ApplyOrder; -use crate::push_down_filter::replace_cols_by_name; -use crate::{OptimizerConfig, OptimizerRule}; -use arrow::error::Result as ArrowResult; -use datafusion_common::ScalarValue::UInt8; -use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, -}; -use datafusion_expr::expr::AggregateFunction; -use datafusion_expr::utils::exprlist_to_fields; -use datafusion_expr::{ - logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union}, - utils::{expr_to_columns, exprlist_to_columns}, - Expr, LogicalPlanBuilder, SubqueryAlias, -}; -use std::collections::HashMap; -use std::{ - collections::{BTreeSet, HashSet}, - sync::Arc, -}; - -// if projection is empty return projection-new_plan, else return new_plan. -#[macro_export] -macro_rules! generate_plan { - ($projection_is_empty:expr, $plan:expr, $new_plan:expr) => { - if $projection_is_empty { - $new_plan - } else { - $plan.with_new_inputs(&[$new_plan])? - } - }; -} - -/// Optimizer that removes unused projections and aggregations from plans -/// This reduces both scans and -#[derive(Default)] -pub struct PushDownProjection {} - -impl OptimizerRule for PushDownProjection { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - let projection = match plan { - LogicalPlan::Projection(projection) => projection, - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - for e in agg.aggr_expr.iter().chain(agg.group_expr.iter()) { - expr_to_columns(e, &mut required_columns)? - } - let new_expr = get_expr(&required_columns, agg.input.schema())?; - let projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - agg.input.clone(), - )?); - let optimized_child = self - .try_optimize(&projection, _config)? - .unwrap_or(projection); - return Ok(Some(plan.with_new_inputs(&[optimized_child])?)); - } - LogicalPlan::TableScan(scan) if scan.projection.is_none() => { - return Ok(Some(push_down_scan(&HashSet::new(), scan, false)?)); - } - _ => return Ok(None), - }; - - let child_plan = &*projection.input; - let projection_is_empty = projection.expr.is_empty(); - - let new_plan = match child_plan { - LogicalPlan::Projection(child_projection) => { - // merge projection - let replace_map = collect_projection_expr(child_projection); - let new_exprs = projection - .expr - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .enumerate() - .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = - projection.schema.fields()[i].qualified_name(); - if e.display_name()? == parent_expr { - Ok(e) - } else { - Ok(e.alias(parent_expr)) - } - } - Err(e) => Err(e), - }) - .collect::>>()?; - let new_plan = LogicalPlan::Projection(Projection::try_new_with_schema( - new_exprs, - child_projection.input.clone(), - projection.schema.clone(), - )?); - - self.try_optimize(&new_plan, _config)?.unwrap_or(new_plan) - } - LogicalPlan::Join(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - for (l, r) in join.on.iter() { - expr_to_columns(l, &mut push_columns)?; - expr_to_columns(r, &mut push_columns)?; - } - if let Some(expr) = &join.filter { - expr_to_columns(expr, &mut push_columns)?; - } - - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::CrossJoin(join) => { - // collect column in on/filter in join and projection. - let mut push_columns: HashSet = HashSet::new(); - for e in projection.expr.iter() { - expr_to_columns(e, &mut push_columns)?; - } - let new_left = generate_projection( - &push_columns, - join.left.schema(), - join.left.clone(), - )?; - let new_right = generate_projection( - &push_columns, - join.right.schema(), - join.right.clone(), - )?; - let new_join = child_plan.with_new_inputs(&[new_left, new_right])?; - - generate_plan!(projection_is_empty, plan, new_join) - } - LogicalPlan::TableScan(scan) - if !scan.projected_schema.fields().is_empty() => - { - let mut used_columns: HashSet = HashSet::new(); - // filter expr may not exist in expr in projection. - // like: TableScan: t1 projection=[bool_col, int_col], full_filters=[t1.id = Int32(1)] - // projection=[bool_col, int_col] don't contain `ti.id`. - exprlist_to_columns(&scan.filters, &mut used_columns)?; - if projection_is_empty { - used_columns - .insert(scan.projected_schema.fields()[0].qualified_column()); - push_down_scan(&used_columns, scan, true)? - } else { - for expr in projection.expr.iter() { - expr_to_columns(expr, &mut used_columns)?; - } - let new_scan = push_down_scan(&used_columns, scan, true)?; - - plan.with_new_inputs(&[new_scan])? - } - } - LogicalPlan::Values(values) if projection_is_empty => { - let first_col = - Expr::Column(values.schema.fields()[0].qualified_column()); - LogicalPlan::Projection(Projection::try_new( - vec![first_col], - Arc::new(child_plan.clone()), - )?) - } - LogicalPlan::Union(union) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // When there is no projection, we need to add the first column to the projection - // Because if push empty down, children may output different columns. - if required_columns.is_empty() { - required_columns.insert(union.schema.fields()[0].qualified_column()); - } - // we don't push down projection expr, we just prune columns, so we just push column - // because push expr may cause more cost. - let projection_column_exprs = get_expr(&required_columns, &union.schema)?; - let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { - let mut replace_map = HashMap::new(); - for (i, field) in input.schema().fields().iter().enumerate() { - replace_map.insert( - union.schema.fields()[i].qualified_name(), - Expr::Column(field.qualified_column()), - ); - } - - let exprs = projection_column_exprs - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) - .collect::>>()?; - - inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new( - exprs, - input.clone(), - )?))) - } - // create schema of all used columns - let schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projection_column_exprs, child_plan)?, - union.schema.metadata().clone(), - )?; - let new_union = LogicalPlan::Union(Union { - inputs, - schema: Arc::new(schema), - }); - - generate_plan!(projection_is_empty, plan, new_union) - } - LogicalPlan::SubqueryAlias(subquery_alias) => { - let replace_map = generate_column_replace_map(subquery_alias); - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - - let new_required_columns = required_columns - .iter() - .map(|c| { - replace_map.get(c).cloned().ok_or_else(|| { - DataFusionError::Internal("replace column failed".to_string()) - }) - }) - .collect::>>()?; - - let new_expr = - get_expr(&new_required_columns, subquery_alias.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - subquery_alias.input.clone(), - )?); - let new_alias = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_alias) - } - LogicalPlan::Aggregate(agg) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Aggregate - let mut new_aggr_expr = vec![]; - for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_aggr_expr.push(e.clone()); - } - } - - // if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it - if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 { - if let Expr::AggregateFunction(AggregateFunction { - fun, args, .. - }) = &agg.aggr_expr[0] - { - if matches!(fun, datafusion_expr::AggregateFunction::Count) - && args.len() == 1 - && args[0] == Expr::Literal(UInt8(Some(1))) - { - new_aggr_expr.push(agg.aggr_expr[0].clone()); - } - } - } - - let new_agg = LogicalPlan::Aggregate(Aggregate::try_new( - agg.input.clone(), - agg.group_expr.clone(), - new_aggr_expr, - )?); - - generate_plan!(projection_is_empty, plan, new_agg) - } - LogicalPlan::Window(window) => { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - // Gather all columns needed for expressions in this Window - let mut new_window_expr = vec![]; - for e in window.window_expr.iter() { - let column = Column::from_name(e.display_name()?); - if required_columns.contains(&column) { - new_window_expr.push(e.clone()); - } - } - - if new_window_expr.is_empty() { - // none columns in window expr are needed, remove the window expr - let input = window.input.clone(); - let new_window = restrict_outputs(input.clone(), &required_columns)? - .unwrap_or((*input).clone()); - - generate_plan!(projection_is_empty, plan, new_window) - } else { - let mut referenced_inputs = HashSet::new(); - exprlist_to_columns(&new_window_expr, &mut referenced_inputs)?; - window - .input - .schema() - .fields() - .iter() - .filter(|f| required_columns.contains(&f.qualified_column())) - .for_each(|f| { - referenced_inputs.insert(f.qualified_column()); - }); - - let input = window.input.clone(); - let new_input = restrict_outputs(input.clone(), &referenced_inputs)? - .unwrap_or((*input).clone()); - let new_window = LogicalPlanBuilder::from(new_input) - .window(new_window_expr)? - .build()?; - - generate_plan!(projection_is_empty, plan, new_window) - } - } - LogicalPlan::Filter(filter) => { - if can_eliminate(projection, child_plan.schema()) { - // when projection schema == filter schema, we can commute directly. - let new_proj = - plan.with_new_inputs(&[filter.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns( - &[filter.predicate.clone()], - &mut required_columns, - )?; - - let new_expr = get_expr(&required_columns, filter.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - filter.input.clone(), - )?); - let new_filter = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_filter) - } - } - LogicalPlan::Sort(sort) => { - if can_eliminate(projection, child_plan.schema()) { - // can commute - let new_proj = plan.with_new_inputs(&[(*sort.input).clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } else { - let mut required_columns = HashSet::new(); - exprlist_to_columns(&projection.expr, &mut required_columns)?; - exprlist_to_columns(&sort.expr, &mut required_columns)?; - - let new_expr = get_expr(&required_columns, sort.input.schema())?; - let new_projection = LogicalPlan::Projection(Projection::try_new( - new_expr, - sort.input.clone(), - )?); - let new_sort = child_plan.with_new_inputs(&[new_projection])?; - - generate_plan!(projection_is_empty, plan, new_sort) - } - } - LogicalPlan::Limit(limit) => { - // can commute - let new_proj = plan.with_new_inputs(&[limit.input.as_ref().clone()])?; - child_plan.with_new_inputs(&[new_proj])? - } - _ => return Ok(None), - }; - - Ok(Some(new_plan)) - } - - fn name(&self) -> &str { - "push_down_projection" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -impl PushDownProjection { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn generate_column_replace_map( - subquery_alias: &SubqueryAlias, -) -> HashMap { - subquery_alias - .input - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - ( - subquery_alias.schema.fields()[i].qualified_column(), - field.qualified_column(), - ) - }) - .collect() -} - -pub fn collect_projection_expr(projection: &Projection) -> HashMap { - projection - .schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(expr, _) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>() -} - -// Get the projection exprs from columns in the order of the schema -fn get_expr(columns: &HashSet, schema: &DFSchemaRef) -> Result> { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let qc = field.qualified_column(); - let uqc = field.unqualified_column(); - if columns.contains(&qc) || columns.contains(&uqc) { - Some(Expr::Column(qc)) - } else { - None - } - }) - .collect::>(); - if columns.len() != expr.len() { - Err(DataFusionError::Plan(format!( - "required columns can't push down, columns: {columns:?}" - ))) - } else { - Ok(expr) - } -} - -fn generate_projection( - used_columns: &HashSet, - schema: &DFSchemaRef, - input: Arc, -) -> Result { - let expr = schema - .fields() - .iter() - .flat_map(|field| { - let column = field.qualified_column(); - if used_columns.contains(&column) { - Some(Expr::Column(column)) - } else { - None - } - }) - .collect::>(); - - Ok(LogicalPlan::Projection(Projection::try_new(expr, input)?)) -} - -fn push_down_scan( - used_columns: &HashSet, - scan: &TableScan, - has_projection: bool, -) -> Result { - // once we reach the table scan, we can use the accumulated set of column - // names to construct the set of column indexes in the scan - // - // we discard non-existing columns because some column names are not part of the schema, - // e.g. when the column derives from an aggregation - // - // Use BTreeSet to remove potential duplicates (e.g. union) as - // well as to sort the projection to ensure deterministic behavior - let schema = scan.source.schema(); - let mut projection: BTreeSet = used_columns - .iter() - .filter(|c| { - c.relation.is_none() || c.relation.as_ref().unwrap() == &scan.table_name - }) - .map(|c| schema.index_of(&c.name)) - .filter_map(ArrowResult::ok) - .collect(); - - if projection.is_empty() { - if has_projection && !schema.fields().is_empty() { - // Ensure that we are reading at least one column from the table in case the query - // does not reference any columns directly such as "SELECT COUNT(1) FROM table", - // except when the table is empty (no column) - projection.insert(0); - } else { - // for table scan without projection, we default to return all columns - projection = scan - .source - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, _)| i) - .collect::>(); - } - } - - // Building new projection from BTreeSet - // preserving source projection order if it exists - let projection = if let Some(original_projection) = &scan.projection { - original_projection - .clone() - .into_iter() - .filter(|idx| projection.contains(idx)) - .collect::>() - } else { - projection.into_iter().collect::>() - }; - - // create the projected schema - let projected_fields: Vec = projection - .iter() - .map(|i| { - DFField::from_qualified(scan.table_name.clone(), schema.fields()[*i].clone()) - }) - .collect(); - - let projected_schema = projected_fields.to_dfschema_ref()?; - - Ok(LogicalPlan::TableScan(TableScan { - table_name: scan.table_name.clone(), - source: scan.source.clone(), - projection: Some(projection), - projected_schema, - filters: scan.filters.clone(), - fetch: scan.fetch, - })) -} - -fn restrict_outputs( - plan: Arc, - permitted_outputs: &HashSet, -) -> Result> { - let schema = plan.schema(); - if permitted_outputs.len() == schema.fields().len() { - return Ok(None); - } - Ok(Some(generate_projection( - permitted_outputs, - schema, - plan.clone(), - )?)) -} - #[cfg(test)] mod tests { - use super::*; - use crate::eliminate_project::EliminateProjection; + use std::collections::HashMap; + use std::sync::Arc; + use std::vec; + + use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::DFSchema; - use datafusion_expr::expr; - use datafusion_expr::expr::Cast; - use datafusion_expr::WindowFrame; - use datafusion_expr::WindowFunction; + use datafusion_common::{Column, DFField, DFSchema, Result}; + use datafusion_expr::builder::table_scan_with_filters; + use datafusion_expr::expr::{self, Cast}; + use datafusion_expr::logical_plan::{ + builder::LogicalPlanBuilder, table_scan, JoinType, + }; use datafusion_expr::{ - col, count, lit, - logical_plan::{builder::LogicalPlanBuilder, table_scan, JoinType}, - max, min, AggregateFunction, Expr, + col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, + WindowFrame, WindowFunction, }; - use std::collections::HashMap; - use std::vec; #[test] fn aggregate_no_group_by() -> Result<()> { @@ -667,6 +101,31 @@ mod tests { assert_optimized_plan_eq(&plan, expected) } + #[test] + fn aggregate_with_periods() -> Result<()> { + let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); + + // Build a plan that looks as follows (note "tag.one" is a column named + // "tag.one", not a column named "one" in a table named "tag"): + // + // Projection: tag.one + // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // TableScan + let plan = table_scan(Some("m4"), &schema, None)? + .aggregate( + Vec::::new(), + vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], + )? + .project([col(Column::new_unqualified("tag.one"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + \n TableScan: m4 projection=[tag.one]"; + + assert_optimized_plan_eq(&plan, expected) + } + #[test] fn redundant_project() -> Result<()> { let table_scan = test_table_scan()?; @@ -904,7 +363,7 @@ mod tests { // Build the LogicalPlan directly (don't use PlanBuilder), so // that the Column references are unqualified (e.g. their // relation is `None`). PlanBuilder resolves the expressions - let expr = vec![col("a"), col("b")]; + let expr = vec![col("test.a"), col("test.b")]; let plan = LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); @@ -951,7 +410,7 @@ mod tests { .project(vec![lit(1_i64), lit(2_i64)])? .build()?; let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[a]"; + \n TableScan: test projection=[]"; assert_optimized_plan_eq(&plan, expected) } @@ -998,7 +457,36 @@ mod tests { let expected = "\ Projection: Int32(1) AS a\ - \n TableScan: test projection=[a]"; + \n TableScan: test projection=[]"; + + assert_optimized_plan_eq(&plan, expected) + } + + #[test] + fn table_full_filter_pushdown() -> Result<()> { + let schema = Schema::new(test_table_scan_fields()); + + let table_scan = table_scan_with_filters( + Some("test"), + &schema, + None, + vec![col("b").eq(lit(1))], + )? + .build()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // there is no need for the first projection + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? + .build()?; + + assert_fields_eq(&plan, vec!["a"]); + + let expected = "\ + Projection: Int32(1) AS a\ + \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; assert_optimized_plan_eq(&plan, expected) } @@ -1134,24 +622,14 @@ mod tests { } fn optimize(plan: &LogicalPlan) -> Result { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownProjection::new()), - Arc::new(EliminateProjection::new()), - ]); - let mut optimized_plan = optimizer + let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); + let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or(optimized_plan); Ok(optimized_plan) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index f58d4b159745d..187e510e557db 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::optimizer::ApplyOrder; +use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::Distinct; -use datafusion_expr::{Aggregate, LogicalPlan}; -use ApplyOrder::BottomUp; +use datafusion_expr::{ + aggregate_function::AggregateFunction as AggregateFunctionFunc, col, + expr::AggregateFunction, LogicalPlanBuilder, +}; +use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -33,6 +36,22 @@ use ApplyOrder::BottomUp; /// ```text /// SELECT a, b FROM tab GROUP BY a, b /// ``` +/// +/// On the other hand, for a `DISTINCT ON` query the replacement is +/// a bit more involved and effectively converts +/// ```text +/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c +/// ``` +/// +/// into +/// ```text +/// SELECT b FROM ( +/// SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b +/// FROM tab +/// GROUP BY a +/// ) +/// ORDER BY a DESC +/// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default)] @@ -52,16 +71,74 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), group_expr, vec![], - input.schema().clone(), // input schema and aggregate schema are the same in this case )?); Ok(Some(aggregate)) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + })) => { + // Construct the aggregation expression to be used to fetch the selected expressions. + let aggr_expr = select_expr + .iter() + .map(|e| { + Expr::AggregateFunction(AggregateFunction::new( + AggregateFunctionFunc::FirstValue, + vec![e.clone()], + false, + None, + sort_expr.clone(), + )) + }) + .collect::>(); + + // Build the aggregation plan + let plan = LogicalPlanBuilder::from(input.as_ref().clone()) + .aggregate(on_expr.clone(), aggr_expr.to_vec())? + .build()?; + + let plan = if let Some(sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + LogicalPlanBuilder::from(plan) + .sort(sort_expr[..on_expr.len()].to_vec())? + .build()? + } else { + plan + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan + .schema() + .fields() + .iter() + .skip(on_expr.len()) + .zip(schema.fields().iter()) + .map(|(new_field, old_field)| { + Ok(col(new_field.qualified_column()).alias_qualified( + old_field.qualifier().cloned(), + old_field.name(), + )) + }) + .collect::>>()?; + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Some(plan)) + } _ => Ok(None), } } @@ -100,4 +177,27 @@ mod tests { expected, ) } + + #[test] + fn replace_distinct_on() -> datafusion_common::Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on( + vec![col("a")], + vec![col("b")], + Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), + )? + .build()?; + + let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ + \n Sort: test.a DESC NULLS FIRST\ + \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } } diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 57513fa4fff41..90c96b4b8b8cb 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -119,7 +119,7 @@ pub struct RewriteDisjunctivePredicate; impl RewriteDisjunctivePredicate { pub fn new() -> Self { - Self::default() + Self } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 04e0e0920b0c9..34ed4a9475cba 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,25 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::alias::AliasGenerator; +use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; -use crate::utils::{ - collect_subquery_cols, conjunction, extract_join_filters, only_or_err, - replace_qualified_name, -}; +use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{context, Column, Result}; +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{ + RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, +}; +use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; -use log::debug; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] -pub struct ScalarSubqueryToJoin { - alias: Arc, -} +pub struct ScalarSubqueryToJoin {} impl ScalarSubqueryToJoin { #[allow(missing_docs)] @@ -65,12 +65,14 @@ impl OptimizerRule for ScalarSubqueryToJoin { fn try_optimize( &self, plan: &LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Filter(filter) => { - let (subqueries, expr) = - self.extract_subquery_exprs(&filter.predicate, self.alias.clone())?; + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( + &filter.predicate, + config.alias_generator(), + )?; if subqueries.is_empty() { // regular filter, no subquery exists clause here @@ -80,26 +82,48 @@ impl OptimizerRule for ScalarSubqueryToJoin { // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = filter.input.as_ref().clone(); for (subquery, alias) in subqueries { - if let Some(optimized_subquery) = - optimize_scalar(&subquery, &cur_input, &alias)? + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? { + if !expr_check_map.is_empty() { + rewrite_expr = + rewrite_expr.clone().transform_up(&|expr| { + if let Expr::Column(col) = &expr { + if let Some(map_expr) = + expr_check_map.get(&col.name) + { + Ok(Transformed::Yes(map_expr.clone())) + } else { + Ok(Transformed::No(expr)) + } + } else { + Ok(Transformed::No(expr)) + } + })?; + } cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now return Ok(None); } } - let new_plan = LogicalPlanBuilder::from(cur_input); - Ok(Some(new_plan.filter(expr)?.build()?)) + let new_plan = LogicalPlanBuilder::from(cur_input) + .filter(rewrite_expr)? + .build()?; + Ok(Some(new_plan)) } LogicalPlan::Projection(projection) => { let mut all_subqueryies = vec![]; - let mut rewrite_exprs = vec![]; + let mut expr_to_rewrite_expr_map = HashMap::new(); + let mut subquery_to_expr_map = HashMap::new(); for expr in projection.expr.iter() { - let (subqueries, expr) = - self.extract_subquery_exprs(expr, self.alias.clone())?; + let (subqueries, rewrite_exprs) = + self.extract_subquery_exprs(expr, config.alias_generator())?; + for (subquery, _) in &subqueries { + subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + } all_subqueryies.extend(subqueries); - rewrite_exprs.push(expr); + expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } if all_subqueryies.is_empty() { // regular projection, no subquery exists clause here @@ -108,17 +132,54 @@ impl OptimizerRule for ScalarSubqueryToJoin { // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); for (subquery, alias) in all_subqueryies { - if let Some(optimized_subquery) = - optimize_scalar(&subquery, &cur_input, &alias)? + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? { cur_input = optimized_subquery; + if !expr_check_map.is_empty() { + if let Some(expr) = subquery_to_expr_map.get(&subquery) { + if let Some(rewrite_expr) = + expr_to_rewrite_expr_map.get(expr) + { + let new_expr = + rewrite_expr.clone().transform_up(&|expr| { + if let Expr::Column(col) = &expr { + if let Some(map_expr) = + expr_check_map.get(&col.name) + { + Ok(Transformed::Yes(map_expr.clone())) + } else { + Ok(Transformed::No(expr)) + } + } else { + Ok(Transformed::No(expr)) + } + })?; + expr_to_rewrite_expr_map.insert(expr, new_expr); + } + } + } } else { // if we can't handle all of the subqueries then bail for now return Ok(None); } } - let new_plan = LogicalPlanBuilder::from(cur_input); - Ok(Some(new_plan.project(rewrite_exprs)?.build()?)) + + let mut proj_exprs = vec![]; + for expr in projection.expr.iter() { + let old_expr_name = expr.display_name()?; + let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); + let new_expr_name = new_expr.display_name()?; + if new_expr_name != old_expr_name { + proj_exprs.push(new_expr.clone().alias(old_expr_name)) + } else { + proj_exprs.push(new_expr.clone()); + } + } + let new_plan = LogicalPlanBuilder::from(cur_input) + .project(proj_exprs)? + .build()?; + Ok(Some(new_plan)) } _ => Ok(None), @@ -153,9 +214,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); - self.sub_query_info.push((subquery, subqry_alias.clone())); - let scalar_column = "__value"; - Ok(Expr::Column(Column::new(Some(subqry_alias), scalar_column))) + self.sub_query_info + .push((subquery.clone(), subqry_alias.clone())); + let scalar_expr = subquery + .subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), Ok)?; + Ok(Expr::Column(create_col_from_scalar_expr( + &scalar_expr, + subqry_alias, + )?)) } _ => Ok(expr), } @@ -188,7 +256,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { /// /// ```text /// select c.id from customers c -/// cross join (select avg(total) as val from orders) a +/// left join (select avg(total) as val from orders) a /// where c.balance > a.val /// ``` /// @@ -198,106 +266,112 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { /// * `filter_input` - The non-subquery portion (from customers) /// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `subquery_alias` - Subquery aliases -fn optimize_scalar( +fn build_join( subquery: &Subquery, filter_input: &LogicalPlan, subquery_alias: &str, -) -> Result> { +) -> Result)>> { let subquery_plan = subquery.subquery.as_ref(); - let proj = match &subquery_plan { - LogicalPlan::Projection(proj) => proj, - _ => { - // this rule does not support this type of scalar subquery - // TODO support more types - debug!( - "cannot translate this type of scalar subquery to a join: {}", - subquery_plan.display_indent() - ); - return Ok(None); - } - }; - let proj = only_or_err(proj.expr.as_slice()) - .map_err(|e| context!("exactly one expression should be projected", e))?; - let proj = Expr::Alias(Box::new(proj.clone()), "__value".to_string()); - let sub_inputs = subquery_plan.inputs(); - let sub_input = only_or_err(sub_inputs.as_slice()) - .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?; - - let aggr = match sub_input { - LogicalPlan::Aggregate(aggr) => aggr, - _ => { - // this rule does not support this type of scalar subquery - // TODO support more types - debug!( - "cannot translate this type of scalar subquery to a join: {}", - subquery_plan.display_indent() - ); - return Ok(None); - } + let mut pull_up = PullUpCorrelatedExpr { + join_filters: vec![], + correlated_subquery_cols_map: Default::default(), + in_predicate_opt: None, + exists_sub_query: false, + can_pull_up: true, + need_handle_count_bug: true, + collected_count_expr_map: Default::default(), + pull_up_having_expr: None, }; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?; + if !pull_up.can_pull_up { + return Ok(None); + } - // extract join filters - let (join_filters, subquery_input) = extract_join_filters(&aggr.input)?; - // Only operate if one column is present and the other closed upon from outside scope - let input_schema = subquery_input.schema(); - let subqry_cols = collect_subquery_cols(&join_filters, input_schema.clone())?; - let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &subqry_cols, subquery_alias).map(Option::Some) - })?; - - let group_by: Vec<_> = subqry_cols - .iter() - .map(|it| Expr::Column(it.clone())) - .collect(); - let subqry_plan = LogicalPlanBuilder::from(subquery_input); - - // project the prior projection + any correlated (and now grouped) columns - let proj: Vec<_> = group_by - .iter() - .cloned() - .chain(vec![proj].iter().cloned()) - .collect(); - let subqry_plan = subqry_plan - .aggregate(group_by, aggr.aggr_expr.clone())? - .project(proj)? + let collected_count_expr_map = + pull_up.collected_count_expr_map.get(&new_plan).cloned(); + let sub_query_alias = LogicalPlanBuilder::from(new_plan) .alias(subquery_alias.to_string())? .build()?; + let mut all_correlated_cols = BTreeSet::new(); + pull_up + .correlated_subquery_cols_map + .values() + .for_each(|cols| all_correlated_cols.extend(cols.clone())); + + // alias the join filter + let join_filter_opt = + conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { + replace_qualified_name(filter, &all_correlated_cols, subquery_alias) + .map(Option::Some) + })?; + // join our sub query into the main plan - let new_plan = if join_filter.is_none() { + let new_plan = if join_filter_opt.is_none() { match filter_input { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: true, schema: _, - }) => subqry_plan, + }) => sub_query_alias, _ => { - // if not correlated, group down to 1 row and cross join on that (preserving row count) + // if not correlated, group down to 1 row and left join on that (preserving row count) LogicalPlanBuilder::from(filter_input.clone()) - .cross_join(subqry_plan)? + .join_on(sub_query_alias, JoinType::Left, None)? .build()? } } } else { // left join if correlated, grouping by the join keys so we don't change row count LogicalPlanBuilder::from(filter_input.clone()) - .join( - subqry_plan, - JoinType::Left, - (Vec::::new(), Vec::::new()), - join_filter, - )? + .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? .build()? }; + let mut computation_project_expr = HashMap::new(); + if let Some(expr_map) = collected_count_expr_map { + for (name, result) in expr_map { + let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr { + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + ( + Box::new(Expr::IsNull(Box::new(Expr::Column( + Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), + )))), + Box::new(result), + ), + ( + Box::new(Expr::Not(Box::new(filter.clone()))), + Box::new(Expr::Literal(ScalarValue::Null)), + ), + ], + else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( + name.clone(), + )))), + }) + } else { + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNull(Box::new(Expr::Column( + Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), + )))), + Box::new(result), + )], + else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( + name.clone(), + )))), + }) + }; + computation_project_expr.insert(name, computer_expr); + } + } - Ok(Some(new_plan)) + Ok(Some((new_plan, computation_project_expr))) } #[cfg(test)] mod tests { use super::*; - use crate::eliminate_cross_join::EliminateCrossJoin; - use crate::eliminate_outer_join::EliminateOuterJoin; - use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::test::*; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -331,24 +405,20 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_1.__value AND Int32(1) < __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ - \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: Int32(1) < __scalar_sq_1.MAX(orders.o_custkey) AND Int32(1) < __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(EliminateOuterJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -390,25 +460,21 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ + \n Filter: customer.c_acctbal < __scalar_sq_1.SUM(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, SUM(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, SUM(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Float64;N]\ - \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value [o_custkey:Int64, __value:Float64;N]\ + \n SubqueryAlias: __scalar_sq_1 [SUM(orders.o_totalprice):Float64;N, o_custkey:Int64]\ + \n Projection: SUM(orders.o_totalprice), orders.o_custkey [SUM(orders.o_totalprice):Float64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64;N, __value:Float64;N]\ - \n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64;N, __value:Float64;N]\ + \n Filter: orders.o_totalprice < __scalar_sq_2.SUM(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [l_orderkey:Int64, __value:Float64;N]\ - \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS __value [l_orderkey:Int64, __value:Float64;N]\ + \n SubqueryAlias: __scalar_sq_2 [SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ + \n Projection: SUM(lineitem.l_extendedprice), lineitem.l_orderkey [SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -436,21 +502,17 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -478,18 +540,15 @@ mod tests { // it will optimize, but fail for the same reason the unoptimized query would let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ + \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateCrossJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -513,20 +572,17 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ + \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateCrossJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -629,24 +685,6 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // we expect the plan to be unchanged because this subquery is not supported by this rule - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] - Subquery: [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], - &plan, - expected, - ); - let expected = "check_analyzed_plan\ \ncaused by\ \nError during planning: Scalar subquery should only return one column"; @@ -674,20 +712,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) + Int32(1) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey) + Int32(1), orders.o_custkey [MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -744,20 +778,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.__value AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey >= __scalar_sq_1.MAX(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -787,20 +817,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(EliminateOuterJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -831,20 +857,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ - \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateCrossJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -868,20 +890,16 @@ mod tests { .build()?; let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.__value [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, __value:UInt32;N]\ - \n Inner Join: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, __value:UInt32;N]\ + \n Filter: test.c < __scalar_sq_1.MIN(sq.c) [a:UInt32, b:UInt32, c:UInt32, MIN(sq.c):UInt32;N, a:UInt32;N]\ + \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, MIN(sq.c):UInt32;N, a:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __scalar_sq_1 [a:UInt32, __value:UInt32;N]\ - \n Projection: sq.a, MIN(sq.c) AS __value [a:UInt32, __value:UInt32;N]\ + \n SubqueryAlias: __scalar_sq_1 [MIN(sq.c):UInt32;N, a:UInt32]\ + \n Projection: MIN(sq.c), sq.a [MIN(sq.c):UInt32;N, a:UInt32]\ \n Aggregate: groupBy=[[sq.a]], aggr=[[MIN(sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -904,20 +922,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey < __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ + \n Filter: customer.c_custkey < __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ + \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateCrossJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -939,19 +953,16 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ + \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateCrossJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -994,25 +1005,21 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.__value AND __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64;N, __value:Int64;N]\ - \n Left Join: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64;N, __value:Int64;N]\ - \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.MIN(orders.o_custkey) AND __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MIN(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MIN(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MIN(orders.o_custkey), orders.o_custkey [MIN(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MIN(orders.o_custkey)]] [o_custkey:Int64, MIN(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); @@ -1047,25 +1054,21 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.__value AND __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N, __value:Int64;N]\ - \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N, __value:Int64;N]\ - \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.MIN(orders.o_custkey) AND __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, MAX(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MIN(orders.o_custkey) AS __value [__value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [MIN(orders.o_custkey):Int64;N]\ + \n Projection: MIN(orders.o_custkey) [MIN(orders.o_custkey):Int64;N]\ \n Aggregate: groupBy=[[]], aggr=[[MIN(orders.o_custkey)]] [MIN(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N]\ + \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( - vec![ - Arc::new(ScalarSubqueryToJoin::new()), - Arc::new(ExtractEquijoinPredicate::new()), - Arc::new(EliminateOuterJoin::new()), - ], + vec![Arc::new(ScalarSubqueryToJoin::new())], &plan, expected, ); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8aebae18c1ae9..e2fbd5e927a16 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,32 +21,43 @@ use std::ops::Not; use super::or_in_list_simplifier::OrInListSimplifier; use super::utils::*; - use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::new_null_array, datatypes::{DataType, Field, Schema}, error::ArrowError, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, +}; +use datafusion_common::{ + exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, + ScalarFunctionDefinition, Volatility, +}; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like, - Volatility, + expr::{InList, InSubquery, ScalarFunction}, + interval_arithmetic::NullableInterval, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::simplify_expressions::SimplifyInfo; - /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, + /// Guarantees about the values of columns. This is provided by the user + /// in [ExprSimplifier::with_guarantees()]. + guarantees: Vec<(Expr, NullableInterval)>, } -const THRESHOLD_INLINE_INLIST: usize = 3; +pub const THRESHOLD_INLINE_INLIST: usize = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an @@ -55,7 +66,10 @@ impl ExprSimplifier { /// /// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext pub fn new(info: S) -> Self { - Self { info } + Self { + info, + guarantees: vec![], + } } /// Simplifies this [`Expr`]`s as much as possible, evaluating @@ -119,6 +133,7 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut or_in_list_simplifier = OrInListSimplifier::new(); + let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and @@ -127,6 +142,7 @@ impl ExprSimplifier { expr.rewrite(&mut const_evaluator)? .rewrite(&mut simplifier)? .rewrite(&mut or_in_list_simplifier)? + .rewrite(&mut guarantee_rewriter)? // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? .rewrite(&mut simplifier) @@ -147,6 +163,65 @@ impl ExprSimplifier { expr.rewrite(&mut expr_rewrite) } + + /// Input guarantees about the values of columns. + /// + /// The guarantees can simplify expressions. For example, if a column `x` is + /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the + /// literal `true`. + /// + /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`, + /// where the [Expr] is a column reference and the [NullableInterval] + /// is an interval representing the known possible values of that column. + /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_optimizer::simplify_expressions::{ + /// ExprSimplifier, SimplifyContext}; + /// + /// let schema = Schema::new(vec![ + /// Field::new("x", DataType::Int64, false), + /// Field::new("y", DataType::UInt32, false), + /// Field::new("z", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// + /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) + /// let expr_x = col("x").gt_eq(lit(3_i64)); + /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32)); + /// let expr_z = col("z").gt(lit(5_i64)); + /// let expr = expr_x.and(expr_y).and(expr_z.clone()); + /// + /// let guarantees = vec![ + /// // x ∈ [3, 5] + /// ( + /// col("x"), + /// NullableInterval::NotNull { + /// values: Interval::make(Some(3_i64), Some(5_i64)).unwrap() + /// } + /// ), + /// // y = 3 + /// (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))), + /// ]; + /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees); + /// let output = simplifier.simplify(expr).unwrap(); + /// // Expression becomes: true AND true AND (z > 5), which simplifies to + /// // z > 5. + /// assert_eq!(output, expr_z); + /// ``` + pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self { + self.guarantees = guarantees; + self + } } #[allow(rustdoc::private_intra_doc_links)] @@ -208,9 +283,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { match self.can_evaluate.pop() { Some(true) => Ok(Expr::Literal(self.evaluate_to_scalar(expr)?)), Some(false) => Ok(expr), - _ => Err(DataFusionError::Internal( - "Failed to pop can_evaluate".to_string(), - )), + _ => internal_err!("Failed to pop can_evaluate"), } } } @@ -259,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) @@ -269,15 +341,17 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::Sort { .. } | Expr::GroupingSet(_) - | Expr::Wildcard - | Expr::QualifiedWildcard { .. } + | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - Self::volatility_ok(fun.volatility()) - } - Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature.volatility) - } + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Self::volatility_ok(fun.volatility()) + } + ScalarFunctionDefinition::UDF(fun) => { + Self::volatility_ok(fun.signature().volatility) + } + ScalarFunctionDefinition::Name(_) => false, + }, Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Not(_) @@ -292,7 +366,6 @@ impl<'a> ConstEvaluator<'a> { | Expr::Negative(_) | Expr::Between { .. } | Expr::Like { .. } - | Expr::ILike { .. } | Expr::SimilarTo { .. } | Expr::Case(_) | Expr::Cast { .. } @@ -318,12 +391,15 @@ impl<'a> ConstEvaluator<'a> { match col_val { ColumnarValue::Array(a) => { if a.len() != 1 { - Err(DataFusionError::Execution(format!( + exec_err!( "Could not evaluate the expression, found a result of length {}", a.len() - ))) + ) + } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { + Ok(ScalarValue::List(a)) } else { - Ok(ScalarValue::try_from_array(&a, 0)?) + // Non-ListArray + ScalarValue::try_from_array(&a, 0) } } ColumnarValue::Scalar(s) => Ok(s), @@ -413,7 +489,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }) if list.len() == 1 && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) => { - let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; + let Expr::ScalarSubquery(subquery) = list.remove(0) else { + unreachable!() + }; Expr::InSubquery(InSubquery::new(expr, subquery, negated)) } @@ -669,18 +747,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right: _, }) if is_null(&left) => *left, - // A * 0 --> 0 (if A is not null) + // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if !info.nullable(&left)? && is_zero(&right) => *right, - // 0 * A --> 0 (if A is not null) + }) if !info.nullable(&left)? + && !info.get_data_type(&left)?.is_floating() + && is_zero(&right) => + { + *right + } + // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if !info.nullable(&right)? && is_zero(&left) => *left, + }) if !info.nullable(&right)? + && !info.get_data_type(&right)?.is_floating() + && is_zero(&left) => + { + *left + } // // Rules for Divide @@ -704,20 +792,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Divide, right, }) if is_null(&right) => *right, - // 0 / 0 -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right, - }) if is_zero(&left) && is_zero(&right) => { - Expr::Literal(ScalarValue::Int32(None)) - } - // A / 0 -> DivideByZero Error + // A / 0 -> DivideByZero Error if A is not null and not floating + // (float / 0 -> inf | -inf | NAN) Expr::BinaryExpr(BinaryExpr { left, op: Divide, right, - }) if !info.nullable(&left)? && is_zero(&right) => { + }) if !info.nullable(&left)? + && !info.get_data_type(&left)?.is_floating() + && is_zero(&right) => + { return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } @@ -737,19 +821,33 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Modulo, right: _, }) if is_null(&left) => *left, - // A % 1 --> 0 + // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right, - }) if !info.nullable(&left)? && is_one(&right) => lit(0), - // A % 0 --> DivideByZero Error + }) if !info.nullable(&left)? + && !info.get_data_type(&left)?.is_floating() + && is_one(&right) => + { + lit(0) + } + // A % 0 --> DivideByZero Error (if A is not floating and not null) + // A % 0 --> NAN (if A is floating and not null) Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right, }) if !info.nullable(&left)? && is_zero(&right) => { - return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); + match info.get_data_type(&left)? { + DataType::Float32 => lit(f32::NAN), + DataType::Float64 => lit(f64::NAN), + _ => { + return Err(DataFusionError::ArrowError( + ArrowError::DivideByZero, + )); + } + } } // @@ -1069,17 +1167,20 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Note: the rationale for this rewrite is that the expr can then be further // simplified using the existing rules for AND/OR - Expr::Case(case) - if !case.when_then_expr.is_empty() - && case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number - && info.is_boolean_type(&case.when_then_expr[0].1)? => + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) if !when_then_expr.is_empty() + && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && info.is_boolean_type(&when_then_expr[0].1)? => { // The disjunction of all the when predicates encountered so far let mut filter_expr = lit(false); // The disjunction of all the cases let mut out_expr = lit(false); - for (when, then) in case.when_then_expr { + for (when, then) in when_then_expr { let case_expr = when .as_ref() .clone() @@ -1090,7 +1191,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { filter_expr = filter_expr.or(*when); } - if let Some(else_expr) = case.else_expr { + if let Some(else_expr) = else_expr { let case_expr = filter_expr.not().and(*else_expr); out_expr = out_expr.or(case_expr); } @@ -1101,25 +1202,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Log, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Power, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Concat, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ConcatWithSeparator, + ), args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, @@ -1163,6 +1267,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { pattern, negated, escape_char: _, + case_insensitive: _, }) if !is_null(&expr) && matches!( pattern.as_ref(), @@ -1172,26 +1277,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { lit(!negated) } - // Rules for ILike - Expr::ILike(Like { - expr, - pattern, - negated, - escape_char: _, - }) if !is_null(&expr) - && matches!( - pattern.as_ref(), - Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" - ) => + // a is not null/unknown --> true (if a is not nullable) + Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) + if !info.nullable(&expr)? => { - lit(!negated) + lit(true) } - // a IS NOT NULL --> true, if a is not nullable - Expr::IsNotNull(expr) if !info.nullable(&expr)? => lit(true), - - // a IS NULL --> false, if a is not nullable - Expr::IsNull(expr) if !info.nullable(&expr)? => lit(false), + // a is null/unknown --> false (if a is not nullable) + Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { + lit(false) + } // no additional rewrites possible expr => expr, @@ -1208,24 +1304,25 @@ mod tests { sync::Arc, }; + use super::*; use crate::simplify_expressions::{ utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; - - use super::*; use crate::test::test_table_scan_with_name; + use arrow::{ array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, }; - use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; - use datafusion_expr::*; + use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::{ execution_props::ExecutionProps, functions::make_scalar_function, }; + use chrono::{DateTime, TimeZone, Utc}; + // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -1313,10 +1410,8 @@ mod tests { expected_expr: Expr, date_time: &DateTime, ) { - let execution_props = ExecutionProps { - query_execution_start_time: *date_time, - var_providers: None, - }; + let execution_props = + ExecutionProps::new().with_query_execution_start_time(*date_time); let mut const_evaluator = ConstEvaluator::try_new(&execution_props).unwrap(); let evaluated_expr = input_expr @@ -1408,7 +1503,7 @@ mod tests { test_evaluate(expr, lit("foobarbaz")); // Check non string arguments - // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) + // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400i64) let expr = call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); @@ -1460,7 +1555,7 @@ mod tests { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF(expr::ScalarUDF::new( + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -1469,15 +1564,21 @@ mod tests { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone())); + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + args.clone(), + )); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args)); - let expected_expr = - Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args)); + let expr = + Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + folded_args, + )); test_evaluate(expr, expected_expr); } @@ -1670,11 +1771,14 @@ mod tests { #[test] fn test_simplify_divide_zero_by_zero() { - // 0 / 0 -> null + // 0 / 0 -> DivideByZero let expr = lit(0) / lit(0); - let expected = lit(ScalarValue::Int32(None)); + let err = try_simplify(expr).unwrap_err(); - assert_eq!(simplify(expr), expected); + assert!( + matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), + "{err}" + ); } #[test] @@ -2504,10 +2608,43 @@ mod tests { col("c1") .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false), ); + assert_change( + regex_match(col("c1"), lit("^(fo_o)$")), + col("c1").eq(lit("fo_o")), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o)$")), + col("c1").eq(lit("fo_o")), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r)$")), + col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))), + ); + assert_change( + regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")), + col("c1") + .not_eq(lit("fo_o")) + .and(col("c1").not_eq(lit("ba_r"))), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")), + ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r")))) + .or(col("c1").eq(lit("ba_z"))), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")), + col("c1").in_list( + vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")], + false, + ), + ); // regular expressions that mismatch captured literals assert_no_change(regex_match(col("c1"), lit("(foo|bar)"))); assert_no_change(regex_match(col("c1"), lit("(foo|bar)*"))); + assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)"))); + assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*"))); + assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*"))); assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*"))); assert_no_change(regex_match(col("c1"), lit("^foo|bar$"))); assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$"))); @@ -2602,6 +2739,7 @@ mod tests { expr: Box::new(expr), pattern: Box::new(lit(pattern)), escape_char: None, + case_insensitive: false, }) } @@ -2611,24 +2749,27 @@ mod tests { expr: Box::new(expr), pattern: Box::new(lit(pattern)), escape_char: None, + case_insensitive: false, }) } fn ilike(expr: Expr, pattern: &str) -> Expr { - Expr::ILike(Like { + Expr::Like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(lit(pattern)), escape_char: None, + case_insensitive: true, }) } fn not_ilike(expr: Expr, pattern: &str) -> Expr { - Expr::ILike(Like { + Expr::Like(Like { negated: true, expr: Box::new(expr), pattern: Box::new(lit(pattern)), escape_char: None, + case_insensitive: true, }) } @@ -2649,6 +2790,19 @@ mod tests { try_simplify(expr).unwrap() } + fn simplify_with_guarantee( + expr: Expr, + guarantees: Vec<(Expr, NullableInterval)>, + ) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(schema), + ) + .with_guarantees(guarantees); + simplifier.simplify(expr).unwrap() + } + fn expr_test_schema() -> DFSchemaRef { Arc::new( DFSchema::new_with_metadata( @@ -2725,6 +2879,25 @@ mod tests { ); } + #[test] + fn simplify_expr_is_unknown() { + assert_eq!(simplify(col("c2").is_unknown()), col("c2").is_unknown(),); + + // 'c2_non_null is unknown' is always false + assert_eq!(simplify(col("c2_non_null").is_unknown()), lit(false)); + } + + #[test] + fn simplify_expr_is_not_known() { + assert_eq!( + simplify(col("c2").is_not_unknown()), + col("c2").is_not_unknown() + ); + + // 'c2_non_null is not unknown' is always true + assert_eq!(simplify(col("c2_non_null").is_not_unknown()), lit(true)); + } + #[test] fn simplify_expr_eq() { let schema = expr_test_schema(); @@ -2793,9 +2966,9 @@ mod tests { #[test] fn simplify_expr_case_when_then_else() { - // CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true + // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true // --> - // CASE WHERE c2 THEN false ELSE c2 + // CASE WHEN c2 THEN false ELSE c2 // --> // false assert_eq!( @@ -2810,9 +2983,9 @@ mod tests { col("c2").not().and(col("c2")) // #1716 ); - // CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2 + // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2 // --> - // CASE WHERE c2 THEN true ELSE c2 + // CASE WHEN c2 THEN true ELSE c2 // --> // c2 // @@ -2830,7 +3003,7 @@ mod tests { col("c2").or(col("c2").not().and(col("c2"))) // #1716 ); - // CASE WHERE ISNULL(c2) THEN true ELSE c2 + // CASE WHEN ISNULL(c2) THEN true ELSE c2 // --> // ISNULL(c2) OR c2 // @@ -2847,7 +3020,7 @@ mod tests { .or(col("c2").is_not_null().and(col("c2"))) ); - // CASE WHERE c1 then true WHERE c2 then false ELSE true + // CASE WHEN c1 then true WHEN c2 then false ELSE true // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) // --> c1 OR (NOT(c1) AND NOT(c2)) // --> c1 OR NOT(c2) @@ -2866,7 +3039,7 @@ mod tests { col("c1").or(col("c1").not().and(col("c2").not())) ); - // CASE WHERE c1 then true WHERE c2 then true ELSE false + // CASE WHEN c1 then true WHEN c2 then true ELSE false // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) // --> c1 OR (NOT(c1) AND c2) // --> c1 OR c2 @@ -3093,4 +3266,90 @@ mod tests { let expr = not_ilike(null, "%"); assert_eq!(simplify(expr), lit_bool_null()); } + + #[test] + fn test_simplify_with_guarantee() { + // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b"))) + let expr_x = col("c3").gt(lit(3_i64)); + let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32)); + let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true); + let expr = expr_x.clone().and(expr_y.clone().or(expr_z)); + + // All guaranteed null + let guarantees = vec![ + (col("c3"), NullableInterval::from(ScalarValue::Int64(None))), + (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))), + (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))), + ]; + + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit_bool_null()); + + // All guaranteed false + let guarantees = vec![ + ( + col("c3"), + NullableInterval::NotNull { + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), + }, + ), + ( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(9))), + ), + (col("c1"), NullableInterval::from(ScalarValue::from("a"))), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit(false)); + + // Guaranteed false or null -> no change. + let guarantees = vec![ + ( + col("c3"), + NullableInterval::MaybeNull { + values: Interval::make(Some(0_i64), Some(2_i64)).unwrap(), + }, + ), + ( + col("c4"), + NullableInterval::MaybeNull { + values: Interval::make(Some(9_u32), Some(9_u32)).unwrap(), + }, + ), + ( + col("c1"), + NullableInterval::NotNull { + values: Interval::try_new( + ScalarValue::from("d"), + ScalarValue::from("f"), + ) + .unwrap(), + }, + ), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(&output, &expr_x); + + // Sufficient true guarantees + let guarantees = vec![ + ( + col("c3"), + NullableInterval::from(ScalarValue::Int64(Some(9))), + ), + ( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(3))), + ), + ]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(output, lit(true)); + + // Only partially simplify + let guarantees = vec![( + col("c4"), + NullableInterval::from(ScalarValue::UInt32(Some(3))), + )]; + let output = simplify_with_guarantee(expr.clone(), guarantees); + assert_eq!(&output, &expr_x); + } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs new file mode 100644 index 0000000000000..860dc326b9b08 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -0,0 +1,536 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] +//! +//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees + +use std::{borrow::Cow, collections::HashMap}; + +use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; + +/// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval]. The interval represents the known +/// possible values of the column. Using these known values, expressions are +/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +/// +/// See a full example in [`ExprSimplifier::with_guarantees()`]. +/// +/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees +pub(crate) struct GuaranteeRewriter<'a> { + guarantees: HashMap<&'a Expr, &'a NullableInterval>, +} + +impl<'a> GuaranteeRewriter<'a> { + pub fn new( + guarantees: impl IntoIterator, + ) -> Self { + Self { + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + } + } +} + +impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + if self.guarantees.is_empty() { + return Ok(expr); + } + + match &expr { + Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(true)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), + _ => Ok(expr), + }, + Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(false)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), + _ => Ok(expr), + }, + Expr::Between(Between { + expr: inner, + negated, + low, + high, + }) => { + if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + self.guarantees.get(inner.as_ref()), + low.as_ref(), + high.as_ref(), + ) { + let expr_interval = NullableInterval::NotNull { + values: Interval::try_new(low.clone(), high.clone())?, + }; + + let contains = expr_interval.contains(*interval)?; + + if contains.is_certainly_true() { + Ok(lit(!negated)) + } else if contains.is_certainly_false() { + Ok(lit(*negated)) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } + + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = self + .guarantees + .get(left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = left.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + let right_interval = self + .guarantees + .get(right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value) = right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + match (left_interval, right_interval) { + (Some(left_interval), Some(right_interval)) => { + let result = + left_interval.apply_operator(op, right_interval.as_ref())?; + if result.is_certainly_true() { + Ok(lit(true)) + } else if result.is_certainly_false() { + Ok(lit(false)) + } else { + Ok(expr) + } + } + _ => Ok(expr), + } + } + + // Columns (if interval is collapsed to a single value) + Expr::Column(_) => { + if let Some(interval) = self.guarantees.get(&expr) { + Ok(interval.single_value().map_or(expr, lit)) + } else { + Ok(expr) + } + } + + Expr::InList(InList { + expr: inner, + list, + negated, + }) => { + if let Some(interval) = self.guarantees.get(inner.as_ref()) { + // Can remove items from the list that don't match the guarantee + let new_list: Vec = list + .iter() + .filter_map(|expr| { + if let Expr::Literal(item) = expr { + match interval + .contains(&NullableInterval::from(item.clone())) + { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. + Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr.clone())), + Err(e) => Some(Err(e)), + } + } else { + Some(Ok(expr.clone())) + } + }) + .collect::>()?; + + Ok(Expr::InList(InList { + expr: inner.clone(), + list: new_list, + negated: *negated, + })) + } else { + Ok(expr) + } + } + + _ => Ok(expr), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::datatypes::DataType; + use datafusion_common::{tree_node::TreeNode, ScalarValue}; + use datafusion_expr::{col, lit, Operator}; + + #[test] + fn test_null_handling() { + // IsNull / IsNotNull can be rewritten to true / false + let guarantees = vec![ + // Note: AlwaysNull case handled by test_column_single_value test, + // since it's a special case of a column with a single value. + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make_unbounded(&DataType::Boolean).unwrap(), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // x IS NULL => guaranteed false + let expr = col("x").is_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(false)); + + // x IS NOT NULL => guaranteed true + let expr = col("x").is_not_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(true)); + } + + fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) + where + ScalarValue: From, + T: Clone, + { + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + let expected = lit(ScalarValue::from(expected_value.clone())); + assert_eq!( + output, expected, + "{} simplified to {}, but expected {}", + expr, output, expected + ); + } + } + + fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { + for expr in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + assert_eq!( + &output, expr, + "{} was simplified to {}, but expected it to be unchanged", + expr, output + ); + } + } + + #[test] + fn test_inequalities_non_null_bounded() { + let guarantees = vec![ + // x ∈ [1, 3] (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + // s.y ∈ [1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + ]; + + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt(lit(0)), false), + (col("s").field("y").lt(lit(0)), false), + (col("x").lt_eq(lit(3)), true), + (col("x").gt(lit(3)), false), + (col("x").gt(lit(0)), true), + (col("x").eq(lit(0)), false), + (col("x").not_eq(lit(0)), true), + (col("x").between(lit(0), lit(5)), true), + (col("x").between(lit(5), lit(10)), false), + (col("x").not_between(lit(0), lit(5)), false), + (col("x").not_between(lit(5), lit(10)), true), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(5)), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").gt(lit(2)), + col("x").lt_eq(lit(2)), + col("x").eq(lit(2)), + col("x").not_eq(lit(2)), + col("x").between(lit(3), lit(5)), + col("x").not_between(lit(3), lit(10)), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_non_null_unbounded() { + let guarantees = vec![ + // y ∈ [2021-01-01, ∞) (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::try_new( + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ) + .unwrap(), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false), + (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true), + (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), + (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true), + ( + col("x").between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + false, + ), + ( + col("x").not_between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Date32(Some(17000)))), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit(ScalarValue::Date32(Some(19000)))), + col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + col("x").not_between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_maybe_null() { + let guarantees = vec![ + // x ∈ ("abc", "def"]? (maybe null) + ( + col("x"), + NullableInterval::MaybeNull { + values: Interval::try_new( + ScalarValue::from("abc"), + ScalarValue::from("def"), + ) + .unwrap(), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit("z")), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit("z")), + }), + false, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit("z")), + col("x").lt_eq(lit("z")), + col("x").gt(lit("a")), + col("x").gt_eq(lit("a")), + col("x").eq(lit("abc")), + col("x").not_eq(lit("a")), + col("x").between(lit("a"), lit("z")), + col("x").not_between(lit("a"), lit("z")), + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_column_single_value() { + let scalars = [ + ScalarValue::Null, + ScalarValue::Int32(Some(1)), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(None), + ScalarValue::from("abc"), + ScalarValue::LargeUtf8(Some("def".to_string())), + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ScalarValue::Decimal128(Some(1000), 19, 2), + ]; + + for scalar in scalars { + let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + let output = col("x").rewrite(&mut rewriter).unwrap(); + assert_eq!(output, Expr::Literal(scalar.clone())); + } + } + + #[test] + fn test_in_list() { + let guarantees = vec![ + // x ∈ [1, 10] (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::try_new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ) + .unwrap(), + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // These cases should be simplified so the list doesn't contain any + // values the guarantee says are outside the range. + // (column_name, starting_list, negated, expected_list) + let cases = &[ + // x IN (9, 11) => x IN (9) + ("x", vec![9, 11], false, vec![9]), + // x IN (10, 2) => x IN (10, 2) + ("x", vec![10, 2], false, vec![10, 2]), + // x NOT IN (9, 11) => x NOT IN (9) + ("x", vec![9, 11], true, vec![9]), + // x NOT IN (0, 22) => x NOT IN () + ("x", vec![0, 22], true, vec![]), + ]; + + for (column_name, starting_list, negated, expected_list) in cases { + let expr = col(*column_name).in_list( + starting_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(), + *negated, + ); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let expected_list = expected_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); + assert_eq!( + output, + Expr::InList(InList { + expr: Box::new(col(*column_name)), + list: expected_list, + negated: *negated, + }) + ); + } + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index dfa0fe70433ba..2cf6ed166cdde 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -17,6 +17,7 @@ pub mod context; pub mod expr_simplifier; +mod guarantees; mod or_in_list_simplifier; mod regex; pub mod simplify_exprs; diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 108f1774b42c0..175b70f2b10e4 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -84,15 +84,12 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::Utf8(Some(pattern)))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), escape_char: None, + case_insensitive: self.i, }; - if self.i { - Expr::ILike(like) - } else { - Expr::Like(like) - } + Expr::Like(like) } fn expr_matches_literal(&self, left: Box, right: Box) -> Expr { @@ -111,7 +108,7 @@ fn collect_concat_to_like_string(parts: &[Hir]) -> Option { for sub in parts { if let HirKind::Literal(l) = sub.kind() { - s.push_str(str_from_literal(l)?); + s.push_str(like_str_from_literal(l)?); } else { return None; } @@ -123,7 +120,7 @@ fn collect_concat_to_like_string(parts: &[Hir]) -> Option { /// returns a str represented by `Literal` if it contains a valid utf8 /// sequence and is safe for like (has no '%' and '_') -fn str_from_literal(l: &Literal) -> Option<&str> { +fn like_str_from_literal(l: &Literal) -> Option<&str> { // if not utf8, no good let s = std::str::from_utf8(&l.0).ok()?; @@ -134,6 +131,14 @@ fn str_from_literal(l: &Literal) -> Option<&str> { } } +/// returns a str represented by `Literal` if it contains a valid utf8 +fn str_from_literal(l: &Literal) -> Option<&str> { + // if not utf8, no good + let s = std::str::from_utf8(&l.0).ok()?; + + Some(s) +} + fn is_safe_for_like(c: char) -> bool { (c != '%') && (c != '_') } @@ -198,8 +203,10 @@ fn anchored_literal_to_expr(v: &[Hir]) -> Option { match v.len() { 2 => Some(lit("")), 3 => { - let HirKind::Literal(l) = v[1].kind() else { return None }; - str_from_literal(l).map(lit) + let HirKind::Literal(l) = v[1].kind() else { + return None; + }; + like_str_from_literal(l).map(lit) } _ => None, } @@ -245,7 +252,7 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { return Some(mode.expr(Box::new(left.clone()), "%".to_owned())); } HirKind::Literal(l) => { - let s = str_from_literal(l)?; + let s = like_str_from_literal(l)?; return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); } HirKind::Concat(inner) if is_anchored_literal(inner) => { diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 239497d9fa7b2..43a41b1185a33 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use super::{ExprSimplifier, SimplifyContext}; -use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, DFSchemaRef, Result}; -use datafusion_expr::{logical_plan::LogicalPlan, utils::from_plan}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::merge_schema; use datafusion_physical_expr::execution_props::ExecutionProps; /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting @@ -65,10 +65,21 @@ impl SimplifyExpressions { ) -> Result { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(plan.inputs())) - } else if let LogicalPlan::TableScan(_) = plan { - // When predicates are pushed into a table scan, there needs to be - // a schema to resolve the fields against. - Arc::clone(plan.schema()) + } else if let LogicalPlan::TableScan(scan) = plan { + // When predicates are pushed into a table scan, there is no input + // schema to resolve predicates against, so it must be handled specially + // + // Note that this is not `plan.schema()` which is the *output* + // schema, and reflects any pushed down projection. The output schema + // will not contain columns that *only* appear in pushed down predicates + // (and no where else) in the plan. + // + // Thus, use the full schema of the inner provider without any + // projection applied for simplification + Arc::new(DFSchema::try_from_qualified_schema( + &scan.table_name, + &scan.source.schema(), + )?) } else { Arc::new(DFSchema::empty()) }; @@ -86,28 +97,14 @@ impl SimplifyExpressions { .expressions() .into_iter() .map(|e| { - // We need to keep original expression name, if any. - // Constant folding should not change expression name. - let name = &e.display_name(); - - // Apply the actual simplification logic + // TODO: unify with `rewrite_preserving_name` + let original_name = e.name_for_alias()?; let new_e = simplifier.simplify(e)?; - - let new_name = &new_e.display_name(); - - if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { - if expr_name != new_expr_name { - Ok(new_e.alias(expr_name)) - } else { - Ok(new_e) - } - } else { - Ok(new_e) - } + new_e.alias_if_changed(original_name) }) .collect::>>()?; - from_plan(plan, &expr, &new_inputs) + plan.with_new_exprs(expr, &new_inputs) } } @@ -125,7 +122,7 @@ mod tests { use crate::simplify_expressions::utils::for_test::{ cast_to_int64_expr, now_expr, to_timestamp_expr, }; - use crate::test::test_table_scan_with_name; + use crate::test::{assert_fields_eq, test_table_scan_with_name}; use super::*; use arrow::datatypes::{DataType, Field, Schema}; @@ -188,6 +185,48 @@ mod tests { Ok(()) } + #[test] + fn test_simplify_table_full_filter_in_scan() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ]; + + let schema = Schema::new(fields); + + let table_scan = table_scan_with_filters( + Some("test"), + &schema, + Some(vec![0]), + vec![col("b").is_not_null()], + )? + .build()?; + assert_eq!(1, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a"]); + + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + + assert_optimized_plan_eq(&table_scan, expected) + } + + #[test] + fn test_simplify_filter_pushdown() -> Result<()> { + let table_scan = test_table_scan(); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a")])? + .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? + .build()?; + + assert_optimized_plan_eq( + &plan, + "\ + Filter: test.b > Int32(1)\ + \n Projection: test.a\ + \n TableScan: test", + ) + } + #[test] fn test_simplify_optimized_plan() -> Result<()> { let table_scan = test_table_scan(); @@ -487,8 +526,8 @@ mod tests { let expected = format!( "Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\ \n TableScan: test", - time.timestamp_nanos(), - time.timestamp_nanos() + time.timestamp_nanos_opt().unwrap(), + time.timestamp_nanos_opt().unwrap() ); assert_eq!(expected, actual); @@ -847,8 +886,7 @@ mod tests { // before simplify: t.g = power(t.f, 1.0) // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = - "TableScan: test, unsupported_filters=[g = f AS g = power(f,Float64(1))]"; + let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; assert_optimized_plan_eq(&plan, expected) } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 9d3620248581d..fa91a3ace2a25 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -18,12 +18,12 @@ //! Utility functions for expression simplification use crate::simplify_expressions::SimplifyInfo; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -226,9 +226,7 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { pub fn as_bool_lit(expr: Expr) -> Result> { match expr { Expr::Literal(ScalarValue::Boolean(v)) => Ok(v), - _ => Err(DataFusionError::Internal(format!( - "Expected boolean literal, got {expr:?}" - ))), + _ => internal_err!("Expected boolean literal, got {expr:?}"), } } @@ -300,13 +298,7 @@ pub fn negate_clause(expr: Expr) -> Expr { like.expr, like.pattern, like.escape_char, - )), - // not (A ilike B) ===> A not ilike B - Expr::ILike(like) => Expr::ILike(Like::new( - !like.negated, - like.expr, - like.pattern, - like.escape_char, + like.case_insensitive, )), // use not clause _ => Expr::Not(Box::new(expr)), @@ -373,7 +365,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -413,7 +405,7 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( @@ -444,9 +436,9 @@ pub fn simpl_concat(args: Vec) -> Result { ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), ) => contiguous_scalar += &v, Expr::Literal(x) => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "The scalar {x} should be casted to string type during the type coercion." - ))) + ) } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. @@ -502,7 +494,7 @@ pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { } } } - Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {s} should be casted to string type during the type coercion."))), + Expr::Literal(s) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -527,14 +519,14 @@ pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { None => Ok(Expr::Literal(ScalarValue::Utf8(None))), } } - Expr::Literal(d) => Err(DataFusionError::Internal(format!( + Expr::Literal(d) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." - ))), + ), d => Ok(concat_ws( d.clone(), args.iter() + .filter(|&x| !is_null(x)) .cloned() - .filter(|x| !is_null(x)) .collect::>(), )), } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index ba7e89094b0f3..7e6fb6b355ab1 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -17,32 +17,39 @@ //! single distinct to group by optimizer rule +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ + aggregate_function::AggregateFunction::{Max, Min, Sum}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan, Projection}, utils::columnize_expr, Expr, ExprSchemable, }; + use hashbrown::HashSet; -use std::sync::Arc; /// single distinct to group by optimizer rule /// ```text -/// SELECT F1(DISTINCT s),F2(DISTINCT s) -/// ... -/// GROUP BY k +/// Before: +/// SELECT a, COUNT(DINSTINCT b), SUM(c) +/// FROM t +/// GROUP BY a /// -/// Into -/// -/// SELECT F1(alias1),F2(alias1) +/// After: +/// SELECT a, COUNT(alias1), SUM(alias2) /// FROM ( -/// SELECT s as alias1, k ... GROUP BY s, k +/// SELECT a, b as alias1, SUM(c) as alias2 +/// FROM t +/// GROUP BY a, b /// ) -/// GROUP BY k +/// GROUP BY a /// ``` #[derive(Default)] pub struct SingleDistinctToGroupBy {} @@ -61,22 +68,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { let mut fields_set = HashSet::new(); - let mut distinct_count = 0; + let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - distinct, args, .. + func_def: AggregateFunctionDefinition::BuiltIn(fun), + distinct, + args, + filter, + order_by, }) = expr { - if *distinct { - distinct_count += 1; + if filter.is_some() || order_by.is_some() { + return Ok(false); } - for e in args { - fields_set.insert(e.display_name()?); + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e.canonical_name()); + } + } else if !matches!(fun, Sum | Min | Max) { + return Ok(false); } } } - let res = distinct_count == aggr_expr.len() && fields_set.len() == 1; - Ok(res) + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } _ => Ok(false), } @@ -102,51 +117,104 @@ impl OptimizerRule for SingleDistinctToGroupBy { .. }) => { if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { + let fields = schema.fields(); // alias all original group_by exprs - let mut group_expr_alias = Vec::with_capacity(group_expr.len()); - let mut inner_group_exprs = group_expr + let (mut inner_group_exprs, out_group_expr_with_alias): ( + Vec, + Vec<(Expr, Option)>, + ) = group_expr .iter() .enumerate() .map(|(i, group_expr)| { - let alias_str = format!("group_alias_{i}"); - let alias_expr = group_expr.clone().alias(&alias_str); - group_expr_alias - .push((alias_str, schema.fields()[i].clone())); - alias_expr + if let Expr::Column(_) = group_expr { + // For Column expressions we can use existing expression as is. + (group_expr.clone(), (group_expr.clone(), None)) + } else { + // For complex expression write is as alias, to be able to refer + // if from parent operators successfully. + // Consider plan below. + // + // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // First aggregate(from bottom) refers to `test.a` column. + // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + // If we were to write plan above as below without alias + // + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. + let alias_str = format!("group_alias_{i}"); + let alias_expr = group_expr.clone().alias(&alias_str); + ( + alias_expr, + (col(alias_str), Some(fields[i].qualified_name())), + ) + } }) - .collect::>(); + .unzip(); // and they can be referenced by the alias in the outer aggr plan - let outer_group_exprs = group_expr_alias + let outer_group_exprs = out_group_expr_with_alias .iter() - .map(|(alias, _)| col(alias)) + .map(|(out_group_expr, _)| out_group_expr.clone()) .collect::>(); // replace the distinct arg with alias + let mut index = 1; let mut group_fields_set = HashSet::new(); - let new_aggr_exprs = aggr_expr + let mut inner_aggr_exprs = vec![]; + let outer_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, - filter, - order_by, + distinct, .. }) => { // is_single_distinct_agg ensure args.len=1 - if group_fields_set.insert(args[0].display_name()?) { + if *distinct + && group_fields_set.insert(args[0].display_name()?) + { inner_group_exprs.push( args[0].clone().alias(SINGLE_DISTINCT_ALIAS), ); } - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - filter.clone(), - order_by.clone(), - ))) + + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + if !(*distinct) { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + args.clone(), + false, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(&alias_str)], + false, + None, + None, + ))) + } else { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun.clone(), + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + ))) + } } _ => Ok(aggr_expr.clone()), }) @@ -155,6 +223,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // construct the inner AggrPlan let inner_fields = inner_group_exprs .iter() + .chain(inner_aggr_exprs.iter()) .map(|expr| expr.to_field(input.schema())) .collect::>>()?; let inner_schema = DFSchema::new_with_metadata( @@ -164,15 +233,16 @@ impl OptimizerRule for SingleDistinctToGroupBy { let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, - Vec::new(), + inner_aggr_exprs, )?); + let outer_fields = outer_group_exprs + .iter() + .chain(outer_aggr_exprs.iter()) + .map(|expr| expr.to_field(&inner_schema)) + .collect::>>()?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( - outer_group_exprs - .iter() - .chain(new_aggr_exprs.iter()) - .map(|expr| expr.to_field(&inner_schema)) - .collect::>>()?, + outer_fields, input.schema().metadata().clone(), )?); @@ -180,34 +250,33 @@ impl OptimizerRule for SingleDistinctToGroupBy { // this optimizer has two kinds of alias: // - group_by aggr // - aggr expr - let mut alias_expr: Vec = Vec::new(); - for (alias, original_field) in group_expr_alias { - alias_expr - .push(col(alias).alias(original_field.qualified_name())); - } - for (i, expr) in new_aggr_exprs.iter().enumerate() { - alias_expr.push(columnize_expr( - expr.clone().alias( - schema.clone().fields()[i + group_expr.len()] - .qualified_name(), - ), - &outer_aggr_schema, - )); - } + let group_size = group_expr.len(); + let alias_expr = out_group_expr_with_alias + .into_iter() + .map(|(group_expr, original_field)| { + if let Some(name) = original_field { + group_expr.alias(name) + } else { + group_expr + } + }) + .chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + let idx = idx + group_size; + let name = fields[idx].qualified_name(); + columnize_expr(expr.clone().alias(name), &outer_aggr_schema) + })) + .collect(); let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( Arc::new(inner_agg), outer_group_exprs, - new_aggr_exprs, + outer_aggr_exprs, )?); - Ok(Some(LogicalPlan::Projection( - Projection::try_new_with_schema( - alias_expr, - Arc::new(outer_aggr), - schema.clone(), - )?, - ))) + Ok(Some(LogicalPlan::Projection(Projection::try_new( + alias_expr, + Arc::new(outer_aggr), + )?))) } else { Ok(None) } @@ -233,7 +302,7 @@ mod tests { use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, - AggregateFunction, + min, sum, AggregateFunction, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -293,7 +362,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -311,7 +380,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -330,7 +399,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -361,9 +430,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -407,9 +476,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1), MAX(alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ + let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(&plan, expected) @@ -449,4 +518,181 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + + #[test] + fn two_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + sum(col("c")), + count_distinct(col("b")), + Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Max, + vec![col("b")], + true, + None, + None, + )), + ], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinctand_and_two_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn one_distinct_and_one_common() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("c")], + vec![min(col("a")), count_distinct(col("b"))], + )? + .build()?; + // Should work + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + None, + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn common_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // SUM(a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Sum, + vec![col("a")], + false, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn distinct_with_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + None, + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + #[test] + fn aggregate_with_filter_and_order_by() -> Result<()> { + let table_scan = test_table_scan()?; + + // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("a")], + true, + Some(Box::new(col("a").gt(lit(5)))), + Some(vec![col("a")]), + )); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![sum(col("a")), expr])? + .build()?; + // Do nothing + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 7d334a80b6826..e691fe9a53516 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::Optimizer; +use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -155,20 +155,42 @@ pub fn assert_optimized_plan_eq( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); + let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); + + // Ensure schemas always match after an optimization + assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } +pub fn assert_optimized_plan_eq_with_rules( + rules: Vec>, + plan: &LogicalPlan, + expected: &str, +) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let config = &mut OptimizerContext::new() + .with_max_passes(1) + .with_skip_failing_rules(false); + let optimizer = Optimizer::with_rules(rules); + let optimized_plan = optimizer + .optimize(plan, config, observe) + .expect("failed to optimize plan"); + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimized_plan.schema()); + Ok(()) +} + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, plan: &LogicalPlan, @@ -177,7 +199,7 @@ pub fn assert_optimized_plan_eq_display_indent( let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ) @@ -211,7 +233,7 @@ pub fn assert_optimizer_err( ) { let optimizer = Optimizer::with_rules(vec![rule]); let res = optimizer.optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), ); @@ -233,7 +255,7 @@ pub fn assert_optimization_skipped( let optimizer = Optimizer::with_rules(vec![rule]); let new_plan = optimizer .optimize_recursively( - optimizer.rules.get(0).unwrap(), + optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 2c70ad0e9acd5..91603e82a54fc 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,16 +19,18 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::{merge_schema, rewrite_preserving_name}; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; -use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; -use datafusion_expr::utils::from_plan; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; @@ -89,6 +91,12 @@ impl OptimizerRule for UnwrapCastInComparison { ) -> Result> { let mut schema = merge_schema(plan.inputs()); + if let LogicalPlan::TableScan(ts) = plan { + let source_schema = + DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; + schema.merge(&source_schema); + } + schema.merge(plan.schema()); let mut expr_rewriter = UnwrapCastExprRewriter { @@ -102,11 +110,7 @@ impl OptimizerRule for UnwrapCastInComparison { .collect::>>()?; let inputs: Vec = plan.inputs().into_iter().cloned().collect(); - Ok(Some(from_plan( - plan, - new_exprs.as_slice(), - inputs.as_slice(), - )?)) + Ok(Some(plan.with_new_exprs(new_exprs, inputs.as_slice())?)) } fn name(&self) -> &str { @@ -225,10 +229,10 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { .map(|right| { let right_type = right.get_type(&self.schema)?; if !is_support_data_type(&right_type) { - return Err(DataFusionError::Internal(format!( + return internal_err!( "The type of list expr {} not support", &right_type - ))); + ); } match right { Expr::Literal(right_lit_value) => { @@ -239,16 +243,16 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { if let Some(value) = casted_scalar_value { Ok(lit(value)) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "Can't cast the list expr {:?} to type {:?}", right_lit_value, &internal_left_type - ))) + ) } } - other_expr => Err(DataFusionError::Internal(format!( + other_expr => internal_err!( "Only support literal expr to optimize, but the expr is {:?}", &other_expr - ))), + ), } }) .collect::>>(); @@ -300,7 +304,7 @@ fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, ) -> Result> { - let lit_data_type = lit_value.get_datatype(); + let lit_data_type = lit_value.data_type(); // the rule just support the signed numeric data type now if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) { return Ok(None); @@ -321,9 +325,7 @@ fn try_cast_literal_to_type( DataType::Timestamp(_, _) => 1_i128, DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), other_type => { - return Err(DataFusionError::Internal(format!( - "Error target data type {other_type:?}" - ))); + return internal_err!("Error target data type {other_type:?}"); } }; let (target_min, target_max) = match target_type { @@ -344,9 +346,7 @@ fn try_cast_literal_to_type( MAX_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1], ), other_type => { - return Err(DataFusionError::Internal(format!( - "Error target data type {other_type:?}" - ))); + return internal_err!("Error target data type {other_type:?}"); } }; let lit_value_target_type = match lit_value { @@ -382,9 +382,7 @@ fn try_cast_literal_to_type( } } other_value => { - return Err(DataFusionError::Internal(format!( - "Invalid literal value {other_value:?}" - ))); + return internal_err!("Invalid literal value {other_value:?}"); } }; @@ -439,9 +437,7 @@ fn try_cast_literal_to_type( ScalarValue::Decimal128(Some(value), *p, *s) } other_type => { - return Err(DataFusionError::Internal(format!( - "Error target data type {other_type:?}" - ))); + return internal_err!("Error target data type {other_type:?}"); } }; Ok(Some(result_scalar)) @@ -827,7 +823,7 @@ mod tests { for s2 in &scalars { let expected_value = ExpectedCast::Value(s2.clone()); - expect_cast(s1.clone(), s2.get_datatype(), expected_value); + expect_cast(s1.clone(), s2.data_type(), expected_value); } } } @@ -852,7 +848,7 @@ mod tests { for s2 in &scalars { let expected_value = ExpectedCast::Value(s2.clone()); - expect_cast(s1.clone(), s2.get_datatype(), expected_value); + expect_cast(s1.clone(), s2.data_type(), expected_value); } } @@ -986,10 +982,10 @@ mod tests { assert_eq!(lit_tz_none, lit_tz_utc); // e.g. DataType::Timestamp(_, None) - let dt_tz_none = lit_tz_none.get_datatype(); + let dt_tz_none = lit_tz_none.data_type(); // e.g. DataType::Timestamp(_, Some(utc)) - let dt_tz_utc = lit_tz_utc.get_datatype(); + let dt_tz_utc = lit_tz_utc.data_type(); // None <--> None expect_cast( @@ -1093,8 +1089,12 @@ mod tests { // Verify that calling the arrow // cast kernel yields the same results // input array - let literal_array = literal.to_array_of_size(1); - let expected_array = expected_value.to_array_of_size(1); + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); let cast_array = cast_with_options( &literal_array, &target_type, @@ -1112,7 +1112,7 @@ mod tests { if let ( DataType::Timestamp(left_unit, left_tz), DataType::Timestamp(right_unit, right_tz), - ) = (actual_value.get_datatype(), expected_value.get_datatype()) + ) = (actual_value.data_type(), expected_value.data_type()) { assert_eq!(left_unit, right_unit); assert_eq!(left_tz, right_tz); diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 266d0a0be7145..48f72ee7a0f87 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,21 +18,13 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{BinaryExpr, Sort}; -use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; -use datafusion_expr::logical_plan::LogicalPlanBuilder; -use datafusion_expr::utils::from_plan; -use datafusion_expr::{ - and, - logical_plan::{Filter, LogicalPlan}, - Expr, Operator, -}; +use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::utils as expr_utils; +use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; use log::{debug, trace}; use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -46,7 +38,6 @@ pub fn optimize_children( plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let new_exprs = plan.expressions(); let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { @@ -55,35 +46,61 @@ pub fn optimize_children( new_inputs.push(new_input.unwrap_or_else(|| input.clone())) } if plan_is_changed { - Ok(Some(from_plan(plan, &new_exprs, &new_inputs)?)) + Ok(Some(plan.with_new_inputs(&new_inputs)?)) } else { Ok(None) } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.has_column(&col) { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_>::Ok(cols) + }) +} + +pub(crate) fn replace_qualified_name( + expr: Expr, + cols: &BTreeSet, + subquery_alias: &str, +) -> Result { + let alias_cols: Vec = cols + .iter() + .map(|col| { + Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) + }) + .collect(); + let replace_map: HashMap<&Column, &Column> = + cols.iter().zip(alias_cols.iter()).collect(); + + replace_col(expr, &replace_map) +} + +/// Log the plan in debug/tracing mode after some part of the optimizer runs +pub fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction` instead" +)] pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { - split_conjunction_impl(expr, vec![]) -} - -fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_impl(left, exprs); - split_conjunction_impl(right, exprs) - } - Expr::Alias(expr, _) => split_conjunction_impl(expr, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_conjunction(expr) } /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` @@ -107,8 +124,12 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// // use split_conjunction_owned to split them /// assert_eq!(split_conjunction_owned(expr), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_conjunction_owned` instead" +)] pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_binary_owned(expr, Operator::And) + expr_utils::split_conjunction_owned(expr) } /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` @@ -133,51 +154,23 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// // use split_binary_owned to split them /// assert_eq!(split_binary_owned(expr, Operator::Plus), split); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary_owned` instead" +)] pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { - split_binary_owned_impl(expr, op, vec![]) -} - -fn split_binary_owned_impl( - expr: Expr, - operator: Operator, - mut exprs: Vec, -) -> Vec { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { - let exprs = split_binary_owned_impl(*left, operator, exprs); - split_binary_owned_impl(*right, operator, exprs) - } - Expr::Alias(expr, _) => split_binary_owned_impl(*expr, operator, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary_owned(expr, op) } /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::split_binary` instead" +)] pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { - split_binary_impl(expr, op, vec![]) -} - -fn split_binary_impl<'a>( - expr: &'a Expr, - operator: Operator, - mut exprs: Vec<&'a Expr>, -) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { - let exprs = split_binary_impl(left, operator, exprs); - split_binary_impl(right, operator, exprs) - } - Expr::Alias(expr, _) => split_binary_impl(expr, operator, exprs), - other => { - exprs.push(other); - exprs - } - } + expr_utils::split_binary(expr, op) } /// Combines an array of filter expressions into a single filter @@ -202,8 +195,12 @@ fn split_binary_impl<'a>( /// // use conjunction to join them together with `AND` /// assert_eq!(conjunction(split), Some(expr)); /// ``` +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::conjunction` instead" +)] pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + expr_utils::conjunction(filters) } /// Combines an array of filter expressions into a single filter @@ -211,34 +208,22 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::disjunction` instead" +)] pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) -} - -/// Recursively un-alias an expressions -#[inline] -pub fn unalias(expr: Expr) -> Expr { - match expr { - Expr::Alias(sub_expr, _) => unalias(*sub_expr), - _ => expr, - } + expr_utils::disjunction(filters) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::add_filter` instead" +)] pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { - // reduce filters to a single filter with an AND - let predicate = predicates - .iter() - .skip(1) - .fold(predicates[0].clone(), |acc, predicate| { - and(acc, (*predicate).to_owned()) - }); - - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(plan), - )?)) + expr_utils::add_filter(plan, predicates) } /// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and @@ -251,22 +236,12 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result) -> Result<(Vec, Vec)> { - let mut joins = vec![]; - let mut others = vec![]; - for filter in exprs.into_iter() { - // If the expression contains correlated predicates, add it to join filters - if filter.contains_outer() { - if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right)) - { - joins.push(strip_outer_reference((*filter).clone())); - } - } else { - others.push((*filter).clone()); - } - } - - Ok((joins, others)) + expr_utils::find_join_exprs(exprs) } /// Returns the first (and only) element in a slice, or an error @@ -278,344 +253,19 @@ pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { /// # Return value /// /// The first element, or an error +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::only_or_err` instead" +)] pub fn only_or_err(slice: &[T]) -> Result<&T> { - match slice { - [it] => Ok(it), - [] => plan_err!("No items found!"), - _ => plan_err!("More than one item found!"), - } -} - -/// Rewrites `expr` using `rewriter`, ensuring that the output has the -/// same name as `expr` prior to rewrite, adding an alias if necessary. -/// -/// This is important when optimizing plans to ensure the output -/// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result -where - R: TreeNodeRewriter, -{ - let original_name = name_for_alias(&expr)?; - let expr = expr.rewrite(rewriter)?; - add_alias_if_changed(original_name, expr) -} - -/// Return the name to use for the specific Expr, recursing into -/// `Expr::Sort` as appropriate -fn name_for_alias(expr: &Expr) -> Result { - match expr { - Expr::Sort(Sort { expr, .. }) => name_for_alias(expr), - expr => expr.display_name(), - } -} - -/// Ensure `expr` has the name as `original_name` by adding an -/// alias if necessary. -fn add_alias_if_changed(original_name: String, expr: Expr) -> Result { - let new_name = name_for_alias(&expr)?; - - if new_name == original_name { - return Ok(expr); - } - - Ok(match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = add_alias_if_changed(original_name, *expr)?; - Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first)) - } - expr => expr.alias(original_name), - }) + expr_utils::only_or_err(slice) } /// merge inputs schema into a single schema. +#[deprecated( + since = "34.0.0", + note = "use `datafusion_expr::utils::merge_schema` instead" +)] pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() - } else { - inputs.iter().map(|input| input.schema()).fold( - DFSchema::empty(), - |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }, - ) - } -} - -/// Extract join predicates from the correlated subquery's [Filter] expressions. -/// The join predicate means that the expression references columns -/// from both the subquery and outer table or only from the outer table. -/// -/// Returns join predicates and subquery(extracted). -pub(crate) fn extract_join_filters( - maybe_filter: &LogicalPlan, -) -> Result<(Vec, LogicalPlan)> { - if let LogicalPlan::Filter(plan_filter) = maybe_filter { - let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); - let (join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; - // if the subquery still has filter expressions, restore them. - let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); - if let Some(expr) = conjunction(subquery_filters) { - plan = plan.filter(expr)? - } - - Ok((join_filters, plan.build()?)) - } else { - Ok((vec![], maybe_filter.clone())) - } -} - -pub(crate) fn collect_subquery_cols( - exprs: &[Expr], - subquery_schema: DFSchemaRef, -) -> Result> { - exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { - let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); - } - } - - cols.extend(using_cols); - Result::<_>::Ok(cols) - }) -} - -pub(crate) fn replace_qualified_name( - expr: Expr, - cols: &BTreeSet, - subquery_alias: &str, -) -> Result { - let alias_cols: Vec = cols - .iter() - .map(|col| { - Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) - }) - .collect(); - let replace_map: HashMap<&Column, &Column> = - cols.iter().zip(alias_cols.iter()).collect(); - - replace_col(expr, &replace_map) -} - -/// Log the plan in debug/tracing mode after some part of the optimizer runs -pub fn log_plan(description: &str, plan: &LogicalPlan) { - debug!("{description}:\n{}\n", plan.display_indent()); - trace!("{description}::\n{}\n", plan.display_indent_schema()); -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - use datafusion_common::Column; - use datafusion_expr::expr::Cast; - use datafusion_expr::{col, lit, utils::expr_to_columns}; - use std::collections::HashSet; - use std::ops::Add; - - #[test] - fn test_split_conjunction() { - let expr = col("a"); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_conjunction_two() { - let expr = col("a").eq(lit(5)).and(col("b")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_alias() { - let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias")); - let expr1 = col("a").eq(lit(5)); - let expr2 = col("b"); // has no alias - - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr1, &expr2]); - } - - #[test] - fn test_split_conjunction_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - let result = split_conjunction(&expr); - assert_eq!(result, vec![&expr]); - } - - #[test] - fn test_split_binary_owned() { - let expr = col("a"); - assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); - } - - #[test] - fn test_split_binary_owned_two() { - assert_eq!( - split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_binary_owned_different_op() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!( - // expr is connected by OR, but pass in AND - split_binary_owned(expr.clone(), Operator::And), - vec![expr] - ); - } - - #[test] - fn test_split_conjunction_owned() { - let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_split_conjunction_owned_two() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), - vec![col("a").eq(lit(5)), col("b")] - ); - } - - #[test] - fn test_split_conjunction_owned_alias() { - assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), - vec![ - col("a").eq(lit(5)), - // no alias on b - col("b"), - ] - ); - } - - #[test] - fn test_conjunction_empty() { - assert_eq!(conjunction(vec![]), None); - } - - #[test] - fn test_conjunction() { - // `[A, B, C]` - let expr = conjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A AND B) AND C` - assert_eq!(expr, Some(col("a").and(col("b")).and(col("c")))); - - // which is different than `A AND (B AND C)` - assert_ne!(expr, Some(col("a").and(col("b").and(col("c"))))); - } - - #[test] - fn test_disjunction_empty() { - assert_eq!(disjunction(vec![]), None); - } - - #[test] - fn test_disjunction() { - // `[A, B, C]` - let expr = disjunction(vec![col("a"), col("b"), col("c")]); - - // --> `(A OR B) OR C` - assert_eq!(expr, Some(col("a").or(col("b")).or(col("c")))); - - // which is different than `A OR (B OR C)` - assert_ne!(expr, Some(col("a").or(col("b").or(col("c"))))); - } - - #[test] - fn test_split_conjunction_owned_or() { - let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); - } - - #[test] - fn test_collect_expr() -> Result<()> { - let mut accum: HashSet = HashSet::new(); - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), - &mut accum, - )?; - assert_eq!(1, accum.len()); - assert!(accum.contains(&Column::from_name("a"))); - Ok(()) - } - - #[test] - fn test_rewrite_preserving_name() { - test_rewrite(col("a"), col("a")); - - test_rewrite(col("a"), col("b")); - - // cast data types - test_rewrite( - col("a"), - Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), - ); - - // change literal type from i32 to i64 - test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); - - // SortExpr a+1 ==> b + 2 - test_rewrite( - Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), - Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), - ); - } - - /// rewrites `expr_from` to `rewrite_to` using - /// `rewrite_preserving_name` verifying the result is `expected_expr` - fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { - struct TestRewriter { - rewrite_to: Expr, - } - - impl TreeNodeRewriter for TestRewriter { - type N = Expr; - - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) - } - } - - let mut rewriter = TestRewriter { - rewrite_to: rewrite_to.clone(), - }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); - - let original_name = match &expr_from { - Expr::Sort(Sort { expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); - - let new_name = match &expr { - Expr::Sort(Sort { expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); - - assert_eq!( - original_name, new_name, - "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" - ) - } + expr_utils::merge_schema(inputs) } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/optimizer_integration.rs similarity index 89% rename from datafusion/optimizer/tests/integration-test.rs rename to datafusion/optimizer/tests/optimizer_integration.rs index dfd1955177c9a..d857c6154ea97 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; -use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; @@ -28,9 +31,8 @@ use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; + +use chrono::{DateTime, NaiveDateTime, Utc}; #[cfg(test)] #[ctor::ctor] @@ -67,15 +69,13 @@ fn subquery_filter_with_cast() -> Result<()> { )"; let plan = test_sql(sql)?; let expected = "Projection: test.col_int32\ - \n Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.__value\ - \n CrossJoin:\ - \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: __scalar_sq_1\ - \n Projection: AVG(test.col_int32) AS __value\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ - \n Projection: test.col_int32\ - \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ - \n TableScan: test projection=[col_int32, col_utf8]"; + \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.AVG(test.col_int32)\ + \n TableScan: test projection=[col_int32]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Aggregate: groupBy=[[]], aggr=[[AVG(CAST(test.col_int32 AS Float64))]]\ + \n Projection: test.col_int32\ + \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -187,8 +187,9 @@ fn between_date32_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ - \n TableScan: test projection=[col_date32]"; + \n Projection: \ + \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ + \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -200,8 +201,9 @@ fn between_date64_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ - \n TableScan: test projection=[col_date64]"; + \n Projection: \ + \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ + \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -213,7 +215,7 @@ fn concat_literals() -> Result<()> { FROM test"; let plan = test_sql(sql)?; let expected = - "Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ + "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) @@ -226,7 +228,7 @@ fn concat_ws_literals() -> Result<()> { FROM test"; let plan = test_sql(sql)?; let expected = - "Projection: concat_ws(Utf8(\"-\"), Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ + "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) @@ -312,8 +314,8 @@ fn join_keys_in_subquery_alias_1() { fn push_down_filter_groupby_expr_contains_alias() { let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; let plan = test_sql(sql).unwrap(); - let expected = "Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: test.col_int32 + test.col_uint32 AS c, COUNT(*)\ + \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]\ \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); @@ -324,11 +326,10 @@ fn push_down_filter_groupby_expr_contains_alias() { fn test_same_name_but_not_ambiguous() { let sql = "SELECT t1.col_int32 AS col_int32 FROM test t1 intersect SELECT col_int32 FROM test t2"; let plan = test_sql(sql).unwrap(); - let expected = "LeftSemi Join: col_int32 = t2.col_int32\ - \n Aggregate: groupBy=[[col_int32]], aggr=[[]]\ - \n Projection: t1.col_int32 AS col_int32\ - \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ + \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t1\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); @@ -341,13 +342,13 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let schema_provider = MySchemaProvider::default(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = MyContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // hard code the return value of now() let ts = NaiveDateTime::from_timestamp_opt(1666615693, 0).unwrap(); - let now_time = DateTime::::from_utc(ts, Utc); + let now_time = DateTime::::from_naive_utc_and_offset(ts, Utc); let config = OptimizerContext::new() .with_skip_failing_rules(false) .with_query_execution_start_time(now_time); @@ -359,12 +360,12 @@ fn test_sql(sql: &str) -> Result { } #[derive(Default)] -struct MySchemaProvider { +struct MyContextProvider { options: ConfigOptions, } -impl ContextProvider for MySchemaProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.table(); if table_name.starts_with("test") { let schema = Schema::new_with_metadata( @@ -394,7 +395,7 @@ impl ContextProvider for MySchemaProvider { schema: Arc::new(schema), })) } else { - Err(DataFusionError::Plan("table does not exist".to_string())) + plan_err!("table does not exist") } } @@ -410,6 +411,10 @@ impl ContextProvider for MySchemaProvider { None } + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + fn options(&self) -> &ConfigOptions { &self.options } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b851c00edc2ba..d237c68657a1f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-physical-expr" description = "Physical expression implementation for DataFusion query engine" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -34,35 +34,37 @@ path = "src/lib.rs" [features] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] -default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] -# Enables support for non-scalar, binary operations on dictionaries -# Note: this results in significant additional codegen -dictionary_expressions = ["arrow/dyn_cmp_dict", "arrow/dyn_arith_dict"] +default = ["crypto_expressions", "regex_expressions", "unicode_expressions", "encoding_expressions", +] +encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } +base64 = { version = "0.21", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } -chrono = { version = "0.4.23", default-features = false } -datafusion-common = { path = "../common", version = "26.0.0" } -datafusion-expr = { path = "../expr", version = "26.0.0" } -datafusion-row = { path = "../row", version = "26.0.0" } +chrono = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } -indexmap = "1.9.2" -itertools = { version = "0.10", features = ["use_std"] } -lazy_static = { version = "^1.4.0" } -libc = "0.2.140" +hex = { version = "0.4", optional = true } +indexmap = { workspace = true } +itertools = { version = "0.12", features = ["use_std"] } +log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } paste = "^1.0" petgraph = "0.6.2" -rand = "0.8" +rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } @@ -70,8 +72,8 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.5" -rand = "0.8" -rstest = "0.17.0" +rand = { workspace = true } +rstest = { workspace = true } [[bench]] harness = false diff --git a/datafusion/physical-expr/README.md b/datafusion/physical-expr/README.md index a887d3eb29fe3..424256c77e7e2 100644 --- a/datafusion/physical-expr/README.md +++ b/datafusion/physical-expr/README.md @@ -19,7 +19,7 @@ # DataFusion Physical Expressions -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for physical expressions. diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index db017326083ab..90bfc5efb61e8 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -57,7 +57,7 @@ fn do_benches( .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Utf8(Some(random_string(&mut rng, string_length)))) + .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/physical-expr/src/aggregate/approx_distinct.rs index b8922e1992482..b79a5611c334f 100644 --- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs @@ -29,8 +29,9 @@ use arrow::datatypes::{ ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ + downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; use std::any::Any; use std::convert::TryFrom; @@ -102,9 +103,9 @@ impl AggregateExpr for ApproxDistinct { DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), other => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Support for 'approx_distinct' for data type {other} is not implemented" - ))) + ) } }; Ok(accumulator) @@ -219,10 +220,9 @@ impl TryFrom<&ScalarValue> for HyperLogLog { if let ScalarValue::Binary(Some(slice)) = v { slice.as_slice().try_into() } else { - Err(DataFusionError::Internal( + internal_err!( "Impossibly got invalid scalar value while converting to HyperLogLog" - .into(), - )) + ) } } } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index f0a44cc97a66d..aa4749f64ae9c 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -27,9 +27,10 @@ use arrow::{ }, datatypes::{DataType, Field}, }; -use datafusion_common::DataFusionError; -use datafusion_common::Result; -use datafusion_common::{downcast_value, ScalarValue}; +use datafusion_common::{ + downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError, + Result, ScalarValue, +}; use datafusion_expr::Accumulator; use std::{any::Any, iter, sync::Arc}; @@ -106,9 +107,9 @@ impl ApproxPercentileCont { } } other => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" - ))) + ) } }; Ok(accumulator) @@ -144,17 +145,17 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { let percentile = match lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, - got => return Err(DataFusionError::NotImplemented(format!( + got => return not_impl_err!( "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.get_datatype() - ))) + got.data_type() + ) }; // Ensure the percentile is between 0 and 1. if !(0.0..=1.0).contains(&percentile) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ))); + ); } Ok(percentile) } @@ -179,10 +180,10 @@ fn validate_input_max_size_expr(expr: &Arc) -> Result { ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, - got => return Err(DataFusionError::NotImplemented(format!( + got => return not_impl_err!( "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", - got.get_datatype() - ))) + got.data_type() + ) }; Ok(max_size) } @@ -371,9 +372,9 @@ impl ApproxPercentileAccumulator { .filter_map(|v| v.try_as_f64().transpose()) .collect::>>()?) } - e => Err(DataFusionError::Internal(format!( + e => internal_err!( "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" - ))), + ), } } } @@ -393,9 +394,7 @@ impl Accumulator for ApproxPercentileAccumulator { fn evaluate(&self) -> Result { if self.digest.count() == 0.0 { - return Err(DataFusionError::Execution( - "aggregate function needs at least one non-null element".to_string(), - )); + return exec_err!("aggregate function needs at least one non-null element"); } let q = self.digest.estimate_quantile(self.percentile); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index ebf24750cb9f1..91d5c867d3125 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -22,8 +22,11 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -31,9 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -42,11 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, + expr, + nullable, } } } @@ -59,8 +69,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -74,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -102,7 +113,7 @@ impl PartialEq for ArrayAgg { #[derive(Debug)] pub(crate) struct ArrayAggAccumulator { - values: Vec, + values: Vec, datatype: DataType, } @@ -117,36 +128,29 @@ impl ArrayAggAccumulator { } impl Accumulator for ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let arr = &values[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.push(scalar); - Ok(()) - }) + let val = values[0].clone(); + self.values.push(val); + Ok(()) } + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } assert!(states.len() == 1, "array_agg states must be singleton!"); - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(values); - Ok(()) - } else { - Err(DataFusionError::Internal( - "array_agg state must be list!".into(), - )) - } - }) + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) } fn state(&self) -> Result> { @@ -154,15 +158,30 @@ impl Accumulator for ArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone()), - self.datatype.clone(), - )) + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values) + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + self.datatype.size() - std::mem::size_of_val(&self.datatype) } @@ -173,81 +192,110 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use arrow_array::Array; + use arrow_array::ListArray; + use arrow_buffer::OffsetBuffer; + use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - let list = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])]); + let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] fn array_agg_nested() -> Result<()> { - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len()]), + arrow::compute::concat(&[&a1])?, + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let list = ListArray::new( + Arc::new(Field::new("item", l1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([l1.len() + l2.len() + l3.len()]), + arrow::compute::concat(&[&l1, &l2, &l3])?, + None, ); + let list = ScalarValue::List(Arc::new(list)); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 2d7a6e5b0e4db..1efae424cc699 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -22,14 +22,13 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use arrow::array::{Array, ArrayRef}; +use arrow::array::ArrayRef; use std::collections::HashSet; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; /// Expression for a ARRAY_AGG(DISTINCT) aggregation. @@ -41,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -49,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, + expr, + nullable, } } } @@ -68,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -83,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -126,18 +130,16 @@ impl DistinctArrayAggAccumulator { impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - let arr = &values[0]; - for i in 0..arr.len() { - self.values.insert(ScalarValue::try_from_array(arr, i)?); + let array = &values[0]; + let scalars = ScalarValue::convert_array_to_scalar_vec(array)?; + for scalar in scalars { + self.values.extend(scalar) } Ok(()) } @@ -147,20 +149,24 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - for array in states { - for j in 0..array.len() { - self.values.insert(ScalarValue::try_from_array(array, j)?); - } + assert_eq!( + states.len(), + 1, + "array_agg_distinct states must contain single array" + ); + + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec { + self.values.extend(scalars) } Ok(()) } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )) + let values: Vec = self.values.iter().cloned().collect(); + let arr = ScalarValue::new_list(&values, &self.datatype); + Ok(ScalarValue::List(arr)) } fn size(&self) -> usize { @@ -173,12 +179,57 @@ impl Accumulator for DistinctArrayAggAccumulator { #[cfg(test)] mod tests { + use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::cast::as_list_array; + use arrow_array::types::Int32Type; + use arrow_array::{Array, ListArray}; + use arrow_buffer::OffsetBuffer; + use datafusion_common::utils::array_into_list_array; + use datafusion_common::{internal_err, DataFusionError}; + + // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. + fn sort_list_inner(arr: ScalarValue) -> ScalarValue { + let arr = match arr { + ScalarValue::List(arr) => { + let list_arr = as_list_array(&arr); + list_arr.value(0) + } + _ => { + panic!("Expected ScalarValue::List, got {:?}", arr) + } + }; + + let arr = arrow::compute::sort(&arr, None).unwrap(); + let list_arr = array_into_list_array(arr); + ScalarValue::List(Arc::new(list_arr)) + } + + fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> { + let actual = sort_list_inner(actual); + + match (&expected, &actual) { + (ScalarValue::List(arr1), ScalarValue::List(arr2)) => { + if arr1.eq(arr2) { + Ok(()) + } else { + internal_err!( + "Actual value {:?} not found in expected values {:?}", + actual, + expected + ) + } + } + _ => { + internal_err!("Expected scalar lists as inputs") + } + } + } fn check_distinct_array_agg( input: ArrayRef, @@ -192,103 +243,147 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; - match (expected, actual) { - (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - - e.sort_by(cmp); - a.sort_by(cmp); - // Check that the inputs are the same - assert_eq!(e, a); - } - _ => { - unreachable!() - } - } + compare_list_contents(expected, actual) + } - Ok(()) + fn check_merge_distinct_array_agg( + input1: ArrayRef, + input2: ArrayRef, + expected: ScalarValue, + datatype: DataType, + ) -> Result<()> { + let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); + let agg = Arc::new(DistinctArrayAgg::new( + col("a", &schema)?, + "bla".to_string(), + datatype, + true, + )); + + let mut accum1 = agg.create_accumulator()?; + let mut accum2 = agg.create_accumulator()?; + + accum1.update_batch(&[input1])?; + accum2.update_batch(&[input2])?; + + let array = accum2.state()?[0].raw_data()?; + accum1.merge_batch(&[array])?; + + let actual = accum1.evaluate()?; + + compare_list_contents(expected, actual) } #[test] fn distinct_array_agg_i32() -> Result<()> { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(4), + Some(5), + Some(7), + ])]), + )); + + check_distinct_array_agg(col, expected, DataType::Int32) + } - let out = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); - - check_distinct_array_agg(col, out, DataType::Int32) + #[test] + fn merge_distinct_array_agg_i32() -> Result<()> { + let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); + let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); + + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) } #[test] fn distinct_array_agg_nested() -> Result<()> { // [[1, 2, 3], [4, 5]] - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[6], [7, 8]] - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[9]] - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])]), + )); + check_distinct_array_agg( array, - list, + expected, DataType::List(Arc::new(Field::new_list( "item", Field::new("item", DataType::Int32, true), @@ -296,4 +391,70 @@ mod tests { ))), ) } + + #[test] + fn merge_distinct_array_agg_nested() -> Result<()> { + // [[1, 2], [3, 4]] + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(3), + Some(4), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, + ); + + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, + ); + + // [[6, 7], [8]] + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(6), + Some(7), + ])]); + let a2 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, + ); + + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); + + // Duplicate l1 in the input array and check that it is deduped in the output. + let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap(); + let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap(); + + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) + } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs new file mode 100644 index 0000000000000..9ca83a781a013 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -0,0 +1,653 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions which specify ordering requirement +//! that can evaluated at runtime during query execution + +use std::any::Any; +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use arrow_schema::{Fields, SortOptions}; +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; + +use itertools::izip; + +/// Expression for a ARRAY_AGG(ORDER BY) aggregation. +/// When aggregation works in multiple partitions +/// aggregations are split into multiple partitions, +/// then their results are merged. This aggregator +/// is a version of ARRAY_AGG that can support producing +/// intermediate aggregation (with necessary side information) +/// and that can merge aggregations from multiple partitions. +#[derive(Debug)] +pub struct OrderSensitiveArrayAgg { + /// Column name + name: String, + /// The DataType for the input expression + input_data_type: DataType, + /// The input expression + expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement + ordering_req: LexOrdering, +} + +impl OrderSensitiveArrayAgg { + /// Create a new `OrderSensitiveArrayAgg` aggregate function + pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + nullable: bool, + order_by_data_types: Vec, + ordering_req: LexOrdering, + ) -> Self { + Self { + name: name.into(), + input_data_type, + expr, + nullable, + order_by_data_types, + ordering_req, + } + } +} + +impl AggregateExpr for OrderSensitiveArrayAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new_list( + &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg + Field::new("item", self.input_data_type.clone(), true), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(OrderSensitiveArrayAggAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) + } + + fn state_fields(&self) -> Result> { + let mut fields = vec![Field::new_list( + format_state_name(&self.name, "array_agg"), + Field::new("item", self.input_data_type.clone(), true), + self.nullable, // This should be the same as field() + )]; + let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); + fields.push(Field::new_list( + format_state_name(&self.name, "array_agg_orderings"), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), + self.nullable, + )); + Ok(fields) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + if self.ordering_req.is_empty() { + None + } else { + Some(&self.ordering_req) + } + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for OrderSensitiveArrayAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct OrderSensitiveArrayAggAccumulator { + // `values` stores entries in the ARRAY_AGG result. + values: Vec, + // `ordering_values` stores values of ordering requirement expression + // corresponding to each value in the ARRAY_AGG. + // For each `ScalarValue` inside `values`, there will be a corresponding + // `Vec` inside `ordering_values` which stores it ordering. + // This information is used during merging results of the different partitions. + // For detailed information how merging is done see [`merge_ordered_arrays`] + ordering_values: Vec>, + // `datatypes` stores, datatype of expression inside ARRAY_AGG and ordering requirement expressions. + datatypes: Vec, + // Stores ordering requirement of the Accumulator + ordering_req: LexOrdering, +} + +impl OrderSensitiveArrayAggAccumulator { + /// Create a new order-sensitive ARRAY_AGG accumulator based on the given + /// item data type. + pub fn try_new( + datatype: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { + let mut datatypes = vec![datatype.clone()]; + datatypes.extend(ordering_dtypes.iter().cloned()); + Ok(Self { + values: vec![], + ordering_values: vec![], + datatypes, + ordering_req, + }) + } +} + +impl Accumulator for OrderSensitiveArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let n_row = values[0].len(); + for index in 0..n_row { + let row = get_row_at_idx(values, index)?; + self.values.push(row[0].clone()); + self.ordering_values.push(row[1..].to_vec()); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + // First entry in the state is the aggregation result. + let array_agg_values = &states[0]; + // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside ARRAY_AGG list. + // For each `StructArray` inside ARRAY_AGG list, we will receive an `Array` that stores + // values received from its ordering requirement expression. (This information is necessary for during merging). + let agg_orderings = &states[1]; + + if as_list_array(agg_orderings).is_ok() { + // Stores ARRAY_AGG results coming from each partition + let mut partition_values = vec![]; + // Stores ordering requirement expression results coming from each partition + let mut partition_ordering_values = vec![]; + + // Existing values should be merged also. + partition_values.push(self.values.clone()); + partition_ordering_values.push(self.ordering_values.clone()); + + let array_agg_res = + ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + + for v in array_agg_res.into_iter() { + partition_values.push(v); + } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + // Ordering requirement expression values for each entry in the ARRAY_AGG list + let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?; + for v in other_ordering_values.into_iter() { + partition_ordering_values.push(v); + } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let (new_values, new_orderings) = merge_ordered_arrays( + &partition_values, + &partition_ordering_values, + &sort_options, + )?; + self.values = new_values; + self.ordering_values = new_orderings; + } else { + return exec_err!("Expects to receive a list array"); + } + Ok(()) + } + + fn state(&self) -> Result> { + let mut result = vec![self.evaluate()?]; + result.push(self.evaluate_orderings()?); + Ok(result) + } + + fn evaluate(&self) -> Result { + let arr = ScalarValue::new_list(&self.values, &self.datatypes[0]); + Ok(ScalarValue::List(arr)) + } + + fn size(&self) -> usize { + let mut total = std::mem::size_of_val(self) + + ScalarValue::size_of_vec(&self.values) + - std::mem::size_of_val(&self.values); + + // Add size of the `self.ordering_values` + total += + std::mem::size_of::>() * self.ordering_values.capacity(); + for row in &self.ordering_values { + total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + } + + // Add size of the `self.datatypes` + total += std::mem::size_of::() * self.datatypes.capacity(); + for dtype in &self.datatypes { + total += dtype.size() - std::mem::size_of_val(dtype); + } + + // Add size of the `self.ordering_req` + total += std::mem::size_of::() * self.ordering_req.capacity(); + // TODO: Calculate size of each `PhysicalSortExpr` more accurately. + total + } +} + +impl OrderSensitiveArrayAggAccumulator { + /// Inner Vec\ in the ordering_values can be thought as ordering information for the each ScalarValue in the values array. + /// See [`merge_ordered_arrays`] for more information. + fn convert_array_agg_to_orderings( + &self, + array_agg: Vec>, + ) -> Result>>> { + let mut orderings = vec![]; + // in_data is Vec where ScalarValue does not include ScalarValue::List + for in_data in array_agg.into_iter() { + let ordering = in_data.into_iter().map(|struct_vals| { + if let ScalarValue::Struct(Some(orderings), _) = struct_vals { + Ok(orderings) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", + struct_vals.data_type() + ) + } + }).collect::>>()?; + orderings.push(ordering); + } + Ok(orderings) + } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let struct_field = Fields::from(fields.clone()); + let orderings: Vec = self + .ordering_values + .iter() + .map(|ordering| { + ScalarValue::Struct(Some(ordering.clone()), struct_field.clone()) + }) + .collect(); + let struct_type = DataType::Struct(Fields::from(fields)); + + let arr = ScalarValue::new_list(&orderings, &struct_type); + Ok(ScalarValue::List(arr)) + } +} + +/// This is a wrapper struct to be able to correctly merge ARRAY_AGG +/// data from multiple partitions using `BinaryHeap`. +/// When used inside `BinaryHeap` this struct returns smallest `CustomElement`, +/// where smallest is determined by `ordering` values (`Vec`) +/// according to `sort_options` +#[derive(Debug, PartialEq, Eq)] +struct CustomElement<'a> { + // Stores from which partition entry is received + branch_idx: usize, + // values to be merged + value: ScalarValue, + // according to `ordering` values, comparisons will be done. + ordering: Vec, + // `sort_options` defines, desired ordering by the user + sort_options: &'a [SortOptions], +} + +impl<'a> CustomElement<'a> { + fn new( + branch_idx: usize, + value: ScalarValue, + ordering: Vec, + sort_options: &'a [SortOptions], + ) -> Self { + Self { + branch_idx, + value, + ordering, + sort_options, + } + } + + fn ordering( + &self, + current: &[ScalarValue], + target: &[ScalarValue], + ) -> Result { + // Calculate ordering according to `sort_options` + compare_rows(current, target, self.sort_options) + } +} + +// Overwrite ordering implementation such that +// - `self.ordering` values are used for comparison, +// - When used inside `BinaryHeap` it is a min-heap. +impl<'a> Ord for CustomElement<'a> { + fn cmp(&self, other: &Self) -> Ordering { + // Compares according to custom ordering + self.ordering(&self.ordering, &other.ordering) + // Convert max heap to min heap + .map(|ordering| ordering.reverse()) + // This function return error, when `self.ordering` and `other.ordering` + // have different types (such as one is `ScalarValue::Int64`, other is `ScalarValue::Float32`) + // Here this case won't happen, because data from each partition will have same type + .unwrap() + } +} + +impl<'a> PartialOrd for CustomElement<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// This functions merges `values` array (`&[Vec]`) into single array `Vec` +/// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) +/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the +/// each `ScalarValue` in the `values` array. +/// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` +/// of the `ordering_values` array). +/// +/// As an example +/// values can be \[ +/// \[1, 2, 3, 4, 5\], +/// \[1, 2, 3, 4\], +/// \[1, 2, 3, 4, 5, 6\], +/// \] +/// In this case we will be merging three arrays (doesn't have to be same size) +/// and produce a merged array with size 15 (sum of 5+4+6) +/// Merging will be done according to ordering at `ordering_values` vector. +/// As an example `ordering_values` can be [ +/// \[(1, a), (2, b), (3, b), (4, a), (5, b) \], +/// \[(1, a), (2, b), (3, b), (4, a) \], +/// \[(1, b), (2, c), (3, d), (4, e), (5, a), (6, b) \], +/// ] +/// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) +/// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. +/// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) +fn merge_ordered_arrays( + // We will merge values into single `Vec`. + values: &[Vec], + // `values` will be merged according to `ordering_values`. + // Inner `Vec` can be thought as ordering information for the + // each `ScalarValue` in the values`. + ordering_values: &[Vec>], + // Defines according to which ordering comparisons should be done. + sort_options: &[SortOptions], +) -> Result<(Vec, Vec>)> { + // Keep track the most recent data of each branch, in binary heap data structure. + let mut heap: BinaryHeap = BinaryHeap::new(); + + if !(values.len() == ordering_values.len() + && values + .iter() + .zip(ordering_values.iter()) + .all(|(vals, ordering_vals)| vals.len() == ordering_vals.len())) + { + return exec_err!( + "Expects values arguments and/or ordering_values arguments to have same size" + ); + } + let n_branch = values.len(); + // For each branch we keep track of indices of next will be merged entry + let mut indices = vec![0_usize; n_branch]; + // Keep track of sizes of each branch. + let end_indices = (0..n_branch) + .map(|idx| values[idx].len()) + .collect::>(); + let mut merged_values = vec![]; + let mut merged_orderings = vec![]; + // Continue iterating the loop until consuming data of all branches. + loop { + let min_elem = if let Some(min_elem) = heap.pop() { + min_elem + } else { + // Heap is empty, fill it with the next entries from each branch. + for (idx, end_idx, ordering, branch_index) in izip!( + indices.iter(), + end_indices.iter(), + ordering_values.iter(), + 0..n_branch + ) { + // We consumed this branch, skip it + if idx == end_idx { + continue; + } + + // Push the next element to the heap. + let elem = CustomElement::new( + branch_index, + values[branch_index][*idx].clone(), + ordering[*idx].to_vec(), + sort_options, + ); + heap.push(elem); + } + // Now we have filled the heap, get the largest entry (this will be the next element in merge) + if let Some(min_elem) = heap.pop() { + min_elem + } else { + // Heap is empty, this means that all indices are same with end_indices. e.g + // We have consumed all of the branches. Merging is completed + // Exit from the loop + break; + } + }; + let branch_idx = min_elem.branch_idx; + // Increment the index of merged branch, + indices[branch_idx] += 1; + let row_idx = indices[branch_idx]; + merged_values.push(min_elem.value.clone()); + merged_orderings.push(min_elem.ordering.clone()); + if row_idx < end_indices[branch_idx] { + // Push next entry in the most recently consumed branch to the heap + // If there is an available entry + let value = values[branch_idx][row_idx].clone(); + let ordering_row = ordering_values[branch_idx][row_idx].to_vec(); + let elem = CustomElement::new(branch_idx, value, ordering_row, sort_options); + heap.push(elem); + } + } + + Ok((merged_values, merged_orderings)) +} + +#[cfg(test)] +mod tests { + use crate::aggregate::array_agg_ordered::merge_ordered_arrays; + use arrow_array::{Array, ArrayRef, Int64Array}; + use arrow_schema::SortOptions; + use datafusion_common::utils::get_row_at_idx; + use datafusion_common::{Result, ScalarValue}; + use std::sync::Arc; + + #[test] + fn test_merge_asc() -> Result<()> { + let lhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), + ]; + let n_row = lhs_arrays[0].len(); + let lhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&lhs_arrays, idx)) + .collect::>>()?; + + let rhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])), + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])), + ]; + let n_row = rhs_arrays[0].len(); + let rhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&rhs_arrays, idx)) + .collect::>>()?; + let sort_options = vec![ + SortOptions { + descending: false, + nulls_first: false, + }, + SortOptions { + descending: false, + nulls_first: false, + }, + ]; + + let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; + let lhs_vals = (0..lhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) + .collect::>>()?; + + let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef; + let rhs_vals = (0..rhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) + .collect::>>()?; + let expected = + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef; + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef, + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef, + ]; + + let (merged_vals, merged_ts) = merge_ordered_arrays( + &[lhs_vals, rhs_vals], + &[lhs_orderings, rhs_orderings], + &sort_options, + )?; + let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; + + assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); + + Ok(()) + } + + #[test] + fn test_merge_desc() -> Result<()> { + let lhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), + ]; + let n_row = lhs_arrays[0].len(); + let lhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&lhs_arrays, idx)) + .collect::>>()?; + + let rhs_arrays: Vec = vec![ + Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])), + ]; + let n_row = rhs_arrays[0].len(); + let rhs_orderings = (0..n_row) + .map(|idx| get_row_at_idx(&rhs_arrays, idx)) + .collect::>>()?; + let sort_options = vec![ + SortOptions { + descending: true, + nulls_first: false, + }, + SortOptions { + descending: true, + nulls_first: false, + }, + ]; + + // Values (which will be merged) doesn't have to be ordered. + let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; + let lhs_vals = (0..lhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx)) + .collect::>>()?; + + let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef; + let rhs_vals = (0..rhs_vals_arr.len()) + .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx)) + .collect::>>()?; + let expected = + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef; + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef, + Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef, + ]; + let (merged_vals, merged_ts) = merge_ordered_arrays( + &[lhs_vals, rhs_vals], + &[lhs_orderings, rhs_orderings], + &sort_options, + )?; + let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; + + assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 2fe44602d831a..91f2fb952dcea 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -17,39 +17,42 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use arrow::array::{AsArray, PrimitiveBuilder}; +use log::debug; + use std::any::Any; -use std::convert::TryFrom; +use std::fmt::Debug; use std::sync::Arc; -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; -use crate::aggregate::sum; -use crate::aggregate::sum::sum_batch; -use crate::aggregate::utils::calculate_result_decimal_for_avg; +use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::compute; -use arrow::datatypes::DataType; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use arrow::compute::sum; +use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; use arrow::{ array::{ArrayRef, UInt64Array}, datatypes::Field, }; -use arrow_array::Array; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; +use arrow_array::types::{Decimal256Type, DecimalType}; +use arrow_array::{ + Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, +}; +use arrow_buffer::{i256, ArrowNativeType}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::type_coercion::aggregates::avg_return_type; use datafusion_expr::Accumulator; -use datafusion_row::accessor::RowAccessor; + +use super::groups_accumulator::EmitTo; +use super::utils::DecimalAverager; /// AVG aggregate expression #[derive(Debug, Clone)] pub struct Avg { name: String, expr: Arc, - pub sum_data_type: DataType, - rt_data_type: DataType, - pub pre_cast_to_sum_type: bool, + input_data_type: DataType, + result_data_type: DataType, } impl Avg { @@ -57,34 +60,15 @@ impl Avg { pub fn new( expr: Arc, name: impl Into, - sum_data_type: DataType, + data_type: DataType, ) -> Self { - Self::new_with_pre_cast(expr, name, sum_data_type.clone(), sum_data_type, false) - } + let result_data_type = avg_return_type(&data_type).unwrap(); - pub fn new_with_pre_cast( - expr: Arc, - name: impl Into, - sum_data_type: DataType, - rt_data_type: DataType, - cast_to_sum_type: bool, - ) -> Self { - // the internal sum data type of avg just support FLOAT64 and Decimal data type. - assert!(matches!( - sum_data_type, - DataType::Float64 | DataType::Decimal128(_, _) - )); - // the result of avg just support FLOAT64 and Decimal data type. - assert!(matches!( - rt_data_type, - DataType::Float64 | DataType::Decimal128(_, _) - )); Self { name: name.into(), expr, - sum_data_type, - rt_data_type, - pre_cast_to_sum_type: cast_to_sum_type, + input_data_type: data_type, + result_data_type, } } } @@ -96,15 +80,43 @@ impl AggregateExpr for Avg { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.rt_data_type.clone(), true)) + Ok(Field::new(&self.name, self.result_data_type.clone(), true)) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 or decimal - &self.sum_data_type, - &self.rt_data_type, - )?)) + use DataType::*; + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => Ok(Box::::default()), + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + + ( + Decimal256(sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + _ => not_impl_err!( + "AvgAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } } fn state_fields(&self) -> Result> { @@ -116,7 +128,7 @@ impl AggregateExpr for Avg { ), Field::new( format_state_name(&self.name, "sum"), - self.sum_data_type.clone(), + self.input_data_type.clone(), true, ), ]) @@ -130,34 +142,78 @@ impl AggregateExpr for Avg { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.sum_data_type) + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) } - fn supports_bounded_execution(&self) -> bool { - true + fn create_sliding_accumulator(&self) -> Result> { + self.create_accumulator() } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(AvgRowAccumulator::new( - start_index, - &self.sum_data_type, - &self.rt_data_type, - ))) - } + fn groups_accumulator_supported(&self) -> bool { + use DataType::*; - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) + matches!(&self.result_data_type, Float64 | Decimal128(_, _)) } - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - &self.sum_data_type, - &self.rt_data_type, - )?)) + fn create_groups_accumulator(&self) -> Result> { + use DataType::*; + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => { + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + &self.result_data_type, + |sum: f64, count: u64| Ok(sum / count as f64), + ))) + } + ( + Decimal128(_sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + &self.result_data_type, + avg_fn, + ))) + } + + ( + Decimal256(_sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = move |sum: i256, count: u64| { + decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap()) + }; + + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + &self.result_data_type, + avg_fn, + ))) + } + + _ => not_impl_err!( + "AvgGroupsAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } } } @@ -167,8 +223,8 @@ impl PartialEq for Avg { .downcast_ref::() .map(|x| { self.name == x.name - && self.sum_data_type == x.sum_data_type - && self.rt_data_type == x.rt_data_type + && self.input_data_type == x.input_data_type + && self.result_data_type == x.result_data_type && self.expr.eq(&x.expr) }) .unwrap_or(false) @@ -176,276 +232,411 @@ impl PartialEq for Avg { } /// An accumulator to compute the average -#[derive(Debug)] +#[derive(Debug, Default)] pub struct AvgAccumulator { - // sum is used for null - sum: ScalarValue, - sum_data_type: DataType, - return_data_type: DataType, + sum: Option, count: u64, } -impl AvgAccumulator { - /// Creates a new `AvgAccumulator` - pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result { - Ok(Self { - sum: ScalarValue::try_from(datatype)?, - sum_data_type: datatype.clone(), - return_data_type: return_data_type.clone(), - count: 0, - }) +impl Accumulator for AvgAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Float64(self.sum), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) } } -impl Accumulator for AvgAccumulator { +/// An accumulator to compute the average for decimals +struct DecimalAvgAccumulator { + sum: Option, + count: u64, + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, +} + +impl Debug for DecimalAvgAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DecimalAvgAccumulator") + .field("sum", &self.sum) + .field("count", &self.count) + .field("sum_scale", &self.sum_scale) + .field("sum_precision", &self.sum_precision) + .field("target_precision", &self.target_precision) + .field("target_scale", &self.target_scale) + .finish() + } +} + +impl Accumulator for DecimalAvgAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::new_primitive::( + self.sum, + &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), + )?, + ]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; - self.sum = self - .sum - .add(&sum::sum_batch(values, &self.sum_data_type)?)?; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); + } Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = values[0].as_primitive::(); self.count -= (values.len() - values.null_count()) as u64; - let delta = sum_batch(values, &self.sum.get_datatype())?; - self.sum = self.sum.sub(&delta)?; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap().sub_wrapping(x)); + } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); // counts are summed - self.count += compute::sum(counts).unwrap_or(0); + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); // sums are summed - self.sum = self - .sum - .add(&sum::sum_batch(&states[1], &self.sum_data_type)?)?; + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); + } Ok(()) } fn evaluate(&self) -> Result { - match self.sum { - ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) - } - ScalarValue::Decimal128(value, precision, scale) => { - Ok(match value { - None => ScalarValue::Decimal128(None, precision, scale), - Some(value) => { - // now the sum_type and return type is not the same, need to convert the sum type to return type - calculate_result_decimal_for_avg( - value, - self.count as i128, - scale, - &self.return_data_type, - )? - } - }) - } - _ => Err(DataFusionError::Internal( - "Sum should be f64 or decimal128 on average".to_string(), - )), - } + let v = self + .sum + .map(|v| { + DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )? + .avg(v, T::Native::from_usize(self.count as usize).unwrap()) + }) + .transpose()?; + + ScalarValue::new_primitive::( + v, + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ) + } + fn supports_retract_batch(&self) -> bool { + true } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + std::mem::size_of_val(self) } } +/// An accumulator to compute the average of `[PrimitiveArray]`. +/// Stores values as native types, and does overflow checking +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count #[derive(Debug)] -struct AvgRowAccumulator { - state_index: usize, - sum_datatype: DataType, +struct AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ + /// The type of the internal sum + sum_data_type: DataType, + + /// The type of the returned sum return_data_type: DataType, + + /// Count per group (use u64 to make UInt64Array) + counts: Vec, + + /// Sums per group, stored as the native type + sums: Vec, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the final average (value / count) + avg_fn: F, } -impl AvgRowAccumulator { - pub fn new( - start_index: usize, - sum_datatype: &DataType, - return_data_type: &DataType, - ) -> Self { +impl AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ + pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + debug!( + "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", + std::any::type_name::() + ); + Self { - state_index: start_index, - sum_datatype: sum_datatype.clone(), return_data_type: return_data_type.clone(), + sum_data_type: sum_data_type.clone(), + counts: vec![], + sums: vec![], + null_state: NullState::new(), + avg_fn, } } } -impl RowAccumulator for AvgRowAccumulator { +impl GroupsAccumulator for AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result + Send, +{ fn update_batch( &mut self, values: &[ArrayRef], - accessor: &mut RowAccessor, + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, ) -> Result<()> { - let values = &values[0]; - // count - let delta = (values.len() - values.null_count()) as u64; - accessor.add_u64(self.state_index(), delta); - - // sum - sum::add_to_row( - self.state_index() + 1, - accessor, - &sum::sum_batch(values, &self.sum_datatype)?, - ) - } + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, T::default_value()); + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + + self.counts[group_index] += 1; + }, + ); - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - sum::update_avg_to_row(self.state_index(), accessor, value) + Ok(()) } - fn update_scalar( + fn merge_batch( &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, ) -> Result<()> { - sum::update_avg_to_row(self.state_index(), accessor, value) + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + let partial_sums = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + self.null_state.accumulate( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ); + + // update sums + self.sums.resize(total_num_groups, T::default_value()); + self.null_state.accumulate( + group_indices, + partial_sums, + opt_filter, + total_num_groups, + |group_index, new_value: ::Native| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, + ); + + Ok(()) } - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - // count - let delta = compute::sum(counts).unwrap_or(0); - accessor.add_u64(self.state_index(), delta); - - // sum - let difference = sum::sum_batch(&states[1], &self.sum_datatype)?; - sum::add_to_row(self.state_index() + 1, accessor, &difference) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - match self.sum_datatype { - DataType::Decimal128(p, s) => { - match accessor.get_u64_opt(self.state_index()) { - None => Ok(ScalarValue::Decimal128(None, p, s)), - Some(0) => Ok(ScalarValue::Decimal128(None, p, s)), - Some(n) => { - // now the sum_type and return type is not the same, need to convert the sum type to return type - accessor.get_i128_opt(self.state_index() + 1).map_or_else( - || Ok(ScalarValue::Decimal128(None, p, s)), - |f| { - calculate_result_decimal_for_avg( - f, - n as i128, - s, - &self.return_data_type, - ) - }, - ) - } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + let sums = emit_to.take_needed(&mut self.sums); + let nulls = self.null_state.build(emit_to); + + assert_eq!(nulls.len(), sums.len()); + assert_eq!(counts.len(), sums.len()); + + // don't evaluate averages with null inputs to avoid errors on null values + + let array: PrimitiveArray = if nulls.null_count() > 0 { + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); + let iter = sums.into_iter().zip(counts).zip(nulls.iter()); + + for ((sum, count), is_valid) in iter { + if is_valid { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); } } - DataType::Float64 => Ok(match accessor.get_u64_opt(self.state_index()) { - None => ScalarValue::Float64(None), - Some(0) => ScalarValue::Float64(None), - Some(n) => ScalarValue::Float64( - accessor - .get_f64_opt(self.state_index() + 1) - .map(|f| f / n as f64), - ), - }), - _ => Err(DataFusionError::Internal( - "Sum should be f64 or decimal128 on average".to_string(), - )), - } + builder.finish() + } else { + let averages: Vec = sums + .into_iter() + .zip(counts.into_iter()) + .map(|(sum, count)| (self.avg_fn)(sum, count)) + .collect::>>()?; + PrimitiveArray::new(averages.into(), Some(nulls)) // no copy + .with_data_type(self.return_data_type.clone()) + }; + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let nulls = self.null_state.build(emit_to); + let nulls = Some(nulls); + + let counts = emit_to.take_needed(&mut self.counts); + let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy + + let sums = emit_to.take_needed(&mut self.sums); + let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy + .with_data_type(self.sum_data_type.clone()); + + Ok(vec![ + Arc::new(counts) as ArrayRef, + Arc::new(sums) as ArrayRef, + ]) } - #[inline(always)] - fn state_index(&self) -> usize { - self.state_index + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + + self.sums.capacity() * std::mem::size_of::() } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; - use datafusion_common::Result; + use crate::expressions::tests::assert_aggregate; + use arrow::array::*; + use datafusion_expr::AggregateFunction; #[test] - fn avg_decimal() -> Result<()> { + fn avg_decimal() { // test agg let array: ArrayRef = Arc::new( (1..7) .map(Some) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( + assert_aggregate( array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(Some(35000), 14, 4) - ) + AggregateFunction::Avg, + false, + ScalarValue::Decimal128(Some(35000), 14, 4), + ); } #[test] - fn avg_decimal_with_nulls() -> Result<()> { + fn avg_decimal_with_nulls() { let array: ArrayRef = Arc::new( (1..6) .map(|i| if i == 2 { None } else { Some(i) }) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( + assert_aggregate( array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(Some(32500), 14, 4) - ) + AggregateFunction::Avg, + false, + ScalarValue::Decimal128(Some(32500), 14, 4), + ); } #[test] - fn avg_decimal_all_nulls() -> Result<()> { + fn avg_decimal_all_nulls() { // test agg let array: ArrayRef = Arc::new( std::iter::repeat::>(None) .take(6) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( + assert_aggregate( array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(None, 14, 4) - ) + AggregateFunction::Avg, + false, + ScalarValue::Decimal128(None, 14, 4), + ); } #[test] - fn avg_i32() -> Result<()> { + fn avg_i32() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3_f64)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } #[test] - fn avg_i32_with_nulls() -> Result<()> { + fn avg_i32_with_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![ Some(1), None, @@ -453,33 +644,33 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3.25f64)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64)); } #[test] - fn avg_i32_all_nulls() -> Result<()> { + fn avg_i32_all_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::Float64(None)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None)); } #[test] - fn avg_u32() -> Result<()> { + fn avg_u32() { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Avg, ScalarValue::from(3.0f64)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64)); } #[test] - fn avg_f32() -> Result<()> { + fn avg_f32() { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Avg, ScalarValue::from(3_f64)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } #[test] - fn avg_f64() -> Result<()> { + fn avg_f64() { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Avg, ScalarValue::from(3_f64)) + assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); } } diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 4bbe563edce89..6c97d620616a9 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -15,171 +15,28 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! Defines BitAnd, BitOr, and BitXor Aggregate accumulators use ahash::RandomState; +use datafusion_common::cast::as_list_array; use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::datatypes::DataType; -use arrow::{ - array::{ - ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::Field, -}; -use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use arrow::{array::ArrayRef, datatypes::Field}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::collections::HashSet; -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; +use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::compute::{bit_and, bit_or, bit_xor}; -use datafusion_row::accessor::RowAccessor; - -// returns the new value after bit_and/bit_or/bit_xor with the new values, taking nullability into account -macro_rules! typed_bit_and_or_xor_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let delta = $OP(array); - Ok(ScalarValue::$SCALAR(delta)) - }}; -} - -// bit_and/bit_or/bit_xor the array and returns a ScalarValue of its corresponding type. -macro_rules! bit_and_or_xor_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Int64 => { - typed_bit_and_or_xor_batch!($VALUES, Int64Array, Int64, $OP) - } - DataType::Int32 => { - typed_bit_and_or_xor_batch!($VALUES, Int32Array, Int32, $OP) - } - DataType::Int16 => { - typed_bit_and_or_xor_batch!($VALUES, Int16Array, Int16, $OP) - } - DataType::Int8 => { - typed_bit_and_or_xor_batch!($VALUES, Int8Array, Int8, $OP) - } - DataType::UInt64 => { - typed_bit_and_or_xor_batch!($VALUES, UInt64Array, UInt64, $OP) - } - DataType::UInt32 => { - typed_bit_and_or_xor_batch!($VALUES, UInt32Array, UInt32, $OP) - } - DataType::UInt16 => { - typed_bit_and_or_xor_batch!($VALUES, UInt16Array, UInt16, $OP) - } - DataType::UInt8 => { - typed_bit_and_or_xor_batch!($VALUES, UInt8Array, UInt8, $OP) - } - e => { - return Err(DataFusionError::Internal(format!( - "Bit and/Bit or/Bit xor is not expected to receive the type {e:?}" - ))); - } - } - }}; -} - -/// dynamically-typed bit_and(array) -> ScalarValue -fn bit_and_batch(values: &ArrayRef) -> Result { - bit_and_or_xor_batch!(values, bit_and) -} - -/// dynamically-typed bit_or(array) -> ScalarValue -fn bit_or_batch(values: &ArrayRef) -> Result { - bit_and_or_xor_batch!(values, bit_or) -} - -/// dynamically-typed bit_xor(array) -> ScalarValue -fn bit_xor_batch(values: &ArrayRef) -> Result { - bit_and_or_xor_batch!(values, bit_xor) -} - -// bit_and/bit_or/bit_xor of two scalar values. -macro_rules! typed_bit_and_or_xor_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $TYPE:ident, $OP:ident) => {{ - paste::item! { - match $SCALAR { - None => {} - Some(v) => $ACC.[<$OP _ $TYPE>]($INDEX, *v as $TYPE) - } - } - }}; -} - -macro_rules! bit_and_or_xor_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $OP:ident) => {{ - Ok(match $SCALAR { - ScalarValue::UInt64(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, u64, $OP) - } - ScalarValue::UInt32(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, u32, $OP) - } - ScalarValue::UInt16(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, u16, $OP) - } - ScalarValue::UInt8(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, u8, $OP) - } - ScalarValue::Int64(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, i64, $OP) - } - ScalarValue::Int32(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, i32, $OP) - } - ScalarValue::Int16(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, i16, $OP) - } - ScalarValue::Int8(rhs) => { - typed_bit_and_or_xor_v2!($INDEX, $ACC, rhs, i8, $OP) - } - ScalarValue::Null => { - // do nothing - } - e => { - return Err(DataFusionError::Internal(format!( - "BIT AND/BIT OR/BIT XOR is not expected to receive scalars of incompatible types {:?}", - e - ))) - } - }) - }}; -} - -pub fn bit_and_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - bit_and_or_xor_v2!(index, accessor, s, bitand) -} - -pub fn bit_or_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - bit_and_or_xor_v2!(index, accessor, s, bitor) -} - -pub fn bit_xor_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - bit_and_or_xor_v2!(index, accessor, s, bitxor) -} +use arrow_array::cast::AsArray; +use arrow_array::{downcast_integer, ArrowNumericType}; +use arrow_buffer::ArrowNativeType; /// BIT_AND aggregate expression #[derive(Debug, Clone)] @@ -221,7 +78,19 @@ impl AggregateExpr for BitAnd { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(BitAndAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty) => { + Ok(Box::>::default()) + }; + } + downcast_integer! { + &self.data_type => (helper), + _ => Err(DataFusionError::NotImplemented(format!( + "BitAndAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -240,18 +109,34 @@ impl AggregateExpr for BitAnd { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(BitAndRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitAndAssign; + + // Note the default value for BitAnd should be all set, i.e. `!0` + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| { + x.bitand_assign(y) + }) + .with_starting_value(!0), + )) + }; + } + + let data_type = &self.data_type; + downcast_integer! { + data_type => (helper, data_type), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -273,25 +158,31 @@ impl PartialEq for BitAnd { } } -#[derive(Debug)] -struct BitAndAccumulator { - bit_and: ScalarValue, +struct BitAndAccumulator { + value: Option, } -impl BitAndAccumulator { - /// new bit_and accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - bit_and: ScalarValue::try_from(data_type)?, - }) +impl std::fmt::Debug for BitAndAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitAndAccumulator({})", T::DATA_TYPE) } } -impl Accumulator for BitAndAccumulator { +impl Default for BitAndAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitAndAccumulator +where + T::Native: std::ops::BitAnd, +{ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &bit_and_batch(values)?; - self.bit_and = self.bit_and.bitand(delta)?; + if let Some(x) = bit_and(values[0].as_primitive::()) { + let v = self.value.get_or_insert(x); + *v = *v & x; + } Ok(()) } @@ -300,74 +191,15 @@ impl Accumulator for BitAndAccumulator { } fn state(&self) -> Result> { - Ok(vec![self.bit_and.clone()]) + Ok(vec![self.evaluate()?]) } fn evaluate(&self) -> Result { - Ok(self.bit_and.clone()) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_and) - + self.bit_and.size() - } -} - -#[derive(Debug)] -struct BitAndRowAccumulator { - index: usize, - datatype: DataType, -} - -impl BitAndRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } - } -} - -impl RowAccumulator for BitAndRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &bit_and_batch(values)?; - bit_and_row(self.index, accessor, delta) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - bit_and_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - bit_and_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) - } - - #[inline(always)] - fn state_index(&self) -> usize { - self.index + std::mem::size_of_val(self) } } @@ -411,7 +243,19 @@ impl AggregateExpr for BitOr { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(BitOrAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty) => { + Ok(Box::>::default()) + }; + } + downcast_integer! { + &self.data_type => (helper), + _ => Err(DataFusionError::NotImplemented(format!( + "BitOrAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -430,18 +274,30 @@ impl AggregateExpr for BitOr { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(BitOrRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitOrAssign; + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + $dt, + |x, y| x.bitor_assign(y), + ))) + }; + } + + let data_type = &self.data_type; + downcast_integer! { + data_type => (helper, data_type), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -463,29 +319,35 @@ impl PartialEq for BitOr { } } -#[derive(Debug)] -struct BitOrAccumulator { - bit_or: ScalarValue, +struct BitOrAccumulator { + value: Option, } -impl BitOrAccumulator { - /// new bit_or accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - bit_or: ScalarValue::try_from(data_type)?, - }) +impl std::fmt::Debug for BitOrAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitOrAccumulator({})", T::DATA_TYPE) } } -impl Accumulator for BitOrAccumulator { +impl Default for BitOrAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitOrAccumulator +where + T::Native: std::ops::BitOr, +{ fn state(&self) -> Result> { - Ok(vec![self.bit_or.clone()]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &bit_or_batch(values)?; - self.bit_or = self.bit_or.bitor(delta)?; + if let Some(x) = bit_or(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v | x; + } Ok(()) } @@ -494,71 +356,11 @@ impl Accumulator for BitOrAccumulator { } fn evaluate(&self) -> Result { - Ok(self.bit_or.clone()) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_or) - + self.bit_or.size() - } -} - -#[derive(Debug)] -struct BitOrRowAccumulator { - index: usize, - datatype: DataType, -} - -impl BitOrRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } - } -} - -impl RowAccumulator for BitOrRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &bit_or_batch(values)?; - bit_or_row(self.index, accessor, delta)?; - Ok(()) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - bit_or_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - bit_or_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) - } - - #[inline(always)] - fn state_index(&self) -> usize { - self.index + std::mem::size_of_val(self) } } @@ -602,7 +404,19 @@ impl AggregateExpr for BitXor { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(BitXorAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty) => { + Ok(Box::>::default()) + }; + } + downcast_integer! { + &self.data_type => (helper), + _ => Err(DataFusionError::NotImplemented(format!( + "BitXor not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -621,18 +435,30 @@ impl AggregateExpr for BitXor { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(BitXorRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + use std::ops::BitXorAssign; + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + $dt, + |x, y| x.bitxor_assign(y), + ))) + }; + } + + let data_type = &self.data_type; + downcast_integer! { + data_type => (helper, data_type), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -654,29 +480,35 @@ impl PartialEq for BitXor { } } -#[derive(Debug)] -struct BitXorAccumulator { - bit_xor: ScalarValue, +struct BitXorAccumulator { + value: Option, } -impl BitXorAccumulator { - /// new bit_xor accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - bit_xor: ScalarValue::try_from(data_type)?, - }) +impl std::fmt::Debug for BitXorAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitXorAccumulator({})", T::DATA_TYPE) } } -impl Accumulator for BitXorAccumulator { +impl Default for BitXorAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitXorAccumulator +where + T::Native: std::ops::BitXor, +{ fn state(&self) -> Result> { - Ok(vec![self.bit_xor.clone()]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &bit_xor_batch(values)?; - self.bit_xor = self.bit_xor.bitxor(delta)?; + if let Some(x) = bit_xor(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v ^ x; + } Ok(()) } @@ -685,71 +517,11 @@ impl Accumulator for BitXorAccumulator { } fn evaluate(&self) -> Result { - Ok(self.bit_xor.clone()) + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.bit_xor) - + self.bit_xor.size() - } -} - -#[derive(Debug)] -struct BitXorRowAccumulator { - index: usize, - datatype: DataType, -} - -impl BitXorRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } - } -} - -impl RowAccumulator for BitXorRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &bit_xor_batch(values)?; - bit_xor_row(self.index, accessor, delta)?; - Ok(()) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - bit_xor_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - bit_xor_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) - } - - #[inline(always)] - fn state_index(&self) -> usize { - self.index + std::mem::size_of_val(self) } } @@ -793,9 +565,19 @@ impl AggregateExpr for DistinctBitXor { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctBitXorAccumulator::try_new( - &self.data_type, - )?)) + macro_rules! helper { + ($t:ty) => { + Ok(Box::>::default()) + }; + } + downcast_integer! { + &self.data_type => (helper), + _ => Err(DataFusionError::NotImplemented(format!( + "DistinctBitXorAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -830,34 +612,40 @@ impl PartialEq for DistinctBitXor { } } -#[derive(Debug)] -struct DistinctBitXorAccumulator { - hash_values: HashSet, - data_type: DataType, +struct DistinctBitXorAccumulator { + values: HashSet, } -impl DistinctBitXorAccumulator { - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - hash_values: HashSet::default(), - data_type: data_type.clone(), - }) +impl std::fmt::Debug for DistinctBitXorAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for DistinctBitXorAccumulator { + fn default() -> Self { + Self { + values: HashSet::default(), + } } } -impl Accumulator for DistinctBitXorAccumulator { +impl Accumulator for DistinctBitXorAccumulator +where + T::Native: std::ops::BitXor + std::hash::Hash + Eq, +{ fn state(&self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { - let mut distinct_values = Vec::new(); - self.hash_values + let values = self + .values .iter() - .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![ScalarValue::new_list( - Some(distinct_values), - self.data_type.clone(), - )] + .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + vec![ScalarValue::List(arr)] }; Ok(state_out) } @@ -867,53 +655,42 @@ impl Accumulator for DistinctBitXorAccumulator { return Ok(()); } - let arr = &values[0]; - (0..values[0].len()).try_for_each(|index| { - if !arr.is_null(index) { - let v = ScalarValue::try_from_array(arr, index)?; - self.hash_values.insert(v); + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(array.value(idx)); + } } - Ok(()) - }) + None => array.values().iter().for_each(|x| { + self.values.insert(*x); + }), + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.hash_values.insert(scalar.clone()); - } - }); - } else { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); + if let Some(state) = states.first() { + let list_arr = as_list_array(state)?; + for arr in list_arr.iter().flatten() { + self.update_batch(&[arr])?; } - Ok(()) - }) + } + Ok(()) } fn evaluate(&self) -> Result { - let mut bit_xor_value = ScalarValue::try_from(&self.data_type)?; - for distinct_value in self.hash_values.iter() { - bit_xor_value = bit_xor_value.bitxor(distinct_value)?; + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc ^ *distinct_value; } - Ok(bit_xor_value) + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &T::DATA_TYPE) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.hash_values) - - std::mem::size_of_val(&self.hash_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) + std::mem::size_of_val(self) + + self.values.capacity() * std::mem::size_of::() } } @@ -923,6 +700,7 @@ mod tests { use crate::expressions::col; use crate::expressions::tests::aggregate; use crate::generic_test_op; + use arrow::array::*; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use datafusion_common::Result; diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs b/datafusion/physical-expr/src/aggregate/bool_and_or.rs index bbab4dfce660e..9757d314b6aaf 100644 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ b/datafusion/physical-expr/src/aggregate/bool_and_or.rs @@ -17,27 +17,24 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::convert::TryFrom; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::datatypes::DataType; use arrow::{ array::{ArrayRef, BooleanArray}, datatypes::Field, }; -use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; +use crate::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::compute::{bool_and, bool_or}; -use datafusion_row::accessor::RowAccessor; // returns the new value after bool_and/bool_or with the new values, taking nullability into account macro_rules! typed_bool_and_or_batch { @@ -56,9 +53,9 @@ macro_rules! bool_and_or_batch { typed_bool_and_or_batch!($VALUES, BooleanArray, Boolean, $OP) } e => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Bool and/Bool or is not expected to receive the type {e:?}" - ))); + ); } } }}; @@ -74,53 +71,6 @@ fn bool_or_batch(values: &ArrayRef) -> Result { bool_and_or_batch!(values, bool_or) } -// bool_and/bool_or of two scalar values. -macro_rules! typed_bool_and_or_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $TYPE:ident, $OP:ident) => {{ - paste::item! { - match $SCALAR { - None => {} - Some(v) => $ACC.[<$OP _ $TYPE>]($INDEX, *v as $TYPE) - } - } - }}; -} - -macro_rules! bool_and_or_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $OP:ident) => {{ - Ok(match $SCALAR { - ScalarValue::Boolean(rhs) => { - typed_bool_and_or_v2!($INDEX, $ACC, rhs, bool, $OP) - } - ScalarValue::Null => { - // do nothing - } - e => { - return Err(DataFusionError::Internal(format!( - "BOOL AND/BOOL OR is not expected to receive scalars of incompatible types {:?}", - e - ))) - } - }) - }}; -} - -pub fn bool_and_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - bool_and_or_v2!(index, accessor, s, bitand) -} - -pub fn bool_or_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - bool_and_or_v2!(index, accessor, s, bitor) -} - /// BOOL_AND aggregate expression #[derive(Debug, Clone)] pub struct BoolAnd { @@ -161,7 +111,7 @@ impl AggregateExpr for BoolAnd { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(BoolAndAccumulator::try_new(&self.data_type)?)) + Ok(Box::::default()) } fn state_fields(&self) -> Result> { @@ -180,18 +130,21 @@ impl AggregateExpr for BoolAnd { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(BoolAndRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) + } + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -199,7 +152,7 @@ impl AggregateExpr for BoolAnd { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(BoolAndAccumulator::try_new(&self.data_type)?)) + Ok(Box::::default()) } } @@ -217,25 +170,20 @@ impl PartialEq for BoolAnd { } } -#[derive(Debug)] +#[derive(Debug, Default)] struct BoolAndAccumulator { - bool_and: ScalarValue, -} - -impl BoolAndAccumulator { - /// new bool_and accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - bool_and: ScalarValue::try_from(data_type)?, - }) - } + acc: Option, } impl Accumulator for BoolAndAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - let delta = &bool_and_batch(values)?; - self.bool_and = self.bool_and.and(delta)?; + self.acc = match (self.acc, bool_and_batch(values)?) { + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a && b), + _ => unreachable!(), + }; Ok(()) } @@ -244,74 +192,15 @@ impl Accumulator for BoolAndAccumulator { } fn state(&self) -> Result> { - Ok(vec![self.bool_and.clone()]) + Ok(vec![ScalarValue::Boolean(self.acc)]) } fn evaluate(&self) -> Result { - Ok(self.bool_and.clone()) + Ok(ScalarValue::Boolean(self.acc)) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.bool_and) - + self.bool_and.size() - } -} - -#[derive(Debug)] -struct BoolAndRowAccumulator { - index: usize, - datatype: DataType, -} - -impl BoolAndRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } - } -} - -impl RowAccumulator for BoolAndRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &bool_and_batch(values)?; - bool_and_row(self.index, accessor, delta) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - bool_and_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - bool_and_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) - } - - #[inline(always)] - fn state_index(&self) -> usize { - self.index + std::mem::size_of_val(self) } } @@ -355,7 +244,7 @@ impl AggregateExpr for BoolOr { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(BoolOrAccumulator::try_new(&self.data_type)?)) + Ok(Box::::default()) } fn state_fields(&self) -> Result> { @@ -374,18 +263,21 @@ impl AggregateExpr for BoolOr { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(BoolOrRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + match self.data_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) + } + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -393,7 +285,7 @@ impl AggregateExpr for BoolOr { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(BoolOrAccumulator::try_new(&self.data_type)?)) + Ok(Box::::default()) } } @@ -411,29 +303,24 @@ impl PartialEq for BoolOr { } } -#[derive(Debug)] +#[derive(Debug, Default)] struct BoolOrAccumulator { - bool_or: ScalarValue, -} - -impl BoolOrAccumulator { - /// new bool_or accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - bool_or: ScalarValue::try_from(data_type)?, - }) - } + acc: Option, } impl Accumulator for BoolOrAccumulator { fn state(&self) -> Result> { - Ok(vec![self.bool_or.clone()]) + Ok(vec![ScalarValue::Boolean(self.acc)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - let delta = bool_or_batch(values)?; - self.bool_or = self.bool_or.or(&delta)?; + self.acc = match (self.acc, bool_or_batch(values)?) { + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a || b), + _ => unreachable!(), + }; Ok(()) } @@ -442,71 +329,11 @@ impl Accumulator for BoolOrAccumulator { } fn evaluate(&self) -> Result { - Ok(self.bool_or.clone()) + Ok(ScalarValue::Boolean(self.acc)) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.bool_or) - + self.bool_or.size() - } -} - -#[derive(Debug)] -struct BoolOrRowAccumulator { - index: usize, - datatype: DataType, -} - -impl BoolOrRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } - } -} - -impl RowAccumulator for BoolOrRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &bool_or_batch(values)?; - bool_or_row(self.index, accessor, delta)?; - Ok(()) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - bool_or_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - bool_or_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) - } - - #[inline(always)] - fn state_index(&self) -> usize { - self.index + std::mem::size_of_val(self) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 71ddf91315130..c40f0db194055 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -26,10 +26,10 @@ //! * Signature: see `Signature` //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. -use crate::{expressions, AggregateExpr, PhysicalExpr}; +use crate::aggregate::regr::RegrType; +use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Schema; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::aggregate_function::{return_type, sum_type_of_avg}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; pub use datafusion_expr::AggregateFunction; use std::sync::Arc; @@ -39,6 +39,7 @@ pub fn create_aggregate_expr( fun: &AggregateFunction, distinct: bool, input_phy_exprs: &[Arc], + ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, ) -> Result> { @@ -48,186 +49,256 @@ pub fn create_aggregate_expr( .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let rt_type = return_type(fun, &input_phy_types)?; + let data_type = input_phy_types[0].clone(); + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(input_schema)) + .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); - Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, rt_type), + expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), ), (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - input_phy_types[0].clone(), + data_type, input_phy_exprs[0].clone(), name, )), (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BitAnd, _) => Arc::new(expressions::BitAnd::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BitOr, _) => Arc::new(expressions::BitOr::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BitXor, false) => Arc::new(expressions::BitXor::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BitXor, true) => Arc::new(expressions::DistinctBitXor::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::BoolOr, _) => Arc::new(expressions::BoolOr::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, + )), + (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), )), - (AggregateFunction::Sum, false) => { - let cast_to_sum_type = rt_type != input_phy_types[0]; - Arc::new(expressions::Sum::new_with_pre_cast( - input_phy_exprs[0].clone(), - name, - rt_type, - cast_to_sum_type, - )) - } (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new( vec![input_phy_exprs[0].clone()], name, - rt_type, + data_type, )), - (AggregateFunction::ApproxDistinct, _) => { - Arc::new(expressions::ApproxDistinct::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - )) + (AggregateFunction::ApproxDistinct, _) => Arc::new( + expressions::ApproxDistinct::new(input_phy_exprs[0].clone(), name, data_type), + ), + (AggregateFunction::ArrayAgg, false) => { + let expr = input_phy_exprs[0].clone(); + let nullable = expr.nullable(input_schema)?; + + if ordering_req.is_empty() { + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + } else { + Arc::new(expressions::OrderSensitiveArrayAgg::new( + expr, + name, + data_type, + nullable, + ordering_types, + ordering_req.to_vec(), + )) + } } - (AggregateFunction::ArrayAgg, false) => Arc::new(expressions::ArrayAgg::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - )), (AggregateFunction::ArrayAgg, true) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, - input_phy_types[0].clone(), + data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, + )), + (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new( + input_phy_exprs[0].clone(), + name, + data_type, )), - (AggregateFunction::Avg, false) => { - let sum_type = sum_type_of_avg(&input_phy_types)?; - let cast_to_sum_type = sum_type != input_phy_types[0]; - Arc::new(expressions::Avg::new_with_pre_cast( - input_phy_exprs[0].clone(), - name, - sum_type, - rt_type, - cast_to_sum_type, - )) - } (AggregateFunction::Avg, true) => { - return Err(DataFusionError::NotImplemented( - "AVG(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("AVG(DISTINCT) aggregations are not available"); } (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::Variance, true) => { - return Err(DataFusionError::NotImplemented( - "VAR(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("VAR(DISTINCT) aggregations are not available"); } (AggregateFunction::VariancePop, false) => Arc::new( - expressions::VariancePop::new(input_phy_exprs[0].clone(), name, rt_type), + expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type), ), (AggregateFunction::VariancePop, true) => { - return Err(DataFusionError::NotImplemented( - "VAR_POP(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new( input_phy_exprs[0].clone(), input_phy_exprs[1].clone(), name, - rt_type, + data_type, )), (AggregateFunction::Covariance, true) => { - return Err(DataFusionError::NotImplemented( - "COVAR(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("COVAR(DISTINCT) aggregations are not available"); } (AggregateFunction::CovariancePop, false) => { Arc::new(expressions::CovariancePop::new( input_phy_exprs[0].clone(), input_phy_exprs[1].clone(), name, - rt_type, + data_type, )) } (AggregateFunction::CovariancePop, true) => { - return Err(DataFusionError::NotImplemented( - "COVAR_POP(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available"); } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::Stddev, true) => { - return Err(DataFusionError::NotImplemented( - "STDDEV(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("STDDEV(DISTINCT) aggregations are not available"); } (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::StddevPop, true) => { - return Err(DataFusionError::NotImplemented( - "STDDEV_POP(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); } (AggregateFunction::Correlation, false) => { Arc::new(expressions::Correlation::new( input_phy_exprs[0].clone(), input_phy_exprs[1].clone(), name, - rt_type, + data_type, )) } (AggregateFunction::Correlation, true) => { - return Err(DataFusionError::NotImplemented( - "CORR(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("CORR(DISTINCT) aggregations are not available"); + } + (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Slope, + data_type, + )), + (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Intercept, + data_type, + )), + (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Count, + data_type, + )), + (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::R2, + data_type, + )), + (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::AvgX, + data_type, + )), + (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::AvgY, + data_type, + )), + (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SXX, + data_type, + )), + (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SYY, + data_type, + )), + (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SXY, + data_type, + )), + ( + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY, + true, + ) => { + return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { @@ -235,69 +306,85 @@ pub fn create_aggregate_expr( // Pass in the desired percentile expr input_phy_exprs, name, - rt_type, + data_type, )?) } else { Arc::new(expressions::ApproxPercentileCont::new_with_max_size( // Pass in the desired percentile expr input_phy_exprs, name, - rt_type, + data_type, )?) } } (AggregateFunction::ApproxPercentileCont, true) => { - return Err(DataFusionError::NotImplemented( + return not_impl_err!( "approx_percentile_cont(DISTINCT) aggregations are not available" - .to_string(), - )); + ); } (AggregateFunction::ApproxPercentileContWithWeight, false) => { Arc::new(expressions::ApproxPercentileContWithWeight::new( // Pass in the desired percentile expr input_phy_exprs, name, - rt_type, + data_type, )?) } (AggregateFunction::ApproxPercentileContWithWeight, true) => { - return Err(DataFusionError::NotImplemented( + return not_impl_err!( "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" - .to_string(), - )); + ); } (AggregateFunction::ApproxMedian, false) => { Arc::new(expressions::ApproxMedian::try_new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )?) } (AggregateFunction::ApproxMedian, true) => { - return Err(DataFusionError::NotImplemented( - "APPROX_MEDIAN(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!( + "APPROX_MEDIAN(DISTINCT) aggregations are not available" + ); } (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( input_phy_exprs[0].clone(), name, - rt_type, + data_type, )), (AggregateFunction::Median, true) => { - return Err(DataFusionError::NotImplemented( - "MEDIAN(DISTINCT) aggregations are not available".to_string(), - )); + return not_impl_err!("MEDIAN(DISTINCT) aggregations are not available"); } (AggregateFunction::FirstValue, _) => Arc::new(expressions::FirstValue::new( input_phy_exprs[0].clone(), name, input_phy_types[0].clone(), + ordering_req.to_vec(), + ordering_types, )), (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new( input_phy_exprs[0].clone(), name, input_phy_types[0].clone(), + ordering_req.to_vec(), + ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } @@ -310,9 +397,10 @@ mod tests { DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; use arrow::datatypes::{DataType, Field}; + use datafusion_common::plan_err; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{aggregate_function, type_coercion, Signature}; + use datafusion_expr::{type_coercion, Signature}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -366,8 +454,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -405,8 +493,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -1012,16 +1100,14 @@ mod tests { #[test] fn test_median() -> Result<()> { - let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]); + let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Utf8]); assert!(observed.is_err()); - let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?; + let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Int32])?; assert_eq!(DataType::Int32, observed); - let observed = return_type( - &AggregateFunction::ApproxMedian, - &[DataType::Decimal128(10, 6)], - ); + let observed = + AggregateFunction::ApproxMedian.return_type(&[DataType::Decimal128(10, 6)]); assert!(observed.is_err()); Ok(()) @@ -1029,20 +1115,20 @@ mod tests { #[test] fn test_min_max() -> Result<()> { - let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; + let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; + let observed = AggregateFunction::Max.return_type(&[DataType::Int32])?; assert_eq!(DataType::Int32, observed); // test decimal for min let observed = - return_type(&AggregateFunction::Min, &[DataType::Decimal128(10, 6)])?; + AggregateFunction::Min.return_type(&[DataType::Decimal128(10, 6)])?; assert_eq!(DataType::Decimal128(10, 6), observed); // test decimal for max let observed = - return_type(&AggregateFunction::Max, &[DataType::Decimal128(28, 13)])?; + AggregateFunction::Max.return_type(&[DataType::Decimal128(28, 13)])?; assert_eq!(DataType::Decimal128(28, 13), observed); Ok(()) @@ -1050,24 +1136,24 @@ mod tests { #[test] fn test_sum_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?; + let observed = AggregateFunction::Sum.return_type(&[DataType::Int32])?; assert_eq!(DataType::Int64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?; + let observed = AggregateFunction::Sum.return_type(&[DataType::UInt8])?; assert_eq!(DataType::UInt64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?; + let observed = AggregateFunction::Sum.return_type(&[DataType::Float32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?; + let observed = AggregateFunction::Sum.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); let observed = - return_type(&AggregateFunction::Sum, &[DataType::Decimal128(10, 5)])?; + AggregateFunction::Sum.return_type(&[DataType::Decimal128(10, 5)])?; assert_eq!(DataType::Decimal128(20, 5), observed); let observed = - return_type(&AggregateFunction::Sum, &[DataType::Decimal128(35, 5)])?; + AggregateFunction::Sum.return_type(&[DataType::Decimal128(35, 5)])?; assert_eq!(DataType::Decimal128(38, 5), observed); Ok(()) @@ -1075,73 +1161,73 @@ mod tests { #[test] fn test_sum_no_utf8() { - let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]); + let observed = AggregateFunction::Sum.return_type(&[DataType::Utf8]); assert!(observed.is_err()); } #[test] fn test_sum_upcasts() -> Result<()> { - let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?; + let observed = AggregateFunction::Sum.return_type(&[DataType::UInt32])?; assert_eq!(DataType::UInt64, observed); Ok(()) } #[test] fn test_count_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?; + let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; assert_eq!(DataType::Int64, observed); - let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; + let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; assert_eq!(DataType::Int64, observed); let observed = - return_type(&AggregateFunction::Count, &[DataType::Decimal128(28, 13)])?; + AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; assert_eq!(DataType::Int64, observed); Ok(()) } #[test] fn test_avg_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?; + let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?; + let observed = AggregateFunction::Avg.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?; + let observed = AggregateFunction::Avg.return_type(&[DataType::Int32])?; assert_eq!(DataType::Float64, observed); let observed = - return_type(&AggregateFunction::Avg, &[DataType::Decimal128(10, 6)])?; + AggregateFunction::Avg.return_type(&[DataType::Decimal128(10, 6)])?; assert_eq!(DataType::Decimal128(14, 10), observed); let observed = - return_type(&AggregateFunction::Avg, &[DataType::Decimal128(36, 6)])?; + AggregateFunction::Avg.return_type(&[DataType::Decimal128(36, 6)])?; assert_eq!(DataType::Decimal128(38, 10), observed); Ok(()) } #[test] fn test_avg_no_utf8() { - let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); + let observed = AggregateFunction::Avg.return_type(&[DataType::Utf8]); assert!(observed.is_err()); } #[test] fn test_variance_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Variance, &[DataType::Float32])?; + let observed = AggregateFunction::Variance.return_type(&[DataType::Float32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Float64])?; + let observed = AggregateFunction::Variance.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Int32])?; + let observed = AggregateFunction::Variance.return_type(&[DataType::Int32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::UInt32])?; + let observed = AggregateFunction::Variance.return_type(&[DataType::UInt32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Int64])?; + let observed = AggregateFunction::Variance.return_type(&[DataType::Int64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -1149,25 +1235,25 @@ mod tests { #[test] fn test_variance_no_utf8() { - let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); + let observed = AggregateFunction::Variance.return_type(&[DataType::Utf8]); assert!(observed.is_err()); } #[test] fn test_stddev_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float32])?; + let observed = AggregateFunction::Stddev.return_type(&[DataType::Float32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float64])?; + let observed = AggregateFunction::Stddev.return_type(&[DataType::Float64])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32])?; + let observed = AggregateFunction::Stddev.return_type(&[DataType::Int32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32])?; + let observed = AggregateFunction::Stddev.return_type(&[DataType::UInt32])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64])?; + let observed = AggregateFunction::Stddev.return_type(&[DataType::Int64])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -1175,7 +1261,7 @@ mod tests { #[test] fn test_stddev_no_utf8() { - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); + let observed = AggregateFunction::Stddev.return_type(&[DataType::Utf8]); assert!(observed.is_err()); } @@ -1189,18 +1275,14 @@ mod tests { name: impl Into, ) -> Result> { let name = name.into(); - let coerced_phy_exprs = coerce_exprs_for_test( - fun, - input_phy_exprs, - input_schema, - &aggregate_function::signature(fun), - )?; + let coerced_phy_exprs = + coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; if coerced_phy_exprs.is_empty() { - return Err(DataFusionError::Plan(format!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'", - ))); + return plan_err!( + "Invalid or wrong number of arguments passed to aggregate: '{name}'" + ); } - create_aggregate_expr(fun, distinct, &coerced_phy_exprs, input_schema, name) + create_aggregate_expr(fun, distinct, &coerced_phy_exprs, &[], input_schema, name) } // Returns the coerced exprs for each `input_exprs`. @@ -1227,7 +1309,7 @@ mod tests { // try cast if need input_exprs .iter() - .zip(coerced_types.into_iter()) + .zip(coerced_types) .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) .collect::>>() } diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index 475bfa4ce0da2..61f2db5c8ef93 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -505,13 +505,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 15df28b4e38ad..738ca4e915f7d 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -22,21 +22,25 @@ use std::fmt::Debug; use std::ops::BitAnd; use std::sync::Arc; -use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::array::{Array, Int64Array}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::PrimitiveArray; use arrow_buffer::BooleanBuffer; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; -use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; +use super::groups_accumulator::accumulate::accumulate_indices; +use super::groups_accumulator::EmitTo; + /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug, Clone)] @@ -44,6 +48,10 @@ pub struct Count { name: String, data_type: DataType, nullable: bool, + /// Input exprs + /// + /// For `COUNT(c1)` this is `[c1]` + /// For `COUNT(c1, c2)` this is `[c1, c2]` exprs: Vec>, } @@ -76,6 +84,114 @@ impl Count { } } +/// An accumulator to compute the counts of [`PrimitiveArray`]. +/// Stores values as native types, and does overflow checking +/// +/// Unlike most other accumulators, COUNT never produces NULLs. If no +/// non-null values are seen in any group the output is 0. Thus, this +/// accumulator has no additional null or seen filter tracking. +#[derive(Debug)] +struct CountGroupsAccumulator { + /// Count per group. + /// + /// Note this is an i64 and not a u64 (or usize) because the + /// output type of count is `DataType::Int64`. Thus by using `i64` + /// for the counts, the output [`Int64Array`] can be created + /// without copy. + counts: Vec, +} + +impl CountGroupsAccumulator { + pub fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &values[0]; + + // Add one to each group's counter for each non null, non + // filtered value + self.counts.resize(total_num_groups, 0); + accumulate_indices( + group_indices, + values.nulls(), // ignore values + opt_filter, + |group_index| { + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + + // intermediate counts are always created as non null + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + // Adds the counts with the partial counts + self.counts.resize(total_num_groups, 0); + match opt_filter { + Some(filter) => filter + .iter() + .zip(group_indices.iter()) + .zip(partial_counts.iter()) + .for_each(|((filter_value, &group_index), partial_count)| { + if let Some(true) = filter_value { + self.counts[group_index] += partial_count; + } + }), + None => group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ), + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + + // Count is always non null (null inputs just don't contribute to the overall values) + let nulls = None; + let array = PrimitiveArray::::new(counts.into(), nulls); + + Ok(Arc::new(array)) + } + + // return arrays for counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls + Ok(vec![Arc::new(counts) as ArrayRef]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + } +} + /// count null values for multiple columns /// for each row if one column value is null, then null_count + 1 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { @@ -102,17 +218,13 @@ impl AggregateExpr for Count { } fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + Ok(Field::new(&self.name, DataType::Int64, self.nullable)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "count"), - self.data_type.clone(), + DataType::Int64, true, )]) } @@ -129,19 +241,10 @@ impl AggregateExpr for Count { &self.name } - fn row_accumulator_supported(&self) -> bool { - true - } - - fn supports_bounded_execution(&self) -> bool { - true - } - - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(CountRowAccumulator::new(start_index))) + fn groups_accumulator_supported(&self) -> bool { + // groups accumulator only supports `COUNT(c1)`, not + // `COUNT(c1, c2)`, etc + self.exprs.len() == 1 } fn reverse_expr(&self) -> Option> { @@ -151,6 +254,11 @@ impl AggregateExpr for Count { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(CountAccumulator::new())) } + + fn create_groups_accumulator(&self) -> Result> { + // instantiate specialized accumulator + Ok(Box::new(CountGroupsAccumulator::new())) + } } impl PartialEq for Count { @@ -214,81 +322,12 @@ impl Accumulator for CountAccumulator { Ok(ScalarValue::Int64(Some(self.count))) } - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[derive(Debug)] -struct CountRowAccumulator { - state_index: usize, -} - -impl CountRowAccumulator { - pub fn new(index: usize) -> Self { - Self { state_index: index } - } -} - -impl RowAccumulator for CountRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let array = &values[0]; - let delta = (array.len() - null_count_for_multiple_cols(values)) as u64; - accessor.add_u64(self.state_index, delta); - Ok(()) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - if !values.iter().any(|s| matches!(s, ScalarValue::Null)) { - accessor.add_u64(self.state_index, 1) - } - Ok(()) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - match value { - ScalarValue::Null => { - // do not update the accumulator - } - _ => accessor.add_u64(self.state_index, 1), - } - Ok(()) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - accessor.add_i64(self.state_index, *d); - } - Ok(()) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(ScalarValue::Int64(Some( - accessor.get_u64_opt(self.state_index()).unwrap_or(0) as i64, - ))) + fn supports_retract_batch(&self) -> bool { + true } - #[inline(always)] - fn state_index(&self) -> usize { - self.state_index + fn size(&self) -> usize { + std::mem::size_of_val(self) } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 94e6c082837ab..f5242d983d4cf 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::{DataType, Field}; + use std::any::Any; use std::fmt::Debug; use std::sync::Arc; @@ -27,8 +28,8 @@ use std::collections::HashSet; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; +use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -142,18 +143,11 @@ impl DistinctCountAccumulator { impl Accumulator for DistinctCountAccumulator { fn state(&self) -> Result> { - let mut cols_out = - ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone()); - self.values - .iter() - .enumerate() - .for_each(|(_, distinct_values)| { - if let ScalarValue::List(Some(ref mut v), _) = cols_out { - v.push(distinct_values.clone()); - } - }); - Ok(vec![cols_out]) + let scalars = self.values.iter().cloned().collect::>(); + let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); @@ -167,27 +161,17 @@ impl Accumulator for DistinctCountAccumulator { Ok(()) }) } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.values.insert(scalar.clone()); - } - }); - } else { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); - } - Ok(()) - }) + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec.into_iter() { + self.values.extend(scalars) + } + Ok(()) } fn evaluate(&self) -> Result { @@ -213,32 +197,21 @@ mod tests { Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::DataType; - - macro_rules! state_to_vec { - ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ - match $LIST { - ScalarValue::List(_, field) => match field.data_type() { - &DataType::$DATA_TYPE => (), - _ => panic!("Unexpected DataType for list"), - }, - _ => panic!("Expected a ScalarValue::List"), - } - - match $LIST { - ScalarValue::List(None, _) => None, - ScalarValue::List(Some(scalar_values), _) => { - let vec = scalar_values - .iter() - .map(|scalar_value| match scalar_value { - ScalarValue::$DATA_TYPE(value) => *value, - _ => panic!("Unexpected ScalarValue variant"), - }) - .collect::>>(); - - Some(vec) - } - _ => unreachable!(), - } + use arrow::datatypes::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, + }; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; + use datafusion_common::internal_err; + use datafusion_common::DataFusionError; + + macro_rules! state_to_vec_primitive { + ($LIST:expr, $DATA_TYPE:ident) => {{ + let arr = ScalarValue::raw_data($LIST).unwrap(); + let list_arr = as_list_array(&arr).unwrap(); + let arr = list_arr.values(); + let arr = as_primitive_array::<$DATA_TYPE>(arr)?; + arr.values().iter().cloned().collect::>() }}; } @@ -260,18 +233,25 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); state_vec.sort(); assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); + assert_eq!(state_vec, vec![1, 2, 3]); assert_eq!(result, ScalarValue::Int64(Some(3))); Ok(()) }}; } + fn state_to_vec_bool(sv: &ScalarValue) -> Result> { + let arr = ScalarValue::raw_data(sv)?; + let list_arr = as_list_array(&arr)?; + let arr = list_arr.values(); + let bool_arr = as_boolean_array(arr)?; + Ok(bool_arr.iter().flatten().collect()) + } + fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { let agg = DistinctCount::new( arrays[0].data_type().clone(), @@ -354,13 +334,11 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); dbg!(&state_vec); state_vec.sort_by(|a, b| match (a, b) { - (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs), - _ => a.partial_cmp(b).unwrap(), + (lhs, rhs) => lhs.total_cmp(rhs), }); let nan_idx = state_vec.len() - 1; @@ -368,16 +346,16 @@ mod tests { assert_eq!( &state_vec[..nan_idx], vec![ - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(-4.5), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::INFINITY) + <$PRIM_TYPE>::NEG_INFINITY, + -4.5, + <$PRIM_TYPE as SubNormal>::SUBNORMAL, + 1.0, + 2.0, + 3.0, + <$PRIM_TYPE>::INFINITY ] ); - assert!(state_vec[nan_idx].unwrap_or_default().is_nan()); + assert!(state_vec[nan_idx].is_nan()); assert_eq!(result, ScalarValue::Int64(Some(8))); Ok(()) @@ -386,68 +364,69 @@ mod tests { #[test] fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8) + test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) } #[test] fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16) + test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) } #[test] fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32) + test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) } #[test] fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64) + test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) } #[test] fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8) + test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) } #[test] fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16) + test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) } #[test] fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32) + test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) } #[test] fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64) + test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) } #[test] fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32, f32) + test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) } #[test] fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64, f64) + test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) } #[test] fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec>, i64)> { + let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { let arrays = vec![Arc::new(data) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec!(&states[0], Boolean, bool).unwrap(); + let mut state_vec = state_to_vec_bool(&states[0])?; state_vec.sort(); + let count = match result { ScalarValue::Int64(c) => c.ok_or_else(|| { DataFusionError::Internal("Found None count".to_string()) }), - scalar => Err(DataFusionError::Internal(format!( - "Found non int64 scalar value from count: {scalar}" - ))), + scalar => { + internal_err!("Found non int64 scalar value from count: {scalar}") + } }?; Ok((state_vec, count)) }; @@ -468,22 +447,13 @@ mod tests { Some(false), ]); - assert_eq!( - get_count(zero_count_values)?, - (Vec::>::new(), 0) - ); - assert_eq!(get_count(one_count_values)?, (vec![Some(false)], 1)); - assert_eq!( - get_count(one_count_values_with_null)?, - (vec![Some(true)], 1) - ); - assert_eq!( - get_count(two_count_values)?, - (vec![Some(false), Some(true)], 2) - ); + assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); + assert_eq!(get_count(one_count_values)?, (vec![false], 1)); + assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); + assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); assert_eq!( get_count(two_count_values_with_null)?, - (vec![Some(false), Some(true)], 2) + (vec![false, true], 2) ); Ok(()) } @@ -495,9 +465,9 @@ mod tests { )) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) @@ -508,9 +478,9 @@ mod tests { let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 5e589d4e39fd3..0f838eb6fa1cf 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -754,13 +754,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index a350637c48820..0dc27dede8b62 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -17,25 +17,31 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. -use crate::aggregate::utils::down_cast_any_ref; +use std::any::Any; +use std::sync::Arc; + +use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{ + reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, +}; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::{Result, ScalarValue}; +use arrow_schema::SortOptions; +use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - /// FIRST_VALUE aggregate expression #[derive(Debug)] pub struct FirstValue { name: String, - pub data_type: DataType, + input_data_type: DataType, + order_by_data_types: Vec, expr: Arc, + ordering_req: LexOrdering, } impl FirstValue { @@ -43,12 +49,16 @@ impl FirstValue { pub fn new( expr: Arc, name: impl Into, - data_type: DataType, + input_data_type: DataType, + ordering_req: LexOrdering, + order_by_data_types: Vec, ) -> Self { Self { name: name.into(), - data_type, + input_data_type, + order_by_data_types, expr, + ordering_req, } } } @@ -60,25 +70,47 @@ impl AggregateExpr for FirstValue { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new(&self.data_type)?)) + Ok(Box::new(FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) } fn state_fields(&self) -> Result> { - Ok(vec![Field::new( + let mut fields = vec![Field::new( format_state_name(&self.name, "first_value"), - self.data_type.clone(), + self.input_data_type.clone(), + true, + )]; + fields.extend(ordering_fields( + &self.ordering_req, + &self.order_by_data_types, + )); + fields.push(Field::new( + format_state_name(&self.name, "is_set"), + DataType::Boolean, true, - )]) + )); + Ok(fields) } fn expressions(&self) -> Vec> { vec![self.expr.clone()] } + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + if self.ordering_req.is_empty() { + None + } else { + Some(&self.ordering_req) + } + } + fn name(&self) -> &str { &self.name } @@ -92,12 +124,18 @@ impl AggregateExpr for FirstValue { Some(Arc::new(LastValue::new( self.expr.clone(), name, - self.data_type.clone(), + self.input_data_type.clone(), + reverse_order_bys(&self.ordering_req), + self.order_by_data_types.clone(), ))) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(FirstValueAccumulator::try_new(&self.data_type)?)) + Ok(Box::new(FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) } } @@ -107,7 +145,8 @@ impl PartialEq for FirstValue { .downcast_ref::() .map(|x| { self.name == x.name - && self.data_type == x.data_type + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types && self.expr.eq(&x.expr) }) .unwrap_or(false) @@ -117,42 +156,94 @@ impl PartialEq for FirstValue { #[derive(Debug)] struct FirstValueAccumulator { first: ScalarValue, - // At the beginning, `is_set` is `false`, this means `first` is not seen yet. - // Once we see (`is_set=true`) first value, we do not update `first`. + // At the beginning, `is_set` is false, which means `first` is not seen yet. + // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. is_set: bool, + // Stores ordering values, of the aggregator requirement corresponding to first value + // of the aggregator. These values are used during merging of multiple partitions. + orderings: Vec, + // Stores the applicable ordering requirement. + ordering_req: LexOrdering, } impl FirstValueAccumulator { /// Creates a new `FirstValueAccumulator` for the given `data_type`. - pub fn try_new(data_type: &DataType) -> Result { + pub fn try_new( + data_type: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { + let orderings = ordering_dtypes + .iter() + .map(ScalarValue::try_from) + .collect::>>()?; ScalarValue::try_from(data_type).map(|value| Self { first: value, is_set: false, + orderings, + ordering_req, }) } + + // Updates state with the values in the given row. + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.first = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; + } } impl Accumulator for FirstValueAccumulator { fn state(&self) -> Result> { - Ok(vec![ - self.first.clone(), - ScalarValue::Boolean(Some(self.is_set)), - ]) + let mut result = vec![self.first.clone()]; + result.extend(self.orderings.iter().cloned()); + result.push(ScalarValue::Boolean(Some(self.is_set))); + Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // If we have seen first value, we shouldn't update it - let values = &values[0]; - if !values.is_empty() && !self.is_set { - self.first = ScalarValue::try_from_array(values, 0)?; - self.is_set = true; + if !values[0].is_empty() && !self.is_set { + let row = get_row_at_idx(values, 0)?; + // Update with first value in the array. + self.update_with_new_row(&row); } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // FIRST_VALUE(first1, first2, first3, ...) - self.update_batch(states) + // last index contains is_set flag. + let is_set_idx = states.len() - 1; + let flags = states[is_set_idx].as_boolean(); + let filtered_states = filter_states_according_to_is_set(states, flags)?; + // 1..is_set_idx range corresponds to ordering section + let sort_cols = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + + let ordered_states = if sort_cols.is_empty() { + // When no ordering is given, use the existing state as is: + filtered_states + } else { + let indices = lexsort_to_indices(&sort_cols, None)?; + get_arrayref_at_indices(&filtered_states, &indices)? + }; + if !ordered_states[0].is_empty() { + let first_row = get_row_at_idx(&ordered_states, 0)?; + // When collecting orderings, we exclude the is_set flag from the state. + let first_ordering = &first_row[1..is_set_idx]; + let sort_options = get_sort_options(&self.ordering_req); + // Either there is no existing value, or there is an earlier version in new data. + if !self.is_set + || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + { + // Update with first value in the state. Note that we should exclude the + // is_set flag from the state. Otherwise, we will end up with a state + // containing two is_set flags. + self.update_with_new_row(&first_row[0..is_set_idx]); + } + } + Ok(()) } fn evaluate(&self) -> Result { @@ -162,6 +253,8 @@ impl Accumulator for FirstValueAccumulator { fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + self.first.size() + + ScalarValue::size_of_vec(&self.orderings) + - std::mem::size_of_val(&self.orderings) } } @@ -169,8 +262,10 @@ impl Accumulator for FirstValueAccumulator { #[derive(Debug)] pub struct LastValue { name: String, - pub data_type: DataType, + input_data_type: DataType, + order_by_data_types: Vec, expr: Arc, + ordering_req: LexOrdering, } impl LastValue { @@ -178,12 +273,16 @@ impl LastValue { pub fn new( expr: Arc, name: impl Into, - data_type: DataType, + input_data_type: DataType, + ordering_req: LexOrdering, + order_by_data_types: Vec, ) -> Self { Self { name: name.into(), - data_type, + input_data_type, + order_by_data_types, expr, + ordering_req, } } } @@ -195,25 +294,47 @@ impl AggregateExpr for LastValue { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new(&self.data_type)?)) + Ok(Box::new(LastValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) } fn state_fields(&self) -> Result> { - Ok(vec![Field::new( + let mut fields = vec![Field::new( format_state_name(&self.name, "last_value"), - self.data_type.clone(), + self.input_data_type.clone(), + true, + )]; + fields.extend(ordering_fields( + &self.ordering_req, + &self.order_by_data_types, + )); + fields.push(Field::new( + format_state_name(&self.name, "is_set"), + DataType::Boolean, true, - )]) + )); + Ok(fields) } fn expressions(&self) -> Vec> { vec![self.expr.clone()] } + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + if self.ordering_req.is_empty() { + None + } else { + Some(&self.ordering_req) + } + } + fn name(&self) -> &str { &self.name } @@ -227,12 +348,18 @@ impl AggregateExpr for LastValue { Some(Arc::new(FirstValue::new( self.expr.clone(), name, - self.data_type.clone(), + self.input_data_type.clone(), + reverse_order_bys(&self.ordering_req), + self.order_by_data_types.clone(), ))) } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(LastValueAccumulator::try_new(&self.data_type)?)) + Ok(Box::new(LastValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + )?)) } } @@ -242,7 +369,8 @@ impl PartialEq for LastValue { .downcast_ref::() .map(|x| { self.name == x.name - && self.data_type == x.data_type + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types && self.expr.eq(&x.expr) }) .unwrap_or(false) @@ -252,34 +380,95 @@ impl PartialEq for LastValue { #[derive(Debug)] struct LastValueAccumulator { last: ScalarValue, + // The `is_set` flag keeps track of whether the last value is finalized. + // This information is used to discriminate genuine NULLs and NULLS that + // occur due to empty partitions. + is_set: bool, + orderings: Vec, + // Stores the applicable ordering requirement. + ordering_req: LexOrdering, } impl LastValueAccumulator { /// Creates a new `LastValueAccumulator` for the given `data_type`. - pub fn try_new(data_type: &DataType) -> Result { + pub fn try_new( + data_type: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { + let orderings = ordering_dtypes + .iter() + .map(ScalarValue::try_from) + .collect::>>()?; Ok(Self { last: ScalarValue::try_from(data_type)?, + is_set: false, + orderings, + ordering_req, }) } + + // Updates state with the values in the given row. + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.last = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; + } } impl Accumulator for LastValueAccumulator { fn state(&self) -> Result> { - Ok(vec![self.last.clone()]) + let mut result = vec![self.last.clone()]; + result.extend(self.orderings.clone()); + result.push(ScalarValue::Boolean(Some(self.is_set))); + Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - if !values.is_empty() { + if !values[0].is_empty() { + let row = get_row_at_idx(values, values[0].len() - 1)?; // Update with last value in the array. - self.last = ScalarValue::try_from_array(values, values.len() - 1)?; + self.update_with_new_row(&row); } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // LAST_VALUE(last1, last2, last3, ...) - self.update_batch(states) + // last index contains is_set flag. + let is_set_idx = states.len() - 1; + let flags = states[is_set_idx].as_boolean(); + let filtered_states = filter_states_according_to_is_set(states, flags)?; + // 1..is_set_idx range corresponds to ordering section + let sort_cols = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + + let ordered_states = if sort_cols.is_empty() { + // When no ordering is given, use existing state as is: + filtered_states + } else { + let indices = lexsort_to_indices(&sort_cols, None)?; + get_arrayref_at_indices(&filtered_states, &indices)? + }; + + if !ordered_states[0].is_empty() { + let last_idx = ordered_states[0].len() - 1; + let last_row = get_row_at_idx(&ordered_states, last_idx)?; + // When collecting orderings, we exclude the is_set flag from the state. + let last_ordering = &last_row[1..is_set_idx]; + let sort_options = get_sort_options(&self.ordering_req); + // Either there is no existing value, or there is a newer (latest) + // version in the new data: + if !self.is_set + || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + { + // Update with last value in the state. Note that we should exclude the + // is_set flag from the state. Otherwise, we will end up with a state + // containing two is_set flags. + self.update_with_new_row(&last_row[0..is_set_idx]); + } + } + Ok(()) } fn evaluate(&self) -> Result { @@ -287,23 +476,65 @@ impl Accumulator for LastValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + self.last.size() + std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + + self.last.size() + + ScalarValue::size_of_vec(&self.orderings) + - std::mem::size_of_val(&self.orderings) } } +/// Filters states according to the `is_set` flag at the last column and returns +/// the resulting states. +fn filter_states_according_to_is_set( + states: &[ArrayRef], + flags: &BooleanArray, +) -> Result> { + states + .iter() + .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .collect::>>() +} + +/// Combines array refs and their corresponding orderings to construct `SortColumn`s. +fn convert_to_sort_cols( + arrs: &[ArrayRef], + sort_exprs: &[PhysicalSortExpr], +) -> Vec { + arrs.iter() + .zip(sort_exprs.iter()) + .map(|(item, sort_expr)| SortColumn { + values: item.clone(), + options: Some(sort_expr.options), + }) + .collect::>() +} + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req + .iter() + .map(|item| item.options) + .collect::>() +} + #[cfg(test)] mod tests { use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; + + use arrow::compute::concat; use std::sync::Arc; #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64)?; - let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64)?; + let mut first_accumulator = + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + let mut last_accumulator = + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -327,4 +558,78 @@ mod tests { assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12))); Ok(()) } + + #[test] + fn test_first_last_state_after_merge() -> Result<()> { + let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; + // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12 + let arrs = ranges + .into_iter() + .map(|(start, end)| { + Arc::new((start..end).collect::()) as ArrayRef + }) + .collect::>(); + + // FirstValueAccumulator + let mut first_accumulator = + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + + first_accumulator.update_batch(&[arrs[0].clone()])?; + let state1 = first_accumulator.state()?; + + let mut first_accumulator = + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + first_accumulator.update_batch(&[arrs[1].clone()])?; + let state2 = first_accumulator.state()?; + + assert_eq!(state1.len(), state2.len()); + + let mut states = vec![]; + + for idx in 0..state1.len() { + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); + } + + let mut first_accumulator = + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + first_accumulator.merge_batch(&states)?; + + let merged_state = first_accumulator.state()?; + assert_eq!(merged_state.len(), state1.len()); + + // LastValueAccumulator + let mut last_accumulator = + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + + last_accumulator.update_batch(&[arrs[0].clone()])?; + let state1 = last_accumulator.state()?; + + let mut last_accumulator = + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + last_accumulator.update_batch(&[arrs[1].clone()])?; + let state2 = last_accumulator.state()?; + + assert_eq!(state1.len(), state2.len()); + + let mut states = vec![]; + + for idx in 0..state1.len() { + states.push(concat(&[ + &state1[idx].to_array()?, + &state2[idx].to_array()?, + ])?); + } + + let mut last_accumulator = + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + last_accumulator.merge_batch(&states)?; + + let merged_state = last_accumulator.state()?; + assert_eq!(merged_state.len(), state1.len()); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs index 7ba303960eafc..70afda265aeaf 100644 --- a/datafusion/physical-expr/src/aggregate/grouping.rs +++ b/datafusion/physical-expr/src/aggregate/grouping.rs @@ -24,7 +24,7 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::{AggregateExpr, PhysicalExpr}; use arrow::datatypes::DataType; use arrow::datatypes::Field; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_expr::Accumulator; use crate::expressions::format_state_name; @@ -62,17 +62,13 @@ impl AggregateExpr for Grouping { } fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + Ok(Field::new(&self.name, DataType::Int32, self.nullable)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "grouping"), - self.data_type.clone(), + DataType::Int32, true, )]) } @@ -82,10 +78,9 @@ impl AggregateExpr for Grouping { } fn create_accumulator(&self) -> Result> { - Err(DataFusionError::NotImplemented( + not_impl_err!( "physical plan is not yet implemented for GROUPING aggregate function" - .to_owned(), - )) + ) } fn name(&self) -> &str { diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs new file mode 100644 index 0000000000000..596265a737da0 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -0,0 +1,873 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] +//! +//! [`GroupsAccumulator`]: crate::GroupsAccumulator + +use arrow::datatypes::ArrowPrimitiveType; +use arrow_array::{Array, BooleanArray, PrimitiveArray}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; + +use crate::EmitTo; + +/// Track the accumulator null state per row: if any values for that +/// group were null and if any values have been seen at all for that group. +/// +/// This is part of the inner loop for many [`GroupsAccumulator`]s, +/// and thus the performance is critical and so there are multiple +/// specialized implementations, invoked depending on the specific +/// combinations of the input. +/// +/// Typically there are 4 potential combinations of inputs must be +/// special cased for performance: +/// +/// * With / Without filter +/// * With / Without nulls in the input +/// +/// If the input has nulls, then the accumulator must potentially +/// handle each input null value specially (e.g. for `SUM` to mark the +/// corresponding sum as null) +/// +/// If there are filters present, `NullState` tracks if it has seen +/// *any* value for that group (as some values may be filtered +/// out). Without a filter, the accumulator is only passed groups that +/// had at least one value to accumulate so they do not need to track +/// if they have seen values for a particular group. +/// +/// [`GroupsAccumulator`]: crate::GroupsAccumulator +#[derive(Debug)] +pub struct NullState { + /// Have we seen any non-filtered input values for `group_index`? + /// + /// If `seen_values[i]` is true, have seen at least one non null + /// value for group `i` + /// + /// If `seen_values[i]` is false, have not seen any values that + /// pass the filter yet for group `i` + seen_values: BooleanBufferBuilder, +} + +impl NullState { + pub fn new() -> Self { + Self { + seen_values: BooleanBufferBuilder::new(0), + } + } + + /// return the size of all buffers allocated by this null state, not including self + pub fn size(&self) -> usize { + // capacity is in bits, so convert to bytes + self.seen_values.capacity() / 8 + } + + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value of `value`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs if necessary + // + /// # Arguments: + /// + /// * `values`: the input arguments to the accumulator + /// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) + /// * `opt_filter`: if present, only rows for which is Some(true) are included + /// * `value_fn`: function invoked for (group_index, value) where value is non null + /// + /// # Example + /// + /// ```text + /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ + /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ + /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ + /// │ └─────┘ │ │ └─────┘ │ └─────┘ + /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// + /// group_indices values opt_filter + /// ``` + /// + /// In the example above, `value_fn` is invoked for each (group_index, + /// value) pair where `opt_filter[i]` is true and values is non null + /// + /// ```text + /// value_fn(2, 200) + /// value_fn(0, 200) + /// value_fn(0, 300) + /// ``` + /// + /// It also sets + /// + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + pub fn accumulate( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, + { + let data: &[T::Native] = values.values(); + assert_eq!(data.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = nulls.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real value + let is_valid = (mask & index_mask) != 0; + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + }); + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. TODO file a ticket + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, &new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value) + } + } + }) + } + } + } + + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`BooleanArray`]s. + /// + /// Since `BooleanArray` is not a [`PrimitiveArray`] it must be + /// handled specially. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. + pub fn accumulate_boolean( + &mut self, + group_indices: &[usize], + values: &BooleanArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, bool) + Send, + { + let data = values.values(); + assert_eq!(data.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + // These could be made more performant by iterating in chunks of 64 bits at a time + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + // if we have previously seen nulls, ensure the null + // buffer is big enough (start everything at valid) + group_indices.iter().zip(data.iter()).for_each( + |(&group_index, new_value)| { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value) + }, + ) + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(data.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value) + } + } + }) + } + } + } + + /// Creates the a [`NullBuffer`] representing which group_indices + /// should have null values (because they never saw any values) + /// for the `emit_to` rows. + /// + /// resets the internal state appropriately + pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer { + let nulls: BooleanBuffer = self.seen_values.finish(); + + let nulls = match emit_to { + EmitTo::All => nulls, + EmitTo::First(n) => { + // split off the first N values in seen_values + // + // TODO make this more efficient rather than two + // copies and bitwise manipulation + let first_n_null: BooleanBuffer = nulls.iter().take(n).collect(); + // reset the existing seen buffer + for seen in nulls.iter().skip(n) { + self.seen_values.append(seen); + } + first_n_null + } + }; + NullBuffer::new(nulls) + } +} + +/// This function is called to update the accumulator state per row +/// when the value is not needed (e.g. COUNT) +/// +/// `F`: Invoked like `value_fn(group_index) for all non null values +/// passing the filter. Note that no tracking is done for null inputs +/// or which groups have seen any values +/// +/// See [`NullState::accumulate`], for more details on other +/// arguments. +pub fn accumulate_indices( + group_indices: &[usize], + nulls: Option<&NullBuffer>, + opt_filter: Option<&BooleanArray>, + mut index_fn: F, +) where + F: FnMut(usize) + Send, +{ + match (nulls, opt_filter) { + (None, None) => { + for &group_index in group_indices.iter() { + index_fn(group_index) + } + } + (None, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. TODO file a ticket + let iter = group_indices.iter().zip(filter.iter()); + for (&group_index, filter_value) in iter { + if let Some(true) = filter_value { + index_fn(group_index) + } + } + } + (Some(valids), None) => { + assert_eq!(valids.len(), group_indices.len()); + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = valids.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); + } + + (Some(valids), Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + assert_eq!(valids.len(), group_indices.len()); + // The performance with a filter could likely be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket + filter + .iter() + .zip(group_indices.iter()) + .zip(valids.iter()) + .for_each(|((filter_value, &group_index), is_valid)| { + if let (Some(true), true) = (filter_value, is_valid) { + index_fn(group_index) + } + }) + } + } +} + +/// Ensures that `builder` contains a `BooleanBufferBuilder with at +/// least `total_num_groups`. +/// +/// All new entries are initialized to `default_value` +fn initialize_builder( + builder: &mut BooleanBufferBuilder, + total_num_groups: usize, + default_value: bool, +) -> &mut BooleanBufferBuilder { + if builder.len() < total_num_groups { + let new_groups = total_num_groups - builder.len(); + builder.append_n(new_groups, default_value); + } + builder +} + +#[cfg(test)] +mod test { + use super::*; + + use arrow_array::UInt32Array; + use arrow_buffer::BooleanBuffer; + use hashbrown::HashSet; + use rand::{rngs::ThreadRng, Rng}; + + #[test] + fn accumulate() { + let group_indices = (0..100).collect(); + let values = (0..100).map(|i| (i + 1) * 10).collect(); + let values_with_nulls = (0..100) + .map(|i| if i % 3 == 0 { None } else { Some((i + 1) * 10) }) + .collect(); + + // default to every fifth value being false, every even + // being null + let filter: BooleanArray = (0..100) + .map(|i| { + let is_even = i % 2 == 0; + let is_fifth = i % 5 == 0; + if is_even { + None + } else if is_fifth { + Some(false) + } else { + Some(true) + } + }) + .collect(); + + Fixture { + group_indices, + values, + values_with_nulls, + filter, + } + .run() + } + + #[test] + fn accumulate_fuzz() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + Fixture::new_random(&mut rng).run(); + } + } + + /// Values for testing (there are enough values to exercise the 64 bit chunks + struct Fixture { + /// 100..0 + group_indices: Vec, + + /// 10, 20, ... 1010 + values: Vec, + + /// same as values, but every third is null: + /// None, Some(20), Some(30), None ... + values_with_nulls: Vec>, + + /// filter (defaults to None) + filter: BooleanArray, + } + + impl Fixture { + fn new_random(rng: &mut ThreadRng) -> Self { + // Number of input values in a batch + let num_values: usize = rng.gen_range(1..200); + // number of distinct groups + let num_groups: usize = rng.gen_range(2..1000); + let max_group = num_groups - 1; + + let group_indices: Vec = (0..num_values) + .map(|_| rng.gen_range(0..max_group)) + .collect(); + + let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + + // 10% chance of false + // 10% change of null + // 80% chance of true + let filter: BooleanArray = (0..num_values) + .map(|_| { + let filter_value = rng.gen_range(0.0..1.0); + if filter_value < 0.1 { + Some(false) + } else if filter_value < 0.2 { + None + } else { + Some(true) + } + }) + .collect(); + + // random values with random number and location of nulls + // random null percentage + let null_pct: f32 = rng.gen_range(0.0..1.0); + let values_with_nulls: Vec> = (0..num_values) + .map(|_| { + let is_null = null_pct < rng.gen_range(0.0..1.0); + if is_null { + None + } else { + Some(rng.gen()) + } + }) + .collect(); + + Self { + group_indices, + values, + values_with_nulls, + filter, + } + } + + /// returns `Self::values` an Array + fn values_array(&self) -> UInt32Array { + UInt32Array::from(self.values.clone()) + } + + /// returns `Self::values_with_nulls` as an Array + fn values_with_nulls_array(&self) -> UInt32Array { + UInt32Array::from(self.values_with_nulls.clone()) + } + + /// Calls `NullState::accumulate` and `accumulate_indices` + /// with all combinations of nulls and filter values + fn run(&self) { + let total_num_groups = *self.group_indices.iter().max().unwrap() + 1; + + let group_indices = &self.group_indices; + let values_array = self.values_array(); + let values_with_nulls_array = self.values_with_nulls_array(); + let filter = &self.filter; + + // no null, no filters + Self::accumulate_test(group_indices, &values_array, None, total_num_groups); + + // nulls, no filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + None, + total_num_groups, + ); + + // no nulls, filters + Self::accumulate_test( + group_indices, + &values_array, + Some(filter), + total_num_groups, + ); + + // nulls, filters + Self::accumulate_test( + group_indices, + &values_with_nulls_array, + Some(filter), + total_num_groups, + ); + } + + /// Calls `NullState::accumulate` and `accumulate_indices` to + /// ensure it generates the correct values. + /// + fn accumulate_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + Self::accumulate_values_test( + group_indices, + values, + opt_filter, + total_num_groups, + ); + Self::accumulate_indices_test(group_indices, values.nulls(), opt_filter); + + // Convert values into a boolean array (anything above the + // average is true, otherwise false) + let avg: usize = values.iter().filter_map(|v| v.map(|v| v as usize)).sum(); + let boolean_values: BooleanArray = + values.iter().map(|v| v.map(|v| v as usize > avg)).collect(); + Self::accumulate_boolean_test( + group_indices, + &boolean_values, + opt_filter, + total_num_groups, + ); + } + + /// This is effectively a different implementation of + /// accumulate that we compare with the above implementation + fn accumulate_values_test( + group_indices: &[usize], + values: &UInt32Array, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + let mut accumulated_values = vec![]; + let mut null_state = NullState::new(); + + null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, value| { + accumulated_values.push((group_index, value)); + }, + ); + + // Figure out the expected values + let mut expected_values = vec![]; + let mut mock = MockNullState::new(); + + match opt_filter { + None => group_indices.iter().zip(values.iter()).for_each( + |(&group_index, value)| { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + }, + ), + Some(filter) => { + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, value), is_included)| { + // if value passed filter + if let Some(true) = is_included { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + let seen_values = null_state.seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + + // Validate the final buffer (one value per group) + let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + + let null_buffer = null_state.build(EmitTo::All); + + assert_eq!(null_buffer, expected_null_buffer); + } + + // Calls `accumulate_indices` + // and opt_filter and ensures it calls the right values + fn accumulate_indices_test( + group_indices: &[usize], + nulls: Option<&NullBuffer>, + opt_filter: Option<&BooleanArray>, + ) { + let mut accumulated_values = vec![]; + + accumulate_indices(group_indices, nulls, opt_filter, |group_index| { + accumulated_values.push(group_index); + }); + + // Figure out the expected values + let mut expected_values = vec![]; + + match (nulls, opt_filter) { + (None, None) => group_indices.iter().for_each(|&group_index| { + expected_values.push(group_index); + }), + (Some(nulls), None) => group_indices.iter().zip(nulls.iter()).for_each( + |(&group_index, is_valid)| { + if is_valid { + expected_values.push(group_index); + } + }, + ), + (None, Some(filter)) => group_indices.iter().zip(filter.iter()).for_each( + |(&group_index, is_included)| { + if let Some(true) = is_included { + expected_values.push(group_index); + } + }, + ), + (Some(nulls), Some(filter)) => { + group_indices + .iter() + .zip(nulls.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, is_valid), is_included)| { + // if value passed filter + if let (true, Some(true)) = (is_valid, is_included) { + expected_values.push(group_index); + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + } + + /// This is effectively a different implementation of + /// accumulate_boolean that we compare with the above implementation + fn accumulate_boolean_test( + group_indices: &[usize], + values: &BooleanArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) { + let mut accumulated_values = vec![]; + let mut null_state = NullState::new(); + + null_state.accumulate_boolean( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, value| { + accumulated_values.push((group_index, value)); + }, + ); + + // Figure out the expected values + let mut expected_values = vec![]; + let mut mock = MockNullState::new(); + + match opt_filter { + None => group_indices.iter().zip(values.iter()).for_each( + |(&group_index, value)| { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + }, + ), + Some(filter) => { + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, value), is_included)| { + // if value passed filter + if let Some(true) = is_included { + if let Some(value) = value { + mock.saw_value(group_index); + expected_values.push((group_index, value)); + } + } + }); + } + } + + assert_eq!(accumulated_values, expected_values, + "\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"); + + let seen_values = null_state.seen_values.finish_cloned(); + mock.validate_seen_values(&seen_values); + + // Validate the final buffer (one value per group) + let expected_null_buffer = mock.expected_null_buffer(total_num_groups); + + let null_buffer = null_state.build(EmitTo::All); + + assert_eq!(null_buffer, expected_null_buffer); + } + } + + /// Parallel implementaiton of NullState to check expected values + #[derive(Debug, Default)] + struct MockNullState { + /// group indices that had values that passed the filter + seen_values: HashSet, + } + + impl MockNullState { + fn new() -> Self { + Default::default() + } + + fn saw_value(&mut self, group_index: usize) { + self.seen_values.insert(group_index); + } + + /// did this group index see any input? + fn expected_seen(&self, group_index: usize) -> bool { + self.seen_values.contains(&group_index) + } + + /// Validate that the seen_values matches self.seen_values + fn validate_seen_values(&self, seen_values: &BooleanBuffer) { + for (group_index, is_seen) in seen_values.iter().enumerate() { + let expected_seen = self.expected_seen(group_index); + assert_eq!( + expected_seen, is_seen, + "mismatch at for group {group_index}" + ); + } + } + + /// Create the expected null buffer based on if the input had nulls and a filter + fn expected_null_buffer(&self, total_num_groups: usize) -> NullBuffer { + (0..total_num_groups) + .map(|group_index| self.expected_seen(group_index)) + .collect() + } + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs new file mode 100644 index 0000000000000..cf980f4c3f167 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs @@ -0,0 +1,403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] + +use super::{EmitTo, GroupsAccumulator}; +use arrow::{ + array::{AsArray, UInt32Builder}, + compute, + datatypes::UInt32Type, +}; +use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; +use datafusion_common::{ + utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::Accumulator; + +/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] +/// +/// While [`Accumulator`] are simpler to implement and can support +/// more general calculations (like retractable window functions), +/// they are not as fast as a specialized `GroupsAccumulator`. This +/// interface bridges the gap so the group by operator only operates +/// in terms of [`Accumulator`]. +pub struct GroupsAccumulatorAdapter { + factory: Box Result> + Send>, + + /// state for each group, stored in group_index order + states: Vec, + + /// Current memory usage, in bytes. + /// + /// Note this is incrementally updated with deltas to avoid the + /// call to size() being a bottleneck. We saw size() being a + /// bottleneck in earlier implementations when there were many + /// distinct groups. + allocation_bytes: usize, +} + +struct AccumulatorState { + /// [`Accumulator`] that stores the per-group state + accumulator: Box, + + // scratch space: indexes in the input array that will be fed to + // this accumulator. Stores indexes as `u32` to match the arrow + // `take` kernel input. + indices: Vec, +} + +impl AccumulatorState { + fn new(accumulator: Box) -> Self { + Self { + accumulator, + indices: vec![], + } + } + + /// Returns the amount of memory taken by this structre and its accumulator + fn size(&self) -> usize { + self.accumulator.size() + + std::mem::size_of_val(self) + + self.indices.allocated_size() + } +} + +impl GroupsAccumulatorAdapter { + /// Create a new adapter that will create a new [`Accumulator`] + /// for each group, using the specified factory function + pub fn new(factory: F) -> Self + where + F: Fn() -> Result> + Send + 'static, + { + Self { + factory: Box::new(factory), + states: vec![], + allocation_bytes: 0, + } + } + + /// Ensure that self.accumulators has total_num_groups + fn make_accumulators_if_needed(&mut self, total_num_groups: usize) -> Result<()> { + // can't shrink + assert!(total_num_groups >= self.states.len()); + let vec_size_pre = self.states.allocated_size(); + + // instantiate new accumulators + let new_accumulators = total_num_groups - self.states.len(); + for _ in 0..new_accumulators { + let accumulator = (self.factory)()?; + let state = AccumulatorState::new(accumulator); + self.add_allocation(state.size()); + self.states.push(state); + } + + self.adjust_allocation(vec_size_pre, self.states.allocated_size()); + Ok(()) + } + + /// invokes f(accumulator, values) for each group that has values + /// in group_indices. + /// + /// This function first reorders the input and filter so that + /// values for each group_index are contiguous and then invokes f + /// on the contiguous ranges, to minimize per-row overhead + /// + /// ```text + /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ + /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ ┏━━━━━┓ │ ┌─────┐ │ ┌─────┐ + /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ ┃ 0 ┃ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ ┃ 0 ┃ │ │ 300 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ ┃ 1 ┃ │ │ 200 │ │ │ │NULL │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ────────▶ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ ┃ 2 ┃ │ │ 200 │ │ │ │ t │ │ + /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ ┣━━━━━┫ │ ├─────┤ │ ├─────┤ + /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ ┃ 2 ┃ │ │ 100 │ │ │ │ f │ │ + /// │ └─────┘ │ │ └─────┘ │ └─────┘ ┗━━━━━┛ │ └─────┘ │ └─────┘ + /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// + /// logical group values opt_filter logical group values opt_filter + /// + /// ``` + fn invoke_per_accumulator( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + f: F, + ) -> Result<()> + where + F: Fn(&mut dyn Accumulator, &[ArrayRef]) -> Result<()>, + { + self.make_accumulators_if_needed(total_num_groups)?; + + assert_eq!(values[0].len(), group_indices.len()); + + // figure out which input rows correspond to which groups. + // Note that self.state.indices starts empty for all groups + // (it is cleared out below) + for (idx, group_index) in group_indices.iter().enumerate() { + self.states[*group_index].indices.push(idx as u32); + } + + // groups_with_rows holds a list of group indexes that have + // any rows that need to be accumulated, stored in order of + // group_index + + let mut groups_with_rows = vec![]; + + // batch_indices holds indices into values, each group is contiguous + let mut batch_indices = UInt32Builder::with_capacity(0); + + // offsets[i] is index into batch_indices where the rows for + // group_index i starts + let mut offsets = vec![0]; + + let mut offset_so_far = 0; + for (group_index, state) in self.states.iter_mut().enumerate() { + let indices = &state.indices; + if indices.is_empty() { + continue; + } + + groups_with_rows.push(group_index); + batch_indices.append_slice(indices); + offset_so_far += indices.len(); + offsets.push(offset_so_far); + } + let batch_indices = batch_indices.finish(); + + // reorder the values and opt_filter by batch_indices so that + // all values for each group are contiguous, then invoke the + // accumulator once per group with values + let values = get_arrayref_at_indices(values, &batch_indices)?; + let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; + + // invoke each accumulator with the appropriate rows, first + // pulling the input arguments for this group into their own + // RecordBatch(es) + let iter = groups_with_rows.iter().zip(offsets.windows(2)); + + let mut sizes_pre = 0; + let mut sizes_post = 0; + for (&group_idx, offsets) in iter { + let state = &mut self.states[group_idx]; + sizes_pre += state.size(); + + let values_to_accumulate = + slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?; + (f)(state.accumulator.as_mut(), &values_to_accumulate)?; + + // clear out the state so they are empty for next + // iteration + state.indices.clear(); + sizes_post += state.size(); + } + + self.adjust_allocation(sizes_pre, sizes_post); + Ok(()) + } + + /// Increment the allocation by `n` + /// + /// See [`Self::allocation_bytes`] for rationale. + fn add_allocation(&mut self, size: usize) { + self.allocation_bytes += size; + } + + /// Decrease the allocation by `n` + /// + /// See [`Self::allocation_bytes`] for rationale. + fn free_allocation(&mut self, size: usize) { + // use saturating sub to avoid errors if the accumulators + // report erronious sizes + self.allocation_bytes = self.allocation_bytes.saturating_sub(size) + } + + /// Adjusts the allocation for something that started with + /// start_size and now has new_size avoiding overflow + /// + /// See [`Self::allocation_bytes`] for rationale. + fn adjust_allocation(&mut self, old_size: usize, new_size: usize) { + if new_size > old_size { + self.add_allocation(new_size - old_size) + } else { + self.free_allocation(old_size - new_size) + } + } +} + +impl GroupsAccumulator for GroupsAccumulatorAdapter { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.invoke_per_accumulator( + values, + group_indices, + opt_filter, + total_num_groups, + |accumulator, values_to_accumulate| { + accumulator.update_batch(values_to_accumulate) + }, + )?; + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let vec_size_pre = self.states.allocated_size(); + + let states = emit_to.take_needed(&mut self.states); + + let results: Vec = states + .into_iter() + .map(|state| { + self.free_allocation(state.size()); + state.accumulator.evaluate() + }) + .collect::>()?; + + let result = ScalarValue::iter_to_array(results); + + self.adjust_allocation(vec_size_pre, self.states.allocated_size()); + + result + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let vec_size_pre = self.states.allocated_size(); + let states = emit_to.take_needed(&mut self.states); + + // each accumulator produces a potential vector of values + // which we need to form into columns + let mut results: Vec> = vec![]; + + for state in states { + self.free_allocation(state.size()); + let accumulator_state = state.accumulator.state()?; + results.resize_with(accumulator_state.len(), Vec::new); + for (idx, state_val) in accumulator_state.into_iter().enumerate() { + results[idx].push(state_val); + } + } + + // create an array for each intermediate column + let arrays = results + .into_iter() + .map(ScalarValue::iter_to_array) + .collect::>>()?; + + // double check each array has the same length (aka the + // accumulator was implemented correctly + if let Some(first_col) = arrays.first() { + for arr in &arrays { + assert_eq!(arr.len(), first_col.len()) + } + } + self.adjust_allocation(vec_size_pre, self.states.allocated_size()); + + Ok(arrays) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.invoke_per_accumulator( + values, + group_indices, + opt_filter, + total_num_groups, + |accumulator, values_to_accumulate| { + accumulator.merge_batch(values_to_accumulate)?; + Ok(()) + }, + )?; + Ok(()) + } + + fn size(&self) -> usize { + self.allocation_bytes + } +} + +/// Extension trait for [`Vec`] to account for allocations. +pub trait VecAllocExt { + /// Item type. + type T; + /// Return the amount of memory allocated by this Vec (not + /// recursively counting any heap allocations contained within the + /// structure). Does not include the size of `self` + fn allocated_size(&self) -> usize; +} + +impl VecAllocExt for Vec { + type T = T; + fn allocated_size(&self) -> usize { + std::mem::size_of::() * self.capacity() + } +} + +fn get_filter_at_indices( + opt_filter: Option<&BooleanArray>, + indices: &PrimitiveArray, +) -> Result> { + opt_filter + .map(|filter| { + compute::take( + &filter, indices, None, // None: no index check + ) + }) + .transpose() + .map_err(DataFusionError::ArrowError) +} + +// Copied from physical-plan +pub(crate) fn slice_and_maybe_filter( + aggr_array: &[ArrayRef], + filter_opt: Option<&ArrayRef>, + offsets: &[usize], +) -> Result> { + let (offset, length) = (offsets[0], offsets[1] - offsets[0]); + let sliced_arrays: Vec = aggr_array + .iter() + .map(|array| array.slice(offset, length)) + .collect(); + + if let Some(f) = filter_opt { + let filter_array = f.slice(offset, length); + let filter_array = filter_array.as_boolean(); + + sliced_arrays + .iter() + .map(|array| { + compute::filter(array, filter_array).map_err(DataFusionError::ArrowError) + }) + .collect() + } else { + Ok(sliced_arrays) + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs new file mode 100644 index 0000000000000..21b6cc29e83df --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::AsArray; +use arrow_array::{ArrayRef, BooleanArray}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; +use datafusion_common::Result; + +use crate::GroupsAccumulator; + +use super::{accumulate::NullState, EmitTo}; + +/// An accumulator that implements a single operation over a +/// [`BooleanArray`] where the accumulated state is also boolean (such +/// as [`BitAndAssign`]) +/// +/// F: The function to apply to two elements. The first argument is +/// the existing value and should be updated with the second value +/// (e.g. [`BitAndAssign`] style). +/// +/// [`BitAndAssign`]: std::ops::BitAndAssign +#[derive(Debug)] +pub struct BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + /// values per group + values: BooleanBufferBuilder, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the output + bool_fn: F, +} + +impl BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + pub fn new(bitop_fn: F) -> Self { + Self { + values: BooleanBufferBuilder::new(0), + null_state: NullState::new(), + bool_fn: bitop_fn, + } + } +} + +impl GroupsAccumulator for BooleanGroupsAccumulator +where + F: Fn(bool, bool) -> bool + Send + Sync, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_boolean(); + + if self.values.len() < total_num_groups { + let new_groups = total_num_groups - self.values.len(); + self.values.append_n(new_groups, Default::default()); + } + + // NullState dispatches / handles tracking nulls and groups that saw no values + self.null_state.accumulate_boolean( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let current_value = self.values.get_bit(group_index); + let value = (self.bool_fn)(current_value, new_value); + self.values.set_bit(group_index, value); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let values = self.values.finish(); + + let values = match emit_to { + EmitTo::All => values, + EmitTo::First(n) => { + let first_n: BooleanBuffer = values.iter().take(n).collect(); + // put n+1 back into self.values + for v in values.iter().skip(n) { + self.values.append(v); + } + first_n + } + }; + + let nulls = self.null_state.build(emit_to); + let values = BooleanArray::new(values, Some(nulls)); + Ok(Arc::new(values)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn size(&self) -> usize { + // capacity is in bits, so convert to bytes + self.values.capacity() / 8 + self.null_state.size() + } +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs new file mode 100644 index 0000000000000..d2e64d373be29 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vectorized [`GroupsAccumulator`] + +pub(crate) mod accumulate; +mod adapter; +pub use adapter::GroupsAccumulatorAdapter; + +pub(crate) mod bool_op; +pub(crate) mod prim_op; + +use arrow_array::{ArrayRef, BooleanArray}; +use datafusion_common::Result; + +/// Describes how many rows should be emitted during grouping. +#[derive(Debug, Clone, Copy)] +pub enum EmitTo { + /// Emit all groups + All, + /// Emit only the first `n` groups and shift all existing group + /// indexes down by `n`. + /// + /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted + /// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`. + First(usize), +} + +impl EmitTo { + /// Removes the number of rows from `v` required to emit the right + /// number of rows, returning a `Vec` with elements taken, and the + /// remaining values in `v`. + /// + /// This avoids copying if Self::All + pub fn take_needed(&self, v: &mut Vec) -> Vec { + match self { + Self::All => { + // Take the entire vector, leave new (empty) vector + std::mem::take(v) + } + Self::First(n) => { + // get end n+1,.. values into t + let mut t = v.split_off(*n); + // leave n+1,.. in v + std::mem::swap(v, &mut t); + t + } + } + } +} + +/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and +/// stores the state for *all* groups internally. +/// +/// Each group is assigned a `group_index` by the hash table and each +/// accumulator manages the specific state, one per group_index. +/// +/// group_indexes are contiguous (there aren't gaps), and thus it is +/// expected that each GroupAccumulator will use something like `Vec<..>` +/// to store the group states. +pub trait GroupsAccumulator: Send { + /// Updates the accumulator's state from its arguments, encoded as + /// a vector of [`ArrayRef`]s. + /// + /// * `values`: the input arguments to the accumulator + /// + /// * `group_indices`: To which groups do the rows in `values` + /// belong, group id) + /// + /// * `opt_filter`: if present, only update aggregate state using + /// `values[i]` if `opt_filter[i]` is true + /// + /// * `total_num_groups`: the number of groups (the largest + /// group_index is thus `total_num_groups - 1`). + /// + /// Note that subsequent calls to update_batch may have larger + /// total_num_groups as new groups are seen. + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Returns the final aggregate value for each group as a single + /// `RecordBatch`, resetting the internal state. + /// + /// The rows returned *must* be in group_index order: The value + /// for group_index 0, followed by 1, etc. Any group_index that + /// did not have values, should be null. + /// + /// For example, a `SUM` accumulator maintains a running sum for + /// each group, and `evaluate` will produce that running sum as + /// its output for all groups, in group_index order + /// + /// If `emit_to`` is [`EmitTo::All`], the accumulator should + /// return all groups and release / reset its internal state + /// equivalent to when it was first created. + /// + /// If `emit_to` is [`EmitTo::First`], only the first `n` groups + /// should be emitted and the state for those first groups + /// removed. State for the remaining groups must be retained for + /// future use. The group_indices on subsequent calls to + /// `update_batch` or `merge_batch` will be shifted down by + /// `n`. See [`EmitTo::First`] for more details. + fn evaluate(&mut self, emit_to: EmitTo) -> Result; + + /// Returns the intermediate aggregate state for this accumulator, + /// used for multi-phase grouping, resetting its internal state. + /// + /// For example, `AVG` might return two arrays: `SUM` and `COUNT` + /// but the `MIN` aggregate would just return a single array. + /// + /// Note more sophisticated internal state can be passed as + /// single `StructArray` rather than multiple arrays. + /// + /// See [`Self::evaluate`] for details on the required output + /// order and `emit_to`. + fn state(&mut self, emit_to: EmitTo) -> Result>; + + /// Merges intermediate state (the output from [`Self::state`]) + /// into this accumulator's values. + /// + /// For some aggregates (such as `SUM`), `merge_batch` is the same + /// as `update_batch`, but for some aggregates (such as `COUNT`, + /// where the partial counts must be summed) the operations + /// differ. See [`Self::state`] for more details on how state is + /// used and merged. + /// + /// * `values`: arrays produced from calling `state` previously to the accumulator + /// + /// Other arguments are the same as for [`Self::update_batch`]; + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()>; + + /// Amount of memory used to store the state of this accumulator, + /// in bytes. This function is called once per batch, so it should + /// be `O(n)` to compute, not `O(num_groups)` + fn size(&self) -> usize; +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs new file mode 100644 index 0000000000000..130d562712800 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{array::AsArray, datatypes::ArrowPrimitiveType}; +use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; +use arrow_schema::DataType; +use datafusion_common::Result; + +use crate::GroupsAccumulator; + +use super::{accumulate::NullState, EmitTo}; + +/// An accumulator that implements a single operation over +/// [`ArrowPrimitiveType`] where the accumulated state is the same as +/// the input type (such as `Sum`) +/// +/// F: The function to apply to two elements. The first argument is +/// the existing value and should be updated with the second value +/// (e.g. [`BitAndAssign`] style). +/// +/// [`BitAndAssign`]: std::ops::BitAndAssign +#[derive(Debug)] +pub struct PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + /// values per group, stored as the native type + values: Vec, + + /// The output type (needed for Decimal precision and scale) + data_type: DataType, + + /// The starting value for new groups + starting_value: T::Native, + + /// Track nulls in the input / filters + null_state: NullState, + + /// Function that computes the primitive result + prim_fn: F, +} + +impl PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + pub fn new(data_type: &DataType, prim_fn: F) -> Self { + Self { + values: vec![], + data_type: data_type.clone(), + null_state: NullState::new(), + starting_value: T::default_value(), + prim_fn, + } + } + + /// Set the starting values for new groups + pub fn with_starting_value(mut self, starting_value: T::Native) -> Self { + self.starting_value = starting_value; + self + } +} + +impl GroupsAccumulator for PrimitiveGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + F: Fn(&mut T::Native, T::Native) + Send + Sync, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + // update values + self.values.resize(total_num_groups, self.starting_value); + + // NullState dispatches / handles tracking nulls and groups that saw no values + self.null_state.accumulate( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value| { + let value = &mut self.values[group_index]; + (self.prim_fn)(value, new_value); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + let values = PrimitiveArray::::new(values.into(), Some(nulls)) // no copy + .with_data_type(self.data_type.clone()); + Ok(Arc::new(values)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn size(&self) -> usize { + self.values.capacity() * std::mem::size_of::() + self.null_state.size() + } +} diff --git a/datafusion/physical-expr/src/aggregate/hyperloglog.rs b/datafusion/physical-expr/src/aggregate/hyperloglog.rs index bf25ecebcd661..a0d55ca71db14 100644 --- a/datafusion/physical-expr/src/aggregate/hyperloglog.rs +++ b/datafusion/physical-expr/src/aggregate/hyperloglog.rs @@ -68,6 +68,15 @@ const SEED: RandomState = RandomState::with_seeds( 0x0eaea5d736d733a4_u64, ); +impl Default for HyperLogLog +where + T: Hash + ?Sized, +{ + fn default() -> Self { + Self::new() + } +} + impl HyperLogLog where T: Hash + ?Sized, diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 62114d624c7d7..691b1c1752f41 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -20,12 +20,15 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef, UInt32Array}; -use arrow::compute::sort_to_indices; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType}; +use arrow_buffer::ArrowNativeType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; +use std::fmt::Formatter; use std::sync::Arc; /// MEDIAN aggregate expression. This uses a lot of memory because all values need to be @@ -64,10 +67,29 @@ impl AggregateExpr for Median { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(MedianAccumulator { - data_type: self.data_type.clone(), - all_values: vec![], - })) + use arrow_array::types::*; + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(MedianAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + })) + }; + } + let dt = &self.data_type; + downcast_integer! { + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "MedianAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -104,129 +126,209 @@ impl PartialEq for Median { } } -#[derive(Debug)] /// The median accumulator accumulates the raw input values /// as `ScalarValue`s /// -/// The intermediate state is represented as a List of those scalars -struct MedianAccumulator { +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct MedianAccumulator { data_type: DataType, - all_values: Vec, + all_values: Vec, +} + +impl std::fmt::Debug for MedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MedianAccumulator({})", self.data_type) + } } -impl Accumulator for MedianAccumulator { +impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let state = - ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone()); - Ok(vec![state]) + let all_values = self + .all_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1); - let array = &values[0]; - - assert_eq!(array.data_type(), &self.data_type); - self.all_values.reserve(self.all_values.len() + array.len()); - for index in 0..array.len() { - self.all_values - .push(ScalarValue::try_from_array(array, index)?); - } - + let values = values[0].as_primitive::(); + self.all_values.reserve(values.len() - values.null_count()); + self.all_values.extend(values.iter().flatten()); Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - assert_eq!(states.len(), 1); - - let array = &states[0]; - assert!(matches!(array.data_type(), DataType::List(_))); - for index in 0..array.len() { - match ScalarValue::try_from_array(array, index)? { - ScalarValue::List(Some(mut values), _) => { - self.all_values.append(&mut values); - } - ScalarValue::List(None, _) => {} // skip empty state - v => { - return Err(DataFusionError::Internal(format!( - "unexpected state in median. Expected DataType::List, got {v:?}" - ))) - } - } + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? } Ok(()) } fn evaluate(&self) -> Result { - if !self.all_values.iter().any(|v| !v.is_null()) { - return ScalarValue::try_from(&self.data_type); - } + // TODO: evaluate could pass &mut self + let mut d = self.all_values.clone(); + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); - // Create an array of all the non null values and find the - // sorted indexes - let array = ScalarValue::iter_to_array( - self.all_values - .iter() - // ignore null values - .filter(|v| !v.is_null()) - .cloned(), - )?; - - // find the mid point - let len = array.len(); - let mid = len / 2; - - // only sort up to the top size/2 elements - let limit = Some(mid + 1); - let options = None; - let indices = sort_to_indices(&array, options, limit)?; - - // pick the relevant indices in the original arrays - let result = if len >= 2 && len % 2 == 0 { - // even number of values, average the two mid points - let s1 = scalar_at_index(&array, &indices, mid - 1)?; - let s2 = scalar_at_index(&array, &indices, mid)?; - match s1.add(s2)? { - ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)), - ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)), - ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)), - ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)), - ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)), - ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)), - ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)), - ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)), - ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)), - ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)), - v => { - return Err(DataFusionError::Internal(format!( - "Unsupported type in MedianAccumulator: {v:?}" - ))) - } - } + let len = d.len(); + let median = if len == 0 { + None + } else if len % 2 == 0 { + let (low, high, _) = d.select_nth_unstable_by(len / 2, cmp); + let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); + let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + Some(median) } else { - // odd number of values, pick that one - scalar_at_index(&array, &indices, mid)? + let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp); + Some(*median) }; - - Ok(result) + ScalarValue::new_primitive::(median, &self.data_type) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.all_values) - - std::mem::size_of_val(&self.all_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) + std::mem::size_of_val(self) + + self.all_values.capacity() * std::mem::size_of::() } } -/// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue` -fn scalar_at_index( - array: &dyn Array, - indices: &UInt32Array, - indicies_index: usize, -) -> Result { - let array_index = indices - .value(indicies_index) - .try_into() - .expect("Convert uint32 to usize"); - ScalarValue::try_from_array(array, array_index) +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::col; + use crate::expressions::tests::aggregate; + use crate::generic_test_op; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn median_decimal() -> Result<()> { + // test median + let array: ArrayRef = Arc::new( + (1..7) + .map(Some) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(Some(3), 10, 4) + ) + } + + #[test] + fn median_decimal_with_nulls() -> Result<()> { + let array: ArrayRef = Arc::new( + (1..6) + .map(|i| if i == 2 { None } else { Some(i) }) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(Some(3), 10, 4) + ) + } + + #[test] + fn median_decimal_all_nulls() -> Result<()> { + // test median + let array: ArrayRef = Arc::new( + std::iter::repeat::>(None) + .take(6) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + Median, + ScalarValue::Decimal128(None, 10, 4) + ) + } + + #[test] + fn median_i32_odd() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + } + + #[test] + fn median_i32_even() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + } + + #[test] + fn median_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3i32)) + } + + #[test] + fn median_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + generic_test_op!(a, DataType::Int32, Median, ScalarValue::Int32(None)) + } + + #[test] + fn median_u32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + } + + #[test] + fn median_u32_even() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from(vec![ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, 6_u32, + ])); + generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + } + + #[test] + fn median_f32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3_f32)) + } + + #[test] + fn median_f32_even() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from(vec![ + 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, 6_f32, + ])); + generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3.5_f32)) + } + + #[test] + fn median_f64_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3_f64)) + } + + #[test] + fn median_f64_even() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, 6_f64, + ])); + generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) + } } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index f811dae7b5609..7e3ef2a2ababb 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -21,32 +21,41 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; use arrow::compute; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; use arrow::{ array::{ - ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + LargeStringArray, StringArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, datatypes::Field, }; +use arrow_array::types::{ + Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion_common::internal_err; use datafusion_common::ScalarValue; use datafusion_common::{downcast_value, DataFusionError, Result}; use datafusion_expr::Accumulator; -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use arrow::array::Array; use arrow::array::Decimal128Array; -use datafusion_row::accessor::RowAccessor; +use arrow::array::Decimal256Array; +use arrow::datatypes::i256; +use arrow::datatypes::Decimal256Type; use super::moving_min_max; @@ -86,6 +95,48 @@ impl Max { } } } +/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` +/// the specified [`ArrowPrimitiveType`]. +/// +/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType +macro_rules! instantiate_max_accumulator { + ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + &$SELF.data_type, + |cur, new| { + if *cur < new { + *cur = new + } + }, + ) + // Initialize each accumulator to $NATIVE::MIN + .with_starting_value($NATIVE::MIN), + )) + }}; +} + +/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` +/// the specified [`ArrowPrimitiveType`]. +/// +/// +/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType +macro_rules! instantiate_min_accumulator { + ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + &$SELF.data_type, + |cur, new| { + if *cur > new { + *cur = new + } + }, + ) + // Initialize each accumulator to $NATIVE::MAX + .with_starting_value($NATIVE::MAX), + )) + }}; +} impl AggregateExpr for Max { /// Return a reference to Any that can be used for downcasting @@ -121,22 +172,90 @@ impl AggregateExpr for Max { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) + fn groups_accumulator_supported(&self) -> bool { + use DataType::*; + matches!( + self.data_type, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } - fn supports_bounded_execution(&self) -> bool { - true - } + fn create_groups_accumulator(&self) -> Result> { + use DataType::*; + use TimeUnit::*; + + match self.data_type { + Int8 => instantiate_max_accumulator!(self, i8, Int8Type), + Int16 => instantiate_max_accumulator!(self, i16, Int16Type), + Int32 => instantiate_max_accumulator!(self, i32, Int32Type), + Int64 => instantiate_max_accumulator!(self, i64, Int64Type), + UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type), + UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type), + UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type), + UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type), + Float32 => { + instantiate_max_accumulator!(self, f32, Float32Type) + } + Float64 => { + instantiate_max_accumulator!(self, f64, Float64Type) + } + Date32 => instantiate_max_accumulator!(self, i32, Date32Type), + Date64 => instantiate_max_accumulator!(self, i64, Date64Type), + Time32(Second) => { + instantiate_max_accumulator!(self, i32, Time32SecondType) + } + Time32(Millisecond) => { + instantiate_max_accumulator!(self, i32, Time32MillisecondType) + } + Time64(Microsecond) => { + instantiate_max_accumulator!(self, i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + instantiate_max_accumulator!(self, i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + instantiate_max_accumulator!(self, i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + instantiate_max_accumulator!(self, i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + instantiate_max_accumulator!(self, i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + instantiate_max_accumulator!(self, i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + instantiate_max_accumulator!(self, i128, Decimal128Type) + } + Decimal256(_, _) => { + instantiate_max_accumulator!(self, i256, Decimal256Type) + } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(MaxRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/arrow-datafusion/issues/6906 + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!( + "GroupsAccumulator not supported for max({})", + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -172,6 +291,16 @@ macro_rules! typed_min_max_batch_string { }}; } +// Statically-typed version of min/max(array) -> ScalarValue for binay types. +macro_rules! typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + // Statically-typed version of min/max(array) -> ScalarValue for non-string types. macro_rules! typed_min_max_batch { ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ @@ -196,6 +325,16 @@ macro_rules! min_max_batch { scale ) } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } // all types that have a natural order DataType::Float64 => { typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) @@ -272,10 +411,10 @@ macro_rules! min_max_batch { } other => { // This should have been handled before - return Err(DataFusionError::Internal(format!( + return internal_err!( "Min/Max accumulator not implemented for type {:?}", other - ))); + ); } } }}; @@ -293,6 +432,17 @@ fn min_batch(values: &ArrayRef) -> Result { DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } _ => min_max_batch!(values, min), }) } @@ -309,6 +459,17 @@ fn max_batch(values: &ArrayRef) -> Result { DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } _ => min_max_batch!(values, max), }) } @@ -328,18 +489,6 @@ macro_rules! typed_min_max { }}; } -// min/max of two non-string scalar values. -macro_rules! typed_min_max_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $TYPE:ident, $OP:ident) => {{ - paste::item! { - match $SCALAR { - None => {} - Some(v) => $ACC.[<$OP _ $TYPE>]($INDEX, *v as $TYPE) - } - } - }}; -} - // min/max of two scalar string values. macro_rules! typed_min_max_string { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ @@ -367,9 +516,7 @@ macro_rules! interval_min_max { Some(interval_choose_min_max!($OP)) => $RHS.clone(), Some(_) => $LHS.clone(), None => { - return Err(DataFusionError::Internal( - "Comparison error while computing interval min/max".to_string(), - )) + return internal_err!("Comparison error while computing interval min/max") } } }}; @@ -386,10 +533,23 @@ macro_rules! min_max { if lhsp.eq(rhsp) && lhss.eq(rhss) { typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) } else { - return Err(DataFusionError::Internal(format!( + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", (lhs, rhs) - ))); + ); } } (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { @@ -431,6 +591,12 @@ macro_rules! min_max { (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) } @@ -527,64 +693,35 @@ macro_rules! min_max { ) => { interval_min_max!($OP, $VALUE, $DELTA) } - e => { - return Err(DataFusionError::Internal(format!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - e - ))) - } - }) - }}; -} - -// min/max of two scalar values of the same type -macro_rules! min_max_v2 { - ($INDEX:ident, $ACC:ident, $SCALAR:expr, $OP:ident) => {{ - Ok(match $SCALAR { - ScalarValue::Boolean(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, bool, $OP) - } - ScalarValue::Float64(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, f64, $OP) - } - ScalarValue::Float32(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, f32, $OP) - } - ScalarValue::UInt64(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, u64, $OP) - } - ScalarValue::UInt32(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, u32, $OP) - } - ScalarValue::UInt16(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, u16, $OP) - } - ScalarValue::UInt8(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, u8, $OP) - } - ScalarValue::Int64(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, i64, $OP) - } - ScalarValue::Int32(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, i32, $OP) - } - ScalarValue::Int16(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, i16, $OP) + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) } - ScalarValue::Int8(rhs) => { - typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP) + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) } - ScalarValue::Decimal128(rhs, ..) => { - typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP) + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) } - ScalarValue::Null => { - // do nothing + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) } e => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", e - ))) + ) } }) }}; @@ -595,19 +732,11 @@ pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, min) } -pub fn min_row(index: usize, accessor: &mut RowAccessor, s: &ScalarValue) -> Result<()> { - min_max_v2!(index, accessor, s, min) -} - /// the maximum of two scalar values pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, max) } -pub fn max_row(index: usize, accessor: &mut RowAccessor, s: &ScalarValue) -> Result<()> { - min_max_v2!(index, accessor, s, max) -} - /// An accumulator to compute the maximum value #[derive(Debug)] pub struct MaxAccumulator { @@ -699,66 +828,12 @@ impl Accumulator for SlidingMaxAccumulator { Ok(self.max.clone()) } - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() - } -} - -#[derive(Debug)] -struct MaxRowAccumulator { - index: usize, - data_type: DataType, -} - -impl MaxRowAccumulator { - pub fn new(index: usize, data_type: DataType) -> Self { - Self { index, data_type } - } -} - -impl RowAccumulator for MaxRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &max_batch(values)?; - max_row(self.index, accessor, delta) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - max_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - max_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.data_type, self.index)) + fn supports_retract_batch(&self) -> bool { + true } - #[inline(always)] - fn state_index(&self) -> usize { - self.index + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() } } @@ -821,22 +896,85 @@ impl AggregateExpr for Min { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) - } - - fn supports_bounded_execution(&self) -> bool { - true + fn groups_accumulator_supported(&self) -> bool { + use DataType::*; + matches!( + self.data_type, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(MinRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + use DataType::*; + use TimeUnit::*; + match self.data_type { + Int8 => instantiate_min_accumulator!(self, i8, Int8Type), + Int16 => instantiate_min_accumulator!(self, i16, Int16Type), + Int32 => instantiate_min_accumulator!(self, i32, Int32Type), + Int64 => instantiate_min_accumulator!(self, i64, Int64Type), + UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type), + UInt16 => instantiate_min_accumulator!(self, u16, UInt16Type), + UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type), + UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type), + Float32 => { + instantiate_min_accumulator!(self, f32, Float32Type) + } + Float64 => { + instantiate_min_accumulator!(self, f64, Float64Type) + } + Date32 => instantiate_min_accumulator!(self, i32, Date32Type), + Date64 => instantiate_min_accumulator!(self, i64, Date64Type), + Time32(Second) => { + instantiate_min_accumulator!(self, i32, Time32SecondType) + } + Time32(Millisecond) => { + instantiate_min_accumulator!(self, i32, Time32MillisecondType) + } + Time64(Microsecond) => { + instantiate_min_accumulator!(self, i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + instantiate_min_accumulator!(self, i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + instantiate_min_accumulator!(self, i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + instantiate_min_accumulator!(self, i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + instantiate_min_accumulator!(self, i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + instantiate_min_accumulator!(self, i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + instantiate_min_accumulator!(self, i128, Decimal128Type) + } + Decimal256(_, _) => { + instantiate_min_accumulator!(self, i256, Decimal256Type) + } + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!( + "GroupsAccumulator not supported for min({})", + self.data_type + ), + } } fn reverse_expr(&self) -> Option> { @@ -958,67 +1096,12 @@ impl Accumulator for SlidingMinAccumulator { Ok(self.min.clone()) } - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() - } -} - -#[derive(Debug)] -struct MinRowAccumulator { - index: usize, - data_type: DataType, -} - -impl MinRowAccumulator { - pub fn new(index: usize, data_type: DataType) -> Self { - Self { index, data_type } - } -} - -impl RowAccumulator for MinRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = &min_batch(values)?; - min_row(self.index, accessor, delta)?; - Ok(()) - } - - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - min_row(self.index, accessor, value) - } - - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - min_row(self.index, accessor, value) - } - - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) - } - - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.data_type, self.index)) + fn supports_retract_batch(&self) -> bool { + true } - #[inline(always)] - fn state_index(&self) -> usize { - self.index + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } } @@ -1026,8 +1109,8 @@ impl RowAccumulator for MinRowAccumulator { mod tests { use super::*; use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; + use crate::expressions::tests::{aggregate, aggregate_new}; + use crate::{generic_test_op, generic_test_op_new}; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use datafusion_common::Result; @@ -1123,11 +1206,14 @@ mod tests { let right = ScalarValue::Decimal128(Some(124), 10, 3); let result = max(&left, &right); - let expect = DataFusionError::Internal(format!( + let err_msg = format!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) - )); - assert_eq!(expect.to_string(), result.unwrap_err().to_string()); + ); + let expect = DataFusionError::Internal(err_msg); + assert!(expect + .strip_backtrace() + .starts_with(&result.unwrap_err().strip_backtrace())); // max batch let array: ArrayRef = Arc::new( @@ -1211,12 +1297,7 @@ mod tests { #[test] fn max_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Max, - ScalarValue::Utf8(Some("d".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) } #[test] @@ -1233,12 +1314,7 @@ mod tests { #[test] fn min_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::Utf8, - Min, - ScalarValue::Utf8(Some("a".to_string())) - ) + generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) } #[test] @@ -1442,6 +1518,26 @@ mod tests { ) } + #[test] + fn max_new_timestamp_micro() -> Result<()> { + let dt = DataType::Timestamp(TimeUnit::Microsecond, None); + let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) + .with_data_type(dt.clone()); + let expected: ArrayRef = + Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); + generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) + } + + #[test] + fn max_new_timestamp_micro_with_tz() -> Result<()> { + let dt = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); + let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) + .with_data_type(dt.clone()); + let expected: ArrayRef = + Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); + generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) + } + #[test] fn max_bool() -> Result<()> { let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 09fd9bcfc524a..329bb1e6415ec 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,22 +15,24 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregate::row_accumulator::RowAccumulator; -use crate::expressions::{ArrayAgg, FirstValue, LastValue}; -use crate::PhysicalExpr; +use crate::expressions::{FirstValue, LastValue, OrderSensitiveArrayAgg}; +use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Field; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; +use self::groups_accumulator::GroupsAccumulator; + pub(crate) mod approx_distinct; pub(crate) mod approx_median; pub(crate) mod approx_percentile_cont; pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; +pub(crate) mod array_agg_ordered; pub(crate) mod average; pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; @@ -41,12 +43,14 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub mod build_in; +pub(crate) mod groups_accumulator; mod hyperloglog; pub mod moving_min_max; -pub mod row_accumulator; +pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; @@ -65,7 +69,7 @@ pub(crate) mod variance; /// `PartialEq` to allows comparing equality between the /// trait objects. pub trait AggregateExpr: Send + Sync + Debug + PartialEq { - /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -85,35 +89,33 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. fn expressions(&self) -> Vec>; + /// Order by requirements for the aggregate function + /// By default it is `None` (there is no requirement) + /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + None + } + /// Human readable name such as `"MIN(c2)"`. The default /// implementation returns placeholder text. fn name(&self) -> &str { "AggregateExpr: default name" } - /// If the aggregate expression is supported by row format - fn row_accumulator_supported(&self) -> bool { - false - } - - /// Specifies whether this aggregate function can run using bounded memory. - /// Any accumulator returning "true" needs to implement `retract_batch`. - fn supports_bounded_execution(&self) -> bool { + /// If the aggregate expression has a specialized + /// [`GroupsAccumulator`] implementation. If this returns true, + /// `[Self::create_groups_accumulator`] will be called. + fn groups_accumulator_supported(&self) -> bool { false } - /// RowAccumulator to access/update row-based aggregation state in-place. - /// Currently, row accumulator only supports states of fixed-sized type. + /// Return a specialized [`GroupsAccumulator`] that manages state + /// for all groups. /// - /// We recommend implementing `RowAccumulator` along with the standard `Accumulator`, - /// when its state is of fixed size, as RowAccumulator is more memory efficient and CPU-friendly. - fn create_row_accumulator( - &self, - _start_index: usize, - ) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "RowAccumulator hasn't been implemented for {self:?} yet" - ))) + /// For maximum performance, a [`GroupsAccumulator`] should be + /// implemented in addition to [`Accumulator`]. + fn create_groups_accumulator(&self) -> Result> { + not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } /// Construct an expression that calculates the aggregate in reverse. @@ -126,9 +128,7 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Creates accumulator implementation that supports retract fn create_sliding_accumulator(&self) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "Retractable Accumulator hasn't been implemented for {self:?} yet" - ))) + not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") } } @@ -139,5 +139,5 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { aggr_expr.as_any().is::() || aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() + || aggr_expr.as_any().is::() } diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/physical-expr/src/aggregate/regr.rs new file mode 100644 index 0000000000000..6922cb131cacc --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/regr.rs @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::Float64Array; +use arrow::{ + array::{ArrayRef, UInt64Array}, + compute::cast, + datatypes::DataType, + datatypes::Field, +}; +use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::Accumulator; + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; + +#[derive(Debug)] +pub struct Regr { + name: String, + regr_type: RegrType, + expr_y: Arc, + expr_x: Arc, +} + +impl Regr { + pub fn get_regr_type(&self) -> RegrType { + self.regr_type.clone() + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +pub enum RegrType { + /// Variant for `regr_slope` aggregate expression + /// Returns the slope of the linear regression line for non-null pairs in aggregate columns. + /// Given input column Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal + /// RSS (Residual Sum of Squares) fitting. + Slope, + /// Variant for `regr_intercept` aggregate expression + /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns. + /// Given input column Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal + /// RSS fitting. + Intercept, + /// Variant for `regr_count` aggregate expression + /// Returns the number of input rows for which both expressions are not null. + /// Given input column Y and X: `regr_count(Y, X)` returns the count of non-null pairs. + Count, + /// Variant for `regr_r2` aggregate expression + /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns. + /// The R-squared value represents the proportion of variance in Y that is predictable from X. + R2, + /// Variant for `regr_avgx` aggregate expression + /// Returns the average of the independent variable for non-null pairs in aggregate columns. + /// Given input column X: `regr_avgx(Y, X)` returns the average of X values. + AvgX, + /// Variant for `regr_avgy` aggregate expression + /// Returns the average of the dependent variable for non-null pairs in aggregate columns. + /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values. + AvgY, + /// Variant for `regr_sxx` aggregate expression + /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns. + /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean. + SXX, + /// Variant for `regr_syy` aggregate expression + /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns. + /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean. + SYY, + /// Variant for `regr_sxy` aggregate expression + /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns. + /// Given input column Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means. + SXY, +} + +impl Regr { + pub fn new( + expr_y: Arc, + expr_x: Arc, + name: impl Into, + regr_type: RegrType, + return_type: DataType, + ) -> Self { + // the result of regr_slope only support FLOAT64 data type. + assert!(matches!(return_type, DataType::Float64)); + Self { + name: name.into(), + regr_type, + expr_y, + expr_x, + } + } +} + +impl AggregateExpr for Regr { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(&self.name, "mean_x"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean_y"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_x"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_y"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr_y.clone(), self.expr_x.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Regr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.expr_y.eq(&x.expr_y) + && self.expr_x.eq(&x.expr_x) + }) + .unwrap_or(false) + } +} + +/// `RegrAccumulator` is used to compute linear regression aggregate functions +/// by maintaining statistics needed to compute them in an online fashion. +/// +/// This struct uses Welford's online algorithm for calculating variance and covariance: +/// +/// +/// Given the statistics, the following aggregate functions can be calculated: +/// +/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as: +/// cov_pop(x, y) / var_pop(x). +/// It represents the expected change in Y for a one-unit change in X. +/// +/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as: +/// mean_y - (regr_slope(y, x) * mean_x). +/// It represents the expected value of Y when X is 0. +/// +/// - `regr_count(y, x)`: Count of the non-null(both x and y) input rows. +/// +/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as: +/// (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)). +/// It provides a measure of how well the model's predictions match the observed data. +/// +/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x. +/// +/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y. +/// +/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as: +/// m2_x. +/// +/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as: +/// m2_y. +/// +/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as: +/// algo_const. +/// +/// Here's how the statistics maintained in this struct are calculated: +/// - `cov_pop(x, y)`: algo_const / count. +/// - `var_pop(x)`: m2_x / count. +/// - `var_pop(y)`: m2_y / count. +#[derive(Debug)] +pub struct RegrAccumulator { + count: u64, + mean_x: f64, + mean_y: f64, + m2_x: f64, + m2_y: f64, + algo_const: f64, + regr_type: RegrType, +} + +impl RegrAccumulator { + /// Creates a new `RegrAccumulator` + pub fn try_new(regr_type: &RegrType) -> Result { + Ok(Self { + count: 0_u64, + mean_x: 0_f64, + mean_y: 0_f64, + m2_x: 0_f64, + m2_y: 0_f64, + algo_const: 0_f64, + regr_type: regr_type.clone(), + }) + } +} + +impl Accumulator for RegrAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean_x), + ScalarValue::from(self.mean_y), + ScalarValue::from(self.m2_x), + ScalarValue::from(self.m2_y), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // regr_slope(Y, X) calculates k in y = k*x + b + let values_y = &cast(&values[0], &DataType::Float64)?; + let values_x = &cast(&values[1], &DataType::Float64)?; + + let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); + let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + + for i in 0..values_y.len() { + // skip either x or y is NULL + let value_y = if values_y.is_valid(i) { + arr_y.next() + } else { + None + }; + let value_x = if values_x.is_valid(i) { + arr_x.next() + } else { + None + }; + if value_y.is_none() || value_x.is_none() { + continue; + } + + // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] + let value_y = unwrap_or_internal_err!(value_y); + let value_x = unwrap_or_internal_err!(value_x); + + self.count += 1; + let delta_x = value_x - self.mean_x; + let delta_y = value_y - self.mean_y; + self.mean_x += delta_x / self.count as f64; + self.mean_y += delta_y / self.count as f64; + let delta_x_2 = value_x - self.mean_x; + let delta_y_2 = value_y - self.mean_y; + self.m2_x += delta_x * delta_x_2; + self.m2_y += delta_y * delta_y_2; + self.algo_const += delta_x * (value_y - self.mean_y); + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values_y = &cast(&values[0], &DataType::Float64)?; + let values_x = &cast(&values[1], &DataType::Float64)?; + + let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); + let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + + for i in 0..values_y.len() { + // skip either x or y is NULL + let value_y = if values_y.is_valid(i) { + arr_y.next() + } else { + None + }; + let value_x = if values_x.is_valid(i) { + arr_x.next() + } else { + None + }; + if value_y.is_none() || value_x.is_none() { + continue; + } + + // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] + let value_y = unwrap_or_internal_err!(value_y); + let value_x = unwrap_or_internal_err!(value_x); + + if self.count > 1 { + self.count -= 1; + let delta_x = value_x - self.mean_x; + let delta_y = value_y - self.mean_y; + self.mean_x -= delta_x / self.count as f64; + self.mean_y -= delta_y / self.count as f64; + let delta_x_2 = value_x - self.mean_x; + let delta_y_2 = value_y - self.mean_y; + self.m2_x -= delta_x * delta_x_2; + self.m2_y -= delta_y * delta_y_2; + self.algo_const -= delta_x * (value_y - self.mean_y); + } else { + self.count = 0; + self.mean_x = 0.0; + self.m2_x = 0.0; + self.m2_y = 0.0; + self.mean_y = 0.0; + self.algo_const = 0.0; + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let count_arr = downcast_value!(states[0], UInt64Array); + let mean_x_arr = downcast_value!(states[1], Float64Array); + let mean_y_arr = downcast_value!(states[2], Float64Array); + let m2_x_arr = downcast_value!(states[3], Float64Array); + let m2_y_arr = downcast_value!(states[4], Float64Array); + let algo_const_arr = downcast_value!(states[5], Float64Array); + + for i in 0..count_arr.len() { + let count_b = count_arr.value(i); + if count_b == 0_u64 { + continue; + } + let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = ( + self.count, + self.mean_x, + self.mean_y, + self.m2_x, + self.m2_y, + self.algo_const, + ); + let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = ( + count_b, + mean_x_arr.value(i), + mean_y_arr.value(i), + m2_x_arr.value(i), + m2_y_arr.value(i), + algo_const_arr.value(i), + ); + + // Assuming two different batches of input have calculated the states: + // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a} + // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b} + // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab, + // algo_const_ab} + // + // Reference for the algorithm to merge states: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + let count_ab = count_a + count_b; + let (count_a, count_b) = (count_a as f64, count_b as f64); + let d_x = mean_x_b - mean_x_a; + let d_y = mean_y_b - mean_y_a; + let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64; + let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64; + let m2_x_ab = + m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64; + let m2_y_ab = + m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64; + let algo_const_ab = algo_const_a + + algo_const_b + + d_x * d_y * count_a * count_b / count_ab as f64; + + self.count = count_ab; + self.mean_x = mean_x_ab; + self.mean_y = mean_y_ab; + self.m2_x = m2_x_ab; + self.m2_y = m2_y_ab; + self.algo_const = algo_const_ab; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + let cov_pop_x_y = self.algo_const / self.count as f64; + let var_pop_x = self.m2_x / self.count as f64; + let var_pop_y = self.m2_y / self.count as f64; + + let nullif_or_stat = |cond: bool, stat: f64| { + if cond { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(stat))) + } + }; + + match self.regr_type { + RegrType::Slope => { + // Only 0/1 point or slope is infinite + let nullif_cond = self.count <= 1 || var_pop_x == 0.0; + nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x) + } + RegrType::Intercept => { + let slope = cov_pop_x_y / var_pop_x; + // Only 0/1 point or slope is infinite + let nullif_cond = self.count <= 1 || var_pop_x == 0.0; + nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x) + } + RegrType::Count => Ok(ScalarValue::Float64(Some(self.count as f64))), + RegrType::R2 => { + // Only 0/1 point or all x(or y) is the same + let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0; + nullif_or_stat( + nullif_cond, + (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y), + ) + } + RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x), + RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y), + RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x), + RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y), + RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const), + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs deleted file mode 100644 index e5282629220f7..0000000000000 --- a/datafusion/physical-expr/src/aggregate/row_accumulator.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Accumulator over row format - -use arrow::array::ArrayRef; -use arrow_schema::DataType; -use datafusion_common::{Result, ScalarValue}; -use datafusion_row::accessor::RowAccessor; -use std::fmt::Debug; - -/// Row-based accumulator where the internal aggregate state(s) are stored using row format. -/// -/// Unlike the [`datafusion_expr::Accumulator`], the [`RowAccumulator`] does not store the state internally. -/// Instead, it knows how to access/update the state stored in a row via the the provided accessor and -/// its state's starting field index in the row. -/// -/// For example, we are evaluating `SELECT a, sum(b), avg(c), count(d) from GROUP BY a;`, we would have one row used as -/// aggregation state for each distinct `a` value, the index of the first and the only state of `sum(b)` would be 0, -/// the index of the first state of `avg(c)` would be 1, and the index of the first and only state of `cound(d)` would be 3: -/// -/// sum(b) state_index = 0 count(d) state_index = 3 -/// | | -/// v v -/// +--------+----------+--------+----------+ -/// | sum(b) | count(c) | sum(c) | count(d) | -/// +--------+----------+--------+----------+ -/// ^ -/// | -/// avg(c) state_index = 1 -/// -pub trait RowAccumulator: Send + Sync + Debug { - /// updates the accumulator's state from a vector of arrays. - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()>; - - /// updates the accumulator's state from a vector of Scalar value. - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()>; - - /// updates the accumulator's state from a Scalar value. - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()>; - - /// updates the accumulator's state from a vector of states. - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()>; - - /// returns its value based on its current state. - fn evaluate(&self, accessor: &RowAccessor) -> Result; - - /// State's starting field index in the row. - fn state_index(&self) -> usize; -} - -/// Returns if `data_type` is supported with `RowAccumulator` -pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Boolean - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - ) -} diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index e1b9b9ae23ff1..64e19ef502c7b 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -27,7 +27,7 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::Accumulator; /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression @@ -230,9 +230,7 @@ impl Accumulator for StddevAccumulator { Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) } } - _ => Err(DataFusionError::Internal( - "Variance should be f64".to_string(), - )), + _ => internal_err!("Variance should be f64"), } } @@ -447,13 +445,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file mode 100644 index 0000000000000..7adc736932ad7 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::{format_state_name, Literal}; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + name: String, + data_type: DataType, + expr: Arc, + delimiter: Arc, + nullable: bool, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new( + expr: Arc, + delimiter: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + delimiter, + expr, + nullable: true, + } + } +} + +impl AggregateExpr for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result> { + if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { + match delimiter.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + return Ok(Box::new(StringAggAccumulator::new(delimiter))); + } + ScalarValue::Null => { + return Ok(Box::new(StringAggAccumulator::new(""))); + } + _ => return not_impl_err!("StringAgg not supported for {}", self.name), + } + } + not_impl_err!("StringAgg not supported for {}", self.name) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "string_agg"), + self.data_type.clone(), + self.nullable, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone(), self.delimiter.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for StringAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + && self.delimiter.eq(&x.delimiter) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::tests::aggregate; + use crate::expressions::{col, create_aggregate_expr, try_cast}; + use arrow::array::ArrayRef; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use arrow_array::LargeStringArray; + use arrow_array::StringArray; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; + + fn assert_string_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + delimiter: String, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = + coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = create_aggregate_expr( + &function, + distinct, + &[input, delimiter], + &[], + &schema, + "agg", + ) + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn string_agg_utf8() { + let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), + ",".to_owned(), + ); + } + + #[test] + fn string_agg_largeutf8() { + let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), + "|".to_owned(), + ); + } +} diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 1c70dc67beeb8..03f666cc4e5d5 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -15,43 +15,38 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::compute; +use super::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use arrow::compute::sum; use arrow::datatypes::DataType; -use arrow::{ - array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::Field, +use arrow::{array::ArrayRef, datatypes::Field}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, }; -use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; +use arrow_buffer::ArrowNativeType; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::type_coercion::aggregates::sum_return_type; use datafusion_expr::Accumulator; -use crate::aggregate::row_accumulator::{ - is_row_accumulator_support_dtype, RowAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::compute::cast; -use datafusion_row::accessor::RowAccessor; - /// SUM aggregate expression #[derive(Debug, Clone)] pub struct Sum { name: String, - pub data_type: DataType, + // The DataType for the input expression + data_type: DataType, + // The DataType for the final sum + return_type: DataType, expr: Arc, nullable: bool, - pub pre_cast_to_sum_type: bool, } impl Sum { @@ -61,30 +56,35 @@ impl Sum { name: impl Into, data_type: DataType, ) -> Self { + let return_type = sum_return_type(&data_type).unwrap(); Self { name: name.into(), - expr, data_type, + return_type, + expr, nullable: true, - pre_cast_to_sum_type: false, } } +} - pub fn new_with_pre_cast( - expr: Arc, - name: impl Into, - data_type: DataType, - pre_cast_to_sum_type: bool, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - pre_cast_to_sum_type, +/// Sum only supports a subset of numeric types, instead relying on type coercion +/// +/// This macro is similar to [downcast_primitive](arrow_array::downcast_primitive) +/// +/// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) +macro_rules! downcast_sum { + ($s:ident, $helper:ident) => { + match $s.return_type { + DataType::UInt64 => $helper!(UInt64Type, $s.return_type), + DataType::Int64 => $helper!(Int64Type, $s.return_type), + DataType::Float64 => $helper!(Float64Type, $s.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_type), + _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.return_type), } - } + }; } +pub(crate) use downcast_sum; impl AggregateExpr for Sum { /// Return a reference to Any that can be used for downcasting @@ -95,28 +95,26 @@ impl AggregateExpr for Sum { fn field(&self) -> Result { Ok(Field::new( &self.name, - self.data_type.clone(), + self.return_type.clone(), self.nullable, )) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(self, helper) } fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "sum"), - self.data_type.clone(), - self.nullable, - ), - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - self.nullable, - ), - ]) + Ok(vec![Field::new( + format_state_name(&self.name, "sum"), + self.return_type.clone(), + self.nullable, + )]) } fn expressions(&self) -> Vec> { @@ -127,22 +125,20 @@ impl AggregateExpr for Sum { &self.name } - fn row_accumulator_supported(&self) -> bool { - is_row_accumulator_support_dtype(&self.data_type) - } - - fn supports_bounded_execution(&self) -> bool { + fn groups_accumulator_supported(&self) -> bool { true } - fn create_row_accumulator( - &self, - start_index: usize, - ) -> Result> { - Ok(Box::new(SumRowAccumulator::new( - start_index, - self.data_type.clone(), - ))) + fn create_groups_accumulator(&self) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + &$dt, + |x, y| *x = x.add_wrapping(y), + ))) + }; + } + downcast_sum!(self, helper) } fn reverse_expr(&self) -> Option> { @@ -150,7 +146,12 @@ impl AggregateExpr for Sum { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(self, helper) } } @@ -168,357 +169,200 @@ impl PartialEq for Sum { } } -#[derive(Debug)] -struct SumAccumulator { - sum: ScalarValue, - count: u64, +/// This accumulator computes SUM incrementally +struct SumAccumulator { + sum: Option, + data_type: DataType, } -impl SumAccumulator { - /// new sum accumulator - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - sum: ScalarValue::try_from(data_type)?, - count: 0, - }) +impl std::fmt::Debug for SumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SumAccumulator({})", self.data_type) } } -// returns the new value after sum with the new values, taking nullability into account -macro_rules! typed_sum_delta_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let delta = compute::sum(array); - ScalarValue::$SCALAR(delta) - }}; -} - -fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result { - let array = downcast_value!(values, Decimal128Array); - let result = compute::sum(array); - Ok(ScalarValue::Decimal128(result, precision, scale)) -} - -// sums the array and returns a ScalarValue of its corresponding type. -pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result { - // TODO refine the cast kernel in arrow-rs - let cast_values = if values.data_type() != sum_type { - Some(cast(values, sum_type)?) - } else { - None - }; - let values = cast_values.as_ref().unwrap_or(values); - Ok(match values.data_type() { - DataType::Decimal128(precision, scale) => { - sum_decimal_batch(values, *precision, *scale)? - } - DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), - DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), - DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), - DataType::Int32 => typed_sum_delta_batch!(values, Int32Array, Int32), - DataType::Int16 => typed_sum_delta_batch!(values, Int16Array, Int16), - DataType::Int8 => typed_sum_delta_batch!(values, Int8Array, Int8), - DataType::UInt64 => typed_sum_delta_batch!(values, UInt64Array, UInt64), - DataType::UInt32 => typed_sum_delta_batch!(values, UInt32Array, UInt32), - DataType::UInt16 => typed_sum_delta_batch!(values, UInt16Array, UInt16), - DataType::UInt8 => typed_sum_delta_batch!(values, UInt8Array, UInt8), - e => { - return Err(DataFusionError::Internal(format!( - "Sum is not expected to receive the type {e:?}" - ))); - } - }) -} - -macro_rules! sum_row { - ($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{ - paste::item! { - if let Some(v) = $DELTA { - $ACC.[]($INDEX, *v) - } - } - }}; -} - -macro_rules! avg_row { - ($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{ - paste::item! { - if let Some(v) = $DELTA { - $ACC.add_u64($INDEX, 1); - $ACC.[]($INDEX + 1, *v) - } - } - }}; -} - -pub(crate) fn add_to_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - match s { - ScalarValue::Null => { - // do nothing - } - ScalarValue::Float64(rhs) => { - sum_row!(index, accessor, rhs, f64) - } - ScalarValue::Float32(rhs) => { - sum_row!(index, accessor, rhs, f32) - } - ScalarValue::UInt64(rhs) => { - sum_row!(index, accessor, rhs, u64) - } - ScalarValue::Int64(rhs) => { - sum_row!(index, accessor, rhs, i64) - } - ScalarValue::Decimal128(rhs, _, _) => { - sum_row!(index, accessor, rhs, i128) - } - ScalarValue::Dictionary(_, value) => { - let value = value.as_ref(); - return add_to_row(index, accessor, value); - } - _ => { - let msg = - format!("Row sum updater is not expected to receive a scalar {s:?}"); - return Err(DataFusionError::Internal(msg)); - } - } - Ok(()) -} - -pub(crate) fn update_avg_to_row( - index: usize, - accessor: &mut RowAccessor, - s: &ScalarValue, -) -> Result<()> { - match s { - ScalarValue::Null => { - // do nothing - } - ScalarValue::Float64(rhs) => { - avg_row!(index, accessor, rhs, f64) - } - ScalarValue::Float32(rhs) => { - avg_row!(index, accessor, rhs, f32) - } - ScalarValue::UInt64(rhs) => { - avg_row!(index, accessor, rhs, u64) - } - ScalarValue::Int64(rhs) => { - avg_row!(index, accessor, rhs, i64) - } - ScalarValue::Decimal128(rhs, _, _) => { - avg_row!(index, accessor, rhs, i128) - } - ScalarValue::Dictionary(_, value) => { - let value = value.as_ref(); - return update_avg_to_row(index, accessor, value); - } - _ => { - let msg = - format!("Row avg updater is not expected to receive a scalar {s:?}"); - return Err(DataFusionError::Internal(msg)); +impl SumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: None, + data_type, } } - Ok(()) } -impl Accumulator for SumAccumulator { +impl Accumulator for SumAccumulator { fn state(&self) -> Result> { - Ok(vec![self.sum.clone(), ScalarValue::from(self.count)]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.count += (values.len() - values.null_count()) as u64; - let delta = sum_batch(values, &self.sum.get_datatype())?; - self.sum = self.sum.add(&delta)?; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.count -= (values.len() - values.null_count()) as u64; - let delta = sum_batch(values, &self.sum.get_datatype())?; - self.sum = self.sum.sub(&delta)?; + let values = values[0].as_primitive::(); + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(T::Native::usize_as(0)); + *v = v.add_wrapping(x); + } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... self.update_batch(states) } fn evaluate(&self) -> Result { - // TODO: add the checker for overflow - // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. - if self.count == 0 { - ScalarValue::try_from(&self.sum.get_datatype()) - } else { - Ok(self.sum.clone()) - } + ScalarValue::new_primitive::(self.sum, &self.data_type) } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + std::mem::size_of_val(self) } } -#[derive(Debug)] -struct SumRowAccumulator { - index: usize, - datatype: DataType, +/// This accumulator incrementally computes sums over a sliding window +/// +/// This is separate from [`SumAccumulator`] as requires additional state +struct SlidingSumAccumulator { + sum: T::Native, + count: u64, + data_type: DataType, } -impl SumRowAccumulator { - pub fn new(index: usize, datatype: DataType) -> Self { - Self { index, datatype } +impl std::fmt::Debug for SlidingSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SlidingSumAccumulator({})", self.data_type) } } -impl RowAccumulator for SumRowAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - let values = &values[0]; - let delta = sum_batch(values, &self.datatype)?; - add_to_row(self.index, accessor, &delta) +impl SlidingSumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: T::Native::usize_as(0), + count: 0, + data_type, + } } +} - fn update_scalar_values( - &mut self, - values: &[ScalarValue], - accessor: &mut RowAccessor, - ) -> Result<()> { - let value = &values[0]; - add_to_row(self.index, accessor, value) +impl Accumulator for SlidingSumAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?, self.count.into()]) } - fn update_scalar( - &mut self, - value: &ScalarValue, - accessor: &mut RowAccessor, - ) -> Result<()> { - add_to_row(self.index, accessor, value) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = self.sum.add_wrapping(x) + } + Ok(()) } - fn merge_batch( - &mut self, - states: &[ArrayRef], - accessor: &mut RowAccessor, - ) -> Result<()> { - self.update_batch(states, accessor) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let values = states[0].as_primitive::(); + if let Some(x) = sum(values) { + self.sum = self.sum.add_wrapping(x) + } + if let Some(x) = sum(states[1].as_primitive::()) { + self.count += x; + } + Ok(()) } - fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&self.datatype, self.index)) + fn evaluate(&self) -> Result { + let v = (self.count != 0).then_some(self.sum); + ScalarValue::new_primitive::(v, &self.data_type) } - #[inline(always)] - fn state_index(&self) -> usize { - self.index + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = sum(values) { + self.sum = self.sum.sub_wrapping(x) + } + self.count -= (values.len() - values.null_count()) as u64; + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::tests::aggregate; - use crate::expressions::{col, Avg}; - use crate::generic_test_op; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use arrow_array::DictionaryArray; - use datafusion_common::Result; + use crate::expressions::tests::assert_aggregate; + use arrow_array::*; + use datafusion_expr::AggregateFunction; #[test] - fn sum_decimal() -> Result<()> { - // test sum batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = sum_batch(&array, &DataType::Decimal128(10, 0))?; - assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); - + fn sum_decimal() { // test agg let array: ArrayRef = Arc::new( (1..6) .map(Some) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( + assert_aggregate( array, - DataType::Decimal128(10, 0), - Sum, - ScalarValue::Decimal128(Some(15), 20, 0) - ) + AggregateFunction::Sum, + false, + ScalarValue::Decimal128(Some(15), 20, 0), + ); } #[test] - fn sum_decimal_with_nulls() -> Result<()> { - // test with batch - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = sum_batch(&array, &DataType::Decimal128(10, 0))?; - assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); - + fn sum_decimal_with_nulls() { // test agg let array: ArrayRef = Arc::new( (1..6) .map(|i| if i == 2 { None } else { Some(i) }) .collect::() - .with_precision_and_scale(35, 0)?, + .with_precision_and_scale(35, 0) + .unwrap(), ); - generic_test_op!( + + assert_aggregate( array, - DataType::Decimal128(35, 0), - Sum, - ScalarValue::Decimal128(Some(13), 38, 0) - ) + AggregateFunction::Sum, + false, + ScalarValue::Decimal128(Some(13), 38, 0), + ); } #[test] - fn sum_decimal_all_nulls() -> Result<()> { + fn sum_decimal_all_nulls() { // test with batch let array: ArrayRef = Arc::new( std::iter::repeat::>(None) .take(6) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - let result = sum_batch(&array, &DataType::Decimal128(10, 0))?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); // test agg - generic_test_op!( + assert_aggregate( array, - DataType::Decimal128(10, 0), - Sum, - ScalarValue::Decimal128(None, 20, 0) - ) + AggregateFunction::Sum, + false, + ScalarValue::Decimal128(None, 20, 0), + ); } #[test] - fn sum_i32() -> Result<()> { + fn sum_i32() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(15i32)) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15i64)); } #[test] - fn sum_i32_with_nulls() -> Result<()> { + fn sum_i32_with_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![ Some(1), None, @@ -526,104 +370,33 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(13i32)) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(13i64)); } #[test] - fn sum_i32_all_nulls() -> Result<()> { + fn sum_i32_all_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Sum, ScalarValue::Int32(None)) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::Int64(None)); } #[test] - fn sum_u32() -> Result<()> { + fn sum_u32() { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Sum, ScalarValue::from(15u32)) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15u64)); } #[test] - fn sum_f32() -> Result<()> { + fn sum_f32() { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Sum, ScalarValue::from(15_f32)) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); } #[test] - fn sum_f64() -> Result<()> { + fn sum_f64() { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Sum, ScalarValue::from(15_f64)) - } - - fn row_aggregate( - array: &ArrayRef, - agg: Arc, - row_accessor: &mut RowAccessor, - row_indexs: Vec, - ) -> Result { - let mut accum = agg.create_row_accumulator(0)?; - - for row_index in row_indexs { - let scalar_value = ScalarValue::try_from_array(array, row_index)?; - accum.update_scalar(&scalar_value, row_accessor)?; - } - accum.evaluate(row_accessor) - } - - #[test] - fn sum_dictionary_f64() -> Result<()> { - let keys = Int32Array::from(vec![2, 3, 1, 0, 1]); - let values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64])); - - let a: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values).unwrap()); - - let row_schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - let mut row_accessor = RowAccessor::new(&row_schema); - let mut buffer: Vec = vec![0; 16]; - row_accessor.point_to(0, &mut buffer); - - let expected = ScalarValue::from(9_f64); - - let agg = Arc::new(Sum::new( - col("a", &row_schema)?, - "bla".to_string(), - expected.get_datatype(), - )); - - let actual = row_aggregate(&a, agg, &mut row_accessor, vec![0, 1, 2])?; - assert_eq!(expected, actual); - - Ok(()) - } - - #[test] - fn avg_dictionary_f64() -> Result<()> { - let keys = Int32Array::from(vec![2, 1, 1, 3, 0]); - let values = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64])); - - let a: ArrayRef = Arc::new(DictionaryArray::try_new(keys, values).unwrap()); - - let row_schema = Schema::new(vec![ - Field::new("count", DataType::UInt64, true), - Field::new("a", DataType::Float64, true), - ]); - let mut row_accessor = RowAccessor::new(&row_schema); - let mut buffer: Vec = vec![0; 24]; - row_accessor.point_to(0, &mut buffer); - - let expected = ScalarValue::from(2.3333333333333335_f64); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - let agg = Arc::new(Avg::new( - col("a", &schema)?, - "bla".to_string(), - expected.get_datatype(), - )); - - let actual = row_aggregate(&a, agg, &mut row_accessor, vec![0, 1, 2])?; - assert_eq!(expected, actual); - - Ok(()) + assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); } } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 73f74df967542..0cf4a90ab8cc4 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -18,17 +18,21 @@ use crate::expressions::format_state_name; use arrow::datatypes::{DataType, Field}; use std::any::Any; -use std::fmt::Debug; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; use std::collections::HashSet; +use crate::aggregate::sum::downcast_sum; use crate::aggregate::utils::down_cast_any_ref; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::type_coercion::aggregates::sum_return_type; use datafusion_expr::Accumulator; /// Expression for a SUM(DISTINCT) aggregation. @@ -36,8 +40,10 @@ use datafusion_expr::Accumulator; pub struct DistinctSum { /// Column name name: String, - /// The DataType for the final sum + // The DataType for the input expression data_type: DataType, + // The DataType for the final sum + return_type: DataType, /// The input arguments, only contains 1 item for sum exprs: Vec>, } @@ -49,9 +55,11 @@ impl DistinctSum { name: String, data_type: DataType, ) -> Self { + let return_type = sum_return_type(&data_type).unwrap(); Self { name, data_type, + return_type, exprs, } } @@ -63,14 +71,14 @@ impl AggregateExpr for DistinctSum { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, self.return_type.clone(), true)) } fn state_fields(&self) -> Result> { // State field is a List which stores items to rebuild hash set. Ok(vec![Field::new_list( format_state_name(&self.name, "sum distinct"), - Field::new("item", self.data_type.clone(), true), + Field::new("item", self.return_type.clone(), true), false, )]) } @@ -84,7 +92,12 @@ impl AggregateExpr for DistinctSum { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctSumAccumulator::try_new(&self.data_type)?)) + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) + }; + } + downcast_sum!(self, helper) } } @@ -106,33 +119,61 @@ impl PartialEq for DistinctSum { } } -#[derive(Debug)] -struct DistinctSumAccumulator { - hash_values: HashSet, +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone)] +struct Hashable(T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {} + +struct DistinctSumAccumulator { + values: HashSet, RandomState>, data_type: DataType, } -impl DistinctSumAccumulator { + +impl std::fmt::Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { pub fn try_new(data_type: &DataType) -> Result { Ok(Self { - hash_values: HashSet::default(), + values: HashSet::default(), data_type: data_type.clone(), }) } } -impl Accumulator for DistinctSumAccumulator { +impl Accumulator for DistinctSumAccumulator { fn state(&self) -> Result> { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { - let mut distinct_values = Vec::new(); - self.hash_values + let distinct_values = self + .values .iter() - .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![ScalarValue::new_list( - Some(distinct_values), - self.data_type.clone(), - )] + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list( + &distinct_values, + &self.data_type, + ))] }; Ok(state_out) } @@ -142,64 +183,49 @@ impl Accumulator for DistinctSumAccumulator { return Ok(()); } - let arr = &values[0]; - (0..values[0].len()).try_for_each(|index| { - if !arr.is_null(index) { - let v = ScalarValue::try_from_array(arr, index)?; - self.hash_values.insert(v); + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } } - Ok(()) - }) + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? } - - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.hash_values.insert(scalar.clone()); - } - }); - } else { - return Err(DataFusionError::Internal( - "Unexpected accumulator state".into(), - )); - } - Ok(()) - }) + Ok(()) } fn evaluate(&self) -> Result { - let mut sum_value = ScalarValue::try_from(&self.data_type)?; - for distinct_value in self.hash_values.iter() { - sum_value = sum_value.add(distinct_value)?; + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) } - Ok(sum_value) + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &self.data_type) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.hash_values) - - std::mem::size_of_val(&self.hash_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) + std::mem::size_of_val(self) + + self.values.capacity() * std::mem::size_of::() } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; + use crate::expressions::tests::assert_aggregate; + use arrow::array::*; use datafusion_common::Result; + use datafusion_expr::AggregateFunction; fn run_update_batch( return_type: DataType, @@ -213,26 +239,6 @@ mod tests { Ok((accum.state()?, accum.evaluate()?)) } - macro_rules! generic_test_sum_distinct { - ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(DistinctSum::new( - vec![col("a", &schema)?], - "count_distinct_a".to_string(), - $EXPECTED.get_datatype(), - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - #[test] fn sum_distinct_update_batch() -> Result<()> { let array_int64: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 3])); @@ -246,7 +252,7 @@ mod tests { } #[test] - fn sum_distinct_i32_with_nulls() -> Result<()> { + fn sum_distinct_i32_with_nulls() { let array = Arc::new(Int32Array::from(vec![ Some(1), Some(1), @@ -255,11 +261,11 @@ mod tests { Some(2), Some(3), ])); - generic_test_sum_distinct!(array, DataType::Int32, ScalarValue::from(6_i32)) + assert_aggregate(array, AggregateFunction::Sum, true, 6_i64.into()); } #[test] - fn sum_distinct_u32_with_nulls() -> Result<()> { + fn sum_distinct_u32_with_nulls() { let array: ArrayRef = Arc::new(UInt32Array::from(vec![ Some(1_u32), Some(1_u32), @@ -267,28 +273,30 @@ mod tests { Some(3_u32), None, ])); - generic_test_sum_distinct!(array, DataType::UInt32, ScalarValue::from(4_u32)) + assert_aggregate(array, AggregateFunction::Sum, true, 4_u64.into()); } #[test] - fn sum_distinct_f64() -> Result<()> { + fn sum_distinct_f64() { let array: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64])); - generic_test_sum_distinct!(array, DataType::Float64, ScalarValue::from(4_f64)) + assert_aggregate(array, AggregateFunction::Sum, true, 4_f64.into()); } #[test] - fn sum_distinct_decimal_with_nulls() -> Result<()> { + fn sum_distinct_decimal_with_nulls() { let array: ArrayRef = Arc::new( (1..6) .map(|i| if i == 2 { None } else { Some(i % 2) }) .collect::() - .with_precision_and_scale(35, 0)?, + .with_precision_and_scale(35, 0) + .unwrap(), ); - generic_test_sum_distinct!( + assert_aggregate( array, - DataType::Decimal128(35, 0), - ScalarValue::Decimal128(Some(1), 38, 0) - ) + AggregateFunction::Sum, + true, + ScalarValue::Decimal128(Some(1), 38, 0), + ); } } diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 7e6d2dcf8f4fe..90f5244f477de 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,6 +28,9 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; +use arrow_array::cast::as_list_array; +use arrow_array::types::Float64Type; +use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; @@ -566,20 +569,22 @@ impl TDigest { /// [`TDigest`]. pub(crate) fn to_scalar_state(&self) -> Vec { // Gather up all the centroids - let centroids: Vec<_> = self + let centroids: Vec = self .centroids .iter() .flat_map(|c| [c.mean(), c.weight()]) .map(|v| ScalarValue::Float64(Some(v))) .collect(); + let arr = ScalarValue::new_list(¢roids, &DataType::Float64); + vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), - ScalarValue::new_list(Some(centroids), DataType::Float64), + ScalarValue::List(arr), ] } @@ -600,10 +605,18 @@ impl TDigest { }; let centroids: Vec<_> = match &state[5] { - ScalarValue::List(Some(c), f) if *f.data_type() == DataType::Float64 => c - .chunks(2) - .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) - .collect(), + ScalarValue::List(arr) => { + let list_array = as_list_array(arr); + let arr = list_array.values(); + + let f64arr = + as_primitive_array::(arr).expect("expected f64 array"); + f64arr + .values() + .chunks(2) + .map(|v| Centroid::new(v[0], v[1])) + .collect() + } v => panic!("invalid centroids type {v:?}"), }; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 158ceb316e844..e5421ef5ab7ec 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -17,11 +17,17 @@ //! Utilities used in aggregates -use crate::AggregateExpr; +use crate::{AggregateExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; +use arrow_array::ArrowNativeTypeOp; +use arrow_buffer::ArrowNativeType; +use arrow_schema::{DataType, Field}; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -30,76 +36,172 @@ use std::sync::Arc; pub fn get_accum_scalar_values_as_arrays( accum: &dyn Accumulator, ) -> Result> { - Ok(accum + accum .state()? .iter() .map(|s| s.to_array_of_size(1)) - .collect::>()) + .collect::>>() } -pub fn calculate_result_decimal_for_avg( - lit_value: i128, - count: i128, - scale: i8, - target_type: &DataType, -) -> Result { - match target_type { - DataType::Decimal128(p, s) => { - // Different precision for decimal128 can store different range of value. - // For example, the precision is 3, the max of value is `999` and the min - // value is `-999` - let (target_mul, target_min, target_max) = ( - 10_i128.pow(*s as u32), - MIN_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1], - MAX_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1], - ); - let lit_scale_mul = 10_i128.pow(scale as u32); - if target_mul >= lit_scale_mul { - if let Some(value) = lit_value.checked_mul(target_mul / lit_scale_mul) { - let new_value = value / count; - if new_value >= target_min && new_value <= target_max { - Ok(ScalarValue::Decimal128(Some(new_value), *p, *s)) - } else { - Err(DataFusionError::Internal( - "Arithmetic Overflow in AvgAccumulator".to_string(), - )) - } - } else { - // can't convert the lit decimal to the returned data type - Err(DataFusionError::Internal( - "Arithmetic Overflow in AvgAccumulator".to_string(), - )) - } +/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow +/// +/// This is needed because different precisions for Decimal128/Decimal256 can +/// store different ranges of values and thus sum/count may not fit in +/// the target type. +/// +/// For example, the precision is 3, the max of value is `999` and the min +/// value is `-999` +pub(crate) struct DecimalAverager { + /// scale factor for sum values (10^sum_scale) + sum_mul: T::Native, + /// scale factor for target (10^target_scale) + target_mul: T::Native, + /// the output precision + target_precision: u8, +} + +impl DecimalAverager { + /// Create a new `DecimalAverager`: + /// + /// * sum_scale: the scale of `sum` values passed to [`Self::avg`] + /// * target_precision: the output precision + /// * target_scale: the output scale + /// + /// Errors if the resulting data can not be stored + pub fn try_new( + sum_scale: i8, + target_precision: u8, + target_scale: i8, + ) -> Result { + let sum_mul = T::Native::from_usize(10_usize) + .map(|b| b.pow_wrapping(sum_scale as u32)) + .ok_or(DataFusionError::Internal( + "Failed to compute sum_mul in DecimalAverager".to_string(), + ))?; + + let target_mul = T::Native::from_usize(10_usize) + .map(|b| b.pow_wrapping(target_scale as u32)) + .ok_or(DataFusionError::Internal( + "Failed to compute target_mul in DecimalAverager".to_string(), + ))?; + + if target_mul >= sum_mul { + Ok(Self { + sum_mul, + target_mul, + target_precision, + }) + } else { + // can't convert the lit decimal to the returned data type + exec_err!("Arithmetic Overflow in AvgAccumulator") + } + } + + /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with + /// target_scale and target_precision and reporting overflow. + /// + /// * sum: The total sum value stored as Decimal128 with sum_scale + /// (passed to `Self::try_new`) + /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value) + #[inline(always)] + pub fn avg(&self, sum: T::Native, count: T::Native) -> Result { + if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) { + let new_value = value.div_wrapping(count); + + let validate = + T::validate_decimal_precision(new_value, self.target_precision); + + if validate.is_ok() { + Ok(new_value) } else { - // can't convert the lit decimal to the returned data type - Err(DataFusionError::Internal( - "Arithmetic Overflow in AvgAccumulator".to_string(), - )) + exec_err!("Arithmetic Overflow in AvgAccumulator") } + } else { + // can't convert the lit decimal to the returned data type + exec_err!("Arithmetic Overflow in AvgAccumulator") } - other => Err(DataFusionError::Internal(format!( - "Error returned data type in AvgAccumulator {other:?}" - ))), } } +/// Adjust array type metadata if needed +/// +/// Since `Decimal128Arrays` created from `Vec` have +/// default precision and scale, this function adjusts the output to +/// match `data_type`, if necessary +pub fn adjust_output_array( + data_type: &DataType, + array: ArrayRef, +) -> Result { + let array = match data_type { + DataType::Decimal128(p, s) => Arc::new( + array + .as_primitive::() + .clone() + .with_precision_and_scale(*p, *s)?, + ) as ArrayRef, + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Second, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + // no adjustment needed for other arrays + _ => array, + }; + Ok(array) +} + /// Downcast a `Box` or `Arc` -/// and return the inner trait object as [`Any`](std::any::Any) so +/// and return the inner trait object as [`Any`] so /// that it can be downcast to a specific implementation. /// /// This method is used when implementing the `PartialEq` /// for [`AggregateExpr`] aggregation expressions and allows comparing the equality /// between the trait objects. pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() + if let Some(obj) = any.downcast_ref::>() { + obj.as_any() + } else if let Some(obj) = any.downcast_ref::>() { + obj.as_any() } else { any } } + +/// Construct corresponding fields for lexicographical ordering requirement expression +pub(crate) fn ordering_fields( + ordering_req: &[PhysicalSortExpr], + // Data type of each expression in the ordering requirement + data_types: &[DataType], +) -> Vec { + ordering_req + .iter() + .zip(data_types.iter()) + .map(|(expr, dtype)| { + Field::new( + expr.to_string().as_str(), + dtype.clone(), + // Multi partitions may be empty hence field should be nullable. + true, + ) + }) + .collect() +} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index a720dd833a87a..d82c5ad5626f4 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -519,13 +519,17 @@ mod tests { let values1 = expr1 .iter() - .map(|e| e.evaluate(batch1)) - .map(|r| r.map(|v| v.into_array(batch1.num_rows()))) + .map(|e| { + e.evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows())) + }) .collect::>>()?; let values2 = expr2 .iter() - .map(|e| e.evaluate(batch2)) - .map(|r| r.map(|v| v.into_array(batch2.num_rows()))) + .map(|e| { + e.evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows())) + }) .collect::>>()?; accum1.update_batch(&values1)?; accum2.update_batch(&values2)?; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs new file mode 100644 index 0000000000000..f43434362a19a --- /dev/null +++ b/datafusion/physical-expr/src/analysis.rs @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval and selectivity in [`AnalysisContext`] + +use std::fmt::Debug; +use std::sync::Arc; + +use crate::expressions::Column; +use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; +use crate::utils::collect_columns; +use crate::PhysicalExpr; + +use arrow::datatypes::Schema; +use datafusion_common::stats::Precision; +use datafusion_common::{ + internal_err, ColumnStatistics, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; + +/// The shared context used during the analysis of an expression. Includes +/// the boundaries for all known columns. +#[derive(Clone, Debug, PartialEq)] +pub struct AnalysisContext { + // A list of known column boundaries, ordered by the index + // of the column in the current schema. + pub boundaries: Vec, + /// The estimated percentage of rows that this expression would select, if + /// it were to be used as a boolean predicate on a filter. The value will be + /// between 0.0 (selects nothing) and 1.0 (selects everything). + pub selectivity: Option, +} + +impl AnalysisContext { + pub fn new(boundaries: Vec) -> Self { + Self { + boundaries, + selectivity: None, + } + } + + pub fn with_selectivity(mut self, selectivity: f64) -> Self { + self.selectivity = Some(selectivity); + self + } + + /// Create a new analysis context from column statistics. + pub fn try_from_statistics( + input_schema: &Schema, + statistics: &[ColumnStatistics], + ) -> Result { + statistics + .iter() + .enumerate() + .map(|(idx, stats)| ExprBoundaries::try_from_column(input_schema, stats, idx)) + .collect::>>() + .map(Self::new) + } +} + +/// Represents the boundaries (e.g. min and max values) of a particular column +/// +/// This is used range analysis of expressions, to determine if the expression +/// limits the value of particular columns (e.g. analyzing an expression such as +/// `time < 50` would result in a boundary interval for `time` having a max +/// value of `50`). +#[derive(Clone, Debug, PartialEq)] +pub struct ExprBoundaries { + pub column: Column, + /// Minimum and maximum values this expression can have. + pub interval: Interval, + /// Maximum number of distinct values this expression can produce, if known. + pub distinct_count: Precision, +} + +impl ExprBoundaries { + /// Create a new `ExprBoundaries` object from column level statistics. + pub fn try_from_column( + schema: &Schema, + col_stats: &ColumnStatistics, + col_index: usize, + ) -> Result { + let field = &schema.fields()[col_index]; + let empty_field = ScalarValue::try_from(field.data_type())?; + let interval = Interval::try_new( + col_stats + .min_value + .get_value() + .cloned() + .unwrap_or(empty_field.clone()), + col_stats + .max_value + .get_value() + .cloned() + .unwrap_or(empty_field), + )?; + let column = Column::new(field.name(), col_index); + Ok(ExprBoundaries { + column, + interval, + distinct_count: col_stats.distinct_count.clone(), + }) + } + + /// Create `ExprBoundaries` that represent no known bounds for all the + /// columns in `schema` + pub fn try_new_unbounded(schema: &Schema) -> Result> { + schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + Ok(Self { + column: Column::new(field.name(), i), + interval: Interval::make_unbounded(field.data_type())?, + distinct_count: Precision::Absent, + }) + }) + .collect() + } +} + +/// Attempts to refine column boundaries and compute a selectivity value. +/// +/// The function accepts boundaries of the input columns in the `context` parameter. +/// It then tries to tighten these boundaries based on the provided `expr`. +/// The resulting selectivity value is calculated by comparing the initial and final boundaries. +/// The computation assumes that the data within the column is uniformly distributed and not sorted. +/// +/// # Arguments +/// +/// * `context` - The context holding input column boundaries. +/// * `expr` - The expression used to shrink the column boundaries. +/// +/// # Returns +/// +/// * `AnalysisContext` constructed by pruned boundaries and a selectivity value. +pub fn analyze( + expr: &Arc, + context: AnalysisContext, + schema: &Schema, +) -> Result { + let target_boundaries = context.boundaries; + + let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; + + let columns = collect_columns(expr) + .into_iter() + .map(|c| Arc::new(c) as _) + .collect::>(); + + let target_expr_and_indices = graph.gather_node_indices(columns.as_slice()); + + let mut target_indices_and_boundaries = target_expr_and_indices + .iter() + .filter_map(|(expr, i)| { + target_boundaries.iter().find_map(|bound| { + expr.as_any() + .downcast_ref::() + .filter(|expr_column| bound.column.eq(*expr_column)) + .map(|_| (*i, bound.interval.clone())) + }) + }) + .collect::>(); + + match graph + .update_ranges(&mut target_indices_and_boundaries, Interval::CERTAINLY_TRUE)? + { + PropagationResult::Success => { + shrink_boundaries(graph, target_boundaries, target_expr_and_indices) + } + PropagationResult::Infeasible => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0)) + } + PropagationResult::CannotPropagate => { + Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0)) + } + } +} + +/// If the `PropagationResult` indicates success, this function calculates the +/// selectivity value by comparing the initial and final column boundaries. +/// Following this, it constructs and returns a new `AnalysisContext` with the +/// updated parameters. +fn shrink_boundaries( + graph: ExprIntervalGraph, + mut target_boundaries: Vec, + target_expr_and_indices: Vec<(Arc, usize)>, +) -> Result { + let initial_boundaries = target_boundaries.clone(); + target_expr_and_indices.iter().for_each(|(expr, i)| { + if let Some(column) = expr.as_any().downcast_ref::() { + if let Some(bound) = target_boundaries + .iter_mut() + .find(|bound| bound.column.eq(column)) + { + bound.interval = graph.get_interval(*i); + }; + } + }); + + let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries); + + if !(0.0..=1.0).contains(&selectivity) { + return internal_err!("Selectivity is out of limit: {}", selectivity); + } + + Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity)) +} + +/// This function calculates the filter predicate's selectivity by comparing +/// the initial and pruned column boundaries. Selectivity is defined as the +/// ratio of rows in a table that satisfy the filter's predicate. +fn calculate_selectivity( + target_boundaries: &[ExprBoundaries], + initial_boundaries: &[ExprBoundaries], +) -> f64 { + // Since the intervals are assumed uniform and the values + // are not correlated, we need to multiply the selectivities + // of multiple columns to get the overall selectivity. + initial_boundaries + .iter() + .zip(target_boundaries.iter()) + .fold(1.0, |acc, (initial, target)| { + acc * cardinality_ratio(&initial.interval, &target.interval) + }) +} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 631ca376fc059..c2dc88b107739 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -17,1690 +17,2237 @@ //! Array expressions +use std::any::type_name; +use std::collections::HashSet; +use std::sync::Arc; + use arrow::array::*; -use arrow::buffer::Buffer; +use arrow::buffer::OffsetBuffer; use arrow::compute; -use arrow::datatypes::{DataType, Field}; -use core::any::type_name; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; +use arrow::datatypes::{DataType, Field, UInt64Type}; +use arrow::row::{RowConverter, SortField}; +use arrow_buffer::NullBuffer; + +use arrow_schema::{FieldRef, SortOptions}; +use datafusion_common::cast::{ + as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array, + as_list_array, as_null_array, as_string_array, +}; +use datafusion_common::utils::{array_into_list_array, list_ndims}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, +}; + +use itertools::Itertools; + +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} +/// Downcasts multiple arguments into a single concrete type +/// $ARGS: &[ArrayRef] +/// $ARRAY_TYPE: type to downcast to +/// +/// $returns a Vec<$ARRAY_TYPE> macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ $ARGS .iter() .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { Some(array) => Ok(array), - _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + _ => internal_err!("failed to downcast"), }) }}; } -macro_rules! new_builder { - (BooleanBuilder, $len:expr) => { - BooleanBuilder::with_capacity($len) +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. +/// +/// # Arguments +/// +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. +/// +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns +/// +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. +/// +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? + } + } }; - (StringBuilder, $len:expr) => { - StringBuilder::new() + + Ok(res) +} + +/// Returns the length of a concrete array dimension +fn compute_array_length( + arr: Option, + dimension: Option, +) -> Result> { + let mut current_dimension: i64 = 1; + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() + let dimension = match dimension { + Some(value) => { + if value < 1 { + return Ok(None); + } + + value + } + None => return Ok(None), }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; + + loop { + if current_dimension == dimension { + return Ok(Some(value.len() as u64)); + } + + match value.data_type() { + DataType::List(..) => { + value = downcast_arg!(value, ListArray).value(0); + current_dimension += 1; + } + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; + } + _ => return Ok(None), + } + } } -macro_rules! array { - ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - // downcast all arguments to their common format - let args = - downcast_vec!($ARGS, $ARRAY_TYPE).collect::>>()?; - - let builder = new_builder!($BUILDER_TYPE, args[0].len()); - let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, args.len()); - // for each entry in the array - for index in 0..args[0].len() { - for arg in &args { - if arg.is_null(index) { - builder.values().append_null(); - } else { - builder.values().append_value(arg.value(index)); - } +/// Returns the length of each array dimension +fn compute_array_dims(arr: Option) -> Result>>> { + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + if value.is_empty() { + return Ok(None); + } + let mut res = vec![Some(value.len() as u64)]; + + loop { + match value.data_type() { + DataType::List(..) => { + value = downcast_arg!(value, ListArray).value(0); + res.push(Some(value.len() as u64)); } - builder.append(true); + _ => return Ok(Some(res)), + } + } +} + +fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { + let data_type = args[0].data_type(); + if !args.iter().all(|arg| { + arg.data_type().equals_datatype(data_type) + || arg.data_type().equals_datatype(&DataType::Null) + }) { + let types = args.iter().map(|arg| arg.data_type()).collect::>(); + return plan_err!("{name} received incompatible types: '{types:?}'."); + } + + Ok(()) +} + +macro_rules! call_array_function { + ($DATATYPE:expr, false) => { + match $DATATYPE { + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }; + ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ + match $DATATYPE { + DataType::List(_) => array_function!(ListArray), + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), } - Arc::new(builder.finish()) }}; } -fn array_array(args: &[ArrayRef]) -> Result { +/// Convert one or more [`ArrayRef`] of the same type into a +/// `ListArray` or 'LargeListArray' depending on the offset size. +/// +/// # Example (non nested) +/// +/// Calling `array(col1, col2)` where col1 and col2 are non nested +/// would return a single new `ListArray`, where each row was a list +/// of 2 elements: +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ +/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ +/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ +/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ +/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ +/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ +/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ +/// └─────────┘ └─────────┘ └──────────────┘ +/// col1 col2 output +/// ``` +/// +/// # Example (nested) +/// +/// Calling `array(col1, col2)` where col1 and col2 are lists +/// would return a single new `ListArray`, where each row was a list +/// of the corresponding elements of col1 and col2. +/// +/// ``` text +/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ +/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ +/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ +/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ +/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ +/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ +/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ +/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ +/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ +/// col1 col2 output +/// ``` +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal( - "Array requires at least one argument".to_string(), - )); + return plan_err!("Array requires at least one argument"); } - let data_type = args[0].data_type(); - let res = match data_type { - DataType::List(..) => { - let arrays = - downcast_vec!(args, ListArray).collect::>>()?; - let len: i32 = arrays.len() as i32; - let capacity = - Capacities::Array(arrays.iter().map(|a| a.get_array_memory_size()).sum()); - let array_data: Vec<_> = - arrays.iter().map(|a| a.to_data()).collect::>(); - let array_data = array_data.iter().collect(); - let mut mutable = - MutableArrayData::with_capacities(array_data, false, capacity); + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } - for (i, a) in arrays.iter().enumerate() { - mutable.extend(i, 0, a.len()) + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); + + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); + + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); } + } + offsets.push(O::usize_as(mutable.len())); + } + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) +} - let list_data_type = - DataType::List(Arc::new(Field::new("item", data_type.clone(), true))); +/// `make_array` SQL function +pub fn make_array(arrays: &[ArrayRef]) -> Result { + let mut data_type = DataType::Null; + for arg in arrays { + let arg_data_type = arg.data_type(); + if !arg_data_type.equals_datatype(&DataType::Null) { + data_type = arg_data_type.clone(); + break; + } + } - let list_data = ArrayData::builder(list_data_type) - .len(1) - .add_buffer(Buffer::from_slice_ref([0, len])) - .add_child_data(mutable.freeze()) - .build() - .unwrap(); + match data_type { + // Either an empty array or all nulls: + DataType::Null => { + let array = new_null_array(&DataType::Null, arrays.len()); + Ok(Arc::new(array_into_list_array(array))) + } + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), + } +} - Arc::new(ListArray::from(list_data)) - } - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), - data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))) +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: true, we don't construct List for array_element, so we need explicit nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + fn adjusted_array_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 + } else { + index - 1 + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None } - }; + } - Ok(res) -} + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; -/// put values in an array. -pub fn array(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values - .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }) - .collect(); - Ok(ColumnarValue::Array(array_array(arrays.as_slice())?)) -} + // array is null + if len == 0 { + mutable.extend_nulls(1); + continue; + } -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; + let index = adjusted_array_index(indexes.value(row_index), len); + + if let Some(index) = index { + mutable.extend(0, start + index as usize, start + index as usize + 1); + } else { + // Index out of bounds + mutable.extend_nulls(1); + } + } + + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } -macro_rules! append { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let concat = compute::concat(&[child_array, element])?; - let mut scalars = vec![]; - for i in 0..concat.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array( - &concat, i, - )?)); +fn general_except( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + + let l_values = l.values().to_owned(); + let r_values = r.values().to_owned(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + + let mut rows = Vec::with_capacity(l_values.num_rows()); + let mut dedup = HashSet::new(); + + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in r_slice { + let right_row = r_values.row(i); + dedup.insert(right_row); } - scalars - }}; + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } + + if let Some(values) = converter.convert_rows(rows)?.first() { + Ok(GenericListArray::::new( + field.to_owned(), + OffsetBuffer::new(offsets.into()), + values.to_owned(), + l.nulls().cloned(), + )) + } else { + internal_err!("array_except failed to convert rows") + } } -/// Array_append SQL function -pub fn array_append(args: &[ColumnarValue]) -> Result { +pub fn array_except(args: &[ArrayRef]) -> Result { if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "Array_append function requires two arguments, got {}", - args.len() - ))); + return internal_err!("array_except needs two arguments"); } - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + let array1 = &args[0]; + let array2 = &args[1]; + + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::List(field), DataType::List(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (DataType::LargeList(field), DataType::LargeList(_)) => { + check_datatypes("array_except", &[array1, array2])?; + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = general_except::(list1, list2, field)?; + Ok(Arc::new(result)) + } + (dt1, dt2) => { + internal_err!("array_except got unexpected types: {dt1:?} and {dt2:?}") + } + } +} - let element = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_append function requires scalar element".to_string(), - )) +/// array_slice SQL function +/// +/// We follow the behavior of array_slice in DuckDB +/// Note that array_slice is 1-indexed. And there are two additional arguments `from` and `to` in array_slice. +/// +/// > array_slice(array, from, to) +/// +/// Positive index is treated as the index from the start of the array. If the +/// `from` index is smaller than 1, it is treated as 1. If the `to` index is larger than the +/// length of the array, it is treated as the length of the array. +/// +/// Negative index is treated as the index from the end of the array. If the index +/// is larger than the length of the array, it is NOT VALID, either in `from` or `to`. +/// The `to` index is exclusive like python slice syntax. +/// +/// See test cases in `array.slt` for more details. +pub fn array_slice(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + + let values = list_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls. + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + // We have the slice syntax compatible with DuckDB v0.8.1. + // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. + + fn adjusted_from_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + index + len as i64 + } else { + // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) + std::cmp::max(index - 1, 0) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None } - }; + } - let data_type = arr.data_type(); - let arrays = match data_type { - DataType::List(field) => { - match (field.data_type(), element.data_type()) { - (DataType::Utf8, DataType::Utf8) => append!(arr, element, StringArray), - (DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, element, LargeStringArray), - (DataType::Boolean, DataType::Boolean) => append!(arr, element, BooleanArray), - (DataType::Float32, DataType::Float32) => append!(arr, element, Float32Array), - (DataType::Float64, DataType::Float64) => append!(arr, element, Float64Array), - (DataType::Int8, DataType::Int8) => append!(arr, element, Int8Array), - (DataType::Int16, DataType::Int16) => append!(arr, element, Int16Array), - (DataType::Int32, DataType::Int32) => append!(arr, element, Int32Array), - (DataType::Int64, DataType::Int64) => append!(arr, element, Int64Array), - (DataType::UInt8, DataType::UInt8) => append!(arr, element, UInt8Array), - (DataType::UInt16, DataType::UInt16) => append!(arr, element, UInt16Array), - (DataType::UInt32, DataType::UInt32) => append!(arr, element, UInt32Array), - (DataType::UInt64, DataType::UInt64) => append!(arr, element, UInt64Array), - (array_data_type, element_data_type) => { - return Err(DataFusionError::NotImplemented(format!( - "Array_append is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." - ))) - } - } + fn adjusted_to_index(index: i64, len: usize) -> Option { + // 0 ~ len - 1 + let adjusted_zero_index = if index < 0 { + // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive + index + len as i64 - 1 + } else { + // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) + std::cmp::min(index - 1, len as i64 - 1) + }; + + if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { + Some(adjusted_zero_index) + } else { + // Out of bounds + None } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))) + } + + let mut offsets = vec![0]; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; + + // len 0 indicate array is null, return empty array in this row. + if len == 0 { + offsets.push(offsets[row_index]); + continue; } - }; - array(arrays.as_slice()) -} + // If index is null, we consider it as the minimum / maximum index of the array. + let from_index = if from_array.is_null(row_index) { + Some(0) + } else { + adjusted_from_index(from_array.value(row_index), len) + }; + + let to_index = if to_array.is_null(row_index) { + Some(len as i64 - 1) + } else { + adjusted_to_index(to_array.value(row_index), len) + }; -macro_rules! prepend { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let concat = compute::concat(&[element, child_array])?; - let mut scalars = vec![]; - for i in 0..concat.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array( - &concat, i, - )?)); + if let (Some(from), Some(to)) = (from_index, to_index) { + if from <= to { + assert!(start + to as usize <= end); + mutable.extend(0, start + from as usize, start + to as usize + 1); + offsets.push(offsets[row_index] + (to - from + 1) as i32); + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); + } + } else { + // invalid range, return empty array + offsets.push(offsets[row_index]); } - scalars - }}; + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } -/// Array_prepend SQL function -pub fn array_prepend(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "Array_prepend function requires two arguments, got {}", - args.len() - ))); - } +/// array_pop_back SQL function +pub fn array_pop_back(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let from_array = Int64Array::from(vec![1; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64 - 1)) + .collect::>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) +} - let element = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_prepend function requires scalar element".to_string(), - )) +/// Appends or prepends elements to a ListArray. +/// +/// This function takes a ListArray, an ArrayRef, a FieldRef, and a boolean flag +/// indicating whether to append or prepend the elements. It returns a `Result` +/// representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `list_array` - A reference to the ListArray to which elements will be appended/prepended. +/// * `element_array` - A reference to the Array containing elements to be appended/prepended. +/// * `field` - A reference to the Field describing the data type of the arrays. +/// * `is_append` - A boolean flag indicating whether to append (`true`) or prepend (`false`) elements. +/// +/// # Examples +/// +/// general_append_and_prepend( +/// [1, 2, 3], 4, append => [1, 2, 3, 4] +/// 5, [6, 7, 8], prepend => [5, 6, 7, 8] +/// ) +fn general_append_and_prepend( + list_array: &ListArray, + element_array: &ArrayRef, + data_type: &DataType, + is_append: bool, +) -> Result { + let mut offsets = vec![0]; + let values = list_array.values(); + let original_data = values.to_data(); + let element_data = element_array.to_data(); + let capacity = Capacities::Array(original_data.len() + element_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &element_data], + false, + capacity, + ); + + let values_index = 0; + let element_index = 1; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + if is_append { + mutable.extend(values_index, start, end); + mutable.extend(element_index, row_index, row_index + 1); + } else { + mutable.extend(element_index, row_index, row_index + 1); + mutable.extend(values_index, start, end); } - }; + offsets.push(offsets[row_index] + (end - start + 1) as i32); + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) +} - let arr = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), +/// Generates an array of integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub fn gen_range(args: &[ArrayRef]) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return internal_err!("gen_range expects 1 to 3 arguments"), }; - let data_type = arr.data_type(); - let arrays = match data_type { - DataType::List(field) => { - match (field.data_type(), element.data_type()) { - (DataType::Utf8, DataType::Utf8) => prepend!(arr, element, StringArray), - (DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, element, LargeStringArray), - (DataType::Boolean, DataType::Boolean) => prepend!(arr, element, BooleanArray), - (DataType::Float32, DataType::Float32) => prepend!(arr, element, Float32Array), - (DataType::Float64, DataType::Float64) => prepend!(arr, element, Float64Array), - (DataType::Int8, DataType::Int8) => prepend!(arr, element, Int8Array), - (DataType::Int16, DataType::Int16) => prepend!(arr, element, Int16Array), - (DataType::Int32, DataType::Int32) => prepend!(arr, element, Int32Array), - (DataType::Int64, DataType::Int64) => prepend!(arr, element, Int64Array), - (DataType::UInt8, DataType::UInt8) => prepend!(arr, element, UInt8Array), - (DataType::UInt16, DataType::UInt16) => prepend!(arr, element, UInt16Array), - (DataType::UInt32, DataType::UInt32) => prepend!(arr, element, UInt32Array), - (DataType::UInt64, DataType::UInt64) => prepend!(arr, element, UInt64Array), - (array_data_type, element_data_type) => { - return Err(DataFusionError::NotImplemented(format!( - "Array_prepend is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." - ))) - } - } + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + if step == 0 { + return exec_err!("step can't be 0 for function range(start [, stop, step]"); + } + if step < 0 { + // Decreasing range + values.extend((stop + 1..start + 1).rev().step_by((-step) as usize)); + } else { + // Increasing range + values.extend((start..stop).step_by(step as usize)); } + + offsets.push(values.len() as i32); + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + None, + )?); + Ok(arr) +} + +/// array_pop_front SQL function +pub fn array_pop_front(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let from_array = Int64Array::from(vec![2; list_array.len()]); + let to_array = Int64Array::from( + list_array + .iter() + .map(|arr| arr.map_or(0, |arr| arr.len() as i64)) + .collect::>(), + ); + let args = vec![args[0].clone(), Arc::new(from_array), Arc::new(to_array)]; + array_slice(args.as_slice()) +} + +/// Array_append SQL function +pub fn array_append(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; + + check_datatypes("array_append", &[list_array.values(), element_array])?; + let res = match list_array.value_type() { + DataType::List(_) => concat_internal(args)?, + DataType::Null => return make_array(&[element_array.to_owned()]), data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))) + return general_append_and_prepend( + list_array, + element_array, + &data_type, + true, + ); } }; - array(arrays.as_slice()) + Ok(res) } -/// Array_concat/Array_cat SQL function -pub fn array_concat(args: &[ColumnarValue]) -> Result { - let arrays: Vec = args - .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }) - .collect(); - let data_type = arrays[0].data_type(); - match data_type { - DataType::List(..) => { - let list_arrays = - downcast_vec!(arrays, ListArray).collect::>>()?; - let len: usize = list_arrays.iter().map(|a| a.values().len()).sum(); - let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum()); - let array_data: Vec<_> = - list_arrays.iter().map(|a| a.to_data()).collect::>(); - let array_data = array_data.iter().collect(); - let mut mutable = - MutableArrayData::with_capacities(array_data, false, capacity); +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) + } + _ => return internal_err!("array_sort expects 1 to 3 arguments"), + }; - for (i, a) in list_arrays.iter().enumerate() { - mutable.extend(i, 0, a.len()) - } + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); - let builder = mutable.into_builder(); - let list = builder - .len(1) - .buffers(vec![Buffer::from_slice_ref([0, len as i32])]) - .build() - .unwrap(); + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); - return Ok(ColumnarValue::Array(Arc::new(make_array(list)))); + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); } - _ => Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." - ))), } -} -macro_rules! fill { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); - let mut acc = ColumnarValue::Scalar($ELEMENT); - for value in arr.iter().rev() { - match value { - Some(value) => { - let mut repeated = vec![]; - for _ in 0..value { - repeated.push(acc.clone()); - } - acc = array(repeated.as_slice()).unwrap(); - } - _ => { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires non nullable array" - ))); - } - } - } + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} - acc - }}; +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + } } -/// Array_fill SQL function -pub fn array_fill(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires two arguments, got {}", - args.len() - ))); +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => internal_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), } +} - let element = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_fill function requires scalar element".to_string(), - )) +/// Array_prepend SQL function +pub fn array_prepend(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[1])?; + let element_array = &args[0]; + + check_datatypes("array_prepend", &[element_array, list_array.values()])?; + let res = match list_array.value_type() { + DataType::List(_) => concat_internal(args)?, + DataType::Null => return make_array(&[element_array.to_owned()]), + data_type => { + return general_append_and_prepend( + list_array, + element_array, + &data_type, + false, + ); } }; - let arr = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + Ok(res) +} - let res = match arr.data_type() { - DataType::List(..) => { - let arr = downcast_arg!(arr, ListArray); - let array_values = arr.values(); - match arr.value_type() { - DataType::Int8 => fill!(array_values, element, Int8Array), - DataType::Int16 => fill!(array_values, element, Int16Array), - DataType::Int32 => fill!(array_values, element, Int32Array), - DataType::Int64 => fill!(array_values, element, Int64Array), - DataType::UInt8 => fill!(array_values, element, UInt8Array), - DataType::UInt16 => fill!(array_values, element, UInt16Array), - DataType::UInt32 => fill!(array_values, element, UInt32Array), - DataType::UInt64 => fill!(array_values, element, UInt64Array), - data_type => { - return Err(DataFusionError::Internal(format!( - "Array_fill is not implemented for type '{data_type:?}'." - ))); +fn align_array_dimensions(args: Vec) -> Result> { + let args_ndim = args + .iter() + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) + .collect::>(); + let max_ndim = args_ndim.iter().max().unwrap_or(&0); + + // Align the dimensions of the arrays + let aligned_args: Result> = args + .into_iter() + .zip(args_ndim.iter()) + .map(|(array, ndim)| { + if ndim < max_ndim { + let mut aligned_array = array.clone(); + for _ in 0..(max_ndim - ndim) { + let data_type = aligned_array.data_type().to_owned(); + let array_lengths = vec![1; aligned_array.len()]; + let offsets = OffsetBuffer::::from_lengths(array_lengths); + + aligned_array = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type, true)), + offsets, + aligned_array, + None, + )?) } + Ok(aligned_array) + } else { + Ok(array.clone()) } - } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))); - } - }; + }) + .collect(); - Ok(res) + aligned_args } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE).value(0); +// Concatenate arrays on the same row. +fn concat_internal(args: &[ArrayRef]) -> Result { + let args = align_array_dimensions(args.to_vec())?; + + let list_arrays = + downcast_vec!(args, ListArray).collect::>>()?; - match child_array + // Assume number of rows is the same for all arrays + let row_count = list_arrays[0].len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + let nulls = list_arrays .iter() - .skip($INDEX) - .position(|x| x == Some(element)) - { - Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some( - (value + $INDEX + 1) as u8, - )))), - None => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + .map(|arr| arr.is_null(i)) + .collect::>(); + + // If all the arrays are null, the concatenated array is null + let is_null = nulls.iter().all(|&x| x); + if is_null { + array_lengths.push(0); + valid.append(false); + } else { + // Get all the arrays on i-th row + let values = list_arrays + .iter() + .map(|arr| arr.value(i)) + .collect::>(); + + let elements = values + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + // Concatenated array on i-th row + let concated_array = compute::concat(elements.as_slice())?; + array_lengths.push(concated_array.len()); + arrays.push(concated_array); + valid.append(true); } - }}; -} + } + // Assume all arrays have the same data type + let data_type = list_arrays[0].value_type(); + let buffer = valid.finish(); -/// Array_position SQL function -pub fn array_position(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); - let element = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_position function requires scalar element".to_string(), - )) + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + + Ok(Arc::new(list_arr)) +} + +/// Array_concat/Array_cat SQL function +pub fn array_concat(args: &[ArrayRef]) -> Result { + let mut new_args = vec![]; + for arg in args { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { + new_args.push(arg.clone()); } - }; + } - let mut index: usize = 0; - if args.len() == 3 { - let scalar = - match &args[2] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => return Err(DataFusionError::Internal( - "Array_position function requires positive integer scalar element" - .to_string(), - )), - }; + concat_internal(new_args.as_slice()) +} - index = - match scalar { - ScalarValue::Int8(Some(value)) => value as usize, - ScalarValue::Int16(Some(value)) => value as usize, - ScalarValue::Int32(Some(value)) => value as usize, - ScalarValue::Int64(Some(value)) => value as usize, - ScalarValue::UInt8(Some(value)) => value as usize, - ScalarValue::UInt16(Some(value)) => value as usize, - ScalarValue::UInt32(Some(value)) => value as usize, - ScalarValue::UInt64(Some(value)) => value as usize, - _ => return Err(DataFusionError::Internal( - "Array_position function requires positive integer scalar element" - .to_string(), - )), - }; +/// Array_empty SQL function +pub fn array_empty(args: &[ArrayRef]) -> Result { + if as_null_array(&args[0]).is_ok() { + // Make sure to return Boolean type. + return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); + } + let array_type = args[0].data_type(); - if index == 0 { - index = 0; - } else { - index -= 1; - } + match array_type { + DataType::List(_) => array_empty_dispatch::(&args[0]), + DataType::LargeList(_) => array_empty_dispatch::(&args[0]), + _ => internal_err!("array_empty does not support type '{array_type:?}'."), } +} - match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::Utf8 => position!(arr, element, index, StringArray), - DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), - DataType::Boolean => position!(arr, element, index, BooleanArray), - DataType::Float32 => position!(arr, element, index, Float32Array), - DataType::Float64 => position!(arr, element, index, Float64Array), - DataType::Int8 => position!(arr, element, index, Int8Array), - DataType::Int16 => position!(arr, element, index, Int16Array), - DataType::Int32 => position!(arr, element, index, Int32Array), - DataType::Int64 => position!(arr, element, index, Int64Array), - DataType::UInt8 => position!(arr, element, index, UInt8Array), - DataType::UInt16 => position!(arr, element, index, UInt16Array), - DataType::UInt32 => position!(arr, element, index, UInt32Array), - DataType::UInt64 => position!(arr, element, index, UInt64Array), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array_position is not implemented for types '{data_type:?}'." - ))), - }, - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." - ))), +fn array_empty_dispatch(array: &ArrayRef) -> Result { + let array = as_generic_list_array::(array)?; + let builder = array + .iter() + .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) + .collect::(); + Ok(Arc::new(builder)) +} + +/// Array_repeat SQL function +pub fn array_repeat(args: &[ArrayRef]) -> Result { + let element = &args[0]; + let count_array = as_int64_array(&args[1])?; + + match element.data_type() { + DataType::List(_) => { + let list_array = as_list_array(element)?; + general_list_repeat(list_array, count_array) + } + _ => general_repeat(element, count_array), } } -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE).value(0); +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. `array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result { + let data_type = array.data_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); - let mut res = vec![]; - for (i, x) in child_array.iter().enumerate() { - if x == Some(element) { - res.push(ScalarValue::UInt8(Some((i + 1) as u8))); + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); } - } - let field = Arc::new(Field::new("item", DataType::UInt8, true)); - Ok(ColumnarValue::Scalar(ScalarValue::List(Some(res), field))) - }}; + let data = mutable.freeze(); + arrow_array::make_array(data) + }; + new_values.push(repeated_array); + } + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) } -/// Array_positions SQL function -pub fn array_positions(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &ListArray, + count_array: &Int64Array, +) -> Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); + + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); + } - let element = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_positions function requires scalar element".to_string(), - )) - } - }; + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); - match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::Utf8 => positions!(arr, element, StringArray), - DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), - DataType::Boolean => positions!(arr, element, BooleanArray), - DataType::Float32 => positions!(arr, element, Float32Array), - DataType::Float64 => positions!(arr, element, Float64Array), - DataType::Int8 => positions!(arr, element, Int8Array), - DataType::Int16 => positions!(arr, element, Int16Array), - DataType::Int32 => positions!(arr, element, Int32Array), - DataType::Int64 => positions!(arr, element, Int64Array), - DataType::UInt8 => positions!(arr, element, UInt8Array), - DataType::UInt16 => positions!(arr, element, UInt16Array), - DataType::UInt32 => positions!(arr, element, UInt32Array), - DataType::UInt64 => positions!(arr, element, UInt64Array), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array_positions is not implemented for types '{data_type:?}'." - ))), - }, - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." - ))), + let list_arr = ListArray::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), + }; + new_values.push(list_arr); } + + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(lengths), + values, + None, + )?)) } -macro_rules! remove { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE).value(0); - let mut builder = new_builder!($BUILDER_TYPE, child_array.len()); +/// Array_position SQL function +pub fn array_position(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; - for x in child_array { - match x { - Some(x) => { - if x != element { - builder.append_value(x); - } - } - None => builder.append_null(), + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() + .iter() + .map(|&x| x - 1) + .collect::>() + } else { + vec![0; list_array.len()] + }; + + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); } + } else { + // We will get null if we got null in the array, so we don't need to check } - let arr = builder.finish(); + } - let mut scalars = vec![]; - for i in 0..arr.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, i)?)); + general_position::(list_array, element_array, arr_from) +} + +fn general_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: Vec, // 0-indexed +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() + { + let from = from as usize; + + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); + + data.push(index); + } else { + data.push(None); } - scalars - }}; + } + + Ok(Arc::new(UInt64Array::from(data))) } -/// Array_remove SQL function -pub fn array_remove(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; +/// Array_positions SQL function +pub fn array_positions(args: &[ArrayRef]) -> Result { + let arr = as_list_array(&args[0])?; + let element = &args[1]; - let element = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_remove function requires scalar element".to_string(), - )) + check_datatypes("array_positions", &[arr.values(), element])?; + + general_positions::(arr, element) +} + +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let indexes = eq_array + .iter() + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .collect::>(); + + data.push(Some(indexes)); + } else { + data.push(None); } - }; + } - let data_type = arr.data_type(); - let res = match data_type { - DataType::List(field) => { - match (field.data_type(), element.data_type()) { - (DataType::Utf8, DataType::Utf8) => remove!(arr, element, StringArray, StringBuilder), - (DataType::LargeUtf8, DataType::LargeUtf8) => remove!(arr, element, LargeStringArray, LargeStringBuilder), - (DataType::Boolean, DataType::Boolean) => remove!(arr, element, BooleanArray, BooleanBuilder), - (DataType::Float32, DataType::Float32) => remove!(arr, element, Float32Array, Float32Builder), - (DataType::Float64, DataType::Float64) => remove!(arr, element, Float64Array, Float64Builder), - (DataType::Int8, DataType::Int8) => remove!(arr, element, Int8Array, Int8Builder), - (DataType::Int16, DataType::Int16) => remove!(arr, element, Int16Array, Int16Builder), - (DataType::Int32, DataType::Int32) => remove!(arr, element, Int32Array, Int32Builder), - (DataType::Int64, DataType::Int64) => remove!(arr, element, Int64Array, Int64Builder), - (DataType::UInt8, DataType::UInt8) => remove!(arr, element, UInt8Array, UInt8Builder), - (DataType::UInt16, DataType::UInt16) => remove!(arr, element, UInt16Array, UInt16Builder), - (DataType::UInt32, DataType::UInt32) => remove!(arr, element, UInt32Array, UInt32Builder), - (DataType::UInt64, DataType::UInt64) => remove!(arr, element, UInt64Array, UInt64Builder), - (array_data_type, element_data_type) => { - return Err(DataFusionError::NotImplemented(format!( - "Array_remove is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." - ))) - } + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) +} + +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// of `element_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `element_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to remove a list array (where each element is a +/// list of int32s, the second argument are int32 arrays, and the +/// third argument is the number of occurrences to remove +/// +/// ```text +/// general_remove( +/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed) +/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed) +/// ) +/// ``` +fn general_remove( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + let data_type = list_array.value_type(); + let mut new_values = vec![]; + // Build up the offsets for the final output array + let mut offsets = Vec::::with_capacity(arr_n.len() + 1); + offsets.push(OffsetSize::zero()); + + // n is the number of elements to remove in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + match list_array_row { + Some(list_array_row) => { + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; + + // We need to keep at most first n elements as `false`, which represent the elements to remove. + let eq_array = if eq_array.false_count() < *n as usize { + eq_array + } else { + let mut count = 0; + eq_array + .iter() + .map(|e| { + // Keep first n `false` elements, and reverse other elements to `true`. + if let Some(false) = e { + if count < *n { + count += 1; + e + } else { + Some(true) + } + } else { + e + } + }) + .collect::() + }; + + let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; + offsets.push( + offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), + ); + new_values.push(filtered_array); + } + None => { + // Null element results in a null row (no new offsets) + offsets.push(offsets[row_index]); } } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))) - } + } + + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); + arrow::compute::concat(&new_values)? }; - array(res.as_slice()) + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + values, + list_array.nulls().cloned(), + )?)) +} + +fn array_remove_internal( + array: &ArrayRef, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) + } + _ => internal_err!("array_remove_all expects a list array"), + } +} + +pub fn array_remove_all(args: &[ArrayRef]) -> Result { + let arr_n = vec![i64::MAX; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -macro_rules! replace { - ($ARRAY:expr, $FROM:expr, $TO:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let child_array = - downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE); - let from = downcast_arg!($FROM, $ARRAY_TYPE).value(0); - let to = downcast_arg!($TO, $ARRAY_TYPE).value(0); - let mut builder = new_builder!($BUILDER_TYPE, child_array.len()); +pub fn array_remove(args: &[ArrayRef]) -> Result { + let arr_n = vec![1; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) +} - for x in child_array { - match x { - Some(x) => { - if x == from { - builder.append_value(to); - } else { - builder.append_value(x); - } - } - None => builder.append_null(), - } - } - let arr = builder.finish(); +pub fn array_remove_n(args: &[ArrayRef]) -> Result { + let arr_n = as_int64_array(&args[2])?.values().to_vec(); + array_remove_internal(&args[0], &args[1], arr_n) +} - let mut scalars = vec![]; - for i in 0..arr.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, i)?)); +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences +/// of `from_array[i]`, `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s, the second and third argument are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &ListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec, +) -> Result { + // Build up the offsets for the final output array + let mut offsets: Vec = vec![0]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // First array is the original array, second array is the element to replace with. + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut valid = BooleanBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; } - scalars - }}; -} -/// Array_replace SQL function -pub fn array_replace(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; - let from = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "array_replace function requires scalar element".to_string(), - )) - } - }; + let list_array_row = list_array.value(row_index); - let to = match &args[2] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - _ => { - return Err(DataFusionError::Internal( - "array_replace function requires scalar element".to_string(), - )) - } - }; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - if from.data_type() != to.data_type() { - return Err(DataFusionError::Internal( - "array_replace function requires scalar element".to_string(), - )); - } + let original_idx = 0; + let replace_idx = 1; + let n = arr_n[row_index]; + let mut counter = 0; - let data_type = arr.data_type(); - let res = match data_type { - DataType::List(field) => { - match (field.data_type(), from.data_type()) { - (DataType::Utf8, DataType::Utf8) => replace!(arr, from, to, StringArray, StringBuilder), - (DataType::LargeUtf8, DataType::LargeUtf8) => replace!(arr, from, to, LargeStringArray, LargeStringBuilder), - (DataType::Boolean, DataType::Boolean) => replace!(arr, from, to, BooleanArray, BooleanBuilder), - (DataType::Float32, DataType::Float32) => replace!(arr, from, to, Float32Array, Float32Builder), - (DataType::Float64, DataType::Float64) => replace!(arr, from, to, Float64Array, Float64Builder), - (DataType::Int8, DataType::Int8) => replace!(arr, from, to, Int8Array, Int8Builder), - (DataType::Int16, DataType::Int16) => replace!(arr, from, to, Int16Array, Int16Builder), - (DataType::Int32, DataType::Int32) => replace!(arr, from, to, Int32Array, Int32Builder), - (DataType::Int64, DataType::Int64) => replace!(arr, from, to, Int64Array, Int64Builder), - (DataType::UInt8, DataType::UInt8) => replace!(arr, from, to, UInt8Array, UInt8Builder), - (DataType::UInt16, DataType::UInt16) => replace!(arr, from, to, UInt16Array, UInt16Builder), - (DataType::UInt32, DataType::UInt32) => replace!(arr, from, to, UInt32Array, UInt32Builder), - (DataType::UInt64, DataType::UInt64) => replace!(arr, from, to, UInt64Array, UInt64Builder), - (array_data_type, element_data_type) => { - return Err(DataFusionError::NotImplemented(format!( - "Array_replace is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." - ))) + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend(original_idx, start, end); + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + if let Some(true) = to_replace { + mutable.extend(replace_idx, row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend(original_idx, start + i + 1, end); + break; } + } else { + // copy original data for false / null matches + mutable.extend(original_idx, start + i, start + i + 1); } } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))) - } - }; - array(res.as_slice()) + offsets.push(offsets[row_index] + (end - start) as i32); + valid.append(true); + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), + )?)) +} + +pub fn array_replace(args: &[ArrayRef]) -> Result { + // replace at most one occurence for each element + let arr_n = vec![1; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) +} + +pub fn array_replace_n(args: &[ArrayRef]) -> Result { + // replace the specified number of occurences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) +} + +pub fn array_replace_all(args: &[ArrayRef]) -> Result { + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMETER:expr, $ARRAY_TYPE:ident) => {{ + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); for x in arr { match x { Some(x) => { $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMETER); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } } - None => {} } } - Ok($ARG) }}; } -/// Array_to_string SQL function -pub fn array_to_string(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; +fn union_generic_lists( + l: &GenericListArray, + r: &GenericListArray, + field: &FieldRef, +) -> Result> { + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; + + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + let l_values = l.values().clone(); + let r_values = r.values().clone(); + let l_values = converter.convert_columns(&[l_values])?; + let r_values = converter.convert_columns(&[r_values])?; + + // Might be worth adding an upstream OffsetBufferBuilder + let mut offsets = Vec::::with_capacity(l.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows()); + let mut dedup = HashSet::new(); + for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { + let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); + let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); + for i in l_slice { + let left_row = l_values.row(i); + if dedup.insert(left_row) { + rows.push(left_row); + } + } + for i in r_slice { + let right_row = r_values.row(i); + if dedup.insert(right_row) { + rows.push(right_row); + } + } + offsets.push(OffsetSize::usize_as(rows.len())); + dedup.clear(); + } - let scalar = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_to_string function requires scalar element".to_string(), - )) + let values = converter.convert_rows(rows)?; + let offsets = OffsetBuffer::new(offsets.into()); + let result = values[0].clone(); + Ok(GenericListArray::::new( + field.clone(), + offsets, + result, + nulls, + )) +} + +/// Array_union SQL function +pub fn array_union(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_union needs two arguments"); + } + let array1 = &args[0]; + let array2 = &args[1]; + + fn union_arrays( + array1: &ArrayRef, + array2: &ArrayRef, + l_field_ref: &Arc, + r_field_ref: &Arc, + ) -> Result { + match (l_field_ref.data_type(), r_field_ref.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (_, _) => { + let list1 = array1.as_list::(); + let list2 = array2.as_list::(); + let result = union_generic_lists::(list1, list2, l_field_ref)?; + Ok(Arc::new(result)) + } } - }; + } - let delimeter = match scalar { - ScalarValue::Utf8(Some(value)) => String::from(&value), + match (array1.data_type(), array2.data_type()) { + (DataType::Null, _) => Ok(array2.clone()), + (_, DataType::Null) => Ok(array1.clone()), + (DataType::List(l_field_ref), DataType::List(r_field_ref)) => { + union_arrays::(array1, array2, l_field_ref, r_field_ref) + } + (DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => { + union_arrays::(array1, array2, l_field_ref, r_field_ref) + } _ => { - return Err(DataFusionError::Internal( - "Array_to_string function requires positive integer scalar element" - .to_string(), - )) + internal_err!( + "array_union only support list with offsets of type int32 and int64" + ) } - }; + } +} + +/// Array_to_string SQL function +pub fn array_to_string(args: &[ArrayRef]) -> Result { + let arr = &args[0]; + + let delimiters = as_string_array(&args[1])?; + let delimiters: Vec> = delimiters.iter().collect(); + + let mut null_string = String::from(""); + let mut with_null_string = false; + if args.len() == 3 { + null_string = as_string_array(&args[2])?.value(0).to_string(); + with_null_string = true; + } fn compute_array_to_string( arg: &mut String, arr: ArrayRef, - delimeter: String, + delimiter: String, + null_string: String, + with_null_string: bool, ) -> Result<&mut String> { match arr.data_type() { DataType::List(..) => { let list_array = downcast_arg!(arr, ListArray); for i in 0..list_array.len() { - compute_array_to_string(arg, list_array.value(i), delimeter.clone())?; + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; } Ok(arg) } - DataType::Utf8 => to_string!(arg, arr, &delimeter, StringArray), - DataType::LargeUtf8 => to_string!(arg, arr, &delimeter, LargeStringArray), - DataType::Boolean => to_string!(arg, arr, &delimeter, BooleanArray), - DataType::Float32 => to_string!(arg, arr, &delimeter, Float32Array), - DataType::Float64 => to_string!(arg, arr, &delimeter, Float64Array), - DataType::Int8 => to_string!(arg, arr, &delimeter, Int8Array), - DataType::Int16 => to_string!(arg, arr, &delimeter, Int16Array), - DataType::Int32 => to_string!(arg, arr, &delimeter, Int32Array), - DataType::Int64 => to_string!(arg, arr, &delimeter, Int64Array), - DataType::UInt8 => to_string!(arg, arr, &delimeter, UInt8Array), - DataType::UInt16 => to_string!(arg, arr, &delimeter, UInt16Array), - DataType::UInt32 => to_string!(arg, arr, &delimeter, UInt32Array), - DataType::UInt64 => to_string!(arg, arr, &delimeter, UInt64Array), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))), + DataType::Null => Ok(arg), + data_type => { + macro_rules! array_function { + ($ARRAY_TYPE:ident) => { + to_string!( + arg, + arr, + &delimiter, + &null_string, + with_null_string, + $ARRAY_TYPE + ) + }; + } + call_array_function!(data_type, false) + } } } let mut arg = String::from(""); - let mut res = compute_array_to_string(&mut arg, arr, delimeter.clone())?.clone(); - res.truncate(res.len() - delimeter.len()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))) -} + let mut res: Vec> = Vec::new(); -/// Trim_array SQL function -pub fn trim_array(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; - - let scalar = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Trim_array function requires positive integer scalar element" - .to_string(), - )) + match arr.data_type() { + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { + let list_array = arr.as_list::(); + for (arr, &delimiter) in list_array.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); + } else { + res.push(Some(s)); + } + } else { + res.push(None); + } + } } - }; - - let n = match scalar { - ScalarValue::Int8(Some(value)) => value as usize, - ScalarValue::Int16(Some(value)) => value as usize, - ScalarValue::Int32(Some(value)) => value as usize, - ScalarValue::Int64(Some(value)) => value as usize, - ScalarValue::UInt8(Some(value)) => value as usize, - ScalarValue::UInt16(Some(value)) => value as usize, - ScalarValue::UInt32(Some(value)) => value as usize, - ScalarValue::UInt64(Some(value)) => value as usize, _ => { - return Err(DataFusionError::Internal( - "Trim_array function requires positive integer scalar element" - .to_string(), - )) + // delimiter length is 1 + assert_eq!(delimiters.len(), 1); + let delimiter = delimiters[0].unwrap(); + let s = compute_array_to_string( + &mut arg, + arr.clone(), + delimiter.to_string(), + null_string, + with_null_string, + )? + .clone(); + + if !s.is_empty() { + let s = s.strip_suffix(delimiter).unwrap().to_string(); + res.push(Some(s)); + } else { + res.push(Some(s)); + } } - }; - - let list_array = downcast_arg!(arr, ListArray); - let values = list_array.value(0); - let res = values.slice(0, values.len() - n); - - let mut scalars = vec![]; - for i in 0..res.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&res, i)?)); } - array(scalars.as_slice()) + + Ok(Arc::new(StringArray::from(res))) } /// Cardinality SQL function -pub fn cardinality(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; +pub fn cardinality(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?.clone(); - fn compute_cardinality(arg: &mut u64, arr: ArrayRef) -> Result<&mut u64> { - match arr.data_type() { - DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray); - for i in 0..list_array.len() { - compute_cardinality(arg, list_array.value(i))?; - } + let result = list_array + .iter() + .map(|arr| match compute_array_dims(arr)? { + Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), + None => Ok(None), + }) + .collect::>()?; - Ok(arg) + Ok(Arc::new(result) as ArrayRef) +} + +// Create new offsets that are euqiavlent to `flatten` the array. +fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes.iter().map(|i| buffer[*i as usize]).collect(); + OffsetBuffer::new(offsets.into()) +} + +fn flatten_internal( + array: &dyn Array, + indexes: Option>, +) -> Result { + let list_arr = as_list_array(array)?; + let (field, offsets, values, _) = list_arr.clone().into_parts(); + let data_type = field.data_type(); + + match data_type { + // Recursively get the base offsets for flattened array + DataType::List(_) => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + flatten_internal(&values, Some(offsets)) + } else { + flatten_internal(&values, Some(offsets)) } - DataType::Null - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - *arg += arr.len() as u64; - Ok(arg) + } + // Reach the base level, create a new list array + _ => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + let list_arr = ListArray::new(field, offsets, values, None); + Ok(list_arr) + } else { + Ok(list_arr.clone()) } - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))), } } - let mut arg: u64 = 0; - Ok(ColumnarValue::Array(Arc::new(UInt64Array::from(vec![ - *compute_cardinality(&mut arg, arr)?, - ])))) } -/// Array_length SQL function -pub fn array_length(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), +/// Flatten SQL function +pub fn flatten(args: &[ArrayRef]) -> Result { + let flattened_array = flatten_internal(&args[0], None)?; + Ok(Arc::new(flattened_array) as ArrayRef) +} + +/// Dispatch array length computation based on the offset type. +fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() + } else { + Int64Array::from_value(1, list_array.len()) }; - let mut element: u8 = 1; - if args.len() == 2 { - let scalar = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_length function requires positive integer scalar element" - .to_string(), - )) - } - }; - element = match scalar { - ScalarValue::Int8(Some(value)) => value as u8, - ScalarValue::Int16(Some(value)) => value as u8, - ScalarValue::Int32(Some(value)) => value as u8, - ScalarValue::Int64(Some(value)) => value as u8, - ScalarValue::UInt8(Some(value)) => value, - ScalarValue::UInt16(Some(value)) => value as u8, - ScalarValue::UInt32(Some(value)) => value as u8, - ScalarValue::UInt64(Some(value)) => value as u8, - _ => { - return Err(DataFusionError::Internal( - "Array_length function requires positive integer scalar element" - .to_string(), - )) - } - }; + let result = list_array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; - if element == 0 { - return Err(DataFusionError::Internal( - "Array_length function requires positive integer scalar element" - .to_string(), - )); - } - } + Ok(Arc::new(result) as ArrayRef) +} - fn compute_array_length(arg: u8, array: ArrayRef, element: u8) -> Result> { - match array.data_type() { - DataType::List(..) => { - let list_array = downcast_arg!(array, ListArray); - if arg == element + 1 { - Ok(Some(list_array.len() as u8)) - } else { - compute_array_length(arg + 1, list_array.value(0), element) - } - } - DataType::Null - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - if arg == element + 1 { - Ok(Some(array.len() as u8)) - } else { - Ok(None) - } - } - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))), - } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + _ => internal_err!( + "array_length does not support type '{:?}'", + args[0].data_type() + ), } - let arg: u8 = 1; - Ok(ColumnarValue::Array(Arc::new(UInt8Array::from(vec![ - compute_array_length(arg, arr, element)?, - ])))) } /// Array_dims SQL function -pub fn array_dims(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Array(arr) => arr.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - }; +pub fn array_dims(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; - fn compute_array_dims( - arg: &mut Vec, - arr: ArrayRef, - ) -> Result<&mut Vec> { - match arr.data_type() { - DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray).value(0); - arg.push(ScalarValue::UInt8(Some(list_array.len() as u8))); - return compute_array_dims(arg, list_array); + let data = list_array + .iter() + .map(compute_array_dims) + .collect::>>()?; + let result = ListArray::from_iter_primitive::(data); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Array_ndims SQL function +pub fn array_ndims(args: &[ArrayRef]) -> Result { + if let Some(list_array) = args[0].as_list_opt::() { + let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); + + let mut data = vec![]; + for arr in list_array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) } - DataType::Null - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(arg), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))), } + + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } else { + Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) } +} - let list_field = Arc::new(Field::new("item", DataType::UInt8, true)); - let mut arg: Vec = vec![]; - Ok(ColumnarValue::Scalar(ScalarValue::List( - Some(compute_array_dims(&mut arg, arr)?.clone()), - list_field, - ))) +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, + // array_has + Single, } -/// Array_ndims SQL function -pub fn array_ndims(args: &[ColumnarValue]) -> Result { - let arr = match &args[0] { - ColumnarValue::Array(arr) => arr.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), +fn general_array_has_dispatch( + array: &ArrayRef, + sub_array: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let array = if comparison_type == ComparisonType::Single { + let arr = as_generic_list_array::(array)?; + check_datatypes("array_has", &[arr.values(), sub_array])?; + arr + } else { + check_datatypes("array_has", &[array, sub_array])?; + as_generic_list_array::(array)? }; - fn compute_array_ndims(arg: u8, arr: ArrayRef) -> Result { - match arr.data_type() { - DataType::List(..) => { - let list_array = downcast_arg!(arr, ListArray); - compute_array_ndims(arg + 1, list_array.value(0)) + let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + + let element = sub_array.clone(); + let sub_array = if comparison_type != ComparisonType::Single { + as_generic_list_array::(sub_array)? + } else { + array + }; + + for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = if comparison_type != ComparisonType::Single { + converter.convert_columns(&[sub_arr])? + } else { + converter.convert_columns(&[element.clone()])? + }; + + let mut res = match comparison_type { + ComparisonType::All => sub_arr_values + .iter() + .dedup() + .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Any => sub_arr_values + .iter() + .dedup() + .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), + ComparisonType::Single => arr_values + .iter() + .dedup() + .any(|x| x == sub_arr_values.row(row_idx)), + }; + + if comparison_type == ComparisonType::Any { + res |= res; } - DataType::Null - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Boolean - | DataType::Float32 - | DataType::Float64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => Ok(arg), - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." - ))), + + boolean_builder.append_value(res); } } - let arg: u8 = 0; - Ok(ColumnarValue::Array(Arc::new(UInt8Array::from(vec![ - compute_array_ndims(arg, arr)?, - ])))) + Ok(Arc::new(boolean_builder.finish())) } -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::UInt8Array; - use datafusion_common::cast::{ - as_generic_string_array, as_list_array, as_uint64_array, as_uint8_array, - }; - use datafusion_common::scalar::ScalarValue; +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - #[test] - fn test_array() { - // make_array(1, 2, 3) = [1, 2, 3] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ) + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Single) + } + _ => internal_err!("array_has does not support type '{array_type:?}'."), } +} - #[test] - fn test_nested_array() { - // make_array([1, 3, 5], [2, 4, 6]) = [[1, 3, 5], [2, 4, 6]] - let args = [ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2]))), - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 4]))), - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 6]))), - ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 2); - assert_eq!( - &[1, 3, 5], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - assert_eq!( - &[2, 4, 6], - result - .value(1) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); + + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => internal_err!("array_has_any does not support type '{array_type:?}'."), } +} - #[test] - fn test_array_append() { - // array_append([1, 2, 3], 4) = [1, 2, 3, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ - ScalarValue::Int64(Some(1)), - ScalarValue::Int64(Some(2)), - ScalarValue::Int64(Some(3)), - ]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - - let array = array_append(&args) - .expect("failed to initialize function array_append") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); +/// Array_has_all SQL function +pub fn array_has_all(args: &[ArrayRef]) -> Result { + let array_type = args[0].data_type(); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); + match array_type { + DataType::List(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) + } + _ => internal_err!("array_has_all does not support type '{array_type:?}'."), } +} - #[test] - fn test_array_prepend() { - // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ - ScalarValue::Int64(Some(2)), - ScalarValue::Int64(Some(3)), - ScalarValue::Int64(Some(4)), - ]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ]; - - let array = array_prepend(&args) - .expect("failed to initialize function array_append") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_append"); +/// Splits string at occurrences of delimiter and returns an array of parts +/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' +pub fn string_to_array(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + + let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + )); + + match args.len() { + 2 => { + string_array.iter().zip(delimiter_array.iter()).for_each( + |(string, delimiter)| { + match (string, delimiter) { + (Some(string), Some("")) => { + list_builder.values().append_value(string); + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + list_builder.values().append_value(s); + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + list_builder.values().append_value(c); + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }, + ); + } - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); + 3 => { + let null_value_array = as_generic_string_array::(&args[2])?; + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(s); + } + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }); + } + _ => { + return internal_err!( + "Expect string_to_array function to take two or three parameters" + ) + } } - #[test] - fn test_array_concat() { - // array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9] - let args = [ - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ - ScalarValue::Int64(Some(1)), - ScalarValue::Int64(Some(2)), - ScalarValue::Int64(Some(3)), - ]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ - ScalarValue::Int64(Some(4)), - ScalarValue::Int64(Some(5)), - ScalarValue::Int64(Some(6)), - ]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ - ScalarValue::Int64(Some(7)), - ScalarValue::Int64(Some(8)), - ScalarValue::Int64(Some(9)), - ]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ]; - - let array = array_concat(&args) - .expect("failed to initialize function array_concat") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_concat"); + let list_array = list_builder.finish(); + Ok(Arc::new(list_array) as ArrayRef) +} - assert_eq!( - &[1, 2, 3, 4, 5, 6, 7, 8, 9], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } +/// array_intersect SQL function +pub fn array_intersect(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 2); - #[test] - fn test_array_fill() { - // array_fill(4, [5]) = [4, 4, 4, 4, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ScalarValue::Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ]; - - let array = array_fill(&args) - .expect("failed to initialize function array_fill") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_fill"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 4, 4, 4, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } + let first_array = &args[0]; + let second_array = &args[1]; - #[test] - fn test_array_position() { - // array_position([1, 2, 3, 4], 3) = 3 - let list_array = return_array(); - let array = array_position(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]) - .expect("failed to initialize function array_position") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_position"); - - assert_eq!(result, &UInt8Array::from(vec![3])); - } + match (first_array.data_type(), second_array.data_type()) { + (DataType::Null, _) => Ok(second_array.clone()), + (_, DataType::Null) => Ok(first_array.clone()), + _ => { + let first_array = as_list_array(&first_array)?; + let second_array = as_list_array(&second_array)?; - #[test] - fn test_array_positions() { - // array_positions([1, 2, 3, 4], 3) = [3] - let list_array = return_array(); - let array = array_positions(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]) - .expect("failed to initialize function array_position") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_position"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); + } - #[test] - fn test_array_remove() { - // array_remove([1, 2, 3, 4], 3) = [1, 2, 4] - let list_array = return_array(); - let arr = array_remove(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]) - .expect("failed to initialize function array_remove") - .into_array(1); - let result = - as_list_array(&arr).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } + let dt = first_array.value_type(); - #[test] - fn test_array_replace() { - // array_replace([1, 2, 3, 4], 3, 4) = [1, 2, 4, 4] - let list_array = return_array(); - let array = array_replace(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]) - .expect("failed to initialize function array_replace") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_replace"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 4, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } + let mut offsets = vec![0]; + let mut new_arrays = vec![]; - #[test] - fn test_array_to_string() { - // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 - let list_array = return_array(); - let array = array_to_string(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), - ]) - .expect("failed to initialize function array_to_string") - .into_array(1); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1,2,3,4", result.value(0)); - } + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; - #[test] - fn test_nested_array_to_string() { - // array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], '-') = 1-2-3-4-5-6-7-8 - let list_array = return_nested_array(); - let array = array_to_string(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("-")))), - ]) - .expect("failed to initialize function array_to_string") - .into_array(1); - let result = as_generic_string_array::(&array) - .expect("failed to initialize function array_to_string"); - - assert_eq!(result.len(), 1); - assert_eq!("1-2-3-4-5-6-7-8", result.value(0)); - } + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } - #[test] - fn test_trim_array() { - // trim_array([1, 2, 3, 4], 1) = [1, 2, 3] - let list_array = return_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); + } + } - // trim_array([1, 2, 3, 4], 3) = [1] - let list_array = return_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = + new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) + } } +} - #[test] - fn test_nested_trim_array() { - // trim_array([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array(); - let arr = trim_array(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ]) - .expect("failed to initialize function trim_array") - .into_array(1); - let binding = as_list_array(&arr) - .expect("failed to initialize function trim_array") - .value(0); - let result = - as_list_array(&binding).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); +pub fn general_array_distinct( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len()); + offsets.push(OffsetSize::usize_as(0)); + let mut new_arrays = Vec::with_capacity(array.len()); + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + // distinct for each list in ListArray + for arr in array.iter().flatten() { + let values = converter.convert_columns(&[arr])?; + // sort elements in list and remove duplicates + let rows = values.iter().sorted().dedup().collect::>(); + let last_offset: OffsetSize = offsets.last().copied().unwrap(); + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!("array_distinct: failed to get array from rows") + } + }; + new_arrays.push(array); } + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + offsets, + values, + None, + )?)) +} - #[test] - fn test_cardinality() { - // cardinality([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality") - .into_array(1); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![4])); - } +/// array_distinct SQL function +/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] +pub fn array_distinct(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 1); - #[test] - fn test_nested_cardinality() { - // cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]) = 8 - let list_array = return_nested_array(); - let arr = cardinality(&[list_array]) - .expect("failed to initialize function cardinality") - .into_array(1); - let result = - as_uint64_array(&arr).expect("failed to initialize function cardinality"); - - assert_eq!(result, &UInt64Array::from(vec![8])); + // handle null + if args[0].data_type() == &DataType::Null { + return Ok(args[0].clone()); } - #[test] - fn test_array_length() { - // array_length([1, 2, 3, 4]) = 4 - let list_array = return_array(); - let array = array_length(&[list_array.clone()]) - .expect("failed to initialize function array_ndims") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt8Array::from(vec![4])); - - // array_length([1, 2, 3, 4], 1) = 2 - let array = array_length(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::UInt8(Some(1_u8))), - ]) - .expect("failed to initialize function array_ndims") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt8Array::from(vec![4])); + // handle for list & largelist + match args[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&args[0])?; + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + general_array_distinct(array, field) + } + _ => internal_err!("array_distinct only support list array"), } +} - #[test] - fn test_nested_array_length() { - let list_array = return_nested_array(); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = 2 - let array = array_length(&[ - list_array.clone(), - ColumnarValue::Scalar(ScalarValue::UInt8(Some(1_u8))), - ]) - .expect("failed to initialize function array_length") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt8Array::from(vec![2])); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = 4 - let array = array_length(&[ - list_array.clone(), - ColumnarValue::Scalar(ScalarValue::UInt8(Some(2_u8))), - ]) - .expect("failed to initialize function array_length") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt8Array::from(vec![4])); - - // array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3) = NULL - let array = array_length(&[ - list_array, - ColumnarValue::Scalar(ScalarValue::UInt8(Some(3_u8))), - ]) - .expect("failed to initialize function array_length") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_length"); - - assert_eq!(result, &UInt8Array::from(vec![None])); - } +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Int64Type; + /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] - fn test_array_dims() { - // array_dims([1, 2, 3, 4]) = [4] - let list_array = return_array(); - - let array = array_dims(&[list_array]) - .expect("failed to initialize function array_dims") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); + fn test_align_array_dimensions() { + let array1d_1 = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ])); + let array1d_2 = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(6), Some(7), Some(8)]), + ])); + + let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; + let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; + + let res = + align_array_dimensions(vec![array1d_1.to_owned(), array2d_2.to_owned()]) + .unwrap(); + let expected = as_list_array(&array2d_1).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - &[4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() + datafusion_common::utils::list_ndims(res[0].data_type()), + expected_dim ); - } - - #[test] - fn test_nested_array_dims() { - // array_dims([[1, 2, 3, 4], [5, 6, 7, 8]]) = [2, 4] - let list_array = return_nested_array(); - let array = array_dims(&[list_array]) - .expect("failed to initialize function array_dims") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_dims"); + let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; + let array3d_2 = array_into_list_array(array2d_2.to_owned()); + let res = + align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); + let expected = as_list_array(&array3d_1).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - &[2, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() + datafusion_common::utils::list_ndims(res[0].data_type()), + expected_dim ); } #[test] - fn test_array_ndims() { - // array_ndims([1, 2]) = 1 - let list_array = return_array(); + fn test_check_invalid_datatypes() { + let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_ndims"); + let args = [list_array.clone(), int64_array.clone()]; - assert_eq!(result, &UInt8Array::from(vec![1])); - } - - #[test] - fn test_nested_array_ndims() { - // array_ndims([[1, 2], [3, 4]]) = 2 - let list_array = return_nested_array(); - - let array = array_ndims(&[list_array]) - .expect("failed to initialize function array_ndims") - .into_array(1); - let result = - as_uint8_array(&array).expect("failed to initialize function array_ndims"); - - assert_eq!(result, &UInt8Array::from(vec![2])); - } - - fn return_array() -> ColumnarValue { - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) - } + let array = array_append(&args); - fn return_nested_array() -> ColumnarValue { - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(6))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), - ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); - - let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); - ColumnarValue::Array(result.clone()) + assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); } } diff --git a/datafusion/physical-expr/src/conditional_expressions.rs b/datafusion/physical-expr/src/conditional_expressions.rs index 09c5b382da260..a9a25ffe2ec18 100644 --- a/datafusion/physical-expr/src/conditional_expressions.rs +++ b/datafusion/physical-expr/src/conditional_expressions.rs @@ -19,17 +19,17 @@ use arrow::array::{new_null_array, Array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; /// coalesce evaluates to the first value which is not NULL pub fn coalesce(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal(format!( + return internal_err!( "coalesce was called with {} arguments. It requires at least 1.", args.len() - ))); + ); } let return_type = args[0].data_type(); @@ -54,7 +54,7 @@ pub fn coalesce(args: &[ColumnarValue]) -> Result { if value.is_null() { continue; } else { - let last_value = value.to_array_of_size(size); + let last_value = value.to_array_of_size(size)?; current_value = zip(&remainder, &last_value, current_value.as_ref())?; break; diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs index c940933b102fa..580b0ed01b6ed 100644 --- a/datafusion/physical-expr/src/crypto_expressions.rs +++ b/datafusion/physical-expr/src/crypto_expressions.rs @@ -23,11 +23,12 @@ use arrow::{ }; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; -use datafusion_common::cast::{ - as_binary_array, as_generic_binary_array, as_generic_string_array, -}; use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ + cast::{as_binary_array, as_generic_binary_array, as_generic_string_array}, + plan_err, +}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; @@ -65,9 +66,9 @@ fn digest_process( DataType::LargeBinary => { digest_algorithm.digest_binary_array::(a.as_ref()) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {digest_algorithm}", - ))), + other => internal_err!( + "Unsupported data type {other:?} for function {digest_algorithm}" + ), }, ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { @@ -76,9 +77,9 @@ fn digest_process( } ScalarValue::Binary(a) | ScalarValue::LargeBinary(a) => Ok(digest_algorithm .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {digest_algorithm}", - ))), + other => internal_err!( + "Unsupported data type {other:?} for function {digest_algorithm}" + ), }, } } @@ -224,9 +225,9 @@ impl FromStr for DigestAlgorithm { .map(|i| i.to_string()) .collect::>() .join(", "); - return Err(DataFusionError::Plan(format!( - "There is no built-in digest algorithm named '{name}', currently supported algorithms are: {options}", - ))); + return plan_err!( + "There is no built-in digest algorithm named '{name}', currently supported algorithms are: {options}" + ); } }) } @@ -237,11 +238,11 @@ macro_rules! define_digest_function { #[doc = $DOC] pub fn $NAME(args: &[ColumnarValue]) -> Result { if args.len() != 1 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but {} takes exactly one argument", args.len(), - DigestAlgorithm::$METHOD.to_string(), - ))); + DigestAlgorithm::$METHOD.to_string() + ); } digest_process(&args[0], DigestAlgorithm::$METHOD) } @@ -263,11 +264,11 @@ fn hex_encode>(data: T) -> String { /// computes md5 hash digest of the given input pub fn md5(args: &[ColumnarValue]) -> Result { if args.len() != 1 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but {} takes exactly one argument", args.len(), - DigestAlgorithm::Md5, - ))); + DigestAlgorithm::Md5 + ); } let value = digest_process(&args[0], DigestAlgorithm::Md5)?; // md5 requires special handling because of its unique utf8 return type @@ -283,11 +284,7 @@ pub fn md5(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { ColumnarValue::Scalar(ScalarValue::Utf8(opt.map(hex_encode::<_>))) } - _ => { - return Err(DataFusionError::Internal( - "Impossibly got invalid results from digest".into(), - )) - } + _ => return internal_err!("Impossibly got invalid results from digest"), }) } @@ -332,23 +329,21 @@ define_digest_function!( /// Standard algorithms are md5, sha1, sha224, sha256, sha384 and sha512. pub fn digest(args: &[ColumnarValue]) -> Result { if args.len() != 2 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but digest takes exactly two arguments", - args.len(), - ))); + args.len() + ); } let digest_algorithm = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { method.parse::() } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function digest", - ))), + other => internal_err!("Unsupported data type {other:?} for function digest"), }, - ColumnarValue::Array(_) => Err(DataFusionError::Internal( - "Digest using dynamically decided method is not yet supported".into(), - )), + ColumnarValue::Array(_) => { + internal_err!("Digest using dynamically decided method is not yet supported") + } }?; digest_process(&args[0], digest_algorithm) } diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 32db0a7ee179e..bbeb2b0dce86b 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,11 +17,15 @@ //! DateTime expressions +use crate::datetime_expressions; +use crate::expressions::cast_column; use arrow::array::Float64Builder; use arrow::compute::cast; use arrow::{ - array::TimestampNanosecondArray, compute::kernels::temporal, datatypes::TimeUnit, - temporal_conversions::timestamp_ns_to_datetime, + array::TimestampNanosecondArray, + compute::kernels::temporal, + datatypes::TimeUnit, + temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime}, }; use arrow::{ array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray}, @@ -32,19 +36,24 @@ use arrow::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, }; +use arrow_array::types::ArrowTimestampType; use arrow_array::{ - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, + timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampSecondArray, }; use chrono::prelude::*; use chrono::{Duration, Months, NaiveDate}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_generic_string_array, + as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; -use datafusion_common::{DataFusionError, Result}; -use datafusion_common::{ScalarType, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, DataFusionError, Result, ScalarType, + ScalarValue, +}; use datafusion_expr::ColumnarValue; +use std::str::FromStr; use std::sync::Arc; /// given a function `op` that maps a `&str` to a Result of an arrow native type, @@ -66,11 +75,11 @@ where F: Fn(&'a str) -> Result, { if args.len() != 1 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but {} takes exactly one argument", args.len(), - name, - ))); + name + ); } let array = as_generic_string_array::(args[0])?; @@ -100,9 +109,7 @@ where DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, ))), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {name}", - ))), + other => internal_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(a) => { @@ -113,9 +120,7 @@ where let result = a.as_ref().map(|x| (op)(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {name}" - ))), + other => internal_err!("Unsupported data type {other:?} for function {name}"), }, } } @@ -126,6 +131,10 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { } /// to_timestamp SQL function +/// +/// Note: `to_timestamp` returns `Timestamp(Nanosecond)` though its arguments are interpreted as **seconds**. The supported range for integer input is between `-9223372037` and `9223372036`. +/// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. +/// Please use `to_timestamp_seconds` for the input outside of supported bounds. pub fn to_timestamp(args: &[ColumnarValue]) -> Result { handle::( args, @@ -152,6 +161,15 @@ pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { ) } +/// to_timestamp_nanos SQL function +pub fn to_timestamp_nanos(args: &[ColumnarValue]) -> Result { + handle::( + args, + string_to_timestamp_nanos_shim, + "to_timestamp_nanos", + ) +} + /// to_timestamp_seconds SQL function pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { handle::( @@ -170,7 +188,7 @@ pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { pub fn make_now( now_ts: DateTime, ) -> impl Fn(&[ColumnarValue]) -> Result { - let now_ts = Some(now_ts.timestamp_nanos()); + let now_ts = now_ts.timestamp_nanos_opt(); move |_arg| { Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( now_ts, @@ -206,60 +224,173 @@ pub fn make_current_date( pub fn make_current_time( now_ts: DateTime, ) -> impl Fn(&[ColumnarValue]) -> Result { - let nano = Some(now_ts.timestamp_nanos() % 86400000000000); + let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); move |_arg| Ok(ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(nano))) } -fn quarter_month(date: &NaiveDateTime) -> u32 { +fn quarter_month(date: &T) -> u32 +where + T: chrono::Datelike, +{ 1 + 3 * ((date.month() - 1) / 3) } -fn date_trunc_single(granularity: &str, value: i64) -> Result { - let value = timestamp_ns_to_datetime(value) - .ok_or_else(|| { - DataFusionError::Execution(format!("Timestamp {value} out of range")) - })? - .with_nanosecond(0); +fn _date_trunc_coarse(granularity: &str, value: Option) -> Result> +where + T: chrono::Datelike + + chrono::Timelike + + std::ops::Sub + + std::marker::Copy, +{ let value = match granularity { - "second" | "millisecond" | "microsecond" => value, - "minute" => value.and_then(|d| d.with_second(0)), + "millisecond" => value, + "microsecond" => value, + "second" => value.and_then(|d| d.with_nanosecond(0)), + "minute" => value + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)), "hour" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)), "day" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)), "week" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)) .map(|d| d - Duration::seconds(60 * 60 * 24 * d.weekday() as i64)), "month" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)) .and_then(|d| d.with_day0(0)), "quarter" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)) .and_then(|d| d.with_day0(0)) .and_then(|d| d.with_month(quarter_month(&d))), "year" => value + .and_then(|d| d.with_nanosecond(0)) .and_then(|d| d.with_second(0)) .and_then(|d| d.with_minute(0)) .and_then(|d| d.with_hour(0)) .and_then(|d| d.with_day0(0)) .and_then(|d| d.with_month0(0)), unsupported => { - return Err(DataFusionError::Execution(format!( - "Unsupported date_trunc granularity: {unsupported}" - ))); + return exec_err!("Unsupported date_trunc granularity: {unsupported}"); } }; - // `with_x(0)` are infalible because `0` are always a valid - Ok(value.unwrap().timestamp_nanos()) + Ok(value) +} + +fn _date_trunc_coarse_with_tz( + granularity: &str, + value: Option>, +) -> Result> { + let value = _date_trunc_coarse::>(granularity, value)?; + Ok(value.and_then(|value| value.timestamp_nanos_opt())) +} + +fn _date_trunc_coarse_without_tz( + granularity: &str, + value: Option, +) -> Result> { + let value = _date_trunc_coarse::(granularity, value)?; + Ok(value.and_then(|value| value.timestamp_nanos_opt())) +} + +/// Tuncates the single `value`, expressed in nanoseconds since the +/// epoch, for granularities greater than 1 second, in taking into +/// account that some granularities are not uniform durations of time +/// (e.g. months are not always the same lengths, leap seconds, etc) +fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result { + let value = match tz { + Some(tz) => { + // Use chrono DateTime to clear the various fields because need to clear per timezone, + // and NaiveDateTime (ISO 8601) has no concept of timezones + let value = as_datetime_with_timezone::(value, tz) + .ok_or(DataFusionError::Execution(format!( + "Timestamp {value} out of range" + )))?; + _date_trunc_coarse_with_tz(granularity, Some(value)) + } + None => { + // Use chrono NaiveDateTime to clear the various fields, if we don't have a timezone. + let value = timestamp_ns_to_datetime(value).ok_or_else(|| { + DataFusionError::Execution(format!("Timestamp {value} out of range")) + })?; + _date_trunc_coarse_without_tz(granularity, Some(value)) + } + }?; + + // `with_x(0)` are infallible because `0` are always a valid + Ok(value.unwrap()) +} + +// truncates a single value with the given timeunit to the specified granularity +fn general_date_trunc( + tu: TimeUnit, + value: &Option, + tz: Option, + granularity: &str, +) -> Result, DataFusionError> { + let scale = match tu { + TimeUnit::Second => 1_000_000_000, + TimeUnit::Millisecond => 1_000_000, + TimeUnit::Microsecond => 1_000, + TimeUnit::Nanosecond => 1, + }; + + let Some(value) = value else { + return Ok(None); + }; + + // convert to nanoseconds + let nano = date_trunc_coarse(granularity, scale * value, tz)?; + + let result = match tu { + TimeUnit::Second => match granularity { + "minute" => Some(nano / 1_000_000_000 / 60 * 60), + _ => Some(nano / 1_000_000_000), + }, + TimeUnit::Millisecond => match granularity { + "minute" => Some(nano / 1_000_000 / 1_000 / 60 * 1_000 * 60), + "second" => Some(nano / 1_000_000 / 1_000 * 1_000), + _ => Some(nano / 1_000_000), + }, + TimeUnit::Microsecond => match granularity { + "minute" => Some(nano / 1_000 / 1_000_000 / 60 * 60 * 1_000_000), + "second" => Some(nano / 1_000 / 1_000_000 * 1_000_000), + "millisecond" => Some(nano / 1_000 / 1_000 * 1_000), + _ => Some(nano / 1_000), + }, + _ => match granularity { + "minute" => Some(nano / 1_000_000_000 / 60 * 1_000_000_000 * 60), + "second" => Some(nano / 1_000_000_000 * 1_000_000_000), + "millisecond" => Some(nano / 1_000_000 * 1_000_000), + "microsecond" => Some(nano / 1_000 * 1_000), + _ => Some(nano), + }, + }; + Ok(result) +} + +fn parse_tz(tz: &Option>) -> Result> { + tz.as_ref() + .map(|tz| { + Tz::from_str(tz).map_err(|op| { + DataFusionError::Execution(format!("failed on timezone {tz}: {:?}", op)) + }) + }) + .transpose() } /// date_trunc SQL function @@ -270,67 +401,70 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result { if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity { v.to_lowercase() } else { - return Err(DataFusionError::Execution( - "Granularity of `date_trunc` must be non-null scalar Utf8".to_string(), - )); + return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); }; - let f = |x: Option| { - x.map(|x| date_trunc_single(granularity.as_str(), x)) - .transpose() - }; + fn process_array( + array: &dyn Array, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let array = as_primitive_array::(array)?; + let array = array + .iter() + .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str())) + .collect::>>()? + .with_timezone_opt(tz_opt.clone()); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn process_scalar( + v: &Option, + granularity: String, + tz_opt: &Option>, + ) -> Result { + let parsed_tz = parse_tz(tz_opt)?; + let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; + let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); + Ok(ColumnarValue::Scalar(value)) + } Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let nano = (f)(*v)?; - match granularity.as_str() { - "minute" => { - // cast to second - let second = ScalarValue::TimestampSecond( - Some(nano.unwrap() / 1_000_000_000), - tz_opt.clone(), - ); - ColumnarValue::Scalar(second) + process_scalar::(v, granularity, tz_opt)? + } + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { + process_scalar::(v, granularity, tz_opt)? + } + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { + process_scalar::(v, granularity, tz_opt)? + } + ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { + process_scalar::(v, granularity, tz_opt)? + } + ColumnarValue::Array(array) => { + let array_type = array.data_type(); + match array_type { + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + process_array::(array, granularity, tz_opt)? } - "second" => { - // cast to millisecond - let mill = ScalarValue::TimestampMillisecond( - Some(nano.unwrap() / 1_000_000), - tz_opt.clone(), - ); - ColumnarValue::Scalar(mill) + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_array::(array, granularity, tz_opt)? } - "millisecond" => { - // cast to microsecond - let micro = ScalarValue::TimestampMicrosecond( - Some(nano.unwrap() / 1_000), - tz_opt.clone(), - ); - ColumnarValue::Scalar(micro) + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_array::(array, granularity, tz_opt)? } - _ => { - // cast to nanosecond - let nano = ScalarValue::TimestampNanosecond( - Some(nano.unwrap()), - tz_opt.clone(), - ); - ColumnarValue::Scalar(nano) + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_array::(array, granularity, tz_opt)? } + _ => process_array::(array, granularity, &None)?, } } - ColumnarValue::Array(array) => { - let array = as_timestamp_nanosecond_array(array)?; - let array = array - .iter() - .map(f) - .collect::>()?; - - ColumnarValue::Array(Arc::new(array)) - } _ => { - return Err(DataFusionError::Execution( - "array of `date_trunc` must be non-null scalar Utf8".to_string(), - )); + return exec_err!( + "second argument of `date_trunc` must be nanosecond timestamp scalar or array" + ); } }) } @@ -388,14 +522,14 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 }; } - bin_time.timestamp_nanos() + bin_time.timestamp_nanos_opt().unwrap() } fn to_utc_date_time(nanos: i64) -> DateTime { let secs = nanos / 1_000_000_000; let nsec = (nanos % 1_000_000_000) as u32; let date = NaiveDateTime::from_timestamp_opt(secs, nsec).unwrap(); - DateTime::::from_utc(date, Utc) + DateTime::::from_naive_utc_and_offset(date, Utc) } /// DATE_BIN sql function @@ -410,9 +544,7 @@ pub fn date_bin(args: &[ColumnarValue]) -> Result { } else if args.len() == 3 { date_bin_impl(&args[0], &args[1], &args[2]) } else { - Err(DataFusionError::Execution( - "DATE_BIN expected two or three arguments".to_string(), - )) + exec_err!("DATE_BIN expected two or three arguments") } } @@ -457,11 +589,7 @@ fn date_bin_impl( match nanos { Some(v) => Interval::Nanoseconds(v), - _ => { - return Err(DataFusionError::Execution( - "DATE_BIN stride argument is too large".to_string(), - )) - } + _ => return exec_err!("DATE_BIN stride argument is too large"), } } ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(v))) => { @@ -471,9 +599,9 @@ fn date_bin_impl( if months != 0 { // Return error if days or nanos is not zero if days != 0 || nanos != 0 { - return Err(DataFusionError::NotImplemented( - "DATE_BIN stride does not support combination of month, day and nanosecond intervals".to_string(), - )); + return not_impl_err!( + "DATE_BIN stride does not support combination of month, day and nanosecond intervals" + ); } else { Interval::Months(months as i64) } @@ -482,47 +610,41 @@ fn date_bin_impl( .num_nanoseconds(); match nanos { Some(v) => Interval::Nanoseconds(v), - _ => { - return Err(DataFusionError::Execution( - "DATE_BIN stride argument is too large".to_string(), - )) - } + _ => return exec_err!("DATE_BIN stride argument is too large"), } } } ColumnarValue::Scalar(v) => { - return Err(DataFusionError::Execution(format!( + return exec_err!( "DATE_BIN expects stride argument to be an INTERVAL but got {}", - v.get_datatype() - ))) + v.data_type() + ) } - ColumnarValue::Array(_) => return Err(DataFusionError::NotImplemented( + ColumnarValue::Array(_) => { + return not_impl_err!( "DATE_BIN only supports literal values for the stride argument, not arrays" - .to_string(), - )), + ) + } }; let origin = match origin { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => *v, ColumnarValue::Scalar(v) => { - return Err(DataFusionError::Execution(format!( + return exec_err!( "DATE_BIN expects origin argument to be a TIMESTAMP with nanosececond precision but got {}", - v.get_datatype() - ))) + v.data_type() + ) } - ColumnarValue::Array(_) => return Err(DataFusionError::NotImplemented( + ColumnarValue::Array(_) => return not_impl_err!( "DATE_BIN only supports literal values for the origin argument, not arrays" - .to_string(), - )), + ), }; let (stride, stride_fn) = stride.bin_fn(); // Return error if stride is 0 if stride == 0 { - return Err(DataFusionError::Execution( - "DATE_BIN stride must be non-zero".to_string(), - )); + return exec_err!("DATE_BIN stride must be non-zero"); } let f_nanos = |x: Option| x.map(|x| stride_fn(stride, x, origin)); @@ -565,50 +687,53 @@ fn date_bin_impl( )) } ColumnarValue::Array(array) => match array.data_type() { - DataType::Timestamp(TimeUnit::Nanosecond, _) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { let array = as_timestamp_nanosecond_array(array)? .iter() .map(f_nanos) - .collect::(); + .collect::() + .with_timezone_opt(tz_opt.clone()); ColumnarValue::Array(Arc::new(array)) } - DataType::Timestamp(TimeUnit::Microsecond, _) => { + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { let array = as_timestamp_microsecond_array(array)? .iter() .map(f_micros) - .collect::(); + .collect::() + .with_timezone_opt(tz_opt.clone()); ColumnarValue::Array(Arc::new(array)) } - DataType::Timestamp(TimeUnit::Millisecond, _) => { + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { let array = as_timestamp_millisecond_array(array)? .iter() .map(f_millis) - .collect::(); + .collect::() + .with_timezone_opt(tz_opt.clone()); ColumnarValue::Array(Arc::new(array)) } - DataType::Timestamp(TimeUnit::Second, _) => { + DataType::Timestamp(TimeUnit::Second, tz_opt) => { let array = as_timestamp_second_array(array)? .iter() .map(f_secs) - .collect::(); + .collect::() + .with_timezone_opt(tz_opt.clone()); ColumnarValue::Array(Arc::new(array)) } _ => { - return Err(DataFusionError::Execution(format!( + return exec_err!( "DATE_BIN expects source argument to be a TIMESTAMP but got {}", array.data_type() - ))) + ) } }, _ => { - return Err(DataFusionError::Execution( + return exec_err!( "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" - .to_string(), - )); + ); } }) } @@ -648,10 +773,7 @@ macro_rules! extract_date_part { .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) } }, - datatype => Err(DataFusionError::Internal(format!( - "Extract does not support datatype {:?}", - datatype - ))), + datatype => internal_err!("Extract does not support datatype {:?}", datatype), } }; } @@ -659,25 +781,21 @@ macro_rules! extract_date_part { /// DATE_PART SQL function pub fn date_part(args: &[ColumnarValue]) -> Result { if args.len() != 2 { - return Err(DataFusionError::Execution( - "Expected two arguments in DATE_PART".to_string(), - )); + return exec_err!("Expected two arguments in DATE_PART"); } let (date_part, array) = (&args[0], &args[1]); let date_part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { v } else { - return Err(DataFusionError::Execution( - "First argument of `DATE_PART` must be non-null scalar Utf8".to_string(), - )); + return exec_err!("First argument of `DATE_PART` must be non-null scalar Utf8"); }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); let array = match array { ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; let arr = match date_part.to_lowercase().as_str() { @@ -695,9 +813,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { "microsecond" => extract_date_part!(&array, micros), "nanosecond" => extract_date_part!(&array, nanos), "epoch" => extract_date_part!(&array, epoch), - _ => Err(DataFusionError::Execution(format!( - "Date part '{date_part}' not supported" - ))), + _ => exec_err!("Date part '{date_part}' not supported"), }?; Ok(if is_scalar { @@ -778,21 +894,173 @@ where } } } - _ => { - return Err(DataFusionError::Internal(format!( - "Can not convert {:?} to epoch", - array.data_type() - ))) - } + _ => return internal_err!("Can not convert {:?} to epoch", array.data_type()), } Ok(b.finish()) } +/// to_timestammp() SQL function implementation +pub fn to_timestamp_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp function requires 1 arguments, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => cast_column( + &cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)?, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Float64 => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp", + other + ) + } + } +} + +/// to_timestamp_millis() SQL function implementation +pub fn to_timestamp_millis_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_millis function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Millisecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_millis(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_millis", + other + ) + } + } +} + +/// to_timestamp_micros() SQL function implementation +pub fn to_timestamp_micros_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_micros function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Microsecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_micros(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_micros", + other + ) + } + } +} + +/// to_timestamp_nanos() SQL function implementation +pub fn to_timestamp_nanos_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_nanos function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => cast_column( + &args[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + ), + DataType::Utf8 => datetime_expressions::to_timestamp_nanos(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_nanos", + other + ) + } + } +} + +/// to_timestamp_seconds() SQL function implementation +pub fn to_timestamp_seconds_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "to_timestamp_seconds function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 | DataType::Timestamp(_, None) => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + DataType::Utf8 => datetime_expressions::to_timestamp_seconds(args), + other => { + internal_err!( + "Unsupported data type {:?} for function to_timestamp_seconds", + other + ) + } + } +} + +/// from_unixtime() SQL function implementation +pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!( + "from_unixtime function requires 1 argument, got {}", + args.len() + ); + } + + match args[0].data_type() { + DataType::Int64 => { + cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + } + other => { + internal_err!( + "Unsupported data type {:?} for function from_unixtime", + other + ) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::{ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder}; + use arrow::array::{ + as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder, + }; use super::*; @@ -803,7 +1071,7 @@ mod tests { let mut string_builder = StringBuilder::with_capacity(2, 1024); let mut ts_builder = TimestampNanosecondArray::builder(2); - string_builder.append_value("2020-09-08T13:42:29.190855Z"); + string_builder.append_value("2020-09-08T13:42:29.190855"); ts_builder.append_value(1599572549190855000); string_builder.append_null(); @@ -913,11 +1181,130 @@ mod tests { cases.iter().for_each(|(original, granularity, expected)| { let left = string_to_timestamp_nanos(original).unwrap(); let right = string_to_timestamp_nanos(expected).unwrap(); - let result = date_trunc_single(granularity, left).unwrap(); + let result = date_trunc_coarse(granularity, left, None).unwrap(); assert_eq!(result, right, "{original} = {expected}"); }); } + #[test] + fn test_date_trunc_timezones() { + let cases = vec![ + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("+00".into()), + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + None, + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("-02".into()), + vec![ + "2020-09-07T02:00:00Z", + "2020-09-07T02:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T02:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T01:00:00+05", + "2020-09-08T02:00:00+05", + "2020-09-08T03:00:00+05", + "2020-09-08T04:00:00+05", + ], + Some("+05".into()), + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T01:00:00+08", + "2020-09-08T02:00:00+08", + "2020-09-08T03:00:00+08", + "2020-09-08T04:00:00+08", + ], + Some("+08".into()), + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + ], + ), + ]; + + cases.iter().for_each(|(original, tz_opt, expected)| { + let input = original + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let result = date_trunc(&[ + ColumnarValue::Scalar(ScalarValue::from("day")), + ColumnarValue::Array(Arc::new(input)), + ]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + ); + let left = as_primitive_array::(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } + #[test] fn test_date_bin_single() { use chrono::Duration; @@ -1017,7 +1404,7 @@ mod tests { let res = date_bin(&[ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1)))]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); @@ -1028,7 +1415,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" ); @@ -1039,7 +1426,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); @@ -1050,7 +1437,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); @@ -1061,7 +1448,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); @@ -1072,7 +1459,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); @@ -1083,7 +1470,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosececond precision but got Timestamp(Microsecond, None)" ); @@ -1102,7 +1489,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" ); @@ -1114,11 +1501,141 @@ mod tests { ColumnarValue::Array(timestamps), ]); assert_eq!( - res.err().unwrap().to_string(), + res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" ); } + #[test] + fn test_date_bin_timezones() { + let cases = vec![ + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("+00".into()), + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + None, + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T01:00:00Z", + "2020-09-08T02:00:00Z", + "2020-09-08T03:00:00Z", + "2020-09-08T04:00:00Z", + ], + Some("-02".into()), + "1970-01-01T00:00:00Z", + vec![ + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + "2020-09-08T00:00:00Z", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T01:00:00+05", + "2020-09-08T02:00:00+05", + "2020-09-08T03:00:00+05", + "2020-09-08T04:00:00+05", + ], + Some("+05".into()), + "1970-01-01T00:00:00+05", + vec![ + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + "2020-09-08T00:00:00+05", + ], + ), + ( + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T01:00:00+08", + "2020-09-08T02:00:00+08", + "2020-09-08T03:00:00+08", + "2020-09-08T04:00:00+08", + ], + Some("+08".into()), + "1970-01-01T00:00:00+08", + vec![ + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + "2020-09-08T00:00:00+08", + ], + ), + ]; + + cases + .iter() + .for_each(|(original, tz_opt, origin, expected)| { + let input = original + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::() + .with_timezone_opt(tz_opt.clone()); + let result = date_bin(&[ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), + tz_opt.clone(), + )), + ]) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + ); + let left = as_primitive_array::(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } + #[test] fn to_timestamp_invalid_input_type() -> Result<()> { // pass the wrong type of input array to to_timestamp and test diff --git a/datafusion/physical-expr/src/encoding_expressions.rs b/datafusion/physical-expr/src/encoding_expressions.rs new file mode 100644 index 0000000000000..b74310485fb7e --- /dev/null +++ b/datafusion/physical-expr/src/encoding_expressions.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Encoding expressions + +use arrow::{ + array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait, StringArray}, + datatypes::DataType, +}; +use base64::{engine::general_purpose, Engine as _}; +use datafusion_common::ScalarValue; +use datafusion_common::{ + cast::{as_generic_binary_array, as_generic_string_array}, + internal_err, not_impl_err, plan_err, +}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; +use std::{fmt, str::FromStr}; + +#[derive(Debug, Copy, Clone)] +enum Encoding { + Base64, + Hex, +} + +fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result { + match value { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => encoding.encode_utf8_array::(a.as_ref()), + DataType::LargeUtf8 => encoding.encode_utf8_array::(a.as_ref()), + DataType::Binary => encoding.encode_binary_array::(a.as_ref()), + DataType::LargeBinary => encoding.encode_binary_array::(a.as_ref()), + other => internal_err!( + "Unsupported data type {other:?} for function encode({encoding})" + ), + }, + ColumnarValue::Scalar(scalar) => { + match scalar { + ScalarValue::Utf8(a) => { + Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) + } + ScalarValue::LargeUtf8(a) => Ok(encoding + .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Binary(a) => Ok( + encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) + ), + ScalarValue::LargeBinary(a) => Ok(encoding + .encode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), + other => internal_err!( + "Unsupported data type {other:?} for function encode({encoding})" + ), + } + } + } +} + +fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result { + match value { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => encoding.decode_utf8_array::(a.as_ref()), + DataType::LargeUtf8 => encoding.decode_utf8_array::(a.as_ref()), + DataType::Binary => encoding.decode_binary_array::(a.as_ref()), + DataType::LargeBinary => encoding.decode_binary_array::(a.as_ref()), + other => internal_err!( + "Unsupported data type {other:?} for function decode({encoding})" + ), + }, + ColumnarValue::Scalar(scalar) => { + match scalar { + ScalarValue::Utf8(a) => { + encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) + } + ScalarValue::LargeUtf8(a) => encoding + .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), + ScalarValue::Binary(a) => { + encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) + } + ScalarValue::LargeBinary(a) => encoding + .decode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice())), + other => internal_err!( + "Unsupported data type {other:?} for function decode({encoding})" + ), + } + } + } +} + +fn hex_encode(input: &[u8]) -> String { + hex::encode(input) +} + +fn base64_encode(input: &[u8]) -> String { + general_purpose::STANDARD_NO_PAD.encode(input) +} + +fn hex_decode(input: &[u8]) -> Result> { + hex::decode(input).map_err(|e| { + DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) + }) +} + +fn base64_decode(input: &[u8]) -> Result> { + general_purpose::STANDARD_NO_PAD.decode(input).map_err(|e| { + DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + }) +} + +macro_rules! encode_to_array { + ($METHOD: ident, $INPUT:expr) => {{ + let utf8_array: StringArray = $INPUT + .iter() + .map(|x| x.map(|x| $METHOD(x.as_ref()))) + .collect(); + Arc::new(utf8_array) + }}; +} + +macro_rules! decode_to_array { + ($METHOD: ident, $INPUT:expr) => {{ + let binary_array: BinaryArray = $INPUT + .iter() + .map(|x| x.map(|x| $METHOD(x.as_ref())).transpose()) + .collect::>()?; + Arc::new(binary_array) + }}; +} + +impl Encoding { + fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { + ColumnarValue::Scalar(match self { + Self::Base64 => ScalarValue::Utf8( + value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), + ), + Self::Hex => ScalarValue::Utf8(value.map(hex::encode)), + }) + } + + fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue { + ColumnarValue::Scalar(match self { + Self::Base64 => ScalarValue::LargeUtf8( + value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), + ), + Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)), + }) + } + + fn encode_binary_array(self, value: &dyn Array) -> Result + where + T: OffsetSizeTrait, + { + let input_value = as_generic_binary_array::(value)?; + let array: ArrayRef = match self { + Self::Base64 => encode_to_array!(base64_encode, input_value), + Self::Hex => encode_to_array!(hex_encode, input_value), + }; + Ok(ColumnarValue::Array(array)) + } + + fn encode_utf8_array(self, value: &dyn Array) -> Result + where + T: OffsetSizeTrait, + { + let input_value = as_generic_string_array::(value)?; + let array: ArrayRef = match self { + Self::Base64 => encode_to_array!(base64_encode, input_value), + Self::Hex => encode_to_array!(hex_encode, input_value), + }; + Ok(ColumnarValue::Array(array)) + } + + fn decode_scalar(self, value: Option<&[u8]>) -> Result { + let value = match value { + Some(value) => value, + None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))), + }; + + let out = match self { + Self::Base64 => { + general_purpose::STANDARD_NO_PAD + .decode(value) + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to decode value using base64: {}", + e + )) + })? + } + Self::Hex => hex::decode(value).map_err(|e| { + DataFusionError::Internal(format!( + "Failed to decode value using hex: {}", + e + )) + })?, + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out)))) + } + + fn decode_large_scalar(self, value: Option<&[u8]>) -> Result { + let value = match value { + Some(value) => value, + None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))), + }; + + let out = match self { + Self::Base64 => { + general_purpose::STANDARD_NO_PAD + .decode(value) + .map_err(|e| { + DataFusionError::Internal(format!( + "Failed to decode value using base64: {}", + e + )) + })? + } + Self::Hex => hex::decode(value).map_err(|e| { + DataFusionError::Internal(format!( + "Failed to decode value using hex: {}", + e + )) + })?, + }; + + Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out)))) + } + + fn decode_binary_array(self, value: &dyn Array) -> Result + where + T: OffsetSizeTrait, + { + let input_value = as_generic_binary_array::(value)?; + let array: ArrayRef = match self { + Self::Base64 => decode_to_array!(base64_decode, input_value), + Self::Hex => decode_to_array!(hex_decode, input_value), + }; + Ok(ColumnarValue::Array(array)) + } + + fn decode_utf8_array(self, value: &dyn Array) -> Result + where + T: OffsetSizeTrait, + { + let input_value = as_generic_string_array::(value)?; + let array: ArrayRef = match self { + Self::Base64 => decode_to_array!(base64_decode, input_value), + Self::Hex => decode_to_array!(hex_decode, input_value), + }; + Ok(ColumnarValue::Array(array)) + } +} + +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", format!("{self:?}").to_lowercase()) + } +} + +impl FromStr for Encoding { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name { + "base64" => Self::Base64, + "hex" => Self::Hex, + _ => { + let options = [Self::Base64, Self::Hex] + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(", "); + return plan_err!( + "There is no built-in encoding named '{name}', currently supported encodings are: {options}" + ); + } + }) + } +} + +/// Encodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Second argument is the encoding to use. +/// Standard encodings are base64 and hex. +pub fn encode(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return internal_err!( + "{:?} args were supplied but encode takes exactly two arguments", + args.len() + ); + } + let encoding = match &args[1] { + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + method.parse::() + } + _ => not_impl_err!( + "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" + ), + }, + ColumnarValue::Array(_) => not_impl_err!( + "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" + ), + }?; + encode_process(&args[0], encoding) +} + +/// Decodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Second argument is the encoding to use. +/// Standard encodings are base64 and hex. +pub fn decode(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return internal_err!( + "{:?} args were supplied but decode takes exactly two arguments", + args.len() + ); + } + let encoding = match &args[1] { + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + method.parse::() + } + _ => not_impl_err!( + "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" + ), + }, + ColumnarValue::Array(_) => not_impl_err!( + "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" + ), + }?; + decode_process(&args[0], encoding) +} diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index 78279851bba5e..4a562f4ef1012 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -15,592 +15,3701 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::{BinaryExpr, Column}; +use std::hash::Hash; +use std::sync::Arc; + +use crate::expressions::{Column, Literal}; +use crate::physical_expr::deduplicate_physical_exprs; +use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ - normalize_expr_with_equivalence_properties, LexOrdering, PhysicalExpr, - PhysicalSortExpr, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; - use arrow::datatypes::SchemaRef; +use arrow_schema::SortOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; -use std::collections::HashMap; -use std::sync::Arc; +use indexmap::IndexSet; -/// Represents a collection of [`EquivalentClass`] (equivalences -/// between columns in relations) +/// An `EquivalenceClass` is a set of [`Arc`]s that are known +/// to have the same value for all tuples in a relation. These are generated by +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. /// -/// This is used to represent both: -/// -/// 1. Equality conditions (like `A=B`), when `T` = [`Column`] -/// 2. Ordering (like `A ASC = B ASC`), when `T` = [`PhysicalSortExpr`] +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. #[derive(Debug, Clone)] -pub struct EquivalenceProperties { - classes: Vec>, - schema: SchemaRef, +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, } -impl EquivalenceProperties { - pub fn new(schema: SchemaRef) -> Self { - EquivalenceProperties { - classes: vec![], - schema, - } +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) } +} - /// return the set of equivalences - pub fn classes(&self) -> &[EquivalentClass] { - &self.classes +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } } - pub fn schema(&self) -> SchemaRef { - self.schema.clone() + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } } - /// Add the [`EquivalentClass`] from `iter` to this list - pub fn extend>>(&mut self, iter: I) { - for ec in iter { - self.classes.push(ec) - } + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs } - /// Adds new equal conditions into the EquivalenceProperties. New equal - /// conditions usually come from equality predicates in a join/filter. - pub fn add_equal_conditions(&mut self, new_conditions: (&T, &T)) { - let mut idx1: Option = None; - let mut idx2: Option = None; - for (idx, class) in self.classes.iter_mut().enumerate() { - let contains_first = class.contains(new_conditions.0); - let contains_second = class.contains(new_conditions.1); - match (contains_first, contains_second) { - (true, false) => { - class.insert(new_conditions.1.clone()); - idx1 = Some(idx); - } - (false, true) => { - class.insert(new_conditions.0.clone()); - idx2 = Some(idx); - } - (true, true) => { - idx1 = Some(idx); - idx2 = Some(idx); - break; - } - (false, false) => {} - } - } + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } - match (idx1, idx2) { - (Some(idx_1), Some(idx_2)) if idx_1 != idx_2 => { - // need to merge the two existing EquivalentClasses - let second_eq_class = self.classes.get(idx_2).unwrap().clone(); - let first_eq_class = self.classes.get_mut(idx_1).unwrap(); - for prop in second_eq_class.iter() { - if !first_eq_class.contains(prop) { - first_eq_class.insert(prop.clone()); - } - } - self.classes.remove(idx_2); - } - (None, None) => { - // adding new pairs - self.classes.push(EquivalentClass::::new( - new_conditions.0.clone(), - vec![new_conditions.1.clone()], - )); - } - _ => {} + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); } } -} -/// Remove duplicates inside the `in_data` vector, returned vector would consist of unique entries -fn deduplicate_vector(in_data: Vec) -> Vec { - let mut result = vec![]; - for elem in in_data { - if !result.contains(&elem) { - result.push(elem); + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); } } - result -} -/// Find the position of `entry` inside `in_data`, if `entry` is not found return `None`. -fn get_entry_position(in_data: &[T], entry: &T) -> Option { - in_data.iter().position(|item| item.eq(entry)) -} + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } -/// Remove `entry` for the `in_data`, returns `true` if removal is successful (e.g `entry` is indeed in the `in_data`) -/// Otherwise return `false` -fn remove_from_vec(in_data: &mut Vec, entry: &T) -> bool { - if let Some(idx) = get_entry_position(in_data, entry) { - in_data.remove(idx); - true - } else { - false + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) } -} -// Helper function to calculate column info recursively -fn get_column_indices_helper( - indices: &mut Vec<(usize, String)>, - expr: &Arc, -) { - if let Some(col) = expr.as_any().downcast_ref::() { - indices.push((col.index(), col.name().to_string())) - } else if let Some(binary_expr) = expr.as_any().downcast_ref::() { - get_column_indices_helper(indices, binary_expr.left()); - get_column_indices_helper(indices, binary_expr.right()); - }; -} + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } -/// Get index and name of each column that is in the expression (Can return multiple entries for `BinaryExpr`s) -fn get_column_indices(expr: &Arc) -> Vec<(usize, String)> { - let mut result = vec![]; - get_column_indices_helper(&mut result, expr); - result + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } } -/// `OrderingEquivalenceProperties` keeps track of columns that describe the -/// global ordering of the schema. These columns are not necessarily same; e.g. -/// ```text -/// ┌-------┐ -/// | a | b | -/// |---|---| -/// | 1 | 9 | -/// | 2 | 8 | -/// | 3 | 7 | -/// | 5 | 5 | -/// └---┴---┘ -/// ``` -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `OrderingEquivalenceProperties`, we can keep track of these equivalences -/// and treat `a ASC` and `b DESC` as the same ordering requirement. -pub type OrderingEquivalenceProperties = EquivalenceProperties; - -/// EquivalentClass is a set of [`Column`]s or [`PhysicalSortExpr`]s that are known -/// to have the same value in all tuples in a relation. `EquivalentClass` -/// is generated by equality predicates, typically equijoin conditions and equality -/// conditions in filters. `EquivalentClass` is generated by the -/// `ROW_NUMBER` window function. +/// Stores the mapping between source expressions and target expressions for a +/// projection. #[derive(Debug, Clone)] -pub struct EquivalentClass { - /// First element in the EquivalentClass - head: T, - /// Other equal columns - others: Vec, +pub struct ProjectionMapping { + /// Mapping between source expressions and target expressions. + /// Vector indices correspond to the indices after projection. + map: Vec<(Arc, Arc)>, } -impl EquivalentClass { - pub fn new(head: T, others: Vec) -> EquivalentClass { - let others = deduplicate_vector(others); - EquivalentClass { head, others } - } - - pub fn head(&self) -> &T { - &self.head +impl ProjectionMapping { + /// Constructs the mapping between a projection's input and output + /// expressions. + /// + /// For example, given the input projection expressions (`a + b`, `c + d`) + /// and an output schema with two columns `"c + d"` and `"a + b"`, the + /// projection mapping would be: + /// + /// ```text + /// [0]: (c + d, col("c + d")) + /// [1]: (a + b, col("a + b")) + /// ``` + /// + /// where `col("c + d")` means the column named `"c + d"`. + pub fn try_new( + expr: &[(Arc, String)], + input_schema: &SchemaRef, + ) -> Result { + // Construct a map from the input expressions to the output expression of the projection: + expr.iter() + .enumerate() + .map(|(expr_idx, (expression, name))| { + let target_expr = Arc::new(Column::new(name, expr_idx)) as _; + expression + .clone() + .transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_input_field = input_schema.field(idx); + let matching_input_column = + Column::new(matching_input_field.name(), idx); + Ok(Transformed::Yes(Arc::new(matching_input_column))) + } + None => Ok(Transformed::No(e)), + }) + .map(|source_expr| (source_expr, target_expr)) + }) + .collect::>>() + .map(|map| Self { map }) } - pub fn others(&self) -> &[T] { - &self.others + /// Iterate over pairs of (source, target) expressions + pub fn iter( + &self, + ) -> impl Iterator, Arc)> + '_ { + self.map.iter() } - pub fn contains(&self, col: &T) -> bool { - self.head == *col || self.others.contains(col) + /// This function returns the target expression for a given source expression. + /// + /// # Arguments + /// + /// * `expr` - Source physical expression. + /// + /// # Returns + /// + /// An `Option` containing the target for the given source expression, + /// where a `None` value means that `expr` is not inside the mapping. + pub fn target_expr( + &self, + expr: &Arc, + ) -> Option> { + self.map + .iter() + .find(|(source, _)| source.eq(expr)) + .map(|(_, target)| target.clone()) } +} - pub fn insert(&mut self, col: T) -> bool { - if self.head != col && !self.others.contains(&col) { - self.others.push(col); - true - } else { - false - } - } +/// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each +/// class represents a distinct equivalence class in a relation. +#[derive(Debug, Clone)] +pub struct EquivalenceGroup { + classes: Vec, +} - pub fn remove(&mut self, col: &T) -> bool { - let removed = remove_from_vec(&mut self.others, col); - // If we are removing the head, shift others so that its first entry becomes the new head. - if !removed && *col == self.head { - let one_col = self.others.first().cloned(); - if let Some(col) = one_col { - let removed = remove_from_vec(&mut self.others, &col); - self.head = col; - removed - } else { - false - } - } else { - removed - } +impl EquivalenceGroup { + /// Creates an empty equivalence group. + fn empty() -> Self { + Self { classes: vec![] } } - pub fn iter(&self) -> impl Iterator { - std::iter::once(&self.head).chain(self.others.iter()) + /// Creates an equivalence group from the given equivalence classes. + fn new(classes: Vec) -> Self { + let mut result = Self { classes }; + result.remove_redundant_entries(); + result } - pub fn len(&self) -> usize { - self.others.len() + 1 + /// Returns how many equivalence classes there are in this group. + fn len(&self) -> usize { + self.classes.len() } + /// Checks whether this equivalence group is empty. pub fn is_empty(&self) -> bool { self.len() == 0 } -} -/// `LexOrdering` stores the lexicographical ordering for a schema. -/// OrderingEquivalentClass keeps track of different alternative orderings than can -/// describe the schema. -/// For instance, for the table below -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| -/// both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the ordering of the table. -/// For this case, we say that `vec![a ASC, b ASC]`, and `vec![c DESC, d ASC]` are ordering equivalent. -pub type OrderingEquivalentClass = EquivalentClass; - -impl OrderingEquivalentClass { - /// This function extends ordering equivalences with alias information. - /// For instance, assume column a and b are aliases, - /// and column (a ASC), (c DESC) are ordering equivalent. We append (b ASC) to ordering equivalence, - /// since b is alias of colum a. After this function (a ASC), (c DESC), (b ASC) would be ordering equivalent. - fn update_with_aliases(&mut self, columns_map: &HashMap>) { - for (column, columns) in columns_map { - let col_expr = Arc::new(column.clone()) as Arc; - let mut to_insert = vec![]; - for ordering in std::iter::once(&self.head).chain(self.others.iter()) { - for (idx, item) in ordering.iter().enumerate() { - if item.expr.eq(&col_expr) { - for col in columns { - let col_expr = Arc::new(col.clone()) as Arc; - let mut normalized = self.head.clone(); - // Change the corresponding entry in the head with the alias column: - let entry = &mut normalized[idx]; - (entry.expr, entry.options) = (col_expr, item.options); - to_insert.push(normalized); - } + /// Returns an iterator over the equivalence classes in this group. + pub fn iter(&self) -> impl Iterator { + self.classes.iter() + } + + /// Adds the equality `left` = `right` to this equivalence group. + /// New equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. + fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + let mut first_class = None; + let mut second_class = None; + for (idx, cls) in self.classes.iter().enumerate() { + if cls.contains(left) { + first_class = Some(idx); + } + if cls.contains(right) { + second_class = Some(idx); + } + } + match (first_class, second_class) { + (Some(mut first_idx), Some(mut second_idx)) => { + // If the given left and right sides belong to different classes, + // we should unify/bridge these classes. + if first_idx != second_idx { + // By convention, make sure `second_idx` is larger than `first_idx`. + if first_idx > second_idx { + (first_idx, second_idx) = (second_idx, first_idx); } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.classes.swap_remove(second_idx); + self.classes[first_idx].extend(other_class); } } - for items in to_insert { - self.insert(items); + (Some(group_idx), None) => { + // Right side is new, extend left side's class: + self.classes[group_idx].push(right.clone()); + } + (None, Some(group_idx)) => { + // Left side is new, extend right side's class: + self.classes[group_idx].push(left.clone()); + } + (None, None) => { + // None of the expressions is among existing classes. + // Create a new equivalence class and extend the group. + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); } } } -} -/// This is a builder object facilitating incremental construction -/// for ordering equivalences. -pub struct OrderingEquivalenceBuilder { - eq_properties: EquivalenceProperties, - ordering_eq_properties: OrderingEquivalenceProperties, - existing_ordering: Vec, -} + /// Removes redundant entries from this group. + fn remove_redundant_entries(&mut self) { + // Remove duplicate entries from each equivalence class: + self.classes.retain_mut(|cls| { + // Keep groups that have at least two entries as singleton class is + // meaningless (i.e. it contains no non-trivial information): + cls.len() > 1 + }); + // Unify/bridge groups that have common expressions: + self.bridge_classes() + } -impl OrderingEquivalenceBuilder { - pub fn new(schema: SchemaRef) -> Self { - let eq_properties = EquivalenceProperties::new(schema.clone()); - let ordering_eq_properties = OrderingEquivalenceProperties::new(schema); - Self { - eq_properties, - ordering_eq_properties, - existing_ordering: vec![], + /// This utility function unifies/bridges classes that have common expressions. + /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. + /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all + /// equal and belong to one class. This utility converts merges such classes. + fn bridge_classes(&mut self) { + let mut idx = 0; + while idx < self.classes.len() { + let mut next_idx = idx + 1; + let start_size = self.classes[idx].len(); + while next_idx < self.classes.len() { + if self.classes[idx].contains_any(&self.classes[next_idx]) { + let extension = self.classes.swap_remove(next_idx); + self.classes[idx].extend(extension); + } else { + next_idx += 1; + } + } + if self.classes[idx].len() > start_size { + continue; + } + idx += 1; } } - pub fn extend( - mut self, - new_ordering_eq_properties: OrderingEquivalenceProperties, - ) -> Self { - self.ordering_eq_properties - .extend(new_ordering_eq_properties.classes().iter().cloned()); - self + /// Extends this equivalence group with the `other` equivalence group. + fn extend(&mut self, other: Self) { + self.classes.extend(other.classes); + self.remove_redundant_entries(); } - pub fn with_existing_ordering( - mut self, - existing_ordering: Option>, - ) -> Self { - if let Some(existing_ordering) = existing_ordering { - self.existing_ordering = existing_ordering; - } - self + /// Normalizes the given physical expression according to this group. + /// The expression is replaced with the first expression in the equivalence + /// class it matches with (if any). + pub fn normalize_expr(&self, expr: Arc) -> Arc { + expr.clone() + .transform(&|expr| { + for cls in self.iter() { + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + } + } + Ok(Transformed::No(expr)) + }) + .unwrap_or(expr) } - pub fn with_equivalences(mut self, new_eq_properties: EquivalenceProperties) -> Self { - self.eq_properties = new_eq_properties; - self + /// Normalizes the given sort expression according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the sort expression as is. + pub fn normalize_sort_expr( + &self, + mut sort_expr: PhysicalSortExpr, + ) -> PhysicalSortExpr { + sort_expr.expr = self.normalize_expr(sort_expr.expr); + sort_expr } - pub fn add_equal_conditions( - &mut self, - new_equivalent_ordering: Vec, - ) { - let mut normalized_out_ordering = vec![]; - for item in &self.existing_ordering { - // To account for ordering equivalences, first normalize the expression: - let normalized = normalize_expr_with_equivalence_properties( - item.expr.clone(), - self.eq_properties.classes(), - ); - normalized_out_ordering.push(PhysicalSortExpr { - expr: normalized, - options: item.options, - }); - } - // If there is an existing ordering, add new ordering as an equivalence: - if !normalized_out_ordering.is_empty() { - self.ordering_eq_properties.add_equal_conditions(( - &normalized_out_ordering, - &new_equivalent_ordering, - )); - } + /// Normalizes the given sort requirement according to this group. + /// The underlying physical expression is replaced with the first expression + /// in the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, returns + /// the given sort requirement as is. + pub fn normalize_sort_requirement( + &self, + mut sort_requirement: PhysicalSortRequirement, + ) -> PhysicalSortRequirement { + sort_requirement.expr = self.normalize_expr(sort_requirement.expr); + sort_requirement } - pub fn build(self) -> OrderingEquivalenceProperties { - self.ordering_eq_properties + /// This function applies the `normalize_expr` function for all expressions + /// in `exprs` and returns the corresponding normalized physical expressions. + pub fn normalize_exprs( + &self, + exprs: impl IntoIterator>, + ) -> Vec> { + exprs + .into_iter() + .map(|expr| self.normalize_expr(expr)) + .collect() } -} -/// This function applies the given projection to the given equivalence -/// properties to compute the resulting (projected) equivalence properties; e.g. -/// 1) Adding an alias, which can introduce additional equivalence properties, -/// as in Projection(a, a as a1, a as a2). -/// 2) Truncate the [`EquivalentClass`]es that are not in the output schema. -pub fn project_equivalence_properties( - input_eq: EquivalenceProperties, - alias_map: &HashMap>, - output_eq: &mut EquivalenceProperties, -) { - let mut eq_classes = input_eq.classes().to_vec(); - for (column, columns) in alias_map { - let mut find_match = false; - for class in eq_classes.iter_mut() { - if class.contains(column) { - for col in columns { - class.insert(col.clone()); + /// This function applies the `normalize_sort_expr` function for all sort + /// expressions in `sort_exprs` and returns the corresponding normalized + /// sort expressions. + pub fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// This function applies the `normalize_sort_requirement` function for all + /// requirements in `sort_reqs` and returns the corresponding normalized + /// sort requirements. + pub fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + collapse_lex_req( + sort_reqs + .iter() + .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) + .collect(), + ) + } + + /// Projects `expr` according to the given projection mapping. + /// If the resulting expression is invalid after projection, returns `None`. + fn project_expr( + &self, + mapping: &ProjectionMapping, + expr: &Arc, + ) -> Option> { + // First, we try to project expressions with an exact match. If we are + // unable to do this, we consult equivalence classes. + if let Some(target) = mapping.target_expr(expr) { + // If we match the source, we can project directly: + return Some(target); + } else { + // If the given expression is not inside the mapping, try to project + // expressions considering the equivalence classes. + for (source, target) in mapping.iter() { + // If we match an equivalent expression to `source`, then we can + // project. For example, if we have the mapping `(a as a1, a + c)` + // and the equivalence class `(a, b)`, expression `b` projects to `a1`. + if self + .get_equivalence_class(source) + .map_or(false, |group| group.contains(expr)) + { + return Some(target.clone()); } - find_match = true; - break; } } - if !find_match { - eq_classes.push(EquivalentClass::new(column.clone(), columns.clone())); + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // Leaf expression should be inside mapping. + return None; } + children + .into_iter() + .map(|child| self.project_expr(mapping, &child)) + .collect::>>() + .map(|children| expr.clone().with_new_children(children).unwrap()) } - // Prune columns that are no longer in the schema from equivalences. - let schema = output_eq.schema(); - let fields = schema.fields(); - for class in eq_classes.iter_mut() { - let columns_to_remove = class + /// Projects `ordering` according to the given projection mapping. + /// If the resulting ordering is invalid after projection, returns `None`. + fn project_ordering( + &self, + mapping: &ProjectionMapping, + ordering: LexOrderingRef, + ) -> Option { + // If any sort expression is invalid after projection, rest of the + // ordering shouldn't be projected either. For example, if input ordering + // is [a ASC, b ASC, c ASC], and column b is not valid after projection, + // the result should be [a ASC], not [a ASC, c ASC], even if column c is + // valid after projection. + let result = ordering .iter() - .filter(|column| { - let idx = column.index(); - idx >= fields.len() || fields[idx].name() != column.name() + .map_while(|sort_expr| { + self.project_expr(mapping, &sort_expr.expr) + .map(|expr| PhysicalSortExpr { + expr, + options: sort_expr.options, + }) }) - .cloned() .collect::>(); - for column in columns_to_remove { - class.remove(&column); + (!result.is_empty()).then_some(result) + } + + /// Projects this equivalence group according to the given projection mapping. + pub fn project(&self, mapping: &ProjectionMapping) -> Self { + let projected_classes = self.iter().filter_map(|cls| { + let new_class = cls + .iter() + .filter_map(|expr| self.project_expr(mapping, expr)) + .collect::>(); + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + }); + // TODO: Convert the algorithm below to a version that uses `HashMap`. + // once `Arc` can be stored in `HashMap`. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut new_classes = vec![]; + for (source, target) in mapping.iter() { + if new_classes.is_empty() { + new_classes.push((source, vec![target.clone()])); + } + if let Some((_, values)) = + new_classes.iter_mut().find(|(key, _)| key.eq(source)) + { + if !physical_exprs_contains(values, target) { + values.push(target.clone()); + } + } } + // Only add equivalence classes with at least two members as singleton + // equivalence classes are meaningless. + let new_classes = new_classes + .into_iter() + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + + let classes = projected_classes.chain(new_classes).collect(); + Self::new(classes) + } + + /// Returns the equivalence class containing `expr`. If no equivalence class + /// contains `expr`, returns `None`. + fn get_equivalence_class( + &self, + expr: &Arc, + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) } - eq_classes.retain(|props| props.len() > 1); - output_eq.extend(eq_classes); + /// Combine equivalence groups of the given join children. + pub fn join( + &self, + right_equivalences: &Self, + join_type: &JoinType, + left_size: usize, + on: &[(Column, Column)], + ) -> Self { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let mut result = Self::new( + self.iter() + .cloned() + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) + .collect(), + ); + // In we have an inner join, expressions in the "on" condition + // are equal in the resulting table. + if join_type == &JoinType::Inner { + for (lhs, rhs) in on.iter() { + let index = rhs.index() + left_size; + let new_lhs = Arc::new(lhs.clone()) as _; + let new_rhs = Arc::new(Column::new(rhs.name(), index)) as _; + result.add_equal_conditions(&new_lhs, &new_rhs); + } + } + result + } + JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + } + } } -/// This function applies the given projection to the given ordering -/// equivalence properties to compute the resulting (projected) ordering -/// equivalence properties; e.g. -/// 1) Adding an alias, which can introduce additional ordering equivalence -/// properties, as in Projection(a, a as a1, a as a2) extends global ordering -/// of a to a1 and a2. -/// 2) Truncate the [`OrderingEquivalentClass`]es that are not in the output schema. -pub fn project_ordering_equivalence_properties( - input_eq: OrderingEquivalenceProperties, - columns_map: &HashMap>, - output_eq: &mut OrderingEquivalenceProperties, -) { - let mut eq_classes = input_eq.classes().to_vec(); - for class in eq_classes.iter_mut() { - class.update_with_aliases(columns_map); +/// This function constructs a duplicate-free `LexOrderingReq` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. +pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); + } } + output +} - // Prune columns that no longer is in the schema from from the OrderingEquivalenceProperties. - let schema = output_eq.schema(); - let fields = schema.fields(); - for class in eq_classes.iter_mut() { - let sort_exprs_to_remove = class - .iter() - .filter(|sort_exprs| { - sort_exprs.iter().any(|sort_expr| { - let col_infos = get_column_indices(&sort_expr.expr); - // If any one of the columns, used in Expression is invalid, remove expression - // from ordering equivalences - col_infos.into_iter().any(|(idx, name)| { - idx >= fields.len() || fields[idx].name() != &name - }) - }) - }) - .cloned() - .collect::>(); - for sort_exprs in sort_exprs_to_remove { - class.remove(&sort_exprs); +/// This function constructs a duplicate-free `LexOrdering` by filtering out +/// duplicate entries that have same physical expression inside. For example, +/// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. +pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { + let mut output = Vec::::new(); + for item in input { + if !output.iter().any(|req| req.expr.eq(&item.expr)) { + output.push(item); } } - eq_classes.retain(|props| props.len() > 1); + output +} - output_eq.extend(eq_classes); +/// An `OrderingEquivalenceClass` object keeps track of different alternative +/// orderings than can describe a schema. For example, consider the following table: +/// +/// ```text +/// |a|b|c|d| +/// |1|4|3|1| +/// |2|3|3|2| +/// |3|1|2|2| +/// |3|2|1|3| +/// ``` +/// +/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// ordering. In this case, we say that these orderings are equivalent. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct OrderingEquivalenceClass { + orderings: Vec, } -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::Column; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; +impl OrderingEquivalenceClass { + /// Creates new empty ordering equivalence class. + fn empty() -> Self { + Self { orderings: vec![] } + } - use datafusion_expr::Operator; - use std::sync::Arc; + /// Clears (empties) this ordering equivalence class. + pub fn clear(&mut self) { + self.orderings.clear(); + } - #[test] - fn add_equal_conditions_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("x", DataType::Int64, true), - Field::new("y", DataType::Int64, true), - ])); + /// Creates new ordering equivalence class from the given orderings. + pub fn new(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } - let mut eq_properties = EquivalenceProperties::new(schema); - let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - - let new_condition = (&Column::new("b", 1), &Column::new("a", 0)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 2); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - - let new_condition = (&Column::new("b", 1), &Column::new("c", 2)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 3); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - assert!(eq_properties.classes()[0].contains(&Column::new("c", 2))); - - let new_condition = (&Column::new("x", 3), &Column::new("y", 4)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 2); - - let new_condition = (&Column::new("x", 3), &Column::new("a", 0)); - eq_properties.add_equal_conditions(new_condition); - assert_eq!(eq_properties.classes().len(), 1); - assert_eq!(eq_properties.classes()[0].len(), 5); - assert!(eq_properties.classes()[0].contains(&Column::new("a", 0))); - assert!(eq_properties.classes()[0].contains(&Column::new("b", 1))); - assert!(eq_properties.classes()[0].contains(&Column::new("c", 2))); - assert!(eq_properties.classes()[0].contains(&Column::new("x", 3))); - assert!(eq_properties.classes()[0].contains(&Column::new("y", 4))); + /// Checks whether `ordering` is a member of this equivalence class. + pub fn contains(&self, ordering: &LexOrdering) -> bool { + self.orderings.contains(ordering) + } - Ok(()) + /// Adds `ordering` to this equivalence class. + #[allow(dead_code)] + fn push(&mut self, ordering: LexOrdering) { + self.orderings.push(ordering); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); } - #[test] - fn project_equivalence_properties_test() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])); + /// Checks whether this ordering equivalence class is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } - let mut input_properties = EquivalenceProperties::new(input_schema); - let new_condition = (&Column::new("a", 0), &Column::new("b", 1)); - input_properties.add_equal_conditions(new_condition); - let new_condition = (&Column::new("b", 1), &Column::new("c", 2)); - input_properties.add_equal_conditions(new_condition); + /// Returns an iterator over the equivalent orderings in this class. + pub fn iter(&self) -> impl Iterator { + self.orderings.iter() + } - let out_schema = Arc::new(Schema::new(vec![ - Field::new("a1", DataType::Int64, true), - Field::new("a2", DataType::Int64, true), - Field::new("a3", DataType::Int64, true), - Field::new("a4", DataType::Int64, true), - ])); + /// Returns how many equivalent orderings there are in this class. + pub fn len(&self) -> usize { + self.orderings.len() + } - let mut alias_map = HashMap::new(); - alias_map.insert( - Column::new("a", 0), - vec![ - Column::new("a1", 0), - Column::new("a2", 1), - Column::new("a3", 2), - Column::new("a4", 3), - ], - ); - let mut out_properties = EquivalenceProperties::new(out_schema); + /// Extend this ordering equivalence class with the `other` class. + pub fn extend(&mut self, other: Self) { + self.orderings.extend(other.orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } - project_equivalence_properties(input_properties, &alias_map, &mut out_properties); - assert_eq!(out_properties.classes().len(), 1); - assert_eq!(out_properties.classes()[0].len(), 4); - assert!(out_properties.classes()[0].contains(&Column::new("a1", 0))); - assert!(out_properties.classes()[0].contains(&Column::new("a2", 1))); - assert!(out_properties.classes()[0].contains(&Column::new("a3", 2))); - assert!(out_properties.classes()[0].contains(&Column::new("a4", 3))); + /// Adds new orderings into this ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.orderings.extend(orderings); + // Make sure that there are no redundant orderings: + self.remove_redundant_entries(); + } - Ok(()) + /// Removes redundant orderings from this equivalence class. For instance, + /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is + /// no need to keep ordering `[a ASC, b ASC]` in the state. + fn remove_redundant_entries(&mut self) { + let mut work = true; + while work { + work = false; + let mut idx = 0; + while idx < self.orderings.len() { + let mut ordering_idx = idx + 1; + let mut removal = self.orderings[idx].is_empty(); + while ordering_idx < self.orderings.len() { + work |= resolve_overlap(&mut self.orderings, idx, ordering_idx); + if self.orderings[idx].is_empty() { + removal = true; + break; + } + work |= resolve_overlap(&mut self.orderings, ordering_idx, idx); + if self.orderings[ordering_idx].is_empty() { + self.orderings.swap_remove(ordering_idx); + } else { + ordering_idx += 1; + } + } + if removal { + self.orderings.swap_remove(idx); + } else { + idx += 1; + } + } + } } - #[test] - fn test_deduplicate_vector() -> Result<()> { - assert_eq!(deduplicate_vector(vec![1, 1, 2, 3, 3]), vec![1, 2, 3]); - assert_eq!( - deduplicate_vector(vec![1, 2, 3, 4, 3, 2, 1, 0]), - vec![1, 2, 3, 4, 0] - ); - Ok(()) + /// Returns the concatenation of all the orderings. This enables merge + /// operations to preserve all equivalent orderings simultaneously. + pub fn output_ordering(&self) -> Option { + let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = collapse_lex_ordering(output_ordering); + (!output_ordering.is_empty()).then_some(output_ordering) } - #[test] - fn test_get_entry_position() -> Result<()> { - assert_eq!(get_entry_position(&[1, 1, 2, 3, 3], &2), Some(2)); - assert_eq!(get_entry_position(&[1, 1, 2, 3, 3], &1), Some(0)); - assert_eq!(get_entry_position(&[1, 1, 2, 3, 3], &5), None); - Ok(()) + // Append orderings in `other` to all existing orderings in this equivalence + // class. + pub fn join_suffix(mut self, other: &Self) -> Self { + for ordering in other.iter() { + for idx in 0..self.orderings.len() { + self.orderings[idx].extend(ordering.iter().cloned()); + } + } + self } - #[test] - fn test_remove_from_vec() -> Result<()> { - let mut in_data = vec![1, 1, 2, 3, 3]; - remove_from_vec(&mut in_data, &5); - assert_eq!(in_data, vec![1, 1, 2, 3, 3]); - remove_from_vec(&mut in_data, &2); - assert_eq!(in_data, vec![1, 1, 3, 3]); - remove_from_vec(&mut in_data, &2); - assert_eq!(in_data, vec![1, 1, 3, 3]); - remove_from_vec(&mut in_data, &3); - assert_eq!(in_data, vec![1, 1, 3]); - remove_from_vec(&mut in_data, &3); - assert_eq!(in_data, vec![1, 1]); - Ok(()) + /// Adds `offset` value to the index of each expression inside this + /// ordering equivalence class. + pub fn add_offset(&mut self, offset: usize) { + for ordering in self.orderings.iter_mut() { + for sort_expr in ordering { + sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + } + } } - #[test] - fn test_get_column_infos() -> Result<()> { - let expr1 = Arc::new(Column::new("col1", 2)) as _; - assert_eq!(get_column_indices(&expr1), vec![(2, "col1".to_string())]); - let expr2 = Arc::new(Column::new("col2", 5)) as _; - assert_eq!(get_column_indices(&expr2), vec![(5, "col2".to_string())]); - let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _; - assert_eq!( - get_column_indices(&expr3), - vec![(2, "col1".to_string()), (5, "col2".to_string())] - ); + /// Gets sort options associated with this expression if it is a leading + /// ordering expression. Otherwise, returns `None`. + fn get_options(&self, expr: &Arc) -> Option { + for ordering in self.iter() { + let leading_ordering = &ordering[0]; + if leading_ordering.expr.eq(expr) { + return Some(leading_ordering.options); + } + } + None + } +} + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: usize, +) -> Arc { + expr.transform_down(&|e| match e.as_any().downcast_ref::() { + Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + col.name(), + offset + col.index(), + )))), + None => Ok(Transformed::No(e)), + }) + .unwrap() + // Note that we can safely unwrap here since our transform always returns + // an `Ok` value. +} + +/// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of +/// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. +fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> bool { + let length = orderings[idx].len(); + let other_length = orderings[pre_idx].len(); + for overlap in 1..=length.min(other_length) { + if orderings[idx][length - overlap..] == orderings[pre_idx][..overlap] { + orderings[idx].truncate(length - overlap); + return true; + } + } + false +} + +/// A `EquivalenceProperties` object stores useful information related to a schema. +/// Currently, it keeps track of: +/// - Equivalent expressions, e.g expressions that have same value. +/// - Valid sort expressions (orderings) for the schema. +/// - Constants expressions (e.g expressions that are known to have constant values). +/// +/// Consider table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 9 | +/// | 2 | 8 | +/// | 3 | 7 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where both `a ASC` and `b DESC` can describe the table ordering. With +/// `EquivalenceProperties`, we can keep track of these different valid sort +/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// +/// Similarly, consider the table below: +/// +/// ```text +/// ┌-------┐ +/// | a | b | +/// |---|---| +/// | 1 | 1 | +/// | 2 | 2 | +/// | 3 | 3 | +/// | 5 | 5 | +/// └---┴---┘ +/// ``` +/// +/// where columns `a` and `b` always have the same value. We keep track of such +/// equivalences inside this object. With this information, we can optimize +/// things like partitioning. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that +/// the existing partitioning satisfies the requirement. +#[derive(Debug, Clone)] +pub struct EquivalenceProperties { + /// Collection of equivalence classes that store expressions with the same + /// value. + eq_group: EquivalenceGroup, + /// Equivalent sort expressions for this table. + oeq_class: OrderingEquivalenceClass, + /// Expressions whose values are constant throughout the table. + /// TODO: We do not need to track constants separately, they can be tracked + /// inside `eq_groups` as `Literal` expressions. + constants: Vec>, + /// Schema associated with this object. + schema: SchemaRef, +} + +impl EquivalenceProperties { + /// Creates an empty `EquivalenceProperties` object. + pub fn new(schema: SchemaRef) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::empty(), + constants: vec![], + schema, + } + } + + /// Creates a new `EquivalenceProperties` object with the given orderings. + pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + Self { + eq_group: EquivalenceGroup::empty(), + oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), + constants: vec![], + schema, + } + } + + /// Returns the associated schema. + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Returns a reference to the ordering equivalence class within. + pub fn oeq_class(&self) -> &OrderingEquivalenceClass { + &self.oeq_class + } + + /// Returns a reference to the equivalence group within. + pub fn eq_group(&self) -> &EquivalenceGroup { + &self.eq_group + } + + /// Returns a reference to the constant expressions + pub fn constants(&self) -> &[Arc] { + &self.constants + } + + /// Returns the normalized version of the ordering equivalence class within. + /// Normalization removes constants and duplicates as well as standardizing + /// expressions according to the equivalence group within. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + OrderingEquivalenceClass::new( + self.oeq_class + .iter() + .map(|ordering| self.normalize_sort_exprs(ordering)) + .collect(), + ) + } + + /// Extends this `EquivalenceProperties` with the `other` object. + pub fn extend(mut self, other: Self) -> Self { + self.eq_group.extend(other.eq_group); + self.oeq_class.extend(other.oeq_class); + self.add_constants(other.constants) + } + + /// Clears (empties) the ordering equivalence class within this object. + /// Call this method when existing orderings are invalidated. + pub fn clear_orderings(&mut self) { + self.oeq_class.clear(); + } + + /// Extends this `EquivalenceProperties` by adding the orderings inside the + /// ordering equivalence class `other`. + pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { + self.oeq_class.extend(other); + } + + /// Adds new orderings into the existing ordering equivalence class. + pub fn add_new_orderings( + &mut self, + orderings: impl IntoIterator, + ) { + self.oeq_class.add_new_orderings(orderings); + } + + /// Incorporates the given equivalence group to into the existing + /// equivalence group within. + pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { + self.eq_group.extend(other_eq_group); + } + + /// Adds a new equality condition into the existing equivalence group. + /// If the given equality defines a new equivalence class, adds this new + /// equivalence class to the equivalence group. + pub fn add_equal_conditions( + &mut self, + left: &Arc, + right: &Arc, + ) { + self.eq_group.add_equal_conditions(left, right); + } + + /// Track/register physical expressions with constant values. + pub fn add_constants( + mut self, + constants: impl IntoIterator>, + ) -> Self { + for expr in self.eq_group.normalize_exprs(constants) { + if !physical_exprs_contains(&self.constants, &expr) { + self.constants.push(expr); + } + } + self + } + + /// Updates the ordering equivalence group within assuming that the table + /// is re-sorted according to the argument `sort_exprs`. Note that constants + /// and equivalence classes are unchanged as they are unaffected by a re-sort. + pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. + self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); + self + } + + /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the + /// equivalence group and the ordering equivalence class within. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_exprs(&self, sort_exprs: LexOrderingRef) -> LexOrdering { + // Convert sort expressions to sort requirements: + let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); + // Normalize the requirements: + let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); + // Convert sort requirements back to sort expressions: + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + } + + /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the + /// equivalence group and the ordering equivalence class within. It works by: + /// - Removing expressions that have a constant value from the given requirement. + /// - Replacing sections that belong to some equivalence class in the equivalence + /// group with the first entry in the matching equivalence class. + /// + /// Assume that `self.eq_group` states column `a` and `b` are aliases. + /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` + /// are equivalent (in the sense that both describe the ordering of the table). + /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this + /// function would return `vec![a ASC, c ASC]`. Internally, it would first + /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result + /// after deduplication. + fn normalize_sort_requirements( + &self, + sort_reqs: LexRequirementRef, + ) -> LexRequirement { + let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); + let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + // Prune redundant sections in the requirement: + collapse_lex_req( + normalized_sort_reqs + .iter() + .filter(|&order| { + !physical_exprs_contains(&constants_normalized, &order.expr) + }) + .cloned() + .collect(), + ) + } + + /// Checks whether the given ordering is satisfied by any of the existing + /// orderings. + pub fn ordering_satisfy(&self, given: LexOrderingRef) -> bool { + // Convert the given sort expressions to sort requirements: + let sort_requirements = PhysicalSortRequirement::from_sort_exprs(given.iter()); + self.ordering_satisfy_requirement(&sort_requirements) + } + + /// Checks whether the given sort requirements are satisfied by any of the + /// existing orderings. + pub fn ordering_satisfy_requirement(&self, reqs: LexRequirementRef) -> bool { + let mut eq_properties = self.clone(); + // First, standardize the given requirement: + let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); + for normalized_req in normalized_reqs { + // Check whether given ordering is satisfied + if !eq_properties.ordering_satisfy_single(&normalized_req) { + return false; + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + eq_properties = + eq_properties.add_constants(std::iter::once(normalized_req.expr)); + } + true + } + + /// Determines whether the ordering specified by the given sort requirement + /// is satisfied based on the orderings within, equivalence classes, and + /// constant expressions. + /// + /// # Arguments + /// + /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering + /// satisfaction check will be done. + /// + /// # Returns + /// + /// Returns `true` if the specified ordering is satisfied, `false` otherwise. + fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { + let expr_ordering = self.get_expr_ordering(req.expr.clone()); + let ExprOrdering { expr, state, .. } = expr_ordering; + match state { + SortProperties::Ordered(options) => { + let sort_expr = PhysicalSortExpr { expr, options }; + sort_expr.satisfy(req, self.schema()) + } + // Singleton expressions satisfies any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + } + } + + /// Checks whether the `given`` sort requirements are equal or more specific + /// than the `reference` sort requirements. + pub fn requirements_compatible( + &self, + given: LexRequirementRef, + reference: LexRequirementRef, + ) -> bool { + let normalized_given = self.normalize_sort_requirements(given); + let normalized_reference = self.normalize_sort_requirements(reference); + + (normalized_reference.len() <= normalized_given.len()) + && normalized_reference + .into_iter() + .zip(normalized_given) + .all(|(reference, given)| given.compatible(&reference)) + } + + /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking + /// any ties by choosing `lhs`. + /// + /// The finer ordering is the ordering that satisfies both of the orderings. + /// If the orderings are incomparable, returns `None`. + /// + /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is + /// the latter. + pub fn get_finer_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + // Convert the given sort expressions to sort requirements: + let lhs = PhysicalSortRequirement::from_sort_exprs(lhs); + let rhs = PhysicalSortRequirement::from_sort_exprs(rhs); + let finer = self.get_finer_requirement(&lhs, &rhs); + // Convert the chosen sort requirements back to sort expressions: + finer.map(PhysicalSortRequirement::to_sort_exprs) + } + + /// Returns the finer ordering among the requirements `lhs` and `rhs`, + /// breaking any ties by choosing `lhs`. + /// + /// The finer requirements are the ones that satisfy both of the given + /// requirements. If the requirements are incomparable, returns `None`. + /// + /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` + /// is the latter. + pub fn get_finer_requirement( + &self, + req1: LexRequirementRef, + req2: LexRequirementRef, + ) -> Option { + let mut lhs = self.normalize_sort_requirements(req1); + let mut rhs = self.normalize_sort_requirements(req2); + lhs.iter_mut() + .zip(rhs.iter_mut()) + .all(|(lhs, rhs)| { + lhs.expr.eq(&rhs.expr) + && match (lhs.options, rhs.options) { + (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, + (Some(options), None) => { + rhs.options = Some(options); + true + } + (None, Some(options)) => { + lhs.options = Some(options); + true + } + (None, None) => true, + } + }) + .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) + } + + /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). + /// The meet of a set of orderings is the finest ordering that is satisfied + /// by all the orderings in that set. For details, see: + /// + /// + /// + /// If there is no ordering that satisfies both `lhs` and `rhs`, returns + /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` + /// is `[a ASC]`. + pub fn get_meet_ordering( + &self, + lhs: LexOrderingRef, + rhs: LexOrderingRef, + ) -> Option { + let lhs = self.normalize_sort_exprs(lhs); + let rhs = self.normalize_sort_exprs(rhs); + let mut meet = vec![]; + for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { + if lhs.eq(&rhs) { + meet.push(lhs); + } else { + break; + } + } + (!meet.is_empty()).then_some(meet) + } + + /// Projects argument `expr` according to `projection_mapping`, taking + /// equivalences into account. + /// + /// For example, assume that columns `a` and `c` are always equal, and that + /// `projection_mapping` encodes following mapping: + /// + /// ```text + /// a -> a1 + /// b -> b1 + /// ``` + /// + /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to + /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + pub fn project_expr( + &self, + expr: &Arc, + projection_mapping: &ProjectionMapping, + ) -> Option> { + self.eq_group.project_expr(projection_mapping, expr) + } + + /// Projects constants based on the provided `ProjectionMapping`. + /// + /// This function takes a `ProjectionMapping` and identifies/projects + /// constants based on the existing constants and the mapping. It ensures + /// that constants are appropriately propagated through the projection. + /// + /// # Arguments + /// + /// - `mapping`: A reference to a `ProjectionMapping` representing the + /// mapping of source expressions to target expressions in the projection. + /// + /// # Returns + /// + /// Returns a `Vec>` containing the projected constants. + fn projected_constants( + &self, + mapping: &ProjectionMapping, + ) -> Vec> { + // First, project existing constants. For example, assume that `a + b` + // is known to be constant. If the projection were `a as a_new`, `b as b_new`, + // then we would project constant `a + b` as `a_new + b_new`. + let mut projected_constants = self + .constants + .iter() + .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .collect::>(); + // Add projection expressions that are known to be constant: + for (source, target) in mapping.iter() { + if self.is_expr_constant(source) + && !physical_exprs_contains(&projected_constants, target) + { + projected_constants.push(target.clone()); + } + } + projected_constants + } + + /// Projects the equivalences within according to `projection_mapping` + /// and `output_schema`. + pub fn project( + &self, + projection_mapping: &ProjectionMapping, + output_schema: SchemaRef, + ) -> Self { + let mut projected_orderings = self + .oeq_class + .iter() + .filter_map(|order| self.eq_group.project_ordering(projection_mapping, order)) + .collect::>(); + for (source, target) in projection_mapping.iter() { + let expr_ordering = ExprOrdering::new(source.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap(); + if let SortProperties::Ordered(options) = expr_ordering.state { + // Push new ordering to the state. + projected_orderings.push(vec![PhysicalSortExpr { + expr: target.clone(), + options, + }]); + } + } + Self { + eq_group: self.eq_group.project(projection_mapping), + oeq_class: OrderingEquivalenceClass::new(projected_orderings), + constants: self.projected_constants(projection_mapping), + schema: output_schema, + } + } + + /// Returns the longest (potentially partial) permutation satisfying the + /// existing ordering. For example, if we have the equivalent orderings + /// `[a ASC, b ASC]` and `[c DESC]`, with `exprs` containing `[c, b, a, d]`, + /// then this function returns `([a ASC, b ASC, c DESC], [2, 1, 0])`. + /// This means that the specification `[a ASC, b ASC, c DESC]` is satisfied + /// by the existing ordering, and `[a, b, c]` resides at indices: `2, 1, 0` + /// inside the argument `exprs` (respectively). For the mathematical + /// definition of "partial permutation", see: + /// + /// + pub fn find_longest_permutation( + &self, + exprs: &[Arc], + ) -> (LexOrdering, Vec) { + let mut eq_properties = self.clone(); + let mut result = vec![]; + // The algorithm is as follows: + // - Iterate over all the expressions and insert ordered expressions + // into the result. + // - Treat inserted expressions as constants (i.e. add them as constants + // to the state). + // - Continue the above procedure until no expression is inserted; i.e. + // the algorithm reaches a fixed point. + // This algorithm should reach a fixed point in at most `exprs.len()` + // iterations. + let mut search_indices = (0..exprs.len()).collect::>(); + for _idx in 0..exprs.len() { + // Get ordered expressions with their indices. + let ordered_exprs = search_indices + .iter() + .flat_map(|&idx| { + let ExprOrdering { expr, state, .. } = + eq_properties.get_expr_ordering(exprs[idx].clone()); + if let SortProperties::Ordered(options) = state { + Some((PhysicalSortExpr { expr, options }, idx)) + } else { + None + } + }) + .collect::>(); + // We reached a fixed point, exit. + if ordered_exprs.is_empty() { + break; + } + // Remove indices that have an ordering from `search_indices`, and + // treat ordered expressions as constants in subsequent iterations. + // We can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { + eq_properties = + eq_properties.add_constants(std::iter::once(expr.clone())); + search_indices.remove(idx); + } + // Add new ordered section to the state. + result.extend(ordered_exprs); + } + result.into_iter().unzip() + } + + /// This function determines whether the provided expression is constant + /// based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant according to equivalence + /// group, `false` otherwise. + fn is_expr_constant(&self, expr: &Arc) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); + let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// Retrieves the ordering information for a given physical expression. + /// + /// This function constructs an `ExprOrdering` object for the provided + /// expression, which encapsulates information about the expression's + /// ordering, including its [`SortProperties`]. + /// + /// # Arguments + /// + /// - `expr`: An `Arc` representing the physical expression + /// for which ordering information is sought. + /// + /// # Returns + /// + /// Returns an `ExprOrdering` object containing the ordering information for + /// the given expression. + pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { + ExprOrdering::new(expr.clone()) + .transform_up(&|expr| Ok(update_ordering(expr, self))) + // Guaranteed to always return `Ok`. + .unwrap() + } +} + +/// This function determines whether the provided expression is constant +/// based on the known constants. +/// +/// # Arguments +/// +/// - `constants`: A `&[Arc]` containing expressions known to +/// be a constant. +/// - `expr`: A reference to a `Arc` representing the expression +/// to check. +/// +/// # Returns +/// +/// Returns `true` if the expression is constant according to equivalence +/// group, `false` otherwise. +fn is_constant_recurse( + constants: &[Arc], + expr: &Arc, +) -> bool { + if physical_exprs_contains(constants, expr) { + return true; + } + let children = expr.children(); + !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn join_equivalence_properties( + left: EquivalenceProperties, + right: EquivalenceProperties, + join_type: &JoinType, + join_schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + on: &[(Column, Column)], +) -> EquivalenceProperties { + let left_size = left.schema.fields.len(); + let mut result = EquivalenceProperties::new(join_schema); + result.add_equivalence_group(left.eq_group().join( + right.eq_group(), + join_type, + left_size, + on, + )); + + let left_oeq_class = left.oeq_class; + let mut right_oeq_class = right.oeq_class; + match maintains_input_order { + [true, false] => { + // In this special case, right side ordering can be prefixed with + // the left side ordering. + if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + + // Right side ordering equivalence properties should be prepended + // with those of the left side while constructing output ordering + // equivalence properties since stream side is the left side. + // + // For example, if the right side ordering equivalences contain + // `b ASC`, and the left side ordering equivalences contain `a ASC`, + // then we should add `a ASC, b ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(left_oeq_class); + } + } + [false, true] => { + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + join_type, + left_size, + ); + // In this special case, left side ordering can be prefixed with + // the right side ordering. + if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + // Left side ordering equivalence properties should be prepended + // with those of the right side while constructing output ordering + // equivalence properties since stream side is the right side. + // + // For example, if the left side ordering equivalences contain + // `a ASC`, and the right side ordering equivalences contain `b ASC`, + // then we should add `b ASC, a ASC` to the ordering equivalences + // of the join output. + let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); + result.add_ordering_equivalence_class(out_oeq_class); + } else { + result.add_ordering_equivalence_class(right_oeq_class); + } + } + [false, false] => {} + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + } + result +} + +/// In the context of a join, update the right side `OrderingEquivalenceClass` +/// so that they point to valid indices in the join output schema. +/// +/// To do so, we increment column indices by the size of the left table when +/// join schema consists of a combination of the left and right schemas. This +/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases, +/// indices do not change. +fn updated_right_ordering_equivalence_class( + right_oeq_class: &mut OrderingEquivalenceClass, + join_type: &JoinType, + left_size: usize, +) { + if matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right + ) { + right_oeq_class.add_offset(left_size); + } +} + +/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. +/// The node can either be a leaf node, or an intermediate node: +/// - If it is a leaf node, we directly find the order of the node by looking +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +fn update_ordering( + mut node: ExprOrdering, + eq_properties: &EquivalenceProperties, +) -> Transformed { + // We have a Column, which is one of the two possible leaf node types: + let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + if eq_properties.is_expr_constant(&normalized_expr) { + node.state = SortProperties::Singleton; + } else if let Some(options) = eq_properties + .normalized_oeq_class() + .get_options(&normalized_expr) + { + node.state = SortProperties::Ordered(options); + } else if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + node.state = node.expr.get_ordering(&node.children_state()); + } else if node.expr.as_any().is::() { + // We have a Literal, which is the other possible leaf node type: + node.state = node.expr.get_ordering(&[]); + } else { + return Transformed::No(node); + } + Transformed::Yes(node) +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + use std::sync::Arc; + + use super::*; + use crate::execution_props::ExecutionProps; + use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; + use crate::functions::create_physical_expr; + + use arrow::compute::{lexsort_to_indices, SortColumn}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{BuiltinScalarFunction, Operator}; + + use itertools::{izip, Itertools}; + use rand::rngs::StdRng; + use rand::seq::SliceRandom; + use rand::{Rng, SeedableRng}; + + // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) + } + + /// Construct a schema with following properties + /// Schema satisfies following orderings: + /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + /// and + /// Column [a=c] (e.g they are aliases). + fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + eq_properties.add_equal_conditions(col_a, col_c); + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) + } + + // Generate a schema which consists of 6 columns (a, b, c, d, e, f) + fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + + /// Construct a schema with random ordering + /// among column a, b, c, d + /// where + /// Column [a=f] (e.g they are aliases). + /// Column e is constant. + fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f); + // Column e has constant value. + eq_properties = eq_properties.add_constants([col_e.clone()]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) + } + + // Convert each tuple to PhysicalSortRequirement + fn convert_to_sort_reqs( + in_data: &[(&Arc, Option)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| { + PhysicalSortRequirement::new((*expr).clone(), *options) + }) + .collect::>() + } + + // Convert each tuple to PhysicalSortExpr + fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect::>() + } + + // Convert each inner tuple to PhysicalSortExpr + fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], + ) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() + } + + #[test] + fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) + } + + #[test] + fn project_equivalence_properties_test() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let input_properties = EquivalenceProperties::new(input_schema.clone()); + let col_a = col("a", &input_schema)?; + + let out_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int64, true), + Field::new("a2", DataType::Int64, true), + Field::new("a3", DataType::Int64, true), + Field::new("a4", DataType::Int64, true), + ])); + + // a as a1, a as a2, a as a3, a as a3 + let proj_exprs = vec![ + (col_a.clone(), "a1".to_string()), + (col_a.clone(), "a2".to_string()), + (col_a.clone(), "a3".to_string()), + (col_a.clone(), "a4".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + // a as a1, a as a2, a as a3, a as a3 + let col_a1 = &col("a1", &out_schema)?; + let col_a2 = &col("a2", &out_schema)?; + let col_a3 = &col("a3", &out_schema)?; + let col_a4 = &col("a4", &out_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // At the output a1=a2=a3=a4 + assert_eq!(out_properties.eq_group().len(), 1); + let eq_class = &out_properties.eq_group().classes[0]; + assert_eq!(eq_class.len(), 4); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); + + Ok(()) + } + + #[test] + fn test_ordering_satisfy() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ])); + let crude = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]; + let finer = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions::default(), + }, + ]; + // finer ordering satisfies, crude ordering should return true + let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + eq_properties_finer.oeq_class.push(finer.clone()); + assert!(eq_properties_finer.ordering_satisfy(&crude)); + + // Crude ordering doesn't satisfy finer ordering. should return false + let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + eq_properties_crude.oeq_class.push(crude.clone()); + assert!(!eq_properties_crude.ordering_satisfy(&finer)); + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: expr.clone(), + options, + }) + .collect::>(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence2() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let floor_a = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let floor_f = &create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("f", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let exp_a = &create_physical_expr( + &BuiltinScalarFunction::Exp, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let test_cases = vec![ + // ------------ TEST CASE 1 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC], requirement is not satisfied. + vec![(col_a, options), (col_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 2 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC], + vec![(floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 2.1 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(f) ASC], (Please note that a=f) + vec![(floor_f, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 3 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, a+b ASC], + vec![(col_a, options), (col_c, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 4 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [floor(a) ASC, a+b ASC], + vec![(floor_a, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + false, + ), + // ------------ TEST CASE 5 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [exp(a) ASC, a+b ASC], + vec![(exp_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + // TODO: If we know that exp function is 1-to-1 function. + // we could have deduced that above requirement is satisfied. + false, + ), + // ------------ TEST CASE 6 ------------ + ( + // orderings + vec![ + // [a ASC, d ASC, b ASC] + vec![(col_a, options), (col_d, options), (col_b, options)], + // [c ASC] + vec![(col_c, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, d ASC, floor(a) ASC], + vec![(col_a, options), (col_d, options), (floor_a, options)], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 7 ------------ + ( + // orderings + vec![ + // [a ASC, c ASC, b ASC] + vec![(col_a, options), (col_c, options), (col_b, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, floor(a) ASC, a + b ASC], + vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 8 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [d ASC] + vec![(col_d, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, c ASC, floor(a) ASC, a + b ASC], + vec![ + (col_a, options), + (col_c, options), + (&floor_a, options), + (&a_plus_b, options), + ], + // expected: requirement is not satisfied. + false, + ), + // ------------ TEST CASE 9 ------------ + ( + // orderings + vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, options), + (col_b, options), + (col_c, options), + (col_d, options), + ], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [a ASC, b ASC, c ASC, floor(a) ASC], + vec![ + (col_a, options), + (col_b, options), + (&col_c, options), + (&floor_a, options), + ], + // expected: requirement is satisfied. + true, + ), + // ------------ TEST CASE 10 ------------ + ( + // orderings + vec![ + // [d ASC, b ASC] + vec![(col_d, options), (col_b, options)], + // [c ASC, a ASC] + vec![(col_c, options), (col_a, options)], + ], + // equivalence classes + vec![vec![col_a, col_f]], + // constants + vec![col_e], + // requirement [c ASC, d ASC, a + b ASC], + vec![(col_c, options), (col_d, options), (&a_plus_b, options)], + // expected: requirement is satisfied. + true, + ), + ]; + + for (orderings, eq_group, constants, reqs, expected) in test_cases { + let err_msg = + format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + let eq_group = eq_group + .into_iter() + .map(|eq_class| { + let eq_classes = eq_class.into_iter().cloned().collect::>(); + EquivalenceClass::new(eq_classes) + }) + .collect::>(); + let eq_group = EquivalenceGroup::new(eq_group); + eq_properties.add_equivalence_group(eq_group); + + let constants = constants.into_iter().cloned(); + eq_properties = eq_properties.add_constants(constants); + + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_ordering_satisfy_different_lengths() -> Result<()> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let options = SortOptions { + descending: false, + nulls_first: false, + }; + // a=c (e.g they are aliases). + let mut eq_properties = EquivalenceProperties::new(test_schema); + eq_properties.add_equal_conditions(col_a, col_c); + + let orderings = vec![ + vec![(col_a, options)], + vec![(col_e, options)], + vec![(col_d, options), (col_f, options)], + ]; + let orderings = convert_to_orderings(&orderings); + + // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. + eq_properties.add_new_orderings(orderings); + + // First entry in the tuple is required ordering, second entry is the expected flag + // that indicates whether this required ordering is satisfied. + // ([a ASC], true) indicate a ASC requirement is already satisfied by existing orderings. + let test_cases = vec![ + // [c ASC, a ASC, e ASC], expected represents this requirement is satisfied + ( + vec![(col_c, options), (col_a, options), (col_e, options)], + true, + ), + (vec![(col_c, options), (col_b, options)], false), + (vec![(col_c, options), (col_d, options)], true), + ( + vec![(col_d, options), (col_f, options), (col_b, options)], + false, + ), + (vec![(col_d, options), (col_f, options)], true), + ]; + + for (reqs, expected) in test_cases { + let err_msg = + format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + let reqs = convert_to_sort_exprs(&reqs); + assert_eq!( + eq_properties.ordering_satisfy(&reqs), + expected, + "{}", + err_msg + ); + } + + Ok(()) + } + + #[test] + fn test_bridge_groups() -> Result<()> { + // First entry in the tuple is argument, second entry is the bridged result + let test_cases = vec![ + // ------- TEST CASE 1 -----------// + ( + vec![vec![1, 2, 3], vec![2, 4, 5], vec![11, 12, 9], vec![7, 6, 5]], + // Expected is compared with set equality. Order of the specific results may change. + vec![vec![1, 2, 3, 4, 5, 6, 7], vec![9, 11, 12]], + ), + // ------- TEST CASE 2 -----------// + ( + vec![vec![1, 2, 3], vec![3, 4, 5], vec![9, 8, 7], vec![7, 6, 5]], + // Expected + vec![vec![1, 2, 3, 4, 5, 6, 7, 8, 9]], + ), + ]; + for (entries, expected) in test_cases { + let entries = entries + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let expected = expected + .into_iter() + .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) + .collect::>(); + let mut eq_groups = EquivalenceGroup::new(entries.clone()); + eq_groups.bridge_classes(); + let eq_groups = eq_groups.classes; + let err_msg = format!( + "error in test entries: {:?}, expected: {:?}, actual:{:?}", + entries, expected, eq_groups + ); + assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + for idx in 0..eq_groups.len() { + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + } + } + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_eq_group() -> Result<()> { + let entries = vec![ + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), + // This group is meaningless should be removed + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + // Given equivalences classes are not in succinct form. + // Expected form is the most plain representation that is functionally same. + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; + let mut eq_groups = EquivalenceGroup::new(entries); + eq_groups.remove_redundant_entries(); + + let eq_groups = eq_groups.classes; + assert_eq!(eq_groups.len(), expected.len()); + assert_eq!(eq_groups.len(), 2); + + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); + Ok(()) + } + + #[test] + fn test_remove_redundant_entries_oeq_class() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + + // First entry in the tuple is the given orderings for the table + // Second entry is the simplest version of the given orderings that is functionally equivalent. + let test_cases = vec![ + // ------- TEST CASE 1 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + ], + ), + // ------- TEST CASE 2 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 3 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC] + vec![(col_a, option_asc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b DESC] + vec![(col_a, option_asc), (col_b, option_desc)], + // [a ASC, c ASC] + vec![(col_a, option_asc), (col_c, option_asc)], + ], + ), + // ------- TEST CASE 4 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [a ASC] + vec![(col_a, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + ), + // ------- TEST CASE 5 --------- + // Empty ordering + ( + vec![vec![]], + // No ordering in the state (empty ordering is ignored). + vec![], + ), + // ------- TEST CASE 6 --------- + ( + // ORDERINGS GIVEN + vec![ + // [a ASC, b ASC] + vec![(col_a, option_asc), (col_b, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + // EXPECTED orderings that is succinct. + vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [b ASC] + vec![(col_b, option_asc)], + ], + ), + // ------- TEST CASE 7 --------- + // b, a + // c, a + // d, b, c + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, c ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, a ASC] + vec![(col_b, option_asc), (col_a, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 8 --------- + // b, e + // c, a + // d, b, e, c, a + ( + // ORDERINGS GIVEN + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC, b ASC, e ASC, c ASC, a ASC] + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_c, option_asc), + (col_a, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC, e ASC] + vec![(col_b, option_asc), (col_e, option_asc)], + // [c ASC, a ASC] + vec![(col_c, option_asc), (col_a, option_asc)], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + // ------- TEST CASE 9 --------- + // b + // a, b, c + // d, a, b + ( + // ORDERINGS GIVEN + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC, a ASC, b ASC] + vec![ + (col_d, option_asc), + (col_a, option_asc), + (col_b, option_asc), + ], + ], + // EXPECTED orderings that is succinct. + vec![ + // [b ASC] + vec![(col_b, option_asc)], + // [a ASC, b ASC, c ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + ], + // [d ASC] + vec![(col_d, option_asc)], + ], + ), + ]; + for (orderings, expected) in test_cases { + let orderings = convert_to_orderings(&orderings); + let expected = convert_to_orderings(&expected); + let actual = OrderingEquivalenceClass::new(orderings.clone()); + let actual = actual.orderings; + let err_msg = format!( + "orderings: {:?}, expected: {:?}, actual :{:?}", + orderings, expected, actual + ); + assert_eq!(actual.len(), expected.len(), "{}", err_msg); + for elem in actual { + assert!(expected.contains(&elem), "{}", err_msg); + } + } + + Ok(()) + } + + #[test] + fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> { + let join_type = JoinType::Inner; + // Join right child schema + let child_fields: Fields = ["x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + let child_schema = Schema::new(child_fields); + let col_x = &col("x", &child_schema)?; + let col_y = &col("y", &child_schema)?; + let col_z = &col("z", &child_schema)?; + let col_w = &col("w", &child_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + // Right child ordering equivalences + let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + + let left_columns_len = 4; + + let fields: Fields = ["a", "b", "c", "d", "x", "y", "z", "w"] + .into_iter() + .map(|name| Field::new(name, DataType::Int32, true)) + .collect(); + + // Join Schema + let schema = Schema::new(fields); + let col_a = &col("a", &schema)?; + let col_d = &col("d", &schema)?; + let col_x = &col("x", &schema)?; + let col_y = &col("y", &schema)?; + let col_z = &col("z", &schema)?; + let col_w = &col("w", &schema)?; + + let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); + // a=x and d=w + join_eq_properties.add_equal_conditions(col_a, col_x); + join_eq_properties.add_equal_conditions(col_d, col_w); + + updated_right_ordering_equivalence_class( + &mut right_oeq_class, + &join_type, + left_columns_len, + ); + join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + let result = join_eq_properties.oeq_class().clone(); + + // [x ASC, y ASC], [z ASC, w ASC] + let orderings = vec![ + vec![(col_x, option_asc), (col_y, option_asc)], + vec![(col_z, option_asc), (col_w, option_asc)], + ]; + let orderings = convert_to_orderings(&orderings); + let expected = OrderingEquivalenceClass::new(orderings); + + assert_eq!(result, expected); + + Ok(()) + } + + /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. + /// + /// The function works by adding a unique column of ascending integers to the original table. This column ensures + /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can + /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce + /// deterministic sorting results. + /// + /// If the table remains the same after sorting with the added unique column, it indicates that the table was + /// already sorted according to `required_ordering` to begin with. + fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, + ) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(unique_col.clone()); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = + Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) + } + + // If we already generated a random result for one of the + // expressions in the equivalence classes. For other expressions in the same + // equivalence class use same result. This util gets already calculated result, when available. + fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, + ) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(res.clone()); + } + } + None + } + + // Generate a table that satisfies the given equivalence properties; i.e. + // equivalences, ordering equivalences, and constants. + fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, + ) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) + as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, schema.clone()) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(representative_array.clone()); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) + } + + #[test] + fn test_schema_normalize_expr_with_equivalence() -> Result<()> { + let col_a = &Column::new("a", 0); + let col_b = &Column::new("b", 1); + let col_c = &Column::new("c", 2); + // Assume that column a and c are aliases. + let (_test_schema, eq_properties) = create_test_params()?; + + let col_a_expr = Arc::new(col_a.clone()) as Arc; + let col_b_expr = Arc::new(col_b.clone()) as Arc; + let col_c_expr = Arc::new(col_c.clone()) as Arc; + // Test cases for equivalence normalization, + // First entry in the tuple is argument, second entry is expected result after normalization. + let expressions = vec![ + // Normalized version of the column a and c should go to a + // (by convention all the expressions inside equivalence class are mapped to the first entry + // in this case a is the first entry in the equivalence class.) + (&col_a_expr, &col_a_expr), + (&col_c_expr, &col_a_expr), + // Cannot normalize column b + (&col_b_expr, &col_b_expr), + ]; + let eq_group = eq_properties.eq_group(); + for (expr, expected_eq) in expressions { + assert!( + expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + "error in test: expr: {expr:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_schema_normalize_sort_requirement_with_equivalence() -> Result<()> { + let option1 = SortOptions { + descending: false, + nulls_first: false, + }; + // Assume that column a and c are aliases. + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + + // Test cases for equivalence normalization + // First entry in the tuple is PhysicalSortRequirement, second entry in the tuple is + // expected PhysicalSortRequirement after normalization. + let test_cases = vec![ + (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), + // In the normalized version column c should be replace with column a + (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), + (vec![(col_c, None)], vec![(col_a, None)]), + (vec![(col_d, Some(option1))], vec![(col_d, Some(option1))]), + ]; + for (reqs, expected) in test_cases.into_iter() { + let reqs = convert_to_sort_reqs(&reqs); + let expected = convert_to_sort_reqs(&expected); + + let normalized = eq_properties.normalize_sort_requirements(&reqs); + assert!( + expected.eq(&normalized), + "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" + ); + } + + Ok(()) + } + + #[test] + fn test_normalize_sort_reqs() -> Result<()> { + // Schema satisfies following properties + // a=c + // and following orderings are valid + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + ( + vec![(col_a, Some(option_desc))], + vec![(col_a, Some(option_desc))], + ), + (vec![(col_a, None)], vec![(col_a, None)]), + // Test whether equivalence works as expected + ( + vec![(col_c, Some(option_asc))], + vec![(col_a, Some(option_asc))], + ), + (vec![(col_c, None)], vec![(col_a, None)]), + // Test whether ordering equivalence works as expected + ( + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_d, Some(option_asc)), (col_b, Some(option_asc))], + ), + ( + vec![(col_d, None), (col_b, None)], + vec![(col_d, None), (col_b, None)], + ), + ( + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + vec![(col_e, Some(option_desc)), (col_f, Some(option_asc))], + ), + // We should be able to normalize in compatible requirements also (not exactly equal) + ( + vec![(col_e, Some(option_desc)), (col_f, None)], + vec![(col_e, Some(option_desc)), (col_f, None)], + ), + ( + vec![(col_e, None), (col_f, None)], + vec![(col_e, None), (col_f, None)], + ), + ]; + + for (reqs, expected_normalized) in requirements.into_iter() { + let req = convert_to_sort_reqs(&reqs); + let expected_normalized = convert_to_sort_reqs(&expected_normalized); + + assert_eq!( + eq_properties.normalize_sort_requirements(&req), + expected_normalized + ); + } + + Ok(()) + } + + #[test] + fn test_get_finer() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. + // Third entry is the expected result. + let tests_cases = vec![ + // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC)] + ( + vec![(col_a, Some(option_asc))], + vec![(col_a, None), (col_b, Some(option_asc))], + Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] + // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] + ( + vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ], + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + Some(vec![ + (col_a, Some(option_asc)), + (col_b, Some(option_asc)), + (col_c, Some(option_asc)), + ]), + ), + // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] + // result should be None + ( + vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], + vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], + None, + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_reqs(&lhs); + let rhs = convert_to_sort_reqs(&rhs); + let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); + let finer = eq_properties.get_finer_requirement(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_get_meet_ordering() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let eq_properties = EquivalenceProperties::new(schema); + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let tests_cases = vec![ + // Get meet ordering between [a ASC] and [a ASC, b ASC] + // result should be [a ASC] + ( + vec![(col_a, option_asc)], + vec![(col_a, option_asc), (col_b, option_asc)], + Some(vec![(col_a, option_asc)]), + ), + // Get meet ordering between [a ASC] and [a DESC] + // result should be None. + (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), + // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] + // result should be [a ASC]. + ( + vec![(col_a, option_asc), (col_b, option_asc)], + vec![(col_a, option_asc), (col_b, option_desc)], + Some(vec![(col_a, option_asc)]), + ), + ]; + for (lhs, rhs, expected) in tests_cases { + let lhs = convert_to_sort_exprs(&lhs); + let rhs = convert_to_sort_exprs(&rhs); + let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); + let finer = eq_properties.get_meet_ordering(&lhs, &rhs); + assert_eq!(finer, expected) + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + // At below we add [d ASC, h DESC] also, for test purposes + let (test_schema, mut eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_h = &col("h", &test_schema)?; + // a + d + let a_plus_d = Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + // [d ASC, h ASC] also satisfies schema. + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }, + PhysicalSortExpr { + expr: col_h.clone(), + options: option_desc, + }, + ]]); + let test_cases = vec![ + // TEST CASE 1 + (vec![col_a], vec![(col_a, option_asc)]), + // TEST CASE 2 + (vec![col_c], vec![(col_c, option_asc)]), + // TEST CASE 3 + ( + vec![col_d, col_e, col_b], + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + ), + // TEST CASE 4 + (vec![col_b], vec![]), + // TEST CASE 5 + (vec![col_d], vec![(col_d, option_asc)]), + // TEST CASE 5 + (vec![&a_plus_d], vec![(&a_plus_d, option_asc)]), + // TEST CASE 6 + ( + vec![col_b, col_d], + vec![(col_d, option_asc), (col_b, option_asc)], + ), + // TEST CASE 6 + ( + vec![col_c, col_e], + vec![(col_c, option_asc), (col_e, option_desc)], + ), + ]; + for (exprs, expected) in test_cases { + let exprs = exprs.into_iter().cloned().collect::>(); + let expected = convert_to_sort_exprs(&expected); + let (actual, _) = eq_properties.find_longest_permutation(&exprs); + assert_eq!(actual, expected); + } + + Ok(()) + } + + #[test] + fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let floor_a = create_physical_expr( + &BuiltinScalarFunction::Floor, + &[col("a", &test_schema)?], + &test_schema, + &ExecutionProps::default(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = vec![ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: exprs[idx].clone(), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) + } + + #[test] + fn test_update_ordering() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + ]); + + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + // b=a (e.g they are aliases) + eq_properties.add_equal_conditions(col_b, col_a); + // [b ASC], [d ASC] + eq_properties.add_new_orderings(vec![ + vec![PhysicalSortExpr { + expr: col_b.clone(), + options: option_asc, + }], + vec![PhysicalSortExpr { + expr: col_d.clone(), + options: option_asc, + }], + ]); + + let test_cases = vec![ + // d + b + ( + Arc::new(BinaryExpr::new( + col_d.clone(), + Operator::Plus, + col_b.clone(), + )) as Arc, + SortProperties::Ordered(option_asc), + ), + // b + (col_b.clone(), SortProperties::Ordered(option_asc)), + // a + (col_a.clone(), SortProperties::Ordered(option_asc)), + // a + c + ( + Arc::new(BinaryExpr::new( + col_a.clone(), + Operator::Plus, + col_c.clone(), + )), + SortProperties::Unordered, + ), + ]; + for (expr, expected) in test_cases { + let leading_orderings = eq_properties + .oeq_class() + .iter() + .flat_map(|ordering| ordering.first().cloned()) + .collect::>(); + let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let err_msg = format!( + "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", + expr, expected, expr_ordering.state + ); + assert_eq!(expr_ordering.state, expected, "{}", err_msg); + } + + Ok(()) + } + + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } + + #[test] + fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { + let sort_options = SortOptions::default(); + let sort_options_not = SortOptions::default().not(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let required_columns = [col_b.clone(), col_a.clone()]; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + eq_properties.add_new_orderings([ + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }], + vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ], + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0, 1]); + assert_eq!( + result, + vec![ + PhysicalSortExpr { + expr: col_b.clone(), + options: sort_options_not + }, + PhysicalSortExpr { + expr: col_a.clone(), + options: sort_options + } + ] + ); + + let required_columns = [ + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("a", 0)) as _, + ]; + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); + + // not satisfied orders + eq_properties.add_new_orderings([vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: sort_options_not, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: sort_options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: sort_options, + }, + ]]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + assert_eq!(idxs, vec![0]); + + Ok(()) + } + + #[test] + fn test_normalize_ordering_equivalence_classes() -> Result<()> { + let sort_options = SortOptions::default(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let col_a_expr = col("a", &schema)?; + let col_b_expr = col("b", &schema)?; + let col_c_expr = col("c", &schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); + + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + let others = vec![ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]; + eq_properties.add_new_orderings(others); + + let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); + expected_eqs.add_new_orderings([ + vec![PhysicalSortExpr { + expr: col_b_expr.clone(), + options: sort_options, + }], + vec![PhysicalSortExpr { + expr: col_c_expr.clone(), + options: sort_options, + }], + ]); + + let oeq_class = eq_properties.oeq_class().clone(); + let expected = expected_eqs.oeq_class(); + assert!(oeq_class.eq(expected)); + + Ok(()) + } + + #[test] + fn test_expr_consists_of_constants() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_d = col("d", &schema)?; + let b_plus_d = Arc::new(BinaryExpr::new( + col_b.clone(), + Operator::Plus, + col_d.clone(), + )) as Arc; + + let constants = vec![col_a.clone(), col_b.clone()]; + let expr = b_plus_d.clone(); + assert!(!is_constant_recurse(&constants, &expr)); + + let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; + let expr = b_plus_d.clone(); + assert!(is_constant_recurse(&constants, &expr)); Ok(()) } } diff --git a/datafusion/physical-expr/src/execution_props.rs b/datafusion/physical-expr/src/execution_props.rs index 5849850031b19..8fdbbb7c5452f 100644 --- a/datafusion/physical-expr/src/execution_props.rs +++ b/datafusion/physical-expr/src/execution_props.rs @@ -17,10 +17,11 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, TimeZone, Utc}; +use datafusion_common::alias::AliasGenerator; use std::collections::HashMap; use std::sync::Arc; -/// Holds per-query execution properties and data (such as statment +/// Holds per-query execution properties and data (such as statement /// starting timestamps). /// /// An [`ExecutionProps`] is created each time a [`LogicalPlan`] is @@ -34,6 +35,8 @@ use std::sync::Arc; #[derive(Clone, Debug)] pub struct ExecutionProps { pub query_execution_start_time: DateTime, + /// Alias generator used by subquery optimizer rules + pub alias_generator: Arc, /// Providers for scalar variables pub var_providers: Option>>, } @@ -51,13 +54,25 @@ impl ExecutionProps { // Set this to a fixed sentinel to make it obvious if this is // not being updated / propagated correctly query_execution_start_time: Utc.timestamp_nanos(0), + alias_generator: Arc::new(AliasGenerator::new()), var_providers: None, } } - /// Marks the execution of query started timestamp + /// Set the query execution start time to use + pub fn with_query_execution_start_time( + mut self, + query_execution_start_time: DateTime, + ) -> Self { + self.query_execution_start_time = query_execution_start_time; + self + } + + /// Marks the execution of query started timestamp. + /// This also instantiates a new alias generator. pub fn start_execution(&mut self) -> &Self { self.query_execution_start_time = Utc::now(); + self.alias_generator = Arc::new(AliasGenerator::new()); &*self } @@ -94,6 +109,6 @@ mod test { #[test] fn debug() { let props = ExecutionProps::new(); - assert_eq!("ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, var_providers: None }", format!("{props:?}")); + assert_eq!("ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, alias_generator: AliasGenerator { next_id: 1 }, var_providers: None }", format!("{props:?}")); } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index e5b66d4a39871..9c7fdd2e814b1 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -15,88 +15,44 @@ // specific language governing permissions and limitations // under the License. -mod adapter; mod kernels; -mod kernels_arrow; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use arrow::array::*; -use arrow::compute::kernels::arithmetic::{ - add_dyn, add_scalar_dyn as add_dyn_scalar, divide_dyn_opt, - divide_scalar_dyn as divide_dyn_scalar, modulus_dyn, - modulus_scalar_dyn as modulus_dyn_scalar, multiply_dyn, - multiply_scalar_dyn as multiply_dyn_scalar, subtract_dyn, - subtract_scalar_dyn as subtract_dyn_scalar, +use crate::array_expressions::{ + array_append, array_concat, array_has_all, array_prepend, }; +use crate::expressions::datum::{apply, apply_cmp}; +use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; +use crate::physical_expr::down_cast_any_ref; +use crate::sort_properties::SortProperties; +use crate::PhysicalExpr; + +use arrow::array::*; +use arrow::compute::cast; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; +use arrow::compute::kernels::cmp::*; use arrow::compute::kernels::comparison::regexp_is_match_utf8; use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; -use arrow::compute::kernels::comparison::{ - eq_dyn_binary_scalar, gt_dyn_binary_scalar, gt_eq_dyn_binary_scalar, - lt_dyn_binary_scalar, lt_eq_dyn_binary_scalar, neq_dyn_binary_scalar, -}; -use arrow::compute::kernels::comparison::{ - eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar, - lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar, -}; -use arrow::compute::kernels::comparison::{ - eq_dyn_scalar, gt_dyn_scalar, gt_eq_dyn_scalar, lt_dyn_scalar, lt_eq_dyn_scalar, - neq_dyn_scalar, -}; -use arrow::compute::kernels::comparison::{ - eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar, - lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar, -}; -use arrow::compute::{cast, CastOptions}; +use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::datatypes::*; +use arrow::record_batch::RecordBatch; -use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn}; -use arrow::compute::kernels::concat_elements::concat_elements_utf8; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; +use datafusion_expr::type_coercion::binary::get_result_type; +use datafusion_expr::{ColumnarValue, Operator}; -use datafusion_expr::type_coercion::{is_decimal, is_timestamp, is_utf8_or_large_utf8}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar, }; -use kernels_arrow::{ - add_decimal_dyn_scalar, add_dyn_decimal, add_dyn_temporal, divide_decimal_dyn_scalar, - divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_binary, - is_distinct_from_bool, is_distinct_from_decimal, is_distinct_from_f32, - is_distinct_from_f64, is_distinct_from_null, is_distinct_from_utf8, - is_not_distinct_from, is_not_distinct_from_binary, is_not_distinct_from_bool, - is_not_distinct_from_decimal, is_not_distinct_from_f32, is_not_distinct_from_f64, - is_not_distinct_from_null, is_not_distinct_from_utf8, modulus_decimal_dyn_scalar, - modulus_dyn_decimal, multiply_decimal_dyn_scalar, multiply_dyn_decimal, - subtract_decimal_dyn_scalar, subtract_dyn_decimal, subtract_dyn_temporal, -}; - -use arrow::datatypes::{DataType, Schema, TimeUnit}; -use arrow::record_batch::RecordBatch; - -use self::kernels_arrow::{ - add_dyn_temporal_left_scalar, add_dyn_temporal_right_scalar, - subtract_dyn_temporal_left_scalar, subtract_dyn_temporal_right_scalar, -}; - -use super::column::Column; -use crate::expressions::cast_column; -use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::intervals::{apply_operator, Interval}; -use crate::physical_expr::down_cast_any_ref; -use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr}; -use datafusion_common::cast::as_boolean_array; - -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::type_coercion::binary::{ - coercion_decimal_mathematics_type, get_result_type, -}; -use datafusion_expr::{ColumnarValue, Operator}; /// Binary expression -#[derive(Debug)] +#[derive(Debug, Hash, Clone)] pub struct BinaryExpr { left: Arc, op: Operator, @@ -162,67 +118,6 @@ impl std::fmt::Display for BinaryExpr { } } -macro_rules! compute_decimal_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - if let ScalarValue::Decimal128(Some(v_i128), _, _) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}($LEFT, v_i128)?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE type - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -macro_rules! compute_decimal_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap(); - let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap(); - Ok(Arc::new(paste::expr! {[<$OP _decimal>]}(ll, rr)?)) - }}; -} - -macro_rules! compute_f32_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _f32>]}(ll, rr)?)) - }}; -} - -macro_rules! compute_f64_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _f64>]}(ll, rr)?)) - }}; -} - -macro_rules! compute_null_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) - }}; -} - /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -238,325 +133,15 @@ macro_rules! compute_utf8_op { }}; } -/// Invoke a compute kernel on a pair of binary data arrays -macro_rules! compute_binary_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _binary>]}(&ll, &rr)?)) - }}; -} - -/// Invoke a compute kernel on a data array and a scalar value -macro_rules! compute_utf8_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident, $OP_TYPE:expr) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - if let ScalarValue::Utf8(Some(string_value)) - | ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT - { - Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}( - &ll, - &string_value, - )?)) - } else if $RIGHT.is_null() { - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } else { - Err(DataFusionError::Internal(format!( - "compute_utf8_op_scalar for '{}' failed to cast literal value {}", - stringify!($OP), - $RIGHT - ))) - } - }}; -} - -/// Invoke a compute kernel on a data array and a scalar value -macro_rules! compute_utf8_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - if let Some(string_value) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}( - $LEFT, - &string_value, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a compute kernel on a data array and a scalar value -macro_rules! compute_binary_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - if let Some(binary_value) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_binary_scalar>]}( - $LEFT, - &binary_value, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a compute kernel on a boolean data array and a scalar value -macro_rules! compute_bool_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - // generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter - // (which could have a value of lt) and the suffix _scalar - if let Some(b) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_bool_scalar>]}( - $LEFT, - b, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a bool compute kernel on array(s) -macro_rules! compute_bool_op { - // invoke binary operator - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&ll, &rr)?)) - }}; - // invoke unary operator - ($OPERAND:expr, $OP:ident, $DT:ident) => {{ - let operand = $OPERAND - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast operant array"); - Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&operand)?)) - }}; -} - -/// Invoke a dyn compute kernel on a data array and a scalar value -/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value -/// OP_TYPE is the return type of scalar function -macro_rules! compute_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter - // (which could have a value of lt_dyn) and the suffix _scalar - if let Some(value) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}( - $LEFT, - value, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a dyn compute kernel on a data array and a scalar value -/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value -/// OP_TYPE is the return type of scalar function -/// SCALAR_TYPE is the type of the scalar value -/// Different to `compute_op_dyn_scalar`, this calls the `_dyn_scalar` functions that -/// take a `SCALAR_TYPE`. -macro_rules! compute_primitive_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $SCALAR_TYPE:ident) => {{ - // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter - // (which could have a value of lt_dyn) and the suffix _scalar - if let Some(value) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]::<$SCALAR_TYPE>}( - $LEFT, - value, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a dyn decimal compute kernel on a data array and a scalar value -/// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar value -/// OP_TYPE is the return type of scalar function -macro_rules! compute_primitive_decimal_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $RET_TYPE:expr) => {{ - // generate the scalar function name, such as add_decimal_dyn_scalar, - // from the $OP parameter (which could have a value of add) and the - // suffix _decimal_dyn_scalar - if let Some(value) = $RIGHT { - Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}( - $LEFT, value, $RET_TYPE, - )?) - } else { - // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE - Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) - } - }}; -} - -/// Invoke a compute kernel on array(s) -macro_rules! compute_op { - // invoke binary operator - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new($OP(&ll, &rr)?)) - }}; - // invoke unary operator - ($OPERAND:expr, $OP:ident, $DT:ident) => {{ - let operand = $OPERAND - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - Ok(Arc::new($OP(&operand)?)) - }}; -} - macro_rules! binary_string_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray), - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Data type {:?} not supported for binary operation '{}' on string arrays", other, stringify!($OP) - ))), - } - }}; -} - -/// Invoke a compute kernel on a pair of arrays -/// The binary_primitive_array_op macro only evaluates for primitive types -/// like integers and floats. -macro_rules! binary_primitive_array_op_dyn { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $RET_TYPE:expr) => {{ - match $LEFT.data_type() { - DataType::Decimal128(_, _) => { - Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT, $RET_TYPE)?) - } - DataType::Dictionary(_, value_type) - if matches!(value_type.as_ref(), &DataType::Decimal128(_, _)) => - { - Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT, $RET_TYPE)?) - } - _ => Ok(Arc::new( - $OP(&$LEFT, &$RIGHT).map_err(|err| DataFusionError::ArrowError(err))?, - )), - } - }}; -} - -/// Invoke a compute dyn kernel on an array and a scalar -/// The binary_primitive_array_op_dyn_scalar macro only evaluates for primitive -/// types like integers and floats. -macro_rules! binary_primitive_array_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $RET_TYPE:expr) => {{ - // unwrap underlying (non dictionary) value - let right = unwrap_dict_value($RIGHT); - let op_type = $LEFT.data_type(); - - let result: Result> = match right { - ScalarValue::Decimal128(v, _, _) => compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, op_type, $RET_TYPE), - ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int8Type), - ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int16Type), - ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int32Type), - ScalarValue::Int64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int64Type), - ScalarValue::UInt8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, UInt8Type), - ScalarValue::UInt16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, UInt16Type), - ScalarValue::UInt32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, UInt32Type), - ScalarValue::UInt64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, UInt64Type), - ScalarValue::Float32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Float32Type), - ScalarValue::Float64(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Float64Type), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on dyn array", - other, stringify!($OP))) - ) - }; - - Some(result) - }} -} - -/// The binary_array_op macro includes types that extend beyond the primitive, -/// such as Utf8 strings. -#[macro_export] -macro_rules! binary_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray), - DataType::Decimal128(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, Decimal128Array), - DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_f32_op!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_f64_op!($LEFT, $RIGHT, $OP, Float64Array), - DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), - DataType::Binary => compute_binary_op!($LEFT, $RIGHT, $OP, BinaryArray), - DataType::LargeBinary => compute_binary_op!($LEFT, $RIGHT, $OP, LargeBinaryArray), - DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray), - - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - compute_op!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - compute_op!($LEFT, $RIGHT, $OP, TimestampMillisecondArray) - } - DataType::Timestamp(TimeUnit::Second, _) => { - compute_op!($LEFT, $RIGHT, $OP, TimestampSecondArray) - } - DataType::Date32 => { - compute_op!($LEFT, $RIGHT, $OP, Date32Array) - } - DataType::Date64 => { - compute_op!($LEFT, $RIGHT, $OP, Date64Array) - } - DataType::Time32(TimeUnit::Second) => { - compute_op!($LEFT, $RIGHT, $OP, Time32SecondArray) - } - DataType::Time32(TimeUnit::Millisecond) => { - compute_op!($LEFT, $RIGHT, $OP, Time32MillisecondArray) - } - DataType::Time64(TimeUnit::Microsecond) => { - compute_op!($LEFT, $RIGHT, $OP, Time64MicrosecondArray) - } - DataType::Time64(TimeUnit::Nanosecond) => { - compute_op!($LEFT, $RIGHT, $OP, Time64NanosecondArray) - } - DataType::Boolean => compute_bool_op!($LEFT, $RIGHT, $OP, BooleanArray), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for binary operation '{}' on dyn arrays", - other, stringify!($OP) - ))), + ), } }}; } @@ -579,10 +164,10 @@ macro_rules! binary_string_array_flag_op { DataType::LargeUtf8 => { compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array", other, stringify!($OP) - ))), + ), } }}; } @@ -621,10 +206,10 @@ macro_rules! binary_string_array_flag_op_scalar { DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", other, stringify!($OP) - ))), + ), }; Some(result) }}; @@ -647,10 +232,10 @@ macro_rules! compute_utf8_flag_op_scalar { } Ok(Arc::new(array)) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", $RIGHT, stringify!($OP) - ))) + ) } }}; } @@ -674,65 +259,57 @@ impl PhysicalExpr for BinaryExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let left_value = self.left.evaluate(batch)?; - let right_value = self.right.evaluate(batch)?; - let left_data_type = left_value.data_type(); - let right_data_type = right_value.data_type(); + use arrow::compute::kernels::numeric::*; + + let lhs = self.left.evaluate(batch)?; + let rhs = self.right.evaluate(batch)?; + let left_data_type = lhs.data_type(); + let right_data_type = rhs.data_type(); let schema = batch.schema(); let input_schema = schema.as_ref(); - // Coerce decimal types to the same scale and precision - let coerced_type = coercion_decimal_mathematics_type( - &self.op, - &left_data_type, - &right_data_type, - ); - let (left_value, right_value) = if let Some(coerced_type) = coerced_type { - let options = CastOptions::default(); - let left_value = cast_column(&left_value, &coerced_type, Some(&options))?; - let right_value = cast_column(&right_value, &coerced_type, Some(&options))?; - (left_value, right_value) - } else { - // No need to coerce if it is not decimal or not math operation - (left_value, right_value) - }; + match self.op { + Operator::Plus => return apply(&lhs, &rhs, add_wrapping), + Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), + Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), + Operator::Divide => return apply(&lhs, &rhs, div), + Operator::Modulo => return apply(&lhs, &rhs, rem), + Operator::Eq => return apply_cmp(&lhs, &rhs, eq), + Operator::NotEq => return apply_cmp(&lhs, &rhs, neq), + Operator::Lt => return apply_cmp(&lhs, &rhs, lt), + Operator::Gt => return apply_cmp(&lhs, &rhs, gt), + Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq), + Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq), + Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct), + Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct), + _ => {} + } let result_type = self.data_type(input_schema)?; // Attempt to use special kernels if one input is scalar and the other is an array - let scalar_result = match (&left_value, &right_value) { + let scalar_result = match (&lhs, &rhs) { (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { // if left is array and right is literal - use scalar operations - self.evaluate_array_scalar(array, scalar.clone(), &result_type)? - .map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) - } - (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => { - // if right is literal and left is array - reverse operator and parameters - self.evaluate_scalar_array(scalar.clone(), array)? + self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { + r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) + }) } (_, _) => None, // default to array implementation }; if let Some(result) = scalar_result { - return result.map(|a| ColumnarValue::Array(a)); + return result.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution let (left, right) = ( - left_value.into_array(batch.num_rows()), - right_value.into_array(batch.num_rows()), + lhs.into_array(batch.num_rows())?, + rhs.into_array(batch.num_rows())?, ); - self.evaluate_with_resolved_args( - left, - &left_data_type, - right, - &right_data_type, - &result_type, - ) - .map(|a| ColumnarValue::Array(a)) + self.evaluate_with_resolved_args(left, &left_data_type, right, &right_data_type) + .map(ColumnarValue::Array) } fn children(&self) -> Vec> { @@ -750,55 +327,6 @@ impl PhysicalExpr for BinaryExpr { ))) } - /// Return the boundaries of this binary expression's result. - fn analyze(&self, context: AnalysisContext) -> AnalysisContext { - match &self.op { - Operator::Eq - | Operator::Gt - | Operator::Lt - | Operator::LtEq - | Operator::GtEq => { - // We currently only support comparison when we know at least one of the sides are - // a known value (a scalar). This includes predicates like a > 20 or 5 > a. - let context = self.left.analyze(context); - let left_boundaries = - analysis_expect!(context, context.boundaries()).clone(); - - let context = self.right.analyze(context); - let right_boundaries = - analysis_expect!(context, context.boundaries.clone()); - - match (left_boundaries.reduce(), right_boundaries.reduce()) { - (_, Some(right_value)) => { - // We know the right side is a scalar, so we can use the operator as is - analyze_expr_scalar_comparison( - context, - &self.op, - &self.left, - right_value, - ) - } - (Some(left_value), _) => { - // If not, we have to swap the operator and left/right (since this means - // left has to be a scalar). - let swapped_op = analysis_expect!(context, self.op.swap()); - analyze_expr_scalar_comparison( - context, - &swapped_op, - &self.right, - left_value, - ) - } - _ => { - // Both sides are columns, so we give up. - context.with_boundaries(None) - } - } - } - _ => context.with_boundaries(None), - } - } - fn evaluate_bounds(&self, children: &[&Interval]) -> Result { // Get children intervals: let left_interval = children[0]; @@ -811,31 +339,121 @@ impl PhysicalExpr for BinaryExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { - // Get children intervals. Graph brings + ) -> Result>> { + // Get children intervals. let left_interval = children[0]; let right_interval = children[1]; - let (left, right) = if self.op.is_logic_operator() { - // TODO: Currently, this implementation only supports the AND operator - // and does not require any further propagation. In the future, - // upon adding support for additional logical operators, this - // method will require modification to support propagating the - // changes accordingly. - return Ok(vec![]); - } else if self.op.is_comparison_operator() { - if interval == &Interval::CERTAINLY_FALSE { - // TODO: We will handle strictly false clauses by negating - // the comparison operator (e.g. GT to LE, LT to GE) - // once open/closed intervals are supported. - return Ok(vec![]); + + if self.op.eq(&Operator::And) { + if interval.eq(&Interval::CERTAINLY_TRUE) { + // A certainly true logical conjunction can only derive from possibly + // true operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_FALSE) + && !right_interval.eq(&Interval::CERTAINLY_FALSE)) + .then(|| vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_TRUE])) + } else if interval.eq(&Interval::CERTAINLY_FALSE) { + // If the logical conjunction is certainly false, one of the + // operands must be false. However, it's not always possible to + // determine which operand is false, leading to different scenarios. + + // If one operand is certainly true and the other one is uncertain, + // then the latter must be certainly false. + if left_interval.eq(&Interval::CERTAINLY_TRUE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_TRUE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } + // If both children are uncertain, or if one is certainly false, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical conjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) + } + } else if self.op.eq(&Operator::Or) { + if interval.eq(&Interval::CERTAINLY_FALSE) { + // A certainly false logical conjunction can only derive from certainly + // false operands. Otherwise, we prove infeasability. + Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) + && !right_interval.eq(&Interval::CERTAINLY_TRUE)) + .then(|| vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE])) + } else if interval.eq(&Interval::CERTAINLY_TRUE) { + // If the logical disjunction is certainly true, one of the + // operands must be true. However, it's not always possible to + // determine which operand is true, leading to different scenarios. + + // If one operand is certainly false and the other one is uncertain, + // then the latter must be certainly true. + if left_interval.eq(&Interval::CERTAINLY_FALSE) + && right_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_FALSE, + Interval::CERTAINLY_TRUE, + ])) + } else if right_interval.eq(&Interval::CERTAINLY_FALSE) + && left_interval.eq(&Interval::UNCERTAIN) + { + Ok(Some(vec![ + Interval::CERTAINLY_TRUE, + Interval::CERTAINLY_FALSE, + ])) + } + // If both children are uncertain, or if one is certainly true, + // we cannot conclusively refine their intervals. In this case, + // propagation does not result in any interval changes. + else { + Ok(Some(vec![])) + } + } else { + // An uncertain logical disjunction result can not shrink the + // end-points of its children. + Ok(Some(vec![])) } - // Propagate the comparison operator. - propagate_comparison(&self.op, left_interval, right_interval)? + } else if self.op.is_comparison_operator() { + Ok( + propagate_comparison(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) } else { - // Propagate the arithmetic operator. - propagate_arithmetic(&self.op, interval, left_interval, right_interval)? - }; - Ok(vec![left, right]) + Ok( + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? + .map(|(left, right)| vec![left, right]), + ) + } + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } + + /// For each operator, [`BinaryExpr`] has distinct ordering rules. + /// TODO: There may be rules specific to some data types (such as division and multiplication on unsigned integers) + fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { + let (left_child, right_child) = (&children[0], &children[1]); + match self.op() { + Operator::Plus => left_child.add(right_child), + Operator::Minus => left_child.sub(right_child), + Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), + Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), + Operator::And | Operator::Or => left_child.and_or(right_child), + _ => SortProperties::Unordered, + } } } @@ -848,199 +466,6 @@ impl PartialEq for BinaryExpr { } } -// Analyze the comparison between an expression (on the left) and a scalar value -// (on the right). The new boundaries will indicate whether it is always true, always -// false, or unknown (with a probablistic selectivity value attached). This operation -// will also include the new upper/lower boundaries for the operand on the left if -// they can be determined. -fn analyze_expr_scalar_comparison( - context: AnalysisContext, - op: &Operator, - left: &Arc, - right: ScalarValue, -) -> AnalysisContext { - let left_bounds = analysis_expect!(context, left.analyze(context.clone()).boundaries); - let left_min = left_bounds.min_value; - let left_max = left_bounds.max_value; - - // Direct selectivity is applicable when we can determine that this comparison will - // always be true or false (e.g. `x > 10` where the `x`'s min value is 11 or `a < 5` - // where the `a`'s max value is 4). - let (always_selects, never_selects) = match op { - Operator::Lt => (right > left_max, right <= left_min), - Operator::LtEq => (right >= left_max, right < left_min), - Operator::Gt => (right < left_min, right >= left_max), - Operator::GtEq => (right <= left_min, right > left_max), - Operator::Eq => ( - // Since min/max can be artificial (e.g. the min or max value of a column - // might be under/over the real value), we can't assume if the right equals - // to any left.min / left.max values it is always going to be selected. But - // we can assume that if the range(left) doesn't overlap with right, it is - // never going to be selected. - false, - right < left_min || right > left_max, - ), - _ => unreachable!(), - }; - - // Both can not be true at the same time. - assert!(!(always_selects && never_selects)); - - let selectivity = match (always_selects, never_selects) { - (true, _) => 1.0, - (_, true) => 0.0, - (false, false) => { - // If there is a partial overlap, then we can estimate the selectivity - // by computing the ratio of the existing overlap to the total range. Since we - // currently don't have access to a value distribution histogram, the part below - // assumes a uniform distribution by default. - - // Our [min, max] is inclusive, so we need to add 1 to the difference. - let total_range = analysis_expect!(context, left_max.distance(&left_min)) + 1; - let overlap_between_boundaries = analysis_expect!( - context, - match op { - Operator::Lt => right.distance(&left_min), - Operator::Gt => left_max.distance(&right), - Operator::LtEq => right.distance(&left_min).map(|dist| dist + 1), - Operator::GtEq => left_max.distance(&right).map(|dist| dist + 1), - Operator::Eq => Some(1), - _ => None, - } - ); - - overlap_between_boundaries as f64 / total_range as f64 - } - }; - - // The context represents all the knowledge we have gathered during the - // analysis process, which we can now add more since the expression's upper - // and lower boundaries might have changed. - let context = match left.as_any().downcast_ref::() { - Some(column_expr) => { - let (left_min, left_max) = match op { - // TODO: for lt/gt, we technically should shrink the possibility space - // by one since a < 5 means that 5 is not a possible value for `a`. However, - // it is currently tricky to do so (e.g. for floats, we can get away with 4.999 - // so we need a smarter logic to find out what is the closest value that is - // different from the scalar_value). - Operator::Lt | Operator::LtEq => { - // We only want to update the upper bound when we know it will help us (e.g. - // it is actually smaller than what we have right now) and it is a valid - // value (e.g. [0, 100] < -100 would update the boundaries to [0, -100] if - // there weren't the selectivity check). - if right < left_max && selectivity > 0.0 { - (left_min, right) - } else { - (left_min, left_max) - } - } - Operator::Gt | Operator::GtEq => { - // Same as above, but this time we want to limit the lower bound. - if right > left_min && selectivity > 0.0 { - (right, left_max) - } else { - (left_min, left_max) - } - } - // For equality, we don't have the range problem so even if the selectivity - // is 0.0, we can still update the boundaries. - Operator::Eq => (right.clone(), right), - _ => unreachable!(), - }; - - let left_bounds = - ExprBoundaries::new(left_min, left_max, left_bounds.distinct_count); - context.with_column_update(column_expr.index(), left_bounds) - } - None => context, - }; - - // The selectivity can't be be greater than 1.0. - assert!(selectivity <= 1.0); - - let (pred_min, pred_max, pred_distinct) = match (always_selects, never_selects) { - (false, true) => (false, false, 1), - (true, false) => (true, true, 1), - _ => (false, true, 2), - }; - - let result_boundaries = Some(ExprBoundaries::new_with_selectivity( - ScalarValue::Boolean(Some(pred_min)), - ScalarValue::Boolean(Some(pred_max)), - Some(pred_distinct), - Some(selectivity), - )); - context.with_boundaries(result_boundaries) -} - -/// unwrap underlying (non dictionary) value, if any, to pass to a scalar kernel -fn unwrap_dict_value(v: ScalarValue) -> ScalarValue { - if let ScalarValue::Dictionary(_key_type, v) = v { - unwrap_dict_value(*v) - } else { - v - } -} - -/// The binary_array_op_dyn_scalar macro includes types that extend -/// beyond the primitive, such as Utf8 strings. -#[macro_export] -macro_rules! binary_array_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - // unwrap underlying (non dictionary) value - let right = unwrap_dict_value($RIGHT); - - let result: Result> = match right { - ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE), - ScalarValue::Decimal128(..) => compute_decimal_op_dyn_scalar!($LEFT, right, $OP, $OP_TYPE), - ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Binary(v) => compute_binary_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::LargeBinary(v) => compute_binary_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Date32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Date64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Time32Second(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Time32Millisecond(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Time64Microsecond(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::Time64Nanosecond(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::TimestampSecond(v, _) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::TimestampMillisecond(v, _) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::TimestampMicrosecond(v, _) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - ScalarValue::TimestampNanosecond(v, _) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on dyn array", - other, stringify!($OP))) - ) - }; - Some(result) - }} -} - -/// Compares the array with the scalar value for equality, sometimes -/// used in other kernels -pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result { - binary_array_op_dyn_scalar!(lhs, rhs.clone(), eq, &DataType::Boolean).ok_or_else( - || { - DataFusionError::Internal(format!( - "Data type {:?} and scalar {:?} not supported for array_eq_scalar", - lhs.data_type(), - rhs.get_datatype() - )) - }, - )? -} - /// Casts dictionary array to result type for binary numerical operators. Such operators /// between array and scalar produce a dictionary array other than primitive array of the /// same operators between array and array. This leads to inconsistent result types causing @@ -1059,10 +484,9 @@ fn to_result_type_array( if value_type.as_ref() == result_type { Ok(cast(&array, result_type)?) } else { - Err(DataFusionError::Internal(format!( - "Incompatible Dictionary value type {:?} with result type {:?} of Binary operator {:?}", - value_type, result_type, op - ))) + internal_err!( + "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}" + ) } } _ => Ok(array), @@ -1079,38 +503,9 @@ impl BinaryExpr { &self, array: &dyn Array, scalar: ScalarValue, - result_type: &DataType, ) -> Result>> { use Operator::*; - let bool_type = &DataType::Boolean; let scalar_result = match &self.op { - Lt => binary_array_op_dyn_scalar!(array, scalar, lt, bool_type), - LtEq => binary_array_op_dyn_scalar!(array, scalar, lt_eq, bool_type), - Gt => binary_array_op_dyn_scalar!(array, scalar, gt, bool_type), - GtEq => binary_array_op_dyn_scalar!(array, scalar, gt_eq, bool_type), - Eq => binary_array_op_dyn_scalar!(array, scalar, eq, bool_type), - NotEq => binary_array_op_dyn_scalar!(array, scalar, neq, bool_type), - Plus => { - binary_primitive_array_op_dyn_scalar!(array, scalar, add, result_type) - } - Minus => binary_primitive_array_op_dyn_scalar!( - array, - scalar, - subtract, - result_type - ), - Multiply => binary_primitive_array_op_dyn_scalar!( - array, - scalar, - multiply, - result_type - ), - Divide => { - binary_primitive_array_op_dyn_scalar!(array, scalar, divide, result_type) - } - Modulo => { - binary_primitive_array_op_dyn_scalar!(array, scalar, modulus, result_type) - } RegexMatch => binary_string_array_flag_op_scalar!( array, scalar, @@ -1151,88 +546,39 @@ impl BinaryExpr { Ok(scalar_result) } - /// Evaluate the expression if the left input is a literal and the - /// right is an array - reverse operator and parameters - fn evaluate_scalar_array( - &self, - scalar: ScalarValue, - array: &ArrayRef, - ) -> Result>> { - use Operator::*; - let bool_type = &DataType::Boolean; - let scalar_result = match &self.op { - Lt => binary_array_op_dyn_scalar!(array, scalar, gt, bool_type), - LtEq => binary_array_op_dyn_scalar!(array, scalar, gt_eq, bool_type), - Gt => binary_array_op_dyn_scalar!(array, scalar, lt, bool_type), - GtEq => binary_array_op_dyn_scalar!(array, scalar, lt_eq, bool_type), - Eq => binary_array_op_dyn_scalar!(array, scalar, eq, bool_type), - NotEq => binary_array_op_dyn_scalar!(array, scalar, neq, bool_type), - // if scalar operation is not supported - fallback to array implementation - _ => None, - }; - Ok(scalar_result) - } - fn evaluate_with_resolved_args( &self, left: Arc, left_data_type: &DataType, right: Arc, right_data_type: &DataType, - result_type: &DataType, ) -> Result { use Operator::*; match &self.op { - Lt => lt_dyn(&left, &right), - LtEq => lt_eq_dyn(&left, &right), - Gt => gt_dyn(&left, &right), - GtEq => gt_eq_dyn(&left, &right), - Eq => eq_dyn(&left, &right), - NotEq => neq_dyn(&left, &right), - IsDistinctFrom => { - match (left_data_type, right_data_type) { - // exchange lhs and rhs when lhs is Null, since `binary_array_op` is - // always try to down cast array according to $LEFT expression. - (DataType::Null, _) => { - binary_array_op!(right, left, is_distinct_from) - } - _ => binary_array_op!(left, right, is_distinct_from), - } - } - IsNotDistinctFrom => binary_array_op!(left, right, is_not_distinct_from), - Plus => binary_primitive_array_op_dyn!(left, right, add_dyn, result_type), - Minus => { - binary_primitive_array_op_dyn!(left, right, subtract_dyn, result_type) - } - Multiply => { - binary_primitive_array_op_dyn!(left, right, multiply_dyn, result_type) - } - Divide => { - binary_primitive_array_op_dyn!(left, right, divide_dyn_opt, result_type) - } - Modulo => { - binary_primitive_array_op_dyn!(left, right, modulus_dyn, result_type) - } + IsDistinctFrom | IsNotDistinctFrom | Lt | LtEq | Gt | GtEq | Eq | NotEq + | Plus | Minus | Multiply | Divide | Modulo => unreachable!(), And => { if left_data_type == &DataType::Boolean { boolean_op!(&left, &right, and_kleene) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", self.op, left.data_type(), right.data_type() - ))) + ) } } Or => { if left_data_type == &DataType::Boolean { boolean_op!(&left, &right, or_kleene) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, left_data_type, right_data_type - ))) + self.op, + left_data_type, + right_data_type + ) } } RegexMatch => { @@ -1252,9 +598,14 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => { - binary_string_array_op!(left, right, concat_elements) - } + StringConcat => match (left_data_type, right_data_type) { + (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]), + (DataType::List(_), _) => array_append(&[left, right]), + (_, DataType::List(_)) => array_prepend(&[left, right]), + _ => binary_string_array_op!(left, right, concat_elements), + }, + AtArrow => array_has_all(&[left, right]), + ArrowAt => array_has_all(&[right, left]), } } } @@ -1266,76 +617,36 @@ pub fn binary( lhs: Arc, op: Operator, rhs: Arc, - input_schema: &Schema, + _input_schema: &Schema, ) -> Result> { - let lhs_type = &lhs.data_type(input_schema)?; - let rhs_type = &rhs.data_type(input_schema)?; - if (is_utf8_or_large_utf8(lhs_type) && is_timestamp(rhs_type)) - || (is_timestamp(lhs_type) && is_utf8_or_large_utf8(rhs_type)) - { - return Err(DataFusionError::Plan(format!( - "The type of {lhs_type} {op:?} {rhs_type} of binary physical should be same" - ))); - } - if !lhs_type.eq(rhs_type) && (!is_decimal(lhs_type) && !is_decimal(rhs_type)) { - return Err(DataFusionError::Internal(format!( - "The type of {lhs_type} {op:?} {rhs_type} of binary physical should be same" - ))); - } Ok(Arc::new(BinaryExpr::new(lhs, op, rhs))) } -pub fn resolve_temporal_op( - lhs: &ArrayRef, - sign: i32, - rhs: &ArrayRef, -) -> Result { - match sign { - 1 => add_dyn_temporal(lhs, rhs), - -1 => subtract_dyn_temporal(lhs, rhs), - other => Err(DataFusionError::Internal(format!( - "Undefined operation for temporal types {other}" - ))), - } -} - -pub fn resolve_temporal_op_scalar( - arr: &ArrayRef, - sign: i32, - scalar: &ScalarValue, - swap: bool, -) -> Result { - match (sign, swap) { - (1, false) => add_dyn_temporal_right_scalar(arr, scalar), - (1, true) => add_dyn_temporal_left_scalar(scalar, arr), - (-1, false) => subtract_dyn_temporal_right_scalar(arr, scalar), - (-1, true) => subtract_dyn_temporal_left_scalar(scalar, arr), - _ => Err(DataFusionError::Internal( - "Undefined operation for temporal types".to_string(), - )), - } -} - #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; - use crate::expressions::{try_cast, Literal}; + use crate::expressions::{col, lit, try_cast, Literal}; use arrow::datatypes::{ ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef, }; - use datafusion_common::{ColumnStatistics, Result, Statistics}; - use datafusion_expr::type_coercion::binary::{coerce_types, math_decimal_coercion}; + use arrow_schema::ArrowError; + use datafusion_common::Result; + use datafusion_expr::type_coercion::binary::get_input_types; - // Create a binary expression without coercion. Used here when we do not want to coerce the expressions - // to valid types. Usage can result in an execution (after plan) error. - fn binary_simple( - l: Arc, + /// Performs a binary operation, applying any type coercion necessary + fn binary_op( + left: Arc, op: Operator, - r: Arc, - input_schema: &Schema, - ) -> Arc { - binary(l, op, r, input_schema).unwrap() + right: Arc, + schema: &Schema, + ) -> Result> { + let left_type = left.data_type(schema)?; + let right_type = right.data_type(schema)?; + let (lhs, rhs) = get_input_types(&left_type, &op, &right_type)?; + + let left_expr = try_cast(left, schema, lhs)?; + let right_expr = try_cast(right, schema, rhs)?; + binary(left_expr, op, right_expr, schema) } #[test] @@ -1348,19 +659,22 @@ mod tests { let b = Int32Array::from(vec![1, 2, 4, 8, 16]); // expression: "a < b" - let lt = binary_simple( + let lt = binary( col("a", &schema)?, Operator::Lt, col("b", &schema)?, &schema, - ); + )?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; - let result = lt.evaluate(&batch)?.into_array(batch.num_rows()); + let result = lt + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); - let expected = vec![false, false, true, true, true]; + let expected = [false, false, true, true, true]; let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); for (i, &expected_item) in expected.iter().enumerate().take(5) { @@ -1380,31 +694,34 @@ mod tests { let b = Int32Array::from(vec![2, 5, 4, 8, 8]); // expression: "a < b OR a == b" - let expr = binary_simple( - binary_simple( + let expr = binary( + binary( col("a", &schema)?, Operator::Lt, col("b", &schema)?, &schema, - ), + )?, Operator::Or, - binary_simple( + binary( col("a", &schema)?, Operator::Eq, col("b", &schema)?, &schema, - ), + )?, &schema, - ); + )?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{expr}")); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.len(), 5); - let expected = vec![true, true, false, true, false]; + let expected = [true, true, false, true, false]; let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); for (i, &expected_item) in expected.iter().enumerate().take(5) { @@ -1429,10 +746,10 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); - let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + let (lhs, rhs) = get_input_types(&$A_TYPE, &$OP, &$B_TYPE)?; - let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?; - let right = try_cast(col("b", &schema)?, &schema, common_type)?; + let left = try_cast(col("a", &schema)?, &schema, lhs)?; + let right = try_cast(col("b", &schema)?, &schema, rhs)?; // verify that we can construct the expression let expression = binary(left, $OP, right, &schema)?; @@ -1445,7 +762,7 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -1468,7 +785,7 @@ mod tests { } #[test] - fn test_type_coersion() -> Result<()> { + fn test_type_coercion() -> Result<()> { test_coercion!( Int32Array, DataType::Int32, @@ -1479,7 +796,7 @@ mod tests { Operator::Plus, Int32Array, DataType::Int32, - vec![2i32, 4i32], + [2i32, 4i32], ); test_coercion!( Int32Array, @@ -1491,7 +808,7 @@ mod tests { Operator::Plus, Int32Array, DataType::Int32, - vec![2i32], + [2i32], ); test_coercion!( Float32Array, @@ -1503,7 +820,7 @@ mod tests { Operator::Plus, Float32Array, DataType::Float32, - vec![2f32], + [2f32], ); test_coercion!( Float32Array, @@ -1515,7 +832,7 @@ mod tests { Operator::Multiply, Float32Array, DataType::Float32, - vec![2f32], + [2f32], ); test_coercion!( StringArray, @@ -1527,7 +844,7 @@ mod tests { Operator::Eq, BooleanArray, DataType::Boolean, - vec![true, true], + [true, true], ); test_coercion!( StringArray, @@ -1539,7 +856,7 @@ mod tests { Operator::Lt, BooleanArray, DataType::Boolean, - vec![true, false], + [true, false], ); test_coercion!( StringArray, @@ -1551,7 +868,7 @@ mod tests { Operator::Eq, BooleanArray, DataType::Boolean, - vec![true, true], + [true, true], ); test_coercion!( StringArray, @@ -1563,7 +880,7 @@ mod tests { Operator::Lt, BooleanArray, DataType::Boolean, - vec![true, false], + [true, false], ); test_coercion!( StringArray, @@ -1575,7 +892,7 @@ mod tests { Operator::RegexMatch, BooleanArray, DataType::Boolean, - vec![true, false, true, false, false], + [true, false, true, false, false], ); test_coercion!( StringArray, @@ -1587,7 +904,7 @@ mod tests { Operator::RegexIMatch, BooleanArray, DataType::Boolean, - vec![true, true, true, true, false], + [true, true, true, true, false], ); test_coercion!( StringArray, @@ -1599,7 +916,7 @@ mod tests { Operator::RegexNotMatch, BooleanArray, DataType::Boolean, - vec![false, true, false, true, true], + [false, true, false, true, true], ); test_coercion!( StringArray, @@ -1611,7 +928,7 @@ mod tests { Operator::RegexNotIMatch, BooleanArray, DataType::Boolean, - vec![false, false, false, false, true], + [false, false, false, false, true], ); test_coercion!( LargeStringArray, @@ -1623,7 +940,7 @@ mod tests { Operator::RegexMatch, BooleanArray, DataType::Boolean, - vec![true, false, true, false, false], + [true, false, true, false, false], ); test_coercion!( LargeStringArray, @@ -1635,7 +952,7 @@ mod tests { Operator::RegexIMatch, BooleanArray, DataType::Boolean, - vec![true, true, true, true, false], + [true, true, true, true, false], ); test_coercion!( LargeStringArray, @@ -1647,7 +964,7 @@ mod tests { Operator::RegexNotMatch, BooleanArray, DataType::Boolean, - vec![false, true, false, true, true], + [false, true, false, true, true], ); test_coercion!( LargeStringArray, @@ -1659,7 +976,7 @@ mod tests { Operator::RegexNotIMatch, BooleanArray, DataType::Boolean, - vec![false, false, false, false, true], + [false, false, false, false, true], ); test_coercion!( Int16Array, @@ -1671,7 +988,7 @@ mod tests { Operator::BitwiseAnd, Int64Array, DataType::Int64, - vec![0i64, 0i64, 1i64], + [0i64, 0i64, 1i64], ); test_coercion!( UInt16Array, @@ -1683,7 +1000,19 @@ mod tests { Operator::BitwiseAnd, UInt64Array, DataType::UInt64, - vec![0u64, 0u64, 1u64], + [0u64, 0u64, 1u64], + ); + test_coercion!( + Int16Array, + DataType::Int16, + vec![3i16, 2i16, 3i16], + Int64Array, + DataType::Int64, + vec![10i64, 6i64, 5i64], + Operator::BitwiseOr, + Int64Array, + DataType::Int64, + [11i64, 6i64, 7i64], ); test_coercion!( UInt16Array, @@ -1695,7 +1024,7 @@ mod tests { Operator::BitwiseOr, UInt64Array, DataType::UInt64, - vec![11u64, 6u64, 7u64], + [11u64, 6u64, 7u64], ); test_coercion!( Int16Array, @@ -1707,7 +1036,7 @@ mod tests { Operator::BitwiseXor, Int64Array, DataType::Int64, - vec![9i64, 4i64, 6i64], + [9i64, 4i64, 6i64], ); test_coercion!( UInt16Array, @@ -1719,7 +1048,55 @@ mod tests { Operator::BitwiseXor, UInt64Array, DataType::UInt64, - vec![9u64, 4u64, 6u64], + [9u64, 4u64, 6u64], + ); + test_coercion!( + Int16Array, + DataType::Int16, + vec![4i16, 27i16, 35i16], + Int64Array, + DataType::Int64, + vec![2i64, 3i64, 4i64], + Operator::BitwiseShiftRight, + Int64Array, + DataType::Int64, + [1i64, 3i64, 2i64], + ); + test_coercion!( + UInt16Array, + DataType::UInt16, + vec![4u16, 27u16, 35u16], + UInt64Array, + DataType::UInt64, + vec![2u64, 3u64, 4u64], + Operator::BitwiseShiftRight, + UInt64Array, + DataType::UInt64, + [1u64, 3u64, 2u64], + ); + test_coercion!( + Int16Array, + DataType::Int16, + vec![2i16, 3i16, 4i16], + Int64Array, + DataType::Int64, + vec![4i64, 12i64, 7i64], + Operator::BitwiseShiftLeft, + Int64Array, + DataType::Int64, + [32i64, 12288i64, 512i64], + ); + test_coercion!( + UInt16Array, + DataType::UInt16, + vec![2u16, 3u16, 4u16], + UInt64Array, + DataType::UInt64, + vec![4u64, 12u64, 7u64], + Operator::BitwiseShiftLeft, + UInt64Array, + DataType::UInt64, + [32u64, 12288u64, 512u64], ); Ok(()) } @@ -1730,8 +1107,7 @@ mod tests { // is no way at the time of this writing to create a dictionary // array using the `From` trait #[test] - #[cfg(feature = "dictionary_expressions")] - fn test_dictionary_type_to_array_coersion() -> Result<()> { + fn test_dictionary_type_to_array_coercion() -> Result<()> { // Test string a string dictionary let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); @@ -1794,7 +1170,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn plus_op_dict() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -1828,7 +1203,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn plus_op_dict_decimal() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2012,7 +1386,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn minus_op_dict() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2046,7 +1419,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn minus_op_dict_decimal() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2222,7 +1594,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn multiply_op_dict() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2256,7 +1627,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn multiply_op_dict_decimal() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2430,7 +1800,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn divide_op_dict() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2470,7 +1839,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn divide_op_dict_decimal() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2525,14 +1893,14 @@ mod tests { Operator::Divide, create_decimal_array( &[ - Some(99193548387), // 0.99193548387 + Some(9919), // 0.9919 None, None, - Some(100813008130), // 1.0081300813 - Some(100000000000), // 1.0 + Some(10081), // 1.0081 + Some(10000), // 1.0 ], - 21, - 11, + 14, + 4, ), )?; @@ -2611,15 +1979,9 @@ mod tests { let a = DictionaryArray::try_new(keys, decimal_array)?; let decimal_array = Arc::new(create_decimal_array( - &[ - Some(6150000000000), - Some(6100000000000), - None, - Some(6200000000000), - Some(6150000000000), - ], - 21, - 11, + &[Some(615000), Some(610000), None, Some(620000), Some(615000)], + 14, + 4, )); apply_arithmetic_scalar( @@ -2656,7 +2018,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn modulus_op_dict() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2696,7 +2057,6 @@ mod tests { } #[test] - #[cfg(feature = "dictionary_expressions")] fn modulus_op_dict_decimal() -> Result<()> { let schema = Schema::new(vec![ Field::new( @@ -2853,9 +2213,12 @@ mod tests { expected: PrimitiveArray, ) -> Result<()> { let arithmetic_op = - binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema); + binary_op(col("a", &schema)?, op, col("b", &schema)?, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2869,9 +2232,12 @@ mod tests { expected: ArrayRef, ) -> Result<()> { let lit = Arc::new(Literal::new(literal)); - let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, &schema); + let arithmetic_op = binary_op(col("a", &schema)?, op, lit, &schema)?; let batch = RecordBatch::try_new(schema, data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(&result, &expected); Ok(()) @@ -2884,16 +2250,13 @@ mod tests { op: Operator, expected: BooleanArray, ) -> Result<()> { - let left_type = left.data_type(); - let right_type = right.data_type(); - let common_type = coerce_types(left_type, &op, right_type)?; - - let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?; - let right_expr = try_cast(col("b", schema)?, schema, common_type)?; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), &expected); Ok(()) @@ -2908,21 +2271,12 @@ mod tests { expected: &BooleanArray, ) -> Result<()> { let scalar = lit(scalar.clone()); - let op_type = coerce_types(&scalar.data_type(schema)?, &op, arr.data_type())?; - let left_expr = if op_type.eq(&scalar.data_type(schema)?) { - scalar - } else { - try_cast(scalar, schema, op_type.clone())? - }; - let right_expr = if op_type.eq(arr.data_type()) { - col("a", schema)? - } else { - try_cast(col("a", schema)?, schema, op_type)? - }; - - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let op = binary_op(scalar, op, col("a", schema)?, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -2937,21 +2291,12 @@ mod tests { expected: &BooleanArray, ) -> Result<()> { let scalar = lit(scalar.clone()); - let op_type = coerce_types(arr.data_type(), &op, &scalar.data_type(schema)?)?; - let right_expr = if op_type.eq(&scalar.data_type(schema)?) { - scalar - } else { - try_cast(scalar, schema, op_type.clone())? - }; - let left_expr = if op_type.eq(arr.data_type()) { - col("a", schema)? - } else { - try_cast(col("a", schema)?, schema, op_type)? - }; - - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let op = binary_op(col("a", schema)?, op, scalar, schema)?; let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected); Ok(()) @@ -3088,14 +2433,14 @@ mod tests { /// Returns (schema, BooleanArray) with [true, NULL, false] fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let a: BooleanArray = vec![Some(true), None, Some(false)].iter().collect(); + let a: BooleanArray = [Some(true), None, Some(false)].iter().collect(); (Arc::new(schema), Arc::new(a)) } #[test] fn eq_op_bool() { let (schema, a, b) = bool_test_arrays(); - let expected = vec![ + let expected = [ Some(true), None, Some(false), @@ -3517,13 +2862,14 @@ mod tests { let tree_depth: i32 = 100; let expr = (0..tree_depth) .map(|_| col("a", schema.as_ref()).unwrap()) - .reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema)) + .reduce(|l, r| binary(l, Operator::Plus, r, &schema).unwrap()) .unwrap(); let result = expr .evaluate(&batch) .expect("evaluation") - .into_array(batch.num_rows()); + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let expected: Int32Array = input .into_iter() @@ -3999,44 +3345,13 @@ mod tests { op: Operator, expected: ArrayRef, ) -> Result<()> { - let (lhs_op_type, rhs_op_type) = - math_decimal_coercion(left.data_type(), right.data_type()); - - let (left_expr, lhs_type) = if let Some(lhs_op_type) = lhs_op_type { - ( - try_cast(col("a", schema)?, schema, lhs_op_type.clone())?, - lhs_op_type, - ) - } else { - (col("a", schema)?, left.data_type().clone()) - }; - - let (right_expr, rhs_type) = if let Some(rhs_op_type) = rhs_op_type { - ( - try_cast(col("b", schema)?, schema, rhs_op_type.clone())?, - rhs_op_type, - ) - } else { - (col("b", schema)?, right.data_type().clone()) - }; - - let coerced_schema = Schema::new(vec![ - Field::new( - schema.field(0).name(), - lhs_type, - schema.field(0).is_nullable(), - ), - Field::new( - schema.field(1).name(), - rhs_type, - schema.field(1).is_nullable(), - ), - ]); - - let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema); + let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; - let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + let result = arithmetic_op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert_eq!(result.as_ref(), expected.as_ref()); Ok(()) @@ -4121,14 +3436,9 @@ mod tests { Field::new("b", DataType::Decimal128(10, 2), true), ])); let expect = Arc::new(create_decimal_array( - &[ - Some(10000000000000), - None, - Some(10081967213114), - Some(10000000000000), - ], - 23, - 11, + &[Some(1000000), None, Some(1008196), Some(1000000)], + 16, + 4, )) as ArrayRef; apply_decimal_arithmetic_op( &schema, @@ -4287,27 +3597,31 @@ mod tests { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), ])); - let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048, 100])); - let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32, 0])); + let a = Arc::new(Int32Array::from(vec![100])); + let b = Arc::new(Int32Array::from(vec![0])); - apply_arithmetic::( + let err = apply_arithmetic::( schema, vec![a, b], Operator::Divide, - Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64), None]), - )?; + Int32Array::from(vec![Some(4), Some(8), Some(16), Some(32), Some(64)]), + ) + .unwrap_err(); + + assert!( + matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), + "{err}" + ); // decimal let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Decimal128(25, 3), true), Field::new("b", DataType::Decimal128(25, 3), true), ])); - let left_decimal_array = - Arc::new(create_decimal_array(&[Some(1234567), Some(1234567)], 25, 3)); - let right_decimal_array = - Arc::new(create_decimal_array(&[Some(10), Some(0)], 25, 3)); + let left_decimal_array = Arc::new(create_decimal_array(&[Some(1234567)], 25, 3)); + let right_decimal_array = Arc::new(create_decimal_array(&[Some(0)], 25, 3)); - apply_arithmetic::( + let err = apply_arithmetic::( schema, vec![left_decimal_array, right_decimal_array], Operator::Divide, @@ -4316,7 +3630,13 @@ mod tests { 38, 29, ), - )?; + ) + .unwrap_err(); + + assert!( + matches!(err, DataFusionError::ArrowError(ArrowError::DivideByZero)), + "{err}" + ); Ok(()) } @@ -4461,287 +3781,6 @@ mod tests { Ok(()) } - /// Return a pair of (schema, statistics) for a table with a single column (called "a") with - /// the same type as the `min_value`/`max_value`. - fn get_test_table_stats( - min_value: ScalarValue, - max_value: ScalarValue, - ) -> (Schema, Statistics) { - assert_eq!(min_value.get_datatype(), max_value.get_datatype()); - let schema = Schema::new(vec![Field::new("a", min_value.get_datatype(), false)]); - let columns = vec![ColumnStatistics { - min_value: Some(min_value), - max_value: Some(max_value), - null_count: None, - distinct_count: None, - }]; - let statistics = Statistics { - column_statistics: Some(columns), - ..Default::default() - }; - (schema, statistics) - } - - #[test] - fn test_analyze_expr_scalar_comparison() -> Result<()> { - // A table where the column 'a' has a min of 1, a max of 100. - let (schema, statistics) = - get_test_table_stats(ScalarValue::from(1i64), ScalarValue::from(100i64)); - - let cases = [ - // (operator, rhs), (expected selectivity, expected min, expected max) - // ------------------------------------------------------------------- - // - // Table: - // - a (min = 1, max = 100, distinct_count = null) - // - // Equality (a = $): - // - ((Operator::Eq, 1), (1.0 / 100.0, 1, 1)), - ((Operator::Eq, 5), (1.0 / 100.0, 5, 5)), - ((Operator::Eq, 99), (1.0 / 100.0, 99, 99)), - ((Operator::Eq, 100), (1.0 / 100.0, 100, 100)), - // For never matches like the following, we still produce the correct - // min/max values since if this condition holds by an off chance, then - // the result of expression will effectively become the = $limit. - ((Operator::Eq, 0), (0.0, 0, 0)), - ((Operator::Eq, -101), (0.0, -101, -101)), - ((Operator::Eq, 101), (0.0, 101, 101)), - // - // Less than (a < $): - // - // Note: upper bounds for less than is currently overstated (by the closest value). - // see the comment in `compare_left_boundaries` for the reason - ((Operator::Lt, 5), (4.0 / 100.0, 1, 5)), - ((Operator::Lt, 99), (98.0 / 100.0, 1, 99)), - ((Operator::Lt, 101), (100.0 / 100.0, 1, 100)), - // Unlike equality, we now have an obligation to provide a range of values here - // so if "col < -100" expr is executed, we don't want to say col can take [0, -100]. - ((Operator::Lt, 0), (0.0, 1, 100)), - ((Operator::Lt, 1), (0.0, 1, 100)), - ((Operator::Lt, -100), (0.0, 1, 100)), - ((Operator::Lt, -200), (0.0, 1, 100)), - // We also don't want to expand the range unnecessarily even if the predicate is - // successful. - ((Operator::Lt, 200), (1.0, 1, 100)), - // - // Less than or equal (a <= $): - // - ((Operator::LtEq, -100), (0.0, 1, 100)), - ((Operator::LtEq, 0), (0.0, 1, 100)), - ((Operator::LtEq, 1), (1.0 / 100.0, 1, 1)), - ((Operator::LtEq, 5), (5.0 / 100.0, 1, 5)), - ((Operator::LtEq, 99), (99.0 / 100.0, 1, 99)), - ((Operator::LtEq, 100), (100.0 / 100.0, 1, 100)), - ((Operator::LtEq, 101), (1.0, 1, 100)), - ((Operator::LtEq, 200), (1.0, 1, 100)), - // - // Greater than (a > $): - // - ((Operator::Gt, -100), (1.0, 1, 100)), - ((Operator::Gt, 0), (1.0, 1, 100)), - ((Operator::Gt, 1), (99.0 / 100.0, 1, 100)), - ((Operator::Gt, 5), (95.0 / 100.0, 5, 100)), - ((Operator::Gt, 99), (1.0 / 100.0, 99, 100)), - ((Operator::Gt, 100), (0.0, 1, 100)), - ((Operator::Gt, 101), (0.0, 1, 100)), - ((Operator::Gt, 200), (0.0, 1, 100)), - // - // Greater than or equal (a >= $): - // - ((Operator::GtEq, -100), (1.0, 1, 100)), - ((Operator::GtEq, 0), (1.0, 1, 100)), - ((Operator::GtEq, 1), (1.0, 1, 100)), - ((Operator::GtEq, 5), (96.0 / 100.0, 5, 100)), - ((Operator::GtEq, 99), (2.0 / 100.0, 99, 100)), - ((Operator::GtEq, 100), (1.0 / 100.0, 100, 100)), - ((Operator::GtEq, 101), (0.0, 1, 100)), - ((Operator::GtEq, 200), (0.0, 1, 100)), - ]; - - for ((operator, rhs), (exp_selectivity, exp_min, exp_max)) in cases { - let context = AnalysisContext::from_statistics(&schema, &statistics); - let left = col("a", &schema).unwrap(); - let right = ScalarValue::Int64(Some(rhs)); - let analysis_ctx = - analyze_expr_scalar_comparison(context, &operator, &left, right); - let boundaries = analysis_ctx - .boundaries - .as_ref() - .expect("Analysis must complete for this test!"); - - assert_eq!( - boundaries - .selectivity - .expect("compare_left_boundaries must produce a selectivity value"), - exp_selectivity - ); - - if exp_selectivity == 1.0 { - // When the expected selectivity is 1.0, the resulting expression - // should always be true. - assert_eq!(boundaries.reduce(), Some(ScalarValue::Boolean(Some(true)))); - } else if exp_selectivity == 0.0 { - // When the expected selectivity is 0.0, the resulting expression - // should always be false. - assert_eq!(boundaries.reduce(), Some(ScalarValue::Boolean(Some(false)))); - } else { - // Otherwise, it should be [false, true] (since we don't know anything for sure) - assert_eq!(boundaries.min_value, ScalarValue::Boolean(Some(false))); - assert_eq!(boundaries.max_value, ScalarValue::Boolean(Some(true))); - } - - // For getting the updated boundaries, we can simply analyze the LHS - // with the existing context. - let left_boundaries = left - .analyze(analysis_ctx) - .boundaries - .expect("this case should not return None"); - assert_eq!(left_boundaries.min_value, ScalarValue::Int64(Some(exp_min))); - assert_eq!(left_boundaries.max_value, ScalarValue::Int64(Some(exp_max))); - } - Ok(()) - } - - #[test] - fn test_comparison_result_estimate_different_type() -> Result<()> { - // A table where the column 'a' has a min of 1.3, a max of 50.7. - let (schema, statistics) = - get_test_table_stats(ScalarValue::from(1.3), ScalarValue::from(50.7)); - let distance = 50.0; // rounded distance is (max - min) + 1 - - // Since the generic version already covers all the paths, we can just - // test a small subset of the cases. - let cases = [ - // (operator, rhs), (expected selectivity, expected min, expected max) - // ------------------------------------------------------------------- - // - // Table: - // - a (min = 1.3, max = 50.7, distinct_count = 25) - // - // Never selects (out of range) - ((Operator::Eq, 1.1), (0.0, 1.1, 1.1)), - ((Operator::Eq, 50.75), (0.0, 50.75, 50.75)), - ((Operator::Lt, 1.3), (0.0, 1.3, 50.7)), - ((Operator::LtEq, 1.29), (0.0, 1.3, 50.7)), - ((Operator::Gt, 50.7), (0.0, 1.3, 50.7)), - ((Operator::GtEq, 50.75), (0.0, 1.3, 50.7)), - // Always selects - ((Operator::Lt, 50.75), (1.0, 1.3, 50.7)), - ((Operator::LtEq, 50.75), (1.0, 1.3, 50.7)), - ((Operator::Gt, 1.29), (1.0, 1.3, 50.7)), - ((Operator::GtEq, 1.3), (1.0, 1.3, 50.7)), - // Partial selection (the x in 'x/distance' is basically the rounded version of - // the bound distance, as per the implementation). - ((Operator::Eq, 27.8), (1.0 / distance, 27.8, 27.8)), - ((Operator::Lt, 5.2), (4.0 / distance, 1.3, 5.2)), // On a uniform distribution, this is {2.6, 3.9} - ((Operator::LtEq, 1.3), (1.0 / distance, 1.3, 1.3)), - ((Operator::Gt, 45.5), (5.0 / distance, 45.5, 50.7)), // On a uniform distribution, this is {46.8, 48.1, 49.4} - ((Operator::GtEq, 50.7), (1.0 / distance, 50.7, 50.7)), - ]; - - for ((operator, rhs), (exp_selectivity, exp_min, exp_max)) in cases { - let context = AnalysisContext::from_statistics(&schema, &statistics); - let left = col("a", &schema).unwrap(); - let right = ScalarValue::from(rhs); - let analysis_ctx = - analyze_expr_scalar_comparison(context, &operator, &left, right); - let boundaries = analysis_ctx - .clone() - .boundaries - .expect("Analysis must complete for this test!"); - - assert_eq!( - boundaries - .selectivity - .expect("compare_left_boundaries must produce a selectivity value"), - exp_selectivity - ); - - if exp_selectivity == 1.0 { - // When the expected selectivity is 1.0, the resulting expression - // should always be true. - assert_eq!(boundaries.reduce(), Some(ScalarValue::from(true))); - } else if exp_selectivity == 0.0 { - // When the expected selectivity is 0.0, the resulting expression - // should always be false. - assert_eq!(boundaries.reduce(), Some(ScalarValue::from(false))); - } else { - // Otherwise, it should be [false, true] (since we don't know anything for sure) - assert_eq!(boundaries.min_value, ScalarValue::from(false)); - assert_eq!(boundaries.max_value, ScalarValue::from(true)); - } - - let left_boundaries = left - .analyze(analysis_ctx) - .boundaries - .expect("this case should not return None"); - assert_eq!( - left_boundaries.min_value, - ScalarValue::Float64(Some(exp_min)) - ); - assert_eq!( - left_boundaries.max_value, - ScalarValue::Float64(Some(exp_max)) - ); - } - Ok(()) - } - - #[test] - fn test_binary_expression_boundaries() -> Result<()> { - // A table where the column 'a' has a min of 1, a max of 100. - let (schema, statistics) = - get_test_table_stats(ScalarValue::from(1), ScalarValue::from(100)); - - // expression: "a >= 25" - let a = col("a", &schema).unwrap(); - let gt = binary_simple( - a.clone(), - Operator::GtEq, - lit(ScalarValue::from(25)), - &schema, - ); - - let context = AnalysisContext::from_statistics(&schema, &statistics); - let predicate_boundaries = gt - .analyze(context) - .boundaries - .expect("boundaries should not be None"); - assert_eq!(predicate_boundaries.selectivity, Some(0.76)); - - Ok(()) - } - - #[test] - fn test_binary_expression_boundaries_rhs() -> Result<()> { - // This test is about the column rewriting feature in the boundary provider - // (e.g. if the lhs is a literal and rhs is the column, then we swap them when - // doing the computation). - - // A table where the column 'a' has a min of 1, a max of 100. - let (schema, statistics) = - get_test_table_stats(ScalarValue::from(1), ScalarValue::from(100)); - - // expression: "50 >= a" - let a = col("a", &schema).unwrap(); - let gt = binary_simple( - lit(ScalarValue::from(50)), - Operator::GtEq, - a.clone(), - &schema, - ); - - let context = AnalysisContext::from_statistics(&schema, &statistics); - let predicate_boundaries = gt - .analyze(context) - .boundaries - .expect("boundaries should not be None"); - assert_eq!(predicate_boundaries.selectivity, Some(0.5)); - - Ok(()) - } - #[test] fn test_display_and_or_combo() { let expr = BinaryExpr::new( diff --git a/datafusion/physical-expr/src/expressions/binary/adapter.rs b/datafusion/physical-expr/src/expressions/binary/adapter.rs deleted file mode 100644 index ec0eda392976b..0000000000000 --- a/datafusion/physical-expr/src/expressions/binary/adapter.rs +++ /dev/null @@ -1,49 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains functions that change types or names of other -//! kernels to make them compatible with the main dispatch logic - -use std::sync::Arc; - -use arrow::array::*; -use datafusion_common::Result; - -/// create a `dyn_op` wrapper function for the specified operation -/// that call the underlying dyn_op arrow kernel if the type is -/// supported, and translates ArrowError to DataFusionError -macro_rules! make_dyn_comp_op { - ($OP:tt) => { - paste::paste! { - /// wrapper over arrow compute kernel that maps Error types and - /// patches missing support in arrow - pub(crate) fn [<$OP _dyn>] (left: &dyn Array, right: &dyn Array) -> Result { - arrow::compute::kernels::comparison::[<$OP _dyn>](left, right) - .map_err(|e| e.into()) - .map(|a| Arc::new(a) as ArrayRef) - } - } - }; -} - -// create eq_dyn, gt_dyn, wrappers etc -make_dyn_comp_op!(eq); -make_dyn_comp_op!(gt); -make_dyn_comp_op!(gt_eq); -make_dyn_comp_op!(lt); -make_dyn_comp_op!(lt_eq); -make_dyn_comp_op!(neq); diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index 185b9416d3d02..22cadec40940d 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -24,6 +24,7 @@ use arrow::compute::kernels::bitwise::{ bitwise_xor, bitwise_xor_scalar, }; use arrow::datatypes::DataType; +use datafusion_common::internal_err; use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::sync::Arc; @@ -68,11 +69,11 @@ macro_rules! create_dyn_kernel { DataType::UInt64 => { call_bitwise_kernel!(left, right, $KERNEL, UInt64Array) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, - stringify!($KERNEL), - ))), + stringify!($KERNEL) + ), } } }; @@ -114,11 +115,11 @@ macro_rules! create_dyn_scalar_kernel { DataType::UInt16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), DataType::UInt32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), DataType::UInt64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, - stringify!($KERNEL), - ))), + stringify!($KERNEL) + ), }; Some(result) } diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs deleted file mode 100644 index 4d984ac8e8452..0000000000000 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ /dev/null @@ -1,2576 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains computation kernels that are eventually -//! destined for arrow-rs but are in datafusion until they are ported. - -use arrow::compute::{ - add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn, - modulus_scalar_dyn, multiply_dyn, multiply_fixed_point, multiply_scalar_dyn, - subtract_dyn, subtract_scalar_dyn, try_unary, -}; -use arrow::datatypes::{ - i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, - DECIMAL128_MAX_PRECISION, -}; -use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array}; -use arrow_array::types::{ArrowDictionaryKeyType, DecimalType}; -use arrow_array::ArrowNativeTypeOp; -use arrow_buffer::ArrowNativeType; -use arrow_schema::{DataType, IntervalUnit}; -use chrono::{Days, Duration, Months, NaiveDate, NaiveDateTime}; -use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array}; -use datafusion_common::scalar::{date32_op, date64_op}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use std::cmp::min; -use std::ops::Add; -use std::sync::Arc; - -use arrow::compute::unary; -use arrow::datatypes::*; - -use arrow_array::temporal_conversions::{MILLISECONDS_IN_DAY, NANOSECONDS_IN_DAY}; -use datafusion_common::delta::shift_months; -use datafusion_common::scalar::{ - calculate_naives, microseconds_add, microseconds_sub, milliseconds_add, - milliseconds_sub, nanoseconds_add, nanoseconds_sub, op_dt, op_dt_mdn, op_mdn, op_ym, - op_ym_dt, op_ym_mdn, parse_timezones, seconds_add, MILLISECOND_MODE, NANOSECOND_MODE, -}; - -use arrow::datatypes::TimeUnit; - -use datafusion_common::cast::{ - as_interval_dt_array, as_interval_mdn_array, as_interval_ym_array, - as_timestamp_microsecond_array, as_timestamp_millisecond_array, - as_timestamp_nanosecond_array, as_timestamp_second_array, -}; -use datafusion_common::scalar::*; - -// Simple (low performance) kernels until optimized kernels are added to arrow -// See https://github.com/apache/arrow-rs/issues/960 - -macro_rules! distinct_float { - ($LEFT:expr, $RIGHT:expr, $LEFT_ISNULL:expr, $RIGHT_ISNULL:expr) => {{ - $LEFT_ISNULL != $RIGHT_ISNULL - || $LEFT.is_nan() != $RIGHT.is_nan() - || (!$LEFT.is_nan() && !$RIGHT.is_nan() && $LEFT != $RIGHT) - }}; -} - -pub(crate) fn is_distinct_from_bool( - left: &BooleanArray, - right: &BooleanArray, -) -> Result { - // Different from `neq_bool` because `null is distinct from null` is false and not null - Ok(left - .iter() - .zip(right.iter()) - .map(|(left, right)| Some(left != right)) - .collect()) -} - -pub(crate) fn is_not_distinct_from_bool( - left: &BooleanArray, - right: &BooleanArray, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(left, right)| Some(left == right)) - .collect()) -} - -pub(crate) fn is_distinct_from( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result -where - T: ArrowNumericType, -{ - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - left_isnull != right_isnull || left_value != right_value - }, - ) -} - -pub(crate) fn is_not_distinct_from( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result -where - T: ArrowNumericType, -{ - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - !(left_isnull != right_isnull || left_value != right_value) - }, - ) -} - -fn distinct< - T, - F: FnMut( - ::Native, - ::Native, - bool, - bool, - ) -> bool, ->( - left: &PrimitiveArray, - right: &PrimitiveArray, - mut op: F, -) -> Result -where - T: ArrowNumericType, -{ - let left_values = left.values(); - let right_values = right.values(); - let left_nulls = left.nulls(); - let right_nulls = right.nulls(); - - let array_len = left.len().min(right.len()); - let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| { - op( - left_values[i], - right_values[i], - left_nulls.map(|x| x.is_null(i)).unwrap_or_default(), - right_nulls.map(|x| x.is_null(i)).unwrap_or_default(), - ) - }); - let array_data = ArrayData::builder(arrow_schema::DataType::Boolean) - .len(array_len) - .add_buffer(distinct.into()); - - Ok(BooleanArray::from(unsafe { array_data.build_unchecked() })) -} - -pub(crate) fn is_distinct_from_f32( - left: &Float32Array, - right: &Float32Array, -) -> Result { - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - distinct_float!(left_value, right_value, left_isnull, right_isnull) - }, - ) -} - -pub(crate) fn is_not_distinct_from_f32( - left: &Float32Array, - right: &Float32Array, -) -> Result { - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - !(distinct_float!(left_value, right_value, left_isnull, right_isnull)) - }, - ) -} - -pub(crate) fn is_distinct_from_f64( - left: &Float64Array, - right: &Float64Array, -) -> Result { - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - distinct_float!(left_value, right_value, left_isnull, right_isnull) - }, - ) -} - -pub(crate) fn is_not_distinct_from_f64( - left: &Float64Array, - right: &Float64Array, -) -> Result { - distinct( - left, - right, - |left_value, right_value, left_isnull, right_isnull| { - !(distinct_float!(left_value, right_value, left_isnull, right_isnull)) - }, - ) -} - -pub(crate) fn is_distinct_from_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(x, y)| Some(x != y)) - .collect()) -} - -pub(crate) fn is_distinct_from_binary( - left: &GenericBinaryArray, - right: &GenericBinaryArray, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(x, y)| Some(x != y)) - .collect()) -} - -pub(crate) fn is_distinct_from_null( - left: &NullArray, - _right: &NullArray, -) -> Result { - let length = left.len(); - make_boolean_array(length, false) -} - -pub(crate) fn is_not_distinct_from_null( - left: &NullArray, - _right: &NullArray, -) -> Result { - let length = left.len(); - make_boolean_array(length, true) -} - -fn make_boolean_array(length: usize, value: bool) -> Result { - Ok((0..length).map(|_| Some(value)).collect()) -} - -pub(crate) fn is_not_distinct_from_utf8( - left: &GenericStringArray, - right: &GenericStringArray, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(x, y)| Some(x == y)) - .collect()) -} - -pub(crate) fn is_not_distinct_from_binary( - left: &GenericBinaryArray, - right: &GenericBinaryArray, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(x, y)| Some(x == y)) - .collect()) -} - -pub(crate) fn is_distinct_from_decimal( - left: &Decimal128Array, - right: &Decimal128Array, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(left, right)| match (left, right) { - (None, None) => Some(false), - (None, Some(_)) | (Some(_), None) => Some(true), - (Some(left), Some(right)) => Some(left != right), - }) - .collect()) -} - -pub(crate) fn is_not_distinct_from_decimal( - left: &Decimal128Array, - right: &Decimal128Array, -) -> Result { - Ok(left - .iter() - .zip(right.iter()) - .map(|(left, right)| match (left, right) { - (None, None) => Some(true), - (None, Some(_)) | (Some(_), None) => Some(false), - (Some(left), Some(right)) => Some(left == right), - }) - .collect()) -} - -pub(crate) fn add_dyn_decimal( - left: &dyn Array, - right: &dyn Array, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - let array = add_dyn(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn add_decimal_dyn_scalar( - left: &dyn Array, - right: i128, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - - let array = add_scalar_dyn::(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn add_dyn_temporal(left: &ArrayRef, right: &ArrayRef) -> Result { - match (left.data_type(), right.data_type()) { - (DataType::Timestamp(..), DataType::Timestamp(..)) => ts_array_op(left, right), - (DataType::Interval(..), DataType::Interval(..)) => { - interval_array_op(left, right, 1) - } - (DataType::Timestamp(..), DataType::Interval(..)) => { - ts_interval_array_op(left, 1, right) - } - (DataType::Interval(..), DataType::Timestamp(..)) => { - ts_interval_array_op(right, 1, left) - } - _ => { - // fall back to kernels in arrow-rs - Ok(add_dyn(left, right)?) - } - } -} - -pub(crate) fn add_dyn_temporal_right_scalar( - left: &ArrayRef, - right: &ScalarValue, -) -> Result { - match (left.data_type(), right.get_datatype()) { - // Date32 + Interval - (DataType::Date32, DataType::Interval(..)) => { - let left = as_date32_array(&left)?; - let ret = Arc::new(try_unary::(left, |days| { - Ok(date32_op(days, right, 1)?) - })?) as _; - Ok(ret) - } - // Date64 + Interval - (DataType::Date64, DataType::Interval(..)) => { - let left = as_date64_array(&left)?; - let ret = Arc::new(try_unary::(left, |ms| { - Ok(date64_op(ms, right, 1)?) - })?) as _; - Ok(ret) - } - // Interval + Interval - (DataType::Interval(..), DataType::Interval(..)) => { - interval_op_scalar_interval(left, 1, right) - } - // Timestamp + Interval - (DataType::Timestamp(..), DataType::Interval(..)) => { - ts_op_scalar_interval(left, 1, right) - } - _ => { - // fall back to kernels in arrow-rs - Ok(add_dyn(left, &right.to_array())?) - } - } -} - -pub(crate) fn add_dyn_temporal_left_scalar( - left: &ScalarValue, - right: &ArrayRef, -) -> Result { - match (left.get_datatype(), right.data_type()) { - // Date32 + Interval - (DataType::Date32, DataType::Interval(..)) => { - if let ScalarValue::Date32(Some(left)) = left { - scalar_date32_array_interval_op( - *left, - right, - NaiveDate::checked_add_days, - NaiveDate::checked_add_months, - ) - } else { - Err(DataFusionError::Internal( - "Date32 value is None".to_string(), - )) - } - } - // Date64 + Interval - (DataType::Date64, DataType::Interval(..)) => { - if let ScalarValue::Date64(Some(left)) = left { - scalar_date64_array_interval_op( - *left, - right, - NaiveDate::checked_add_days, - NaiveDate::checked_add_months, - ) - } else { - Err(DataFusionError::Internal( - "Date64 value is None".to_string(), - )) - } - } - // Interval + Interval - (DataType::Interval(..), DataType::Interval(..)) => { - scalar_interval_op_interval(left, 1, right) - } - // Timestamp + Interval - (DataType::Timestamp(..), DataType::Interval(..)) => { - scalar_ts_op_interval(left, 1, right) - } - _ => { - // fall back to kernels in arrow-rs - Ok(add_dyn(&left.to_array(), right)?) - } - } -} - -pub(crate) fn subtract_decimal_dyn_scalar( - left: &dyn Array, - right: i128, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - - let array = subtract_scalar_dyn::(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn subtract_dyn_temporal( - left: &ArrayRef, - right: &ArrayRef, -) -> Result { - match (left.data_type(), right.data_type()) { - (DataType::Timestamp(..), DataType::Timestamp(..)) => ts_array_op(left, right), - (DataType::Interval(..), DataType::Interval(..)) => { - interval_array_op(left, right, -1) - } - (DataType::Timestamp(..), DataType::Interval(..)) => { - ts_interval_array_op(left, -1, right) - } - (DataType::Interval(..), DataType::Timestamp(..)) => { - ts_interval_array_op(right, -1, left) - } - _ => { - // fall back to kernels in arrow-rs - Ok(subtract_dyn(left, right)?) - } - } -} - -pub(crate) fn subtract_dyn_temporal_right_scalar( - left: &ArrayRef, - right: &ScalarValue, -) -> Result { - match (left.data_type(), right.get_datatype()) { - // Date32 - Interval - (DataType::Date32, DataType::Interval(..)) => { - let left = as_date32_array(&left)?; - let ret = Arc::new(try_unary::(left, |days| { - Ok(date32_op(days, right, -1)?) - })?) as _; - Ok(ret) - } - // Date64 - Interval - (DataType::Date64, DataType::Interval(..)) => { - let left = as_date64_array(&left)?; - let ret = Arc::new(try_unary::(left, |ms| { - Ok(date64_op(ms, right, -1)?) - })?) as _; - Ok(ret) - } - // Timestamp - Timestamp - (DataType::Timestamp(..), DataType::Timestamp(..)) => { - ts_sub_scalar_ts(left, right) - } - // Interval - Interval - (DataType::Interval(..), DataType::Interval(..)) => { - interval_op_scalar_interval(left, -1, right) - } - // Timestamp - Interval - (DataType::Timestamp(..), DataType::Interval(..)) => { - ts_op_scalar_interval(left, -1, right) - } - _ => { - // fall back to kernels in arrow-rs - Ok(subtract_dyn(left, &right.to_array())?) - } - } -} - -pub(crate) fn subtract_dyn_temporal_left_scalar( - left: &ScalarValue, - right: &ArrayRef, -) -> Result { - match (left.get_datatype(), right.data_type()) { - // Date32 - Interval - (DataType::Date32, DataType::Interval(..)) => { - if let ScalarValue::Date32(Some(left)) = left { - scalar_date32_array_interval_op( - *left, - right, - NaiveDate::checked_sub_days, - NaiveDate::checked_sub_months, - ) - } else { - Err(DataFusionError::Internal( - "Date32 value is None".to_string(), - )) - } - } - // Date64 - Interval - (DataType::Date64, DataType::Interval(..)) => { - if let ScalarValue::Date64(Some(left)) = left { - scalar_date64_array_interval_op( - *left, - right, - NaiveDate::checked_sub_days, - NaiveDate::checked_sub_months, - ) - } else { - Err(DataFusionError::Internal( - "Date64 value is None".to_string(), - )) - } - } - // Timestamp - Timestamp - (DataType::Timestamp(..), DataType::Timestamp(..)) => { - scalar_ts_sub_ts(left, right) - } - // Interval - Interval - (DataType::Interval(..), DataType::Interval(..)) => { - scalar_interval_op_interval(left, -1, right) - } - // Timestamp - Interval - (DataType::Timestamp(..), DataType::Interval(..)) => { - scalar_ts_op_interval(left, -1, right) - } - _ => { - // fall back to kernels in arrow-rs - Ok(subtract_dyn(&left.to_array(), right)?) - } - } -} - -fn scalar_date32_array_interval_op( - left: i32, - right: &ArrayRef, - day_op: fn(NaiveDate, Days) -> Option, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1) - .ok_or_else(|| DataFusionError::Execution("Invalid Date entered".to_string()))?; - let prior = epoch.add(Duration::days(left as i64)); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - date32_interval_ym_op(right, &epoch, &prior, month_op) - } - DataType::Interval(IntervalUnit::DayTime) => { - date32_interval_dt_op(right, &epoch, &prior, day_op) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - date32_interval_mdn_op(right, &epoch, &prior, day_op, month_op) - } - _ => Err(DataFusionError::Internal(format!( - "Expected type is an interval, but {} is found", - right.data_type() - ))), - } -} - -fn scalar_date64_array_interval_op( - left: i64, - right: &ArrayRef, - day_op: fn(NaiveDate, Days) -> Option, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1) - .ok_or_else(|| DataFusionError::Execution("Invalid Date entered".to_string()))?; - let prior = epoch.add(Duration::milliseconds(left)); - match right.data_type() { - DataType::Interval(IntervalUnit::YearMonth) => { - date64_interval_ym_op(right, &epoch, &prior, month_op) - } - DataType::Interval(IntervalUnit::DayTime) => { - date64_interval_dt_op(right, &epoch, &prior, day_op) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - date64_interval_mdn_op(right, &epoch, &prior, day_op, month_op) - } - _ => Err(DataFusionError::Internal(format!( - "Expected type is an interval, but {} is found", - right.data_type() - ))), - } -} - -fn get_precision_scale(data_type: &DataType) -> Result<(u8, i8)> { - match data_type { - DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), - DataType::Dictionary(_, value_type) => match value_type.as_ref() { - DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), - _ => Err(DataFusionError::Internal( - "Unexpected data type".to_string(), - )), - }, - _ => Err(DataFusionError::Internal( - "Unexpected data type".to_string(), - )), - } -} - -fn decimal_array_with_precision_scale( - array: ArrayRef, - precision: u8, - scale: i8, -) -> Result { - let array = array.as_ref(); - let decimal_array = match array.data_type() { - DataType::Decimal128(_, _) => { - let array = as_decimal128_array(array)?; - Arc::new(array.clone().with_precision_and_scale(precision, scale)?) - as ArrayRef - } - DataType::Dictionary(_, _) => { - downcast_dictionary_array!( - array => match array.values().data_type() { - DataType::Decimal128(_, _) => { - let decimal_dict_array = array.downcast_dict::().unwrap(); - let decimal_array = decimal_dict_array.values().clone(); - let decimal_array = decimal_array.with_precision_and_scale(precision, scale)?; - Arc::new(array.with_values(&decimal_array)) as ArrayRef - } - t => return Err(DataFusionError::Internal(format!("Unexpected dictionary value type {t}"))), - }, - t => return Err(DataFusionError::Internal(format!("Unexpected datatype {t}"))), - ) - } - _ => { - return Err(DataFusionError::Internal( - "Unexpected data type".to_string(), - )) - } - }; - Ok(decimal_array) -} - -pub(crate) fn multiply_decimal_dyn_scalar( - left: &dyn Array, - right: i128, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - let array = multiply_scalar_dyn::(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn divide_decimal_dyn_scalar( - left: &dyn Array, - right: i128, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - - let mul = 10_i128.pow(scale as u32); - let array = multiply_scalar_dyn::(left, mul)?; - - let array = divide_scalar_dyn::(&array, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn subtract_dyn_decimal( - left: &dyn Array, - right: &dyn Array, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - let array = subtract_dyn(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`. -fn math_op_dict( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result> -where - K: ArrowDictionaryKeyType + ArrowNumericType, - T: ArrowNumericType, - F: Fn(T::Native, T::Native) -> T::Native, -{ - if left.len() != right.len() { - return Err(DataFusionError::Internal(format!( - "Cannot perform operation on arrays of different length ({}, {})", - left.len(), - right.len() - ))); - } - - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - left.values() - .as_primitive::() - .take_iter_unchecked(left.keys_iter()) - }; - - let right_iter = unsafe { - right - .values() - .as_primitive::() - .take_iter_unchecked(right.keys_iter()) - }; - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some(op(left, right)) - } else { - None - } - }) - .collect(); - - Ok(result) -} - -/// Divide a decimal native value by given divisor and round the result. -/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`. -fn divide_and_round(input: I::Native, div: I::Native) -> I::Native -where - I: DecimalType, - I::Native: ArrowNativeTypeOp, -{ - let d = input.div_wrapping(div); - let r = input.mod_wrapping(div); - - let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); - let half_neg = half.neg_wrapping(); - // Round result - match input >= I::Native::ZERO { - true if r >= half => d.add_wrapping(I::Native::ONE), - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), - _ => d, - } -} - -/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`. -/// -fn multiply_fixed_point_dyn( - left: &dyn Array, - right: &dyn Array, - required_scale: i8, -) -> Result { - match (left.data_type(), right.data_type()) { - ( - DataType::Dictionary(_, lhs_value_type), - DataType::Dictionary(_, rhs_value_type), - ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _)) - && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) => - { - downcast_dictionary_array!( - left => match left.values().data_type() { - DataType::Decimal128(_, _) => { - let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?; - let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?; - - let product_scale = lhs_precision_scale.1 + rhs_precision_scale.1; - let precision = min(lhs_precision_scale.0 + rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION); - - if required_scale == product_scale { - return Ok(multiply_dyn(left, right)?.as_primitive::().clone() - .with_precision_and_scale(precision, required_scale).map(|a| Arc::new(a) as ArrayRef)?); - } - - if required_scale > product_scale { - return Err(DataFusionError::Internal(format!( - "Required scale {required_scale} is greater than product scale {product_scale}" - ))); - } - - let divisor = - i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); - - let right = as_dictionary_array::<_>(right); - - let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| { - let a = i256::from_i128(a); - let b = i256::from_i128(b); - - let mut mul = a.wrapping_mul(b); - mul = divide_and_round::(mul, divisor); - mul.as_i128() - }).map(|a| a.with_precision_and_scale(precision, required_scale).unwrap())?; - - Ok(Arc::new(array)) - } - t => unreachable!("Unsupported dictionary value type {}", t), - }, - t => unreachable!("Unsupported data type {}", t), - ) - } - (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { - let left = left.as_any().downcast_ref::().unwrap(); - let right = right.as_any().downcast_ref::().unwrap(); - - Ok(multiply_fixed_point(left, right, required_scale) - .map(|a| Arc::new(a) as ArrayRef)?) - } - (_, _) => Err(DataFusionError::Internal(format!( - "Unsupported data type {}, {}", - left.data_type(), - right.data_type() - ))), - } -} - -pub(crate) fn multiply_dyn_decimal( - left: &dyn Array, - right: &dyn Array, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - let array = multiply_fixed_point_dyn(left, right, scale)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn divide_dyn_opt_decimal( - left: &dyn Array, - right: &dyn Array, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - - let mul = 10_i128.pow(scale as u32); - let array = multiply_scalar_dyn::(left, mul)?; - - // Restore to original precision and scale (metadata only) - let (org_precision, org_scale) = get_precision_scale(right.data_type())?; - let array = decimal_array_with_precision_scale(array, org_precision, org_scale)?; - let array = divide_dyn_opt(&array, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn modulus_dyn_decimal( - left: &dyn Array, - right: &dyn Array, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - let array = modulus_dyn(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -pub(crate) fn modulus_decimal_dyn_scalar( - left: &dyn Array, - right: i128, - result_type: &DataType, -) -> Result { - let (precision, scale) = get_precision_scale(result_type)?; - - let array = modulus_scalar_dyn::(left, right)?; - decimal_array_with_precision_scale(array, precision, scale) -} - -macro_rules! sub_timestamp_macro { - ($array:expr, $rhs:expr, $caster:expr, $interval_type:ty, $opt_tz_lhs:expr, $multiplier:expr, - $opt_tz_rhs:expr, $unit_sub:expr, $naive_sub_fn:expr, $counter:expr) => {{ - let prim_array = $caster(&$array)?; - let ret: PrimitiveArray<$interval_type> = try_unary(prim_array, |lhs| { - let (parsed_lhs_tz, parsed_rhs_tz) = - (parse_timezones($opt_tz_lhs)?, parse_timezones($opt_tz_rhs)?); - let (naive_lhs, naive_rhs) = calculate_naives::<$unit_sub>( - lhs.mul_wrapping($multiplier), - parsed_lhs_tz, - $rhs.mul_wrapping($multiplier), - parsed_rhs_tz, - )?; - Ok($naive_sub_fn($counter(&naive_lhs), $counter(&naive_rhs))) - })?; - Arc::new(ret) as _ - }}; -} - -macro_rules! sub_timestamp_left_scalar_macro { - ($array:expr, $lhs:expr, $caster:expr, $interval_type:ty, $opt_tz_lhs:expr, $multiplier:expr, - $opt_tz_rhs:expr, $unit_sub:expr, $naive_sub_fn:expr, $counter:expr) => {{ - let prim_array = $caster(&$array)?; - let ret: PrimitiveArray<$interval_type> = try_unary(prim_array, |rhs| { - let (parsed_lhs_tz, parsed_rhs_tz) = - (parse_timezones($opt_tz_lhs)?, parse_timezones($opt_tz_rhs)?); - let (naive_lhs, naive_rhs) = calculate_naives::<$unit_sub>( - $lhs.mul_wrapping($multiplier), - parsed_lhs_tz, - rhs.mul_wrapping($multiplier), - parsed_rhs_tz, - )?; - Ok($naive_sub_fn($counter(&naive_lhs), $counter(&naive_rhs))) - })?; - Arc::new(ret) as _ - }}; -} - -macro_rules! op_timestamp_interval_macro { - ($array:expr, $as_timestamp:expr, $ts_type:ty, $fn_op:expr, $scalar:expr, $sign:expr, $tz:expr) => {{ - let array = $as_timestamp(&$array)?; - let ret: PrimitiveArray<$ts_type> = - try_unary::<$ts_type, _, $ts_type>(array, |ts_s| { - Ok($fn_op(ts_s, $scalar, $sign)?) - })?; - Arc::new(ret.with_timezone_opt($tz.clone())) as _ - }}; -} - -macro_rules! scalar_ts_op_interval_macro { - ($ts:ident, $tz:ident, $interval:ident, $sign:ident, - $caster1:expr, $type1:ty, $type2:ty, $op:expr, $back_caster:expr) => {{ - let interval = $caster1(&$interval)?; - let ret: PrimitiveArray<$type1> = - try_unary::<$type2, _, $type1>(interval, |e| { - let prior = $ts.ok_or_else(|| { - DataFusionError::Internal("Timestamp is out-of-range".to_string()) - })?; - Ok($back_caster(&$op(prior, e, $sign))) - })?; - Arc::new(ret.with_timezone_opt($tz.clone())) as _ - }}; -} - -macro_rules! op_interval_macro { - ($array:expr, $as_interval:expr, $interval_type:ty, $fn_op:expr, $scalar:expr, $sign:expr) => {{ - let array = $as_interval(&$array)?; - let ret: PrimitiveArray<$interval_type> = - unary(array, |lhs| $fn_op(lhs, *$scalar, $sign)); - Arc::new(ret) as _ - }}; -} - -macro_rules! op_interval_cross_macro { - ($array:expr, $as_interval:expr, $commute:expr, $fn_op:expr, $scalar:expr, $sign:expr, $t1:ty, $t2:ty) => {{ - let array = $as_interval(&$array)?; - let ret: PrimitiveArray = if $commute { - unary(array, |lhs| { - $fn_op(*$scalar as $t1, lhs as $t2, $sign, $commute) - }) - } else { - unary(array, |lhs| { - $fn_op(lhs as $t1, *$scalar as $t2, $sign, $commute) - }) - }; - Arc::new(ret) as _ - }}; -} - -macro_rules! ts_sub_op { - ($lhs:ident, $rhs:ident, $lhs_tz:ident, $rhs_tz:ident, $coef:expr, $caster:expr, $op:expr, $ts_unit:expr, $mode:expr, $type_out:ty) => {{ - let prim_array_lhs = $caster(&$lhs)?; - let prim_array_rhs = $caster(&$rhs)?; - let ret: PrimitiveArray<$type_out> = - arrow::compute::try_binary(prim_array_lhs, prim_array_rhs, |ts1, ts2| { - let (parsed_lhs_tz, parsed_rhs_tz) = ( - parse_timezones($lhs_tz.as_deref())?, - parse_timezones($rhs_tz.as_deref())?, - ); - let (naive_lhs, naive_rhs) = calculate_naives::<$mode>( - ts1.mul_wrapping($coef), - parsed_lhs_tz, - ts2.mul_wrapping($coef), - parsed_rhs_tz, - )?; - Ok($op($ts_unit(&naive_lhs), $ts_unit(&naive_rhs))) - })?; - Arc::new(ret) as _ - }}; -} - -macro_rules! interval_op { - ($lhs:ident, $rhs:ident, $caster:expr, $op:expr, $sign:ident, $type_in:ty) => {{ - let prim_array_lhs = $caster(&$lhs)?; - let prim_array_rhs = $caster(&$rhs)?; - Arc::new(arrow::compute::binary::<$type_in, $type_in, _, $type_in>( - prim_array_lhs, - prim_array_rhs, - |interval1, interval2| $op(interval1, interval2, $sign), - )?) as _ - }}; -} - -macro_rules! interval_cross_op { - ($lhs:ident, $rhs:ident, $caster1:expr, $caster2:expr, $op:expr, $sign:ident, $commute:ident, $type_in1:ty, $type_in2:ty) => {{ - let prim_array_lhs = $caster1(&$lhs)?; - let prim_array_rhs = $caster2(&$rhs)?; - Arc::new(arrow::compute::binary::< - $type_in1, - $type_in2, - _, - IntervalMonthDayNanoType, - >( - prim_array_lhs, - prim_array_rhs, - |interval1, interval2| $op(interval1, interval2, $sign, $commute), - )?) as _ - }}; -} - -macro_rules! ts_interval_op { - ($lhs:ident, $rhs:ident, $tz:ident, $caster1:expr, $caster2:expr, $op:expr, $sign:ident, $type_in1:ty, $type_in2:ty) => {{ - let prim_array_lhs = $caster1(&$lhs)?; - let prim_array_rhs = $caster2(&$rhs)?; - let ret: PrimitiveArray<$type_in1> = arrow::compute::try_binary( - prim_array_lhs, - prim_array_rhs, - |ts, interval| Ok($op(ts, interval as i128, $sign)?), - )?; - Arc::new(ret.with_timezone_opt($tz.clone())) as _ - }}; -} - -/// This function handles timestamp - timestamp operations where the former is -/// an array and the latter is a scalar, resulting in an array. -pub fn ts_sub_scalar_ts(array: &ArrayRef, scalar: &ScalarValue) -> Result { - let ret = match (array.data_type(), scalar) { - ( - DataType::Timestamp(TimeUnit::Second, opt_tz_lhs), - ScalarValue::TimestampSecond(Some(rhs), opt_tz_rhs), - ) => { - sub_timestamp_macro!( - array, - rhs, - as_timestamp_second_array, - IntervalDayTimeType, - opt_tz_lhs.as_deref(), - 1000, - opt_tz_rhs.as_deref(), - MILLISECOND_MODE, - seconds_sub, - NaiveDateTime::timestamp - ) - } - ( - DataType::Timestamp(TimeUnit::Millisecond, opt_tz_lhs), - ScalarValue::TimestampMillisecond(Some(rhs), opt_tz_rhs), - ) => { - sub_timestamp_macro!( - array, - rhs, - as_timestamp_millisecond_array, - IntervalDayTimeType, - opt_tz_lhs.as_deref(), - 1, - opt_tz_rhs.as_deref(), - MILLISECOND_MODE, - milliseconds_sub, - NaiveDateTime::timestamp_millis - ) - } - ( - DataType::Timestamp(TimeUnit::Microsecond, opt_tz_lhs), - ScalarValue::TimestampMicrosecond(Some(rhs), opt_tz_rhs), - ) => { - sub_timestamp_macro!( - array, - rhs, - as_timestamp_microsecond_array, - IntervalMonthDayNanoType, - opt_tz_lhs.as_deref(), - 1000, - opt_tz_rhs.as_deref(), - NANOSECOND_MODE, - microseconds_sub, - NaiveDateTime::timestamp_micros - ) - } - ( - DataType::Timestamp(TimeUnit::Nanosecond, opt_tz_lhs), - ScalarValue::TimestampNanosecond(Some(rhs), opt_tz_rhs), - ) => { - sub_timestamp_macro!( - array, - rhs, - as_timestamp_nanosecond_array, - IntervalMonthDayNanoType, - opt_tz_lhs.as_deref(), - 1, - opt_tz_rhs.as_deref(), - NANOSECOND_MODE, - nanoseconds_sub, - NaiveDateTime::timestamp_nanos - ) - } - (_, _) => { - return Err(DataFusionError::Internal(format!( - "Invalid array - scalar types for Timestamp subtraction: {:?} - {:?}", - array.data_type(), - scalar.get_datatype() - ))); - } - }; - Ok(ret) -} - -/// This function handles timestamp - timestamp operations where the former is -/// a scalar and the latter is an array, resulting in an array. -pub fn scalar_ts_sub_ts(scalar: &ScalarValue, array: &ArrayRef) -> Result { - let ret = match (scalar, array.data_type()) { - ( - ScalarValue::TimestampSecond(Some(lhs), opt_tz_lhs), - DataType::Timestamp(TimeUnit::Second, opt_tz_rhs), - ) => { - sub_timestamp_left_scalar_macro!( - array, - lhs, - as_timestamp_second_array, - IntervalDayTimeType, - opt_tz_lhs.as_deref(), - 1000, - opt_tz_rhs.as_deref(), - MILLISECOND_MODE, - seconds_sub, - NaiveDateTime::timestamp - ) - } - ( - ScalarValue::TimestampMillisecond(Some(lhs), opt_tz_lhs), - DataType::Timestamp(TimeUnit::Millisecond, opt_tz_rhs), - ) => { - sub_timestamp_left_scalar_macro!( - array, - lhs, - as_timestamp_millisecond_array, - IntervalDayTimeType, - opt_tz_lhs.as_deref(), - 1, - opt_tz_rhs.as_deref(), - MILLISECOND_MODE, - milliseconds_sub, - NaiveDateTime::timestamp_millis - ) - } - ( - ScalarValue::TimestampMicrosecond(Some(lhs), opt_tz_lhs), - DataType::Timestamp(TimeUnit::Microsecond, opt_tz_rhs), - ) => { - sub_timestamp_left_scalar_macro!( - array, - lhs, - as_timestamp_microsecond_array, - IntervalMonthDayNanoType, - opt_tz_lhs.as_deref(), - 1000, - opt_tz_rhs.as_deref(), - NANOSECOND_MODE, - microseconds_sub, - NaiveDateTime::timestamp_micros - ) - } - ( - ScalarValue::TimestampNanosecond(Some(lhs), opt_tz_lhs), - DataType::Timestamp(TimeUnit::Nanosecond, opt_tz_rhs), - ) => { - sub_timestamp_left_scalar_macro!( - array, - lhs, - as_timestamp_nanosecond_array, - IntervalMonthDayNanoType, - opt_tz_lhs.as_deref(), - 1, - opt_tz_rhs.as_deref(), - NANOSECOND_MODE, - nanoseconds_sub, - NaiveDateTime::timestamp_nanos - ) - } - (_, _) => { - return Err(DataFusionError::Internal(format!( - "Invalid scalar - array types for Timestamp subtraction: {:?} - {:?}", - scalar.get_datatype(), - array.data_type() - ))); - } - }; - Ok(ret) -} - -/// This function handles timestamp +/- interval operations where the former is -/// an array and the latter is a scalar, resulting in an array. -pub fn ts_op_scalar_interval( - array: &ArrayRef, - sign: i32, - scalar: &ScalarValue, -) -> Result { - let ret = match array.data_type() { - DataType::Timestamp(TimeUnit::Second, tz) => { - op_timestamp_interval_macro!( - array, - as_timestamp_second_array, - TimestampSecondType, - seconds_add, - scalar, - sign, - tz - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { - op_timestamp_interval_macro!( - array, - as_timestamp_millisecond_array, - TimestampMillisecondType, - milliseconds_add, - scalar, - sign, - tz - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { - op_timestamp_interval_macro!( - array, - as_timestamp_microsecond_array, - TimestampMicrosecondType, - microseconds_add, - scalar, - sign, - tz - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - op_timestamp_interval_macro!( - array, - as_timestamp_nanosecond_array, - TimestampNanosecondType, - nanoseconds_add, - scalar, - sign, - tz - ) - } - _ => Err(DataFusionError::Internal(format!( - "Invalid lhs type for Timestamp vs Interval operations: {}", - array.data_type() - )))?, - }; - Ok(ret) -} - -/// This function handles timestamp +/- interval operations where the former is -/// a scalar and the latter is an array, resulting in an array. -pub fn scalar_ts_op_interval( - scalar: &ScalarValue, - sign: i32, - array: &ArrayRef, -) -> Result { - use DataType::*; - use IntervalUnit::*; - use ScalarValue::*; - let ret = match (scalar, array.data_type()) { - // Second op YearMonth - (TimestampSecond(Some(ts_sec), tz), Interval(YearMonth)) => { - let naive_date = NaiveDateTime::from_timestamp_opt(*ts_sec, 0); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_ym_array, - TimestampSecondType, - IntervalYearMonthType, - shift_months, - NaiveDateTime::timestamp - ) - } - // Millisecond op YearMonth - (TimestampMillisecond(Some(ts_ms), tz), Interval(YearMonth)) => { - let naive_date = NaiveDateTime::from_timestamp_millis(*ts_ms); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_ym_array, - TimestampSecondType, - IntervalYearMonthType, - shift_months, - NaiveDateTime::timestamp - ) - } - // Microsecond op YearMonth - (TimestampMicrosecond(Some(ts_us), tz), Interval(YearMonth)) => { - let naive_date = NaiveDateTime::from_timestamp_micros(*ts_us); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_ym_array, - TimestampSecondType, - IntervalYearMonthType, - shift_months, - NaiveDateTime::timestamp - ) - } - // Nanosecond op YearMonth - (TimestampNanosecond(Some(ts_ns), tz), Interval(YearMonth)) => { - let naive_date = NaiveDateTime::from_timestamp_opt( - ts_ns.div_euclid(1_000_000_000), - ts_ns.rem_euclid(1_000_000_000).try_into().map_err(|_| { - DataFusionError::Internal("Overflow of divison".to_string()) - })?, - ); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_ym_array, - TimestampSecondType, - IntervalYearMonthType, - shift_months, - NaiveDateTime::timestamp - ) - } - // Second op DayTime - (TimestampSecond(Some(ts_sec), tz), Interval(DayTime)) => { - let naive_date = NaiveDateTime::from_timestamp_opt(*ts_sec, 0); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_dt_array, - TimestampSecondType, - IntervalDayTimeType, - add_day_time, - NaiveDateTime::timestamp - ) - } - // Millisecond op DayTime - (TimestampMillisecond(Some(ts_ms), tz), Interval(DayTime)) => { - let naive_date = NaiveDateTime::from_timestamp_millis(*ts_ms); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_dt_array, - TimestampMillisecondType, - IntervalDayTimeType, - add_day_time, - NaiveDateTime::timestamp_millis - ) - } - // Microsecond op DayTime - (TimestampMicrosecond(Some(ts_us), tz), Interval(DayTime)) => { - let naive_date = NaiveDateTime::from_timestamp_micros(*ts_us); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_dt_array, - TimestampMicrosecondType, - IntervalDayTimeType, - add_day_time, - NaiveDateTime::timestamp_micros - ) - } - // Nanosecond op DayTime - (TimestampNanosecond(Some(ts_ns), tz), Interval(DayTime)) => { - let naive_date = NaiveDateTime::from_timestamp_opt( - ts_ns.div_euclid(1_000_000_000), - ts_ns.rem_euclid(1_000_000_000).try_into().map_err(|_| { - DataFusionError::Internal("Overflow of divison".to_string()) - })?, - ); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_dt_array, - TimestampNanosecondType, - IntervalDayTimeType, - add_day_time, - NaiveDateTime::timestamp_nanos - ) - } - // Second op MonthDayNano - (TimestampSecond(Some(ts_sec), tz), Interval(MonthDayNano)) => { - let naive_date = NaiveDateTime::from_timestamp_opt(*ts_sec, 0); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_mdn_array, - TimestampSecondType, - IntervalMonthDayNanoType, - add_m_d_nano, - NaiveDateTime::timestamp - ) - } - // Millisecond op MonthDayNano - (TimestampMillisecond(Some(ts_ms), tz), Interval(MonthDayNano)) => { - let naive_date = NaiveDateTime::from_timestamp_millis(*ts_ms); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_mdn_array, - TimestampMillisecondType, - IntervalMonthDayNanoType, - add_m_d_nano, - NaiveDateTime::timestamp_millis - ) - } - // Microsecond op MonthDayNano - (TimestampMicrosecond(Some(ts_us), tz), Interval(MonthDayNano)) => { - let naive_date = NaiveDateTime::from_timestamp_micros(*ts_us); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_mdn_array, - TimestampMicrosecondType, - IntervalMonthDayNanoType, - add_m_d_nano, - NaiveDateTime::timestamp_micros - ) - } - - // Nanosecond op MonthDayNano - (TimestampNanosecond(Some(ts_ns), tz), Interval(MonthDayNano)) => { - let naive_date = NaiveDateTime::from_timestamp_opt( - ts_ns.div_euclid(1_000_000_000), - ts_ns.rem_euclid(1_000_000_000).try_into().map_err(|_| { - DataFusionError::Internal("Overflow of divison".to_string()) - })?, - ); - scalar_ts_op_interval_macro!( - naive_date, - tz, - array, - sign, - as_interval_mdn_array, - TimestampNanosecondType, - IntervalMonthDayNanoType, - add_m_d_nano, - NaiveDateTime::timestamp_nanos - ) - } - _ => Err(DataFusionError::Internal( - "Invalid types for Timestamp vs Interval operations".to_string(), - ))?, - }; - Ok(ret) -} - -/// This function handles interval +/- interval operations where the former is -/// an array and the latter is a scalar, resulting in an interval array. -pub fn interval_op_scalar_interval( - array: &ArrayRef, - sign: i32, - scalar: &ScalarValue, -) -> Result { - use DataType::*; - use IntervalUnit::*; - use ScalarValue::*; - let ret = match (array.data_type(), scalar) { - (Interval(YearMonth), IntervalYearMonth(Some(rhs))) => { - op_interval_macro!( - array, - as_interval_ym_array, - IntervalYearMonthType, - op_ym, - rhs, - sign - ) - } - (Interval(YearMonth), IntervalDayTime(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_ym_array, - false, - op_ym_dt, - rhs, - sign, - i32, - i64 - ) - } - (Interval(YearMonth), IntervalMonthDayNano(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_ym_array, - false, - op_ym_mdn, - rhs, - sign, - i32, - i128 - ) - } - (Interval(DayTime), IntervalYearMonth(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_dt_array, - true, - op_ym_dt, - rhs, - sign, - i32, - i64 - ) - } - (Interval(DayTime), IntervalDayTime(Some(rhs))) => { - op_interval_macro!( - array, - as_interval_dt_array, - IntervalDayTimeType, - op_dt, - rhs, - sign - ) - } - (Interval(DayTime), IntervalMonthDayNano(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_dt_array, - false, - op_dt_mdn, - rhs, - sign, - i64, - i128 - ) - } - (Interval(MonthDayNano), IntervalYearMonth(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_mdn_array, - true, - op_ym_mdn, - rhs, - sign, - i32, - i128 - ) - } - (Interval(MonthDayNano), IntervalDayTime(Some(rhs))) => { - op_interval_cross_macro!( - array, - as_interval_mdn_array, - true, - op_dt_mdn, - rhs, - sign, - i64, - i128 - ) - } - (Interval(MonthDayNano), IntervalMonthDayNano(Some(rhs))) => { - op_interval_macro!( - array, - as_interval_mdn_array, - IntervalMonthDayNanoType, - op_mdn, - rhs, - sign - ) - } - _ => Err(DataFusionError::Internal(format!( - "Invalid operands for Interval vs Interval operations: {} - {}", - array.data_type(), - scalar.get_datatype(), - )))?, - }; - Ok(ret) -} - -/// This function handles interval +/- interval operations where the former is -/// a scalar and the latter is an array, resulting in an interval array. -pub fn scalar_interval_op_interval( - scalar: &ScalarValue, - sign: i32, - array: &ArrayRef, -) -> Result { - use DataType::*; - use IntervalUnit::*; - use ScalarValue::*; - let ret = match (scalar, array.data_type()) { - // YearMonth op YearMonth - (IntervalYearMonth(Some(lhs)), Interval(YearMonth)) => { - let array = as_interval_ym_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_ym(*lhs, rhs, sign)); - Arc::new(ret) as _ - } - // DayTime op YearMonth - (IntervalDayTime(Some(lhs)), Interval(YearMonth)) => { - let array = as_interval_ym_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_ym_dt(rhs, *lhs, sign, true)); - Arc::new(ret) as _ - } - // MonthDayNano op YearMonth - (IntervalMonthDayNano(Some(lhs)), Interval(YearMonth)) => { - let array = as_interval_ym_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_ym_mdn(rhs, *lhs, sign, true)); - Arc::new(ret) as _ - } - // YearMonth op DayTime - (IntervalYearMonth(Some(lhs)), Interval(DayTime)) => { - let array = as_interval_dt_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_ym_dt(*lhs, rhs, sign, false)); - Arc::new(ret) as _ - } - // DayTime op DayTime - (IntervalDayTime(Some(lhs)), Interval(DayTime)) => { - let array = as_interval_dt_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_dt(*lhs, rhs, sign)); - Arc::new(ret) as _ - } - // MonthDayNano op DayTime - (IntervalMonthDayNano(Some(lhs)), Interval(DayTime)) => { - let array = as_interval_dt_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_dt_mdn(rhs, *lhs, sign, true)); - Arc::new(ret) as _ - } - // YearMonth op MonthDayNano - (IntervalYearMonth(Some(lhs)), Interval(MonthDayNano)) => { - let array = as_interval_mdn_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_ym_mdn(*lhs, rhs, sign, false)); - Arc::new(ret) as _ - } - // DayTime op MonthDayNano - (IntervalDayTime(Some(lhs)), Interval(MonthDayNano)) => { - let array = as_interval_mdn_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_dt_mdn(*lhs, rhs, sign, false)); - Arc::new(ret) as _ - } - // MonthDayNano op MonthDayNano - (IntervalMonthDayNano(Some(lhs)), Interval(MonthDayNano)) => { - let array = as_interval_mdn_array(&array)?; - let ret: PrimitiveArray = - unary(array, |rhs| op_mdn(*lhs, rhs, sign)); - Arc::new(ret) as _ - } - _ => Err(DataFusionError::Internal(format!( - "Invalid operands for Interval vs Interval operations: {} - {}", - scalar.get_datatype(), - array.data_type(), - )))?, - }; - Ok(ret) -} - -/// Performs a timestamp subtraction operation on two arrays and returns the resulting array. -pub fn ts_array_op(array_lhs: &ArrayRef, array_rhs: &ArrayRef) -> Result { - use DataType::*; - use TimeUnit::*; - match (array_lhs.data_type(), array_rhs.data_type()) { - (Timestamp(Second, opt_tz_lhs), Timestamp(Second, opt_tz_rhs)) => Ok(ts_sub_op!( - array_lhs, - array_rhs, - opt_tz_lhs, - opt_tz_rhs, - 1000i64, - as_timestamp_second_array, - seconds_sub, - NaiveDateTime::timestamp, - MILLISECOND_MODE, - IntervalDayTimeType - )), - (Timestamp(Millisecond, opt_tz_lhs), Timestamp(Millisecond, opt_tz_rhs)) => { - Ok(ts_sub_op!( - array_lhs, - array_rhs, - opt_tz_lhs, - opt_tz_rhs, - 1i64, - as_timestamp_millisecond_array, - milliseconds_sub, - NaiveDateTime::timestamp_millis, - MILLISECOND_MODE, - IntervalDayTimeType - )) - } - (Timestamp(Microsecond, opt_tz_lhs), Timestamp(Microsecond, opt_tz_rhs)) => { - Ok(ts_sub_op!( - array_lhs, - array_rhs, - opt_tz_lhs, - opt_tz_rhs, - 1000i64, - as_timestamp_microsecond_array, - microseconds_sub, - NaiveDateTime::timestamp_micros, - NANOSECOND_MODE, - IntervalMonthDayNanoType - )) - } - (Timestamp(Nanosecond, opt_tz_lhs), Timestamp(Nanosecond, opt_tz_rhs)) => { - Ok(ts_sub_op!( - array_lhs, - array_rhs, - opt_tz_lhs, - opt_tz_rhs, - 1i64, - as_timestamp_nanosecond_array, - nanoseconds_sub, - NaiveDateTime::timestamp_nanos, - NANOSECOND_MODE, - IntervalMonthDayNanoType - )) - } - (_, _) => Err(DataFusionError::Execution(format!( - "Invalid array types for Timestamp subtraction: {} - {}", - array_lhs.data_type(), - array_rhs.data_type() - ))), - } -} -/// Performs an interval operation on two arrays and returns the resulting array. -/// The operation sign determines whether to perform addition or subtraction. -/// The data type and unit of the two input arrays must match the supported combinations. -pub fn interval_array_op( - array_lhs: &ArrayRef, - array_rhs: &ArrayRef, - sign: i32, -) -> Result { - use DataType::*; - use IntervalUnit::*; - match (array_lhs.data_type(), array_rhs.data_type()) { - (Interval(YearMonth), Interval(YearMonth)) => Ok(interval_op!( - array_lhs, - array_rhs, - as_interval_ym_array, - op_ym, - sign, - IntervalYearMonthType - )), - (Interval(YearMonth), Interval(DayTime)) => Ok(interval_cross_op!( - array_lhs, - array_rhs, - as_interval_ym_array, - as_interval_dt_array, - op_ym_dt, - sign, - false, - IntervalYearMonthType, - IntervalDayTimeType - )), - (Interval(YearMonth), Interval(MonthDayNano)) => Ok(interval_cross_op!( - array_lhs, - array_rhs, - as_interval_ym_array, - as_interval_mdn_array, - op_ym_mdn, - sign, - false, - IntervalYearMonthType, - IntervalMonthDayNanoType - )), - (Interval(DayTime), Interval(YearMonth)) => Ok(interval_cross_op!( - array_rhs, - array_lhs, - as_interval_ym_array, - as_interval_dt_array, - op_ym_dt, - sign, - true, - IntervalYearMonthType, - IntervalDayTimeType - )), - (Interval(DayTime), Interval(DayTime)) => Ok(interval_op!( - array_lhs, - array_rhs, - as_interval_dt_array, - op_dt, - sign, - IntervalDayTimeType - )), - (Interval(DayTime), Interval(MonthDayNano)) => Ok(interval_cross_op!( - array_lhs, - array_rhs, - as_interval_dt_array, - as_interval_mdn_array, - op_dt_mdn, - sign, - false, - IntervalDayTimeType, - IntervalMonthDayNanoType - )), - (Interval(MonthDayNano), Interval(YearMonth)) => Ok(interval_cross_op!( - array_rhs, - array_lhs, - as_interval_ym_array, - as_interval_mdn_array, - op_ym_mdn, - sign, - true, - IntervalYearMonthType, - IntervalMonthDayNanoType - )), - (Interval(MonthDayNano), Interval(DayTime)) => Ok(interval_cross_op!( - array_rhs, - array_lhs, - as_interval_dt_array, - as_interval_mdn_array, - op_dt_mdn, - sign, - true, - IntervalDayTimeType, - IntervalMonthDayNanoType - )), - (Interval(MonthDayNano), Interval(MonthDayNano)) => Ok(interval_op!( - array_lhs, - array_rhs, - as_interval_mdn_array, - op_mdn, - sign, - IntervalMonthDayNanoType - )), - (_, _) => Err(DataFusionError::Execution(format!( - "Invalid array types for Interval operation: {} {} {}", - array_lhs.data_type(), - sign, - array_rhs.data_type() - ))), - } -} - -/// Performs a timestamp/interval operation on two arrays and returns the resulting array. -/// The operation sign determines whether to perform addition or subtraction. -/// The data type and unit of the two input arrays must match the supported combinations. -pub fn ts_interval_array_op( - array_lhs: &ArrayRef, - sign: i32, - array_rhs: &ArrayRef, -) -> Result { - use DataType::*; - use IntervalUnit::*; - use TimeUnit::*; - match (array_lhs.data_type(), array_rhs.data_type()) { - (Timestamp(Second, tz), Interval(YearMonth)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_second_array, - as_interval_ym_array, - seconds_add_array::, - sign, - TimestampSecondType, - IntervalYearMonthType - )), - (Timestamp(Second, tz), Interval(DayTime)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_second_array, - as_interval_dt_array, - seconds_add_array::, - sign, - TimestampSecondType, - IntervalDayTimeType - )), - (Timestamp(Second, tz), Interval(MonthDayNano)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_second_array, - as_interval_mdn_array, - seconds_add_array::, - sign, - TimestampSecondType, - IntervalMonthDayNanoType - )), - (Timestamp(Millisecond, tz), Interval(YearMonth)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_millisecond_array, - as_interval_ym_array, - milliseconds_add_array::, - sign, - TimestampMillisecondType, - IntervalYearMonthType - )), - (Timestamp(Millisecond, tz), Interval(DayTime)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_millisecond_array, - as_interval_dt_array, - milliseconds_add_array::, - sign, - TimestampMillisecondType, - IntervalDayTimeType - )), - (Timestamp(Millisecond, tz), Interval(MonthDayNano)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_millisecond_array, - as_interval_mdn_array, - milliseconds_add_array::, - sign, - TimestampMillisecondType, - IntervalMonthDayNanoType - )), - (Timestamp(Microsecond, tz), Interval(YearMonth)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_microsecond_array, - as_interval_ym_array, - microseconds_add_array::, - sign, - TimestampMicrosecondType, - IntervalYearMonthType - )), - (Timestamp(Microsecond, tz), Interval(DayTime)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_microsecond_array, - as_interval_dt_array, - microseconds_add_array::, - sign, - TimestampMicrosecondType, - IntervalDayTimeType - )), - (Timestamp(Microsecond, tz), Interval(MonthDayNano)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_microsecond_array, - as_interval_mdn_array, - microseconds_add_array::, - sign, - TimestampMicrosecondType, - IntervalMonthDayNanoType - )), - (Timestamp(Nanosecond, tz), Interval(YearMonth)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_nanosecond_array, - as_interval_ym_array, - nanoseconds_add_array::, - sign, - TimestampNanosecondType, - IntervalYearMonthType - )), - (Timestamp(Nanosecond, tz), Interval(DayTime)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_nanosecond_array, - as_interval_dt_array, - nanoseconds_add_array::, - sign, - TimestampNanosecondType, - IntervalDayTimeType - )), - (Timestamp(Nanosecond, tz), Interval(MonthDayNano)) => Ok(ts_interval_op!( - array_lhs, - array_rhs, - tz, - as_timestamp_nanosecond_array, - as_interval_mdn_array, - nanoseconds_add_array::, - sign, - TimestampNanosecondType, - IntervalMonthDayNanoType - )), - (_, _) => Err(DataFusionError::Execution(format!( - "Invalid array types for Timestamp Interval operation: {} {} {}", - array_lhs.data_type(), - sign, - array_rhs.data_type() - ))), - } -} - -#[inline] -pub fn date32_interval_ym_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |ym| { - let months = Months::new(ym.try_into().map_err(|_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - })?); - let value = month_op(*prior, months).ok_or_else(|| { - DataFusionError::Internal("Resulting date is out of range".to_string()) - })?; - Ok((value - *epoch).num_days() as i32) - }, - )?) as _; - Ok(ret) -} - -#[inline] -pub fn date32_interval_dt_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - day_op: fn(NaiveDate, Days) -> Option, -) -> Result { - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |dt| { - let (days, millis) = IntervalDayTimeType::to_parts(dt); - let days = Days::new(days.try_into().map_err(|_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - })?); - let value = day_op(*prior, days).ok_or_else(|| { - DataFusionError::Internal("Resulting date is out of range".to_string()) - })?; - let milli_days = millis as i64 / MILLISECONDS_IN_DAY; - Ok(((value - *epoch).num_days() - milli_days) as i32) - }, - )?) as _; - Ok(ret) -} - -#[inline] -pub fn date32_interval_mdn_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - day_op: fn(NaiveDate, Days) -> Option, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let cast_err = |_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - }; - let out_of_range = - || DataFusionError::Internal("Resulting date is out of range".to_string()); - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |mdn| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(mdn); - let months_obj = Months::new(months.try_into().map_err(cast_err)?); - let month_diff = month_op(*prior, months_obj).ok_or_else(out_of_range)?; - let days_obj = Days::new(days.try_into().map_err(cast_err)?); - let value = day_op(month_diff, days_obj).ok_or_else(out_of_range)?; - let nano_days = nanos / NANOSECONDS_IN_DAY; - Ok(((value - *epoch).num_days() - nano_days) as i32) - }, - )?) as _; - Ok(ret) -} - -#[inline] -pub fn date64_interval_ym_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |ym| { - let months_obj = Months::new(ym.try_into().map_err(|_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - })?); - let date = month_op(*prior, months_obj).ok_or_else(|| { - DataFusionError::Internal("Resulting date is out of range".to_string()) - })?; - Ok((date - *epoch).num_milliseconds()) - }, - )?) as _; - Ok(ret) -} - -#[inline] -pub fn date64_interval_dt_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - day_op: fn(NaiveDate, Days) -> Option, -) -> Result { - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |dt| { - let (days, millis) = IntervalDayTimeType::to_parts(dt); - let days_obj = Days::new(days.try_into().map_err(|_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - })?); - let date = day_op(*prior, days_obj).ok_or_else(|| { - DataFusionError::Internal("Resulting date is out of range".to_string()) - })?; - Ok((date - *epoch).num_milliseconds() - millis as i64) - }, - )?) as _; - Ok(ret) -} - -#[inline] -pub fn date64_interval_mdn_op( - right: &Arc, - epoch: &NaiveDate, - prior: &NaiveDate, - day_op: fn(NaiveDate, Days) -> Option, - month_op: fn(NaiveDate, Months) -> Option, -) -> Result { - let cast_err = |_| { - DataFusionError::Internal( - "Interval values cannot be casted as unsigned integers".to_string(), - ) - }; - let out_of_range = - || DataFusionError::Internal("Resulting date is out of range".to_string()); - let right: &PrimitiveArray = right.as_primitive(); - let ret = Arc::new(try_unary::( - right, - |mdn| { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(mdn); - let months_obj = Months::new(months.try_into().map_err(cast_err)?); - let month_diff = month_op(*prior, months_obj).ok_or_else(out_of_range)?; - let days_obj = Days::new(days.try_into().map_err(cast_err)?); - let value = day_op(month_diff, days_obj).ok_or_else(out_of_range)?; - Ok((value - *epoch).num_milliseconds() - nanos / 1_000_000) - }, - )?) as _; - Ok(ret) -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type; - use datafusion_expr::Operator; - - fn create_decimal_array( - array: &[Option], - precision: u8, - scale: i8, - ) -> Decimal128Array { - let mut decimal_builder = Decimal128Builder::with_capacity(array.len()); - - for value in array.iter().copied() { - decimal_builder.append_option(value) - } - decimal_builder - .finish() - .with_precision_and_scale(precision, scale) - .unwrap() - } - - fn create_int_array(array: &[Option]) -> Int32Array { - let mut int_builder = Int32Builder::with_capacity(array.len()); - - for value in array.iter().copied() { - int_builder.append_option(value) - } - int_builder.finish() - } - - #[test] - fn comparison_decimal_op_test() -> Result<()> { - let value_i128: i128 = 123; - let decimal_array = create_decimal_array( - &[ - Some(value_i128), - None, - Some(value_i128 - 1), - Some(value_i128 + 1), - ], - 25, - 3, - ); - let left_decimal_array = decimal_array; - let right_decimal_array = create_decimal_array( - &[ - Some(value_i128 - 1), - Some(value_i128), - Some(value_i128 + 1), - Some(value_i128 + 1), - ], - 25, - 3, - ); - - // is_distinct: left distinct right - let result = is_distinct_from(&left_decimal_array, &right_decimal_array)?; - assert_eq!( - BooleanArray::from(vec![Some(true), Some(true), Some(true), Some(false)]), - result - ); - // is_distinct: left distinct right - let result = is_not_distinct_from(&left_decimal_array, &right_decimal_array)?; - assert_eq!( - BooleanArray::from(vec![Some(false), Some(false), Some(false), Some(true)]), - result - ); - Ok(()) - } - - #[test] - fn arithmetic_decimal_op_test() -> Result<()> { - let value_i128: i128 = 123; - let left_decimal_array = create_decimal_array( - &[ - Some(value_i128), - None, - Some(value_i128 - 1), - Some(value_i128 + 1), - ], - 25, - 3, - ); - let right_decimal_array = create_decimal_array( - &[ - Some(value_i128), - Some(value_i128), - Some(value_i128), - Some(value_i128), - ], - 25, - 3, - ); - // add - let result_type = decimal_op_mathematics_type( - &Operator::Plus, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let result = - add_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(246), None, Some(245), Some(247)], 26, 3); - assert_eq!(&expect, result); - let result = add_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(133), None, Some(132), Some(134)], 26, 3); - assert_eq!(&expect, result); - // subtract - let result_type = decimal_op_mathematics_type( - &Operator::Minus, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let result = subtract_dyn_decimal( - &left_decimal_array, - &right_decimal_array, - &result_type, - )?; - let result = as_decimal128_array(&result)?; - let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 26, 3); - assert_eq!(&expect, result); - let result = subtract_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(113), None, Some(112), Some(114)], 26, 3); - assert_eq!(&expect, result); - // multiply - let result_type = decimal_op_mathematics_type( - &Operator::Multiply, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let result = multiply_dyn_decimal( - &left_decimal_array, - &right_decimal_array, - &result_type, - )?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(15129), None, Some(15006), Some(15252)], 38, 6); - assert_eq!(&expect, result); - let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(1230), None, Some(1220), Some(1240)], 38, 6); - assert_eq!(&expect, result); - // divide - let result_type = decimal_op_mathematics_type( - &Operator::Divide, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let left_decimal_array = create_decimal_array( - &[ - Some(1234567), - None, - Some(1234567), - Some(1234567), - Some(1234567), - ], - 25, - 3, - ); - let right_decimal_array = create_decimal_array( - &[Some(10), Some(100), Some(55), Some(-123), None], - 25, - 3, - ); - let result = divide_dyn_opt_decimal( - &left_decimal_array, - &right_decimal_array, - &result_type, - )?; - let result = as_decimal128_array(&result)?; - let expect = create_decimal_array( - &[ - Some(12345670000000000000000000000000000), - None, - Some(2244667272727272727272727272727272), - Some(-1003713008130081300813008130081300), - None, - ], - 38, - 29, - ); - assert_eq!(&expect, result); - let result = divide_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = create_decimal_array( - &[ - Some(12345670000000000000000000000000000), - None, - Some(12345670000000000000000000000000000), - Some(12345670000000000000000000000000000), - Some(12345670000000000000000000000000000), - ], - 38, - 29, - ); - assert_eq!(&expect, result); - // modulus - let result_type = decimal_op_mathematics_type( - &Operator::Modulo, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let result = - modulus_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3); - assert_eq!(&expect, result); - let result = modulus_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; - let result = as_decimal128_array(&result)?; - let expect = - create_decimal_array(&[Some(7), None, Some(7), Some(7), Some(7)], 25, 3); - assert_eq!(&expect, result); - - Ok(()) - } - - #[test] - fn arithmetic_decimal_divide_by_zero() { - let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1); - let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1); - - let result_type = decimal_op_mathematics_type( - &Operator::Divide, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let err = - divide_decimal_dyn_scalar(&left_decimal_array, 0, &result_type).unwrap_err(); - assert_eq!("Arrow error: Divide by zero error", err.to_string()); - let result_type = decimal_op_mathematics_type( - &Operator::Modulo, - left_decimal_array.data_type(), - right_decimal_array.data_type(), - ) - .unwrap(); - let err = - modulus_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type) - .unwrap_err(); - assert_eq!("Arrow error: Divide by zero error", err.to_string()); - let err = - modulus_decimal_dyn_scalar(&left_decimal_array, 0, &result_type).unwrap_err(); - assert_eq!("Arrow error: Divide by zero error", err.to_string()); - } - - #[test] - fn is_distinct_from_non_nulls() -> Result<()> { - let left_int_array = - create_int_array(&[Some(0), Some(1), Some(2), Some(3), Some(4)]); - let right_int_array = - create_int_array(&[Some(4), Some(3), Some(2), Some(1), Some(0)]); - - assert_eq!( - BooleanArray::from(vec![ - Some(true), - Some(true), - Some(false), - Some(true), - Some(true), - ]), - is_distinct_from(&left_int_array, &right_int_array)? - ); - assert_eq!( - BooleanArray::from(vec![ - Some(false), - Some(false), - Some(true), - Some(false), - Some(false), - ]), - is_not_distinct_from(&left_int_array, &right_int_array)? - ); - Ok(()) - } - - #[test] - fn is_distinct_from_nulls() -> Result<()> { - let left_int_array = - create_int_array(&[Some(0), Some(0), None, Some(3), Some(0), Some(0)]); - let right_int_array = - create_int_array(&[Some(0), None, None, None, Some(0), None]); - - assert_eq!( - BooleanArray::from(vec![ - Some(false), - Some(true), - Some(false), - Some(true), - Some(false), - Some(true), - ]), - is_distinct_from(&left_int_array, &right_int_array)? - ); - - assert_eq!( - BooleanArray::from(vec![ - Some(true), - Some(false), - Some(true), - Some(false), - Some(true), - Some(false), - ]), - is_not_distinct_from(&left_int_array, &right_int_array)? - ); - Ok(()) - } - - #[test] - fn test_decimal_multiply_fixed_point_dyn() { - // [123456789] - let a = Decimal128Array::from(vec![123456789000000000000000000]) - .with_precision_and_scale(38, 18) - .unwrap(); - - // [10] - let b = Decimal128Array::from(vec![10000000000000000000]) - .with_precision_and_scale(38, 18) - .unwrap(); - - // Avoid overflow by reducing the scale. - let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap(); - // [1234567890] - let expected = Arc::new( - Decimal128Array::from(vec![12345678900000000000000000000000000000]) - .with_precision_and_scale(38, 28) - .unwrap(), - ) as ArrayRef; - - assert_eq!(&expected, &result); - assert_eq!( - result.as_primitive::().value_as_string(0), - "1234567890.0000000000000000000000000000" - ); - - // [123456789, 10] - let a = Decimal128Array::from(vec![ - 123456789000000000000000000, - 10000000000000000000, - ]) - .with_precision_and_scale(38, 18) - .unwrap(); - - // [10, 123456789, 12] - let b = Decimal128Array::from(vec![ - 10000000000000000000, - 123456789000000000000000000, - 12000000000000000000, - ]) - .with_precision_and_scale(38, 18) - .unwrap(); - - let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]); - let array1 = DictionaryArray::new(keys, Arc::new(a)); - let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]); - let array2 = DictionaryArray::new(keys, Arc::new(b)); - - let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap(); - let expected = Arc::new( - Decimal128Array::from(vec![ - Some(12345678900000000000000000000000000000), - Some(12345678900000000000000000000000000000), - Some(1200000000000000000000000000000), - None, - ]) - .with_precision_and_scale(38, 28) - .unwrap(), - ) as ArrayRef; - - assert_eq!(&expected, &result); - assert_eq!( - result.as_primitive::().value_as_string(0), - "1234567890.0000000000000000000000000000" - ); - assert_eq!( - result.as_primitive::().value_as_string(1), - "1234567890.0000000000000000000000000000" - ); - assert_eq!( - result.as_primitive::().value_as_string(2), - "120.0000000000000000000000000000" - ); - } -} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 903ccda62f084..5fcfd61d90e49 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,6 +16,7 @@ // under the License. use std::borrow::Cow; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::expressions::try_cast; @@ -23,11 +24,13 @@ use crate::expressions::NoOp; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::*; +use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, eq_dyn, is_null, not, or, prep_null_mask_filter}; +use arrow::compute::{and, is_null, not, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, DataFusionError, Result}; +use datafusion_common::exec_err; +use datafusion_common::{cast::as_boolean_array, internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use itertools::Itertools; @@ -51,7 +54,7 @@ type WhenThen = (Arc, Arc); /// [WHEN ...] /// [ELSE result] /// END -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct CaseExpr { /// Optional base expression that can be compared to literal values in the "when" expressions expr: Option>, @@ -85,9 +88,7 @@ impl CaseExpr { else_expr: Option>, ) -> Result { if when_then_expr.is_empty() { - Err(DataFusionError::Execution( - "There must be at least one WHEN clause".to_string(), - )) + exec_err!("There must be at least one WHEN clause") } else { Ok(Self { expr, @@ -125,7 +126,7 @@ impl CaseExpr { let return_type = self.data_type(&batch.schema())?; let expr = self.expr.as_ref().unwrap(); let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows()); + let base_value = base_value.into_array(batch.num_rows())?; let base_nulls = is_null(base_value.as_ref())?; // start with nulls as default output @@ -136,9 +137,9 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value - let when_match = eq_dyn(&when_value, base_value.as_ref())?; + let when_match = eq(&when_value, &base_value)?; // Treat nulls as false let when_match = match when_match.null_count() { 0 => Cow::Borrowed(&when_match), @@ -152,7 +153,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -169,7 +170,7 @@ impl CaseExpr { remainder = or(&base_nulls, &remainder)?; let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -193,7 +194,7 @@ impl CaseExpr { let when_value = self.when_then_expr[i] .0 .evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows()); + let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -213,7 +214,7 @@ impl CaseExpr { ColumnarValue::Scalar(value) if value.is_null() => { new_null_array(&return_type, batch.num_rows()) } - _ => then_value.into_array(batch.num_rows()), + _ => then_value.into_array(batch.num_rows())?, }; current_value = @@ -230,7 +231,7 @@ impl CaseExpr { .unwrap_or_else(|_| e.clone()); let else_ = expr .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows()); + .into_array(batch.num_rows())?; current_value = zip(&remainder, else_.as_ref(), current_value.as_ref())?; } @@ -318,9 +319,7 @@ impl PhysicalExpr for CaseExpr { children: Vec>, ) -> Result> { if children.len() != self.children().len() { - Err(DataFusionError::Internal( - "CaseExpr: Wrong number of children".to_string(), - )) + internal_err!("CaseExpr: Wrong number of children") } else { assert_eq!(children.len() % 2, 0); let expr = match children[0].clone().as_any().downcast_ref::() { @@ -348,6 +347,11 @@ impl PhysicalExpr for CaseExpr { )?)) } } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for CaseExpr { @@ -398,6 +402,7 @@ mod tests { use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; + use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; @@ -420,7 +425,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -448,7 +456,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -480,7 +491,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -518,7 +532,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); @@ -546,7 +563,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -578,7 +598,10 @@ mod tests { Some(x), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -624,7 +647,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_int32_array(&result)?; let expected = @@ -656,7 +682,10 @@ mod tests { Some(else_value), schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -688,7 +717,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -716,7 +748,10 @@ mod tests { None, schema.as_ref(), )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_float64_array(&result).expect("failed to downcast to Float64Array"); @@ -960,9 +995,9 @@ mod tests { let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema); let (when_thens, else_expr) = match coerce_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" - ))), + ), Some(data_type) => { // cast then expr let left = when_thens @@ -1004,11 +1039,10 @@ mod tests { }; thens_type .iter() - .fold(Some(else_type), |left, right_type| match left { - None => None, + .try_fold(else_type, |left_type, right_type| { // TODO: now just use the `equal` coercion rule for case when. If find the issue, and // refactor again. - Some(left_type) => comparison_coercion(&left_type, right_type), + comparison_coercion(&left_type, right_type) }) } } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 8e4e1b57e8c24..0c4ed3c125498 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -17,30 +17,28 @@ use std::any::Any; use std::fmt; +use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; +use crate::sort_properties::SortProperties; use crate::PhysicalExpr; -use arrow::compute; -use arrow::compute::{kernels, CastOptions}; + +use arrow::compute::{can_cast_types, kernels, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use compute::can_cast_types; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; -/// provide DataFusion default cast options -fn default_cast_options() -> CastOptions<'static> { - CastOptions { - safe: false, - format_options: Default::default(), - } -} +const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: false, + format_options: DEFAULT_FORMAT_OPTIONS, +}; /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CastExpr { /// The expression to cast expr: Arc, @@ -60,7 +58,7 @@ impl CastExpr { Self { expr, cast_type, - cast_options: cast_options.unwrap_or_else(default_cast_options), + cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), } } @@ -73,6 +71,11 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + + /// The cast options + pub fn cast_options(&self) -> &CastOptions<'static> { + &self.cast_options + } } impl fmt::Display for CastExpr { @@ -124,13 +127,25 @@ impl PhysicalExpr for CastExpr { &self, interval: &Interval, children: &[&Interval], - ) -> Result>> { + ) -> Result>> { let child_interval = children[0]; // Get child's datatype: - let cast_type = child_interval.get_datatype()?; - Ok(vec![Some( - interval.cast_to(&cast_type, &self.cast_options)?, - )]) + let cast_type = child_interval.data_type(); + Ok(Some( + vec![interval.cast_to(&cast_type, &self.cast_options)?], + )) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.expr.hash(&mut s); + self.cast_type.hash(&mut s); + self.cast_options.hash(&mut s); + } + + /// A [`CastExpr`] preserves the ordering of its child. + fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { + children[0] } } @@ -141,8 +156,7 @@ impl PartialEq for CastExpr { .map(|x| { self.expr.eq(&x.expr) && self.cast_type == x.cast_type - // TODO: Use https://github.com/apache/arrow-rs/issues/2966 when available - && self.cast_options.safe == x.cast_options.safe + && self.cast_options == x.cast_options }) .unwrap_or(false) } @@ -154,13 +168,26 @@ pub fn cast_column( cast_type: &DataType, cast_options: Option<&CastOptions<'static>>, ) -> Result { - let cast_options = cast_options.cloned().unwrap_or_else(default_cast_options); + let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); match value { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array(); + let scalar_array = if cast_type + == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64( + Some((float_ts * 1_000_000_000_f64).trunc() as i64), + ) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? + }; let cast_array = kernels::cast::cast_with_options( &scalar_array, cast_type, @@ -185,12 +212,13 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(expr.clone()) - } else if can_cast_types(&expr_type, &cast_type) { + } else if can_cast_types(&expr_type, &cast_type) + || (expr_type == DataType::Float64 + && cast_type == DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)) + { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { - Err(DataFusionError::NotImplemented(format!( - "Unsupported CAST from {expr_type:?} to {cast_type:?}" - ))) + not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") } } @@ -210,6 +238,7 @@ pub fn cast( mod tests { use super::*; use crate::expressions::col; + use arrow::{ array::{ Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, @@ -218,6 +247,7 @@ mod tests { }, datatypes::*, }; + use datafusion_common::Result; // runs an end-to-end test of physical type cast @@ -247,7 +277,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -296,7 +329,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -342,13 +378,13 @@ mod tests { DataType::Decimal128(10, 3), Decimal128Array, DataType::Decimal128(20, 6), - vec![ + [ Some(1_234_000), Some(2_222_000), Some(3_000), Some(4_000_000), Some(5_000_000), - None, + None ], None ); @@ -363,7 +399,7 @@ mod tests { DataType::Decimal128(10, 3), Decimal128Array, DataType::Decimal128(10, 2), - vec![Some(123), Some(222), Some(0), Some(400), Some(500), None,], + [Some(123), Some(222), Some(0), Some(400), Some(500), None], None ); @@ -384,13 +420,13 @@ mod tests { DataType::Decimal128(10, 0), Int8Array, DataType::Int8, - vec![ + [ Some(1_i8), Some(2_i8), Some(3_i8), Some(4_i8), Some(5_i8), - None, + None ], None ); @@ -406,13 +442,13 @@ mod tests { DataType::Decimal128(10, 0), Int16Array, DataType::Int16, - vec![ + [ Some(1_i16), Some(2_i16), Some(3_i16), Some(4_i16), Some(5_i16), - None, + None ], None ); @@ -428,13 +464,13 @@ mod tests { DataType::Decimal128(10, 0), Int32Array, DataType::Int32, - vec![ + [ Some(1_i32), Some(2_i32), Some(3_i32), Some(4_i32), Some(5_i32), - None, + None ], None ); @@ -449,13 +485,13 @@ mod tests { DataType::Decimal128(10, 0), Int64Array, DataType::Int64, - vec![ + [ Some(1_i64), Some(2_i64), Some(3_i64), Some(4_i64), Some(5_i64), - None, + None ], None ); @@ -479,13 +515,13 @@ mod tests { DataType::Decimal128(10, 3), Float32Array, DataType::Float32, - vec![ + [ Some(1.234_f32), Some(2.222_f32), Some(0.003_f32), Some(4.0_f32), Some(5.0_f32), - None, + None ], None ); @@ -500,13 +536,13 @@ mod tests { DataType::Decimal128(20, 6), Float64Array, DataType::Float64, - vec![ + [ Some(0.001234_f64), Some(0.002222_f64), Some(0.000003_f64), Some(0.004_f64), Some(0.005_f64), - None, + None ], None ); @@ -522,7 +558,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(3, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),], + [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -533,7 +569,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(5, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),], + [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -544,7 +580,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(10, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),], + [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -555,7 +591,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(20, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),], + [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -566,7 +602,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(20, 2), - vec![Some(100), Some(200), Some(300), Some(400), Some(500),], + [Some(100), Some(200), Some(300), Some(400), Some(500)], None ); @@ -577,7 +613,7 @@ mod tests { vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, DataType::Decimal128(10, 2), - vec![Some(150), Some(250), Some(300), Some(112), Some(550),], + [Some(150), Some(250), Some(300), Some(112), Some(550)], None ); @@ -588,12 +624,12 @@ mod tests { vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, DataType::Decimal128(20, 4), - vec![ + [ Some(15000), Some(25000), Some(30000), Some(11235), - Some(55000), + Some(55000) ], None ); @@ -608,7 +644,7 @@ mod tests { vec![1, 2, 3, 4, 5], UInt32Array, DataType::UInt32, - vec![ + [ Some(1_u32), Some(2_u32), Some(3_u32), @@ -628,7 +664,7 @@ mod tests { vec![1, 2, 3, 4, 5], StringArray, DataType::Utf8, - vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], + [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], None ); Ok(()) @@ -658,7 +694,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index eb2be5ef217c1..62da8ff9ed44e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,16 +18,17 @@ //! Column expression use std::any::Any; +use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::physical_expr::down_cast_any_ref; +use crate::PhysicalExpr; + use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::physical_expr::down_cast_any_ref; -use crate::{AnalysisContext, PhysicalExpr}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; /// Represents the column at a given index in a RecordBatch @@ -103,11 +104,9 @@ impl PhysicalExpr for Column { Ok(self) } - /// Return the boundaries of this column, if known. - fn analyze(&self, context: AnalysisContext) -> AnalysisContext { - assert!(self.index < context.column_boundaries.len()); - let col_bounds = context.column_boundaries[self.index].clone(); - context.with_boundaries(col_bounds) + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); } } @@ -125,10 +124,10 @@ impl Column { if self.index < input_schema.fields.len() { Ok(()) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "PhysicalExpr Column references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}", self.name, - self.index, input_schema.fields.len(), input_schema.fields().iter().map(|f| f.name().clone()).collect::>()))) + self.index, input_schema.fields.len(), input_schema.fields().iter().map(|f| f.name().clone()).collect::>()) } } } @@ -176,9 +175,7 @@ impl PhysicalExpr for UnKnownColumn { /// Evaluate the expression fn evaluate(&self, _batch: &RecordBatch) -> Result { - Err(DataFusionError::Plan( - "UnKnownColumn::evaluate() should not be called".to_owned(), - )) + internal_err!("UnKnownColumn::evaluate() should not be called") } fn children(&self) -> Vec> { @@ -191,6 +188,11 @@ impl PhysicalExpr for UnKnownColumn { ) -> Result> { Ok(self) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for UnKnownColumn { @@ -210,33 +212,33 @@ pub fn col(name: &str, schema: &Schema) -> Result> { #[cfg(test)] mod test { use crate::expressions::Column; - use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr}; + use crate::PhysicalExpr; + use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics}; + use datafusion_common::Result; + use std::sync::Arc; #[test] fn out_of_bounds_data_type() { let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); let col = Column::new("id", 9); - let error = col.data_type(&schema).expect_err("error"); - assert_eq!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"]. This was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker", - &format!("{error}")) + let error = col.data_type(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) } #[test] fn out_of_bounds_nullable() { let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); let col = Column::new("id", 9); - let error = col.nullable(&schema).expect_err("error"); - assert_eq!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"]. This was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker", - &format!("{error}")) + let error = col.nullable(&schema).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) } #[test] @@ -245,83 +247,10 @@ mod test { let data: StringArray = vec!["data"].into(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; let col = Column::new("id", 9); - let error = col.evaluate(&batch).expect_err("error"); - assert_eq!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"]. This was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker", - &format!("{error}")); - Ok(()) - } - - /// Returns a pair of (schema, statistics) for a table of: - /// - a => Stats(range=[1, 100], distinct=15) - /// - b => unknown - /// - c => Stats(range=[1, 100], distinct=unknown) - fn get_test_table_stats() -> (Schema, Statistics) { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - - let columns = vec![ - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(100))), - distinct_count: Some(15), - ..Default::default() - }, - ColumnStatistics::default(), - ColumnStatistics { - min_value: Some(ScalarValue::Int32(Some(1))), - max_value: Some(ScalarValue::Int32(Some(75))), - distinct_count: None, - ..Default::default() - }, - ]; - - let statistics = Statistics { - column_statistics: Some(columns), - ..Default::default() - }; - - (schema, statistics) - } - - #[test] - fn stats_bounds_analysis() -> Result<()> { - let (schema, statistics) = get_test_table_stats(); - let context = AnalysisContext::from_statistics(&schema, &statistics); - - let cases = [ - // (name, index, expected boundaries) - ( - "a", - 0, - Some(ExprBoundaries::new( - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(100)), - Some(15), - )), - ), - ("b", 1, None), - ( - "c", - 2, - Some(ExprBoundaries::new( - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(75)), - None, - )), - ), - ]; - - for (name, index, expected) in cases { - let col = Column::new(name, index); - let test_ctx = col.analyze(context.clone()); - assert_eq!(test_ctx.boundaries, expected); - } - + let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); + assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ + but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ + DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs deleted file mode 100644 index f1933c1d180a6..0000000000000 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ /dev/null @@ -1,925 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; -use crate::intervals::{apply_operator, Interval}; -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; - -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::type_coercion::binary::get_result_type; -use datafusion_expr::{ColumnarValue, Operator}; -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -use super::binary::{resolve_temporal_op, resolve_temporal_op_scalar}; - -/// Perform DATE/TIME/TIMESTAMP +/ INTERVAL math -#[derive(Debug)] -pub struct DateTimeIntervalExpr { - lhs: Arc, - op: Operator, - rhs: Arc, -} - -impl DateTimeIntervalExpr { - /// Create a new instance of DateIntervalExpr - pub fn new( - lhs: Arc, - op: Operator, - rhs: Arc, - ) -> Self { - Self { lhs, op, rhs } - } - - /// Get the left-hand side expression - pub fn lhs(&self) -> &Arc { - &self.lhs - } - - /// Get the operator - pub fn op(&self) -> &Operator { - &self.op - } - - /// Get the right-hand side expression - pub fn rhs(&self) -> &Arc { - &self.rhs - } -} - -impl Display for DateTimeIntervalExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{} {} {}", self.lhs, self.op, self.rhs) - } -} - -impl PhysicalExpr for DateTimeIntervalExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> Result { - get_result_type( - &self.lhs.data_type(input_schema)?, - &Operator::Minus, - &self.rhs.data_type(input_schema)?, - ) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - self.lhs.nullable(input_schema) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - let lhs_value = self.lhs.evaluate(batch)?; - let rhs_value = self.rhs.evaluate(batch)?; - // Invert sign for subtraction - let sign = match self.op { - Operator::Plus => 1, - Operator::Minus => -1, - _ => { - // this should be unreachable because we check the operators in `try_new` - let msg = "Invalid operator for DateIntervalExpr"; - return Err(DataFusionError::Internal(msg.to_string())); - } - }; - // RHS is first checked. If it is a Scalar, there are 2 options: - // Either LHS is also a Scalar and matching operation is applied, - // or LHS is an Array and unary operations for related types are - // applied in evaluate_array function. If RHS is an Array, then - // LHS must also be, moreover; they must be the same Timestamp type. - match (lhs_value, rhs_value) { - (ColumnarValue::Scalar(operand_lhs), ColumnarValue::Scalar(operand_rhs)) => { - Ok(ColumnarValue::Scalar(if sign > 0 { - operand_lhs.add(&operand_rhs)? - } else { - operand_lhs.sub(&operand_rhs)? - })) - } - // This function evaluates temporal array vs scalar operations, such as timestamp - timestamp, - // interval + interval, timestamp + interval, and interval + timestamp. It takes one array and one scalar as input - // and an integer sign representing the operation (+1 for addition and -1 for subtraction). - (ColumnarValue::Array(arr), ColumnarValue::Scalar(scalar)) => { - Ok(ColumnarValue::Array(resolve_temporal_op_scalar( - &arr, sign, &scalar, false, - )?)) - } - // This function evaluates operations between a scalar value and an array of temporal - // values. One example is calculating the duration between a scalar timestamp and an - // array of timestamps (i.e. `now() - some_column`). - (ColumnarValue::Scalar(scalar), ColumnarValue::Array(arr)) => { - Ok(ColumnarValue::Array(resolve_temporal_op_scalar( - &arr, sign, &scalar, true, - )?)) - } - // This function evaluates temporal array operations, such as timestamp - timestamp, interval + interval, - // timestamp + interval, and interval + timestamp. It takes two arrays as input and an integer sign representing - // the operation (+1 for addition and -1 for subtraction). - (ColumnarValue::Array(array_lhs), ColumnarValue::Array(array_rhs)) => Ok( - ColumnarValue::Array(resolve_temporal_op(&array_lhs, sign, &array_rhs)?), - ), - } - } - - fn evaluate_bounds(&self, children: &[&Interval]) -> Result { - // Get children intervals: - let left_interval = children[0]; - let right_interval = children[1]; - // Calculate current node's interval: - apply_operator(&self.op, left_interval, right_interval) - } - - fn propagate_constraints( - &self, - interval: &Interval, - children: &[&Interval], - ) -> Result>> { - // Get children intervals. Graph brings - let left_interval = children[0]; - let right_interval = children[1]; - let (left, right) = if self.op.is_comparison_operator() { - if interval == &Interval::CERTAINLY_FALSE { - // TODO: We will handle strictly false clauses by negating - // the comparison operator (e.g. GT to LE, LT to GE) - // once open/closed intervals are supported. - return Ok(vec![]); - } - // Propagate the comparison operator. - propagate_comparison(&self.op, left_interval, right_interval)? - } else { - // Propagate the arithmetic operator. - propagate_arithmetic(&self.op, interval, left_interval, right_interval)? - }; - Ok(vec![left, right]) - } - - fn children(&self) -> Vec> { - vec![self.lhs.clone(), self.rhs.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(DateTimeIntervalExpr::new( - children[0].clone(), - self.op, - children[1].clone(), - ))) - } -} - -impl PartialEq for DateTimeIntervalExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.lhs.eq(&x.lhs) && self.op == x.op && self.rhs.eq(&x.rhs)) - .unwrap_or(false) - } -} - -/// create a DateIntervalExpr -pub fn date_time_interval_expr( - lhs: Arc, - op: Operator, - rhs: Arc, - input_schema: &Schema, -) -> Result> { - match ( - lhs.data_type(input_schema)?, - op, - rhs.data_type(input_schema)?, - ) { - ( - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - Operator::Plus | Operator::Minus, - DataType::Interval(_), - ) - | (DataType::Timestamp(_, _), Operator::Minus, DataType::Timestamp(_, _)) - | (DataType::Interval(_), Operator::Plus, DataType::Timestamp(_, _)) - | ( - DataType::Interval(_), - Operator::Plus | Operator::Minus, - DataType::Interval(_), - ) => Ok(Arc::new(DateTimeIntervalExpr::new(lhs, op, rhs))), - (lhs, _, rhs) => Err(DataFusionError::Execution(format!( - "Invalid operation {op} between '{lhs}' and '{rhs}' for DateIntervalExpr" - ))), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::create_physical_expr; - use crate::execution_props::ExecutionProps; - use arrow::array::{ArrayRef, Date32Builder}; - use arrow::datatypes::*; - use arrow_array::IntervalMonthDayNanoArray; - use chrono::{Duration, NaiveDate}; - use datafusion_common::{Column, Result, ScalarValue, ToDFSchema}; - use datafusion_expr::Expr; - use std::ops::Add; - - #[test] - fn add_32_day_time() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(1, 0)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::Date32(Some(d))) => { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let res = epoch.add(Duration::days(d as i64)); - assert_eq!(format!("{res:?}").as_str(), "1970-01-02"); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn sub_32_year_month() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Minus; - let interval = Expr::Literal(ScalarValue::IntervalYearMonth(Some(13))); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::Date32(Some(d))) => { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let res = epoch.add(Duration::days(d as i64)); - assert_eq!(format!("{res:?}").as_str(), "1968-12-01"); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn add_64_day_time() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date64(Some(0))); - let op = Operator::Plus; - let interval = - Expr::Literal(ScalarValue::new_interval_dt(-15, -24 * 60 * 60 * 1000)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::Date64(Some(d))) => { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let res = epoch.add(Duration::milliseconds(d)); - assert_eq!(format!("{res:?}").as_str(), "1969-12-16"); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn add_32_year_month() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::IntervalYearMonth(Some(1))); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::Date32(Some(d))) => { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let res = epoch.add(Duration::days(d as i64)); - assert_eq!(format!("{res:?}").as_str(), "1970-02-01"); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn add_32_month_day_nano() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::new_interval_mdn(-12, -15, -42)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::Date32(Some(d))) => { - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let res = epoch.add(Duration::days(d as i64)); - assert_eq!(format!("{res:?}").as_str(), "1968-12-17"); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn add_1_millisecond() -> Result<()> { - // setup - let now_ts_ns = chrono::Utc::now().timestamp_nanos(); - let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 1)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(ts), None)) => { - assert_eq!(ts, now_ts_ns + 1_000_000); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - Ok(()) - } - - #[test] - fn add_2_hours() -> Result<()> { - // setup - let now_ts_s = chrono::Utc::now().timestamp(); - let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 2 * 3600 * 1_000)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(ts), None)) => { - assert_eq!(ts, now_ts_s + 2 * 3600); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - Ok(()) - } - - #[test] - fn sub_4_hours() -> Result<()> { - // setup - let now_ts_s = chrono::Utc::now().timestamp(); - let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); - let op = Operator::Minus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 4 * 3600 * 1_000)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(ts), None)) => { - assert_eq!(ts, now_ts_s - 4 * 3600); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - Ok(()) - } - - #[test] - fn add_8_days() -> Result<()> { - // setup - let now_ts_ns = chrono::Utc::now().timestamp_nanos(); - let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(8, 0)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(ts), None)) => { - assert_eq!(ts, now_ts_ns + 8 * 86400 * 1_000_000_000); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - Ok(()) - } - - #[test] - fn sub_16_days() -> Result<()> { - // setup - let now_ts_ns = chrono::Utc::now().timestamp_nanos(); - let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); - let op = Operator::Minus; - let interval = Expr::Literal(ScalarValue::new_interval_dt(16, 0)); - - // exercise - let res = exercise(&dt, op, &interval)?; - - // assert - match res { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(ts), None)) => { - assert_eq!(ts, now_ts_ns - 16 * 86400 * 1_000_000_000); - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - Ok(()) - } - - #[test] - fn array_add_26_days() -> Result<()> { - let mut builder = Date32Builder::with_capacity(8); - builder.append_slice(&[0, 1, 2, 3, 4, 5, 6, 7]); - let a: ArrayRef = Arc::new(builder.finish()); - - let schema = Schema::new(vec![Field::new("a", DataType::Date32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let dfs = schema.clone().to_dfschema()?; - let props = ExecutionProps::new(); - - let dt = Expr::Column(Column::from_name("a")); - let interval = Expr::Literal(ScalarValue::new_interval_dt(26, 0)); - let op = Operator::Plus; - - let lhs = create_physical_expr(&dt, &dfs, &schema, &props)?; - let rhs = create_physical_expr(&interval, &dfs, &schema, &props)?; - - let cut = date_time_interval_expr(lhs, op, rhs, &schema)?; - let res = cut.evaluate(&batch)?; - - let mut builder = Date32Builder::with_capacity(8); - builder.append_slice(&[26, 27, 28, 29, 30, 31, 32, 33]); - let expected: ArrayRef = Arc::new(builder.finish()); - - // assert - match res { - ColumnarValue::Array(array) => { - assert_eq!(&array, &expected) - } - _ => Err(DataFusionError::NotImplemented( - "Unexpected result!".to_string(), - ))?, - } - - Ok(()) - } - - #[test] - fn invalid_interval() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::Null); - - // exercise - let res = exercise(&dt, op, &interval); - assert!(res.is_err(), "Can't add a NULL interval"); - - Ok(()) - } - - #[test] - fn invalid_date() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Null); - let op = Operator::Plus; - let interval = Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(0))); - - // exercise - let res = exercise(&dt, op, &interval); - assert!(res.is_err(), "Can't add to NULL date"); - - Ok(()) - } - - #[test] - fn invalid_op() -> Result<()> { - // setup - let dt = Expr::Literal(ScalarValue::Date32(Some(0))); - let op = Operator::Eq; - let interval = Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(0))); - - // exercise - let res = exercise(&dt, op, &interval); - assert!(res.is_err(), "Can't add dates with == operator"); - - Ok(()) - } - - fn exercise(dt: &Expr, op: Operator, interval: &Expr) -> Result { - let mut builder = Date32Builder::with_capacity(1); - builder.append_value(0); - let a: ArrayRef = Arc::new(builder.finish()); - let schema = Schema::new(vec![Field::new("a", DataType::Date32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let dfs = schema.clone().to_dfschema()?; - let props = ExecutionProps::new(); - - let lhs = create_physical_expr(dt, &dfs, &schema, &props)?; - let rhs = create_physical_expr(interval, &dfs, &schema, &props)?; - - let lhs_str = format!("{lhs}"); - let rhs_str = format!("{rhs}"); - - let cut = DateTimeIntervalExpr::new(lhs, op, rhs); - - assert_eq!(lhs_str, format!("{}", cut.lhs())); - assert_eq!(op, cut.op().clone()); - assert_eq!(rhs_str, format!("{}", cut.rhs())); - - let res = cut.evaluate(&batch)?; - Ok(res) - } - - // In this test, ArrayRef of one element arrays is evaluated with some ScalarValues, - // aiming that resolve_temporal_op_scalar function is working properly and shows the same - // behavior with ScalarValue arithmetic. - fn experiment( - timestamp_scalar: ScalarValue, - interval_scalar: ScalarValue, - ) -> Result<()> { - let timestamp_array = timestamp_scalar.to_array(); - let interval_array = interval_scalar.to_array(); - - // timestamp + interval - let res1 = - resolve_temporal_op_scalar(×tamp_array, 1, &interval_scalar, false)?; - let res2 = timestamp_scalar.add(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Timestamp Scalar={timestamp_scalar} + Interval Scalar={interval_scalar}" - ); - let res1 = - resolve_temporal_op_scalar(×tamp_array, 1, &interval_scalar, true)?; - let res2 = interval_scalar.add(×tamp_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Timestamp Scalar={timestamp_scalar} + Interval Scalar={interval_scalar}" - ); - - // timestamp - interval - let res1 = - resolve_temporal_op_scalar(×tamp_array, -1, &interval_scalar, false)?; - let res2 = timestamp_scalar.sub(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Timestamp Scalar={timestamp_scalar} - Interval Scalar={interval_scalar}" - ); - - // timestamp - timestamp - let res1 = - resolve_temporal_op_scalar(×tamp_array, -1, ×tamp_scalar, false)?; - let res2 = timestamp_scalar.sub(×tamp_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Timestamp Scalar={timestamp_scalar} - Timestamp Scalar={timestamp_scalar}" - ); - let res1 = - resolve_temporal_op_scalar(×tamp_array, -1, ×tamp_scalar, true)?; - let res2 = timestamp_scalar.sub(×tamp_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Timestamp Scalar={timestamp_scalar} - Timestamp Scalar={timestamp_scalar}" - ); - - // interval - interval - let res1 = - resolve_temporal_op_scalar(&interval_array, -1, &interval_scalar, false)?; - let res2 = interval_scalar.sub(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Interval Scalar={interval_scalar} - Interval Scalar={interval_scalar}" - ); - let res1 = - resolve_temporal_op_scalar(&interval_array, -1, &interval_scalar, true)?; - let res2 = interval_scalar.sub(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Interval Scalar={interval_scalar} - Interval Scalar={interval_scalar}" - ); - - // interval + interval - let res1 = - resolve_temporal_op_scalar(&interval_array, 1, &interval_scalar, false)?; - let res2 = interval_scalar.add(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Interval Scalar={interval_scalar} + Interval Scalar={interval_scalar}" - ); - let res1 = - resolve_temporal_op_scalar(&interval_array, 1, &interval_scalar, true)?; - let res2 = interval_scalar.add(&interval_scalar)?.to_array(); - assert_eq!( - &res1, &res2, - "Interval Scalar={interval_scalar} + Interval Scalar={interval_scalar}" - ); - - Ok(()) - } - #[test] - fn test_evalute_with_scalar() -> Result<()> { - // Timestamp (sec) & Interval (DayTime) - let timestamp_scalar = ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(0, 1_000); - - experiment(timestamp_scalar, interval_scalar)?; - - // Timestamp (millisec) & Interval (DayTime) - let timestamp_scalar = ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_milli_opt(0, 0, 0, 0) - .unwrap() - .timestamp_millis(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(0, 1_000); - - experiment(timestamp_scalar, interval_scalar)?; - - // Timestamp (nanosec) & Interval (MonthDayNano) - let timestamp_scalar = ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_nano_opt(0, 0, 0, 0) - .unwrap() - .timestamp_nanos(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_mdn(0, 0, 1_000); - - experiment(timestamp_scalar, interval_scalar)?; - - // Timestamp (nanosec) & Interval (MonthDayNano), negatively resulting cases - - let timestamp_scalar = ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(1970, 1, 1) - .unwrap() - .and_hms_nano_opt(0, 0, 0, 000) - .unwrap() - .timestamp_nanos(), - ), - None, - ); - - Arc::new(IntervalMonthDayNanoArray::from(vec![1_000])); // 1 us - let interval_scalar = ScalarValue::new_interval_mdn(0, 0, 1_000); - - experiment(timestamp_scalar, interval_scalar)?; - - // Timestamp (sec) & Interval (YearMonth) - let timestamp_scalar = ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2023, 1, 1) - .unwrap() - .and_hms_opt(0, 0, 0) - .unwrap() - .timestamp(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_ym(0, 1); - - experiment(timestamp_scalar, interval_scalar)?; - - // More test with all matchings of timestamps and intervals - let timestamp_scalar = ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_opt(23, 59, 59) - .unwrap() - .timestamp(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_ym(0, 1); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_opt(23, 59, 59) - .unwrap() - .timestamp(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(10, 100000); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampSecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_opt(23, 59, 59) - .unwrap() - .timestamp(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_mdn(13, 32, 123456); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 909) - .unwrap() - .timestamp_millis(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_ym(0, 1); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 909) - .unwrap() - .timestamp_millis(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(10, 100000); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMillisecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_milli_opt(23, 59, 59, 909) - .unwrap() - .timestamp_millis(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_mdn(13, 32, 123456); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_micro_opt(23, 59, 59, 987654) - .unwrap() - .timestamp_micros(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_ym(0, 1); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_micro_opt(23, 59, 59, 987654) - .unwrap() - .timestamp_micros(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(10, 100000); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampMicrosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_micro_opt(23, 59, 59, 987654) - .unwrap() - .timestamp_micros(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_mdn(13, 32, 123456); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_nano_opt(23, 59, 59, 999999999) - .unwrap() - .timestamp_nanos(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_ym(0, 1); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_nano_opt(23, 59, 59, 999999999) - .unwrap() - .timestamp_nanos(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_dt(10, 100000); - - experiment(timestamp_scalar, interval_scalar)?; - - let timestamp_scalar = ScalarValue::TimestampNanosecond( - Some( - NaiveDate::from_ymd_opt(2000, 12, 31) - .unwrap() - .and_hms_nano_opt(23, 59, 59, 999999999) - .unwrap() - .timestamp_nanos(), - ), - None, - ); - let interval_scalar = ScalarValue::new_interval_mdn(13, 32, 123456); - - experiment(timestamp_scalar, interval_scalar)?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs new file mode 100644 index 0000000000000..2bb79922cfecc --- /dev/null +++ b/datafusion/physical-expr/src/expressions/datum.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Datum}; +use arrow::error::ArrowError; +use arrow_array::BooleanArray; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` +/// +/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction +pub(crate) fn apply( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + match (&lhs, &rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; + let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +pub(crate) fn apply_cmp( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) +} diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index c07641796aa4d..43fd5a812a16c 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -18,52 +18,148 @@ //! get field of a `ListArray` use crate::PhysicalExpr; -use arrow::array::Array; -use arrow::compute::concat; +use datafusion_common::exec_err; +use crate::array_expressions::{array_element, array_slice}; use crate::physical_expr::down_cast_any_ref; use arrow::{ + array::{Array, Scalar, StringArray}, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::cast::{as_list_array, as_struct_array}; -use datafusion_common::DataFusionError; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::{ - field_util::get_indexed_field as get_data_type_field, ColumnarValue, +use datafusion_common::{ + cast::{as_map_array, as_struct_array}, + DataFusionError, Result, ScalarValue, }; -use std::convert::TryInto; +use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; use std::fmt::Debug; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -/// expression to get a field of a struct array. -#[derive(Debug)] +/// Access a sub field of a nested type, such as `Field` or `List` +#[derive(Clone, Hash, Debug)] +pub enum GetFieldAccessExpr { + /// Named field, For example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key: Arc }, + /// List range, for example `list[i:j]` + ListRange { + start: Arc, + stop: Arc, + }, +} + +impl std::fmt::Display for GetFieldAccessExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), + GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key), + GetFieldAccessExpr::ListRange { start, stop } => { + write!(f, "[{}:{}]", start, stop) + } + } + } +} + +impl PartialEq for GetFieldAccessExpr { + fn eq(&self, other: &dyn Any) -> bool { + use GetFieldAccessExpr::{ListIndex, ListRange, NamedStructField}; + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| match (self, x) { + (NamedStructField { name: lhs }, NamedStructField { name: rhs }) => { + lhs.eq(rhs) + } + (ListIndex { key: lhs }, ListIndex { key: rhs }) => lhs.eq(rhs), + ( + ListRange { + start: start_lhs, + stop: stop_lhs, + }, + ListRange { + start: start_rhs, + stop: stop_rhs, + }, + ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs), + (NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false, + (ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) => false, + (ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) => false, + }) + .unwrap_or(false) + } +} + +/// Expression to get a field of a struct array. +#[derive(Debug, Hash)] pub struct GetIndexedFieldExpr { + /// The expression to find arg: Arc, - key: ScalarValue, + /// The key statement + field: GetFieldAccessExpr, } impl GetIndexedFieldExpr { - /// Create new get field expression - pub fn new(arg: Arc, key: ScalarValue) -> Self { - Self { arg, key } + /// Create new [`GetIndexedFieldExpr`] + pub fn new(arg: Arc, field: GetFieldAccessExpr) -> Self { + Self { arg, field } + } + + /// Create a new [`GetIndexedFieldExpr`] for accessing the named field + pub fn new_field(arg: Arc, name: impl Into) -> Self { + Self::new( + arg, + GetFieldAccessExpr::NamedStructField { + name: ScalarValue::from(name.into()), + }, + ) + } + + /// Create a new [`GetIndexedFieldExpr`] for accessing the specified index + pub fn new_index(arg: Arc, key: Arc) -> Self { + Self::new(arg, GetFieldAccessExpr::ListIndex { key }) } - /// Get the input key - pub fn key(&self) -> &ScalarValue { - &self.key + /// Create a new [`GetIndexedFieldExpr`] for accessing the range + pub fn new_range( + arg: Arc, + start: Arc, + stop: Arc, + ) -> Self { + Self::new(arg, GetFieldAccessExpr::ListRange { start, stop }) + } + + /// Get the description of what field should be accessed + pub fn field(&self) -> &GetFieldAccessExpr { + &self.field } /// Get the input expression pub fn arg(&self) -> &Arc { &self.arg } + + fn schema_access(&self, input_schema: &Schema) -> Result { + Ok(match &self.field { + GetFieldAccessExpr::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } + } + GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.data_type(input_schema)?, + }, + GetFieldAccessExpr::ListRange { start, stop } => { + GetFieldAccessSchema::ListRange { + start_dt: start.data_type(input_schema)?, + stop_dt: stop.data_type(input_schema)?, + } + } + }) + } } impl std::fmt::Display for GetIndexedFieldExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "({}).[{}]", self.arg, self.key) + write!(f, "({}).{}", self.arg, self.field) } } @@ -73,70 +169,75 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone()) + let arg_dt = self.arg.data_type(input_schema)?; + self.schema_access(input_schema)? + .get_accessed_field(&arg_dt) + .map(|f| f.data_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable()) + let arg_dt = self.arg.data_type(input_schema)?; + self.schema_access(input_schema)? + .get_accessed_field(&arg_dt) + .map(|f| f.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(1); - match (array.data_type(), &self.key) { - (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => { - let scalar_null: ScalarValue = array.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) - } - (DataType::List(lst), ScalarValue::Int64(Some(i))) => { - let as_list_array = as_list_array(&array)?; - - if *i < 1 || as_list_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - return Ok(ColumnarValue::Scalar(scalar_null)) + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; + match &self.field { + GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + let entries = arrow::compute::filter(map_array.entries(), &keys)?; + let entries_struct_array = as_struct_array(entries.as_ref())?; + Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) } - - let sliced_array: Vec> = as_list_array - .iter() - .filter_map(|o| match o { - Some(list) => if *i as usize > list.len() { - None - } else { - Some(list.slice((*i -1) as usize, 1)) - }, - None => None - }) - .collect(); - - // concat requires input of at least one array - if sliced_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) - } else { - let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); - let iter = concat(vec.as_slice()).unwrap(); - - Ok(ColumnarValue::Array(iter)) + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(k) { + None => exec_err!( + "get indexed field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } } - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => Err(DataFusionError::Execution( - format!("get indexed field {k} not found in struct"))), - Some(col) => Ok(ColumnarValue::Array(col.clone())) + (DataType::Struct(_), name) => exec_err!( + "get indexed field is only possible on struct with utf8 indexes. \ + Tried with {name:?} index"), + (dt, name) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {name:?} index"), + }, + GetFieldAccessExpr::ListIndex{key} => { + let key = key.evaluate(batch)?.into_array(batch.num_rows())?; + match (array.data_type(), key.data_type()) { + (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ + array, key + ])?)), + (DataType::List(_), key) => exec_err!( + "get indexed field is only possible on lists with int64 indexes. \ + Tried with {key:?} index"), + (dt, key) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {key:?} index"), + } + }, + GetFieldAccessExpr::ListRange{start, stop} => { + let start = start.evaluate(batch)?.into_array(batch.num_rows())?; + let stop = stop.evaluate(batch)?.into_array(batch.num_rows())?; + match (array.data_type(), start.data_type(), stop.data_type()) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ + array, start, stop + ])?)), + (DataType::List(_), start, stop) => exec_err!( + "get indexed field is only possible on lists with int64 indexes. \ + Tried with {start:?} and {stop:?} indices"), + (dt, start, stop) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {start:?} and {stop:?} indices"), } - } - (DataType::List(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes. \ - Tried with {key:?} index"))), - (DataType::Struct(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on struct with utf8 indexes. \ - Tried with {key:?} index"))), - (dt, key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {key:?} index"))), + }, } } @@ -150,16 +251,21 @@ impl PhysicalExpr for GetIndexedFieldExpr { ) -> Result> { Ok(Arc::new(GetIndexedFieldExpr::new( children[0].clone(), - self.key.clone(), + self.field.clone(), ))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for GetIndexedFieldExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.arg.eq(&x.arg) && self.key == x.key) + .map(|x| self.arg.eq(&x.arg) && self.field.eq(&x.field)) .unwrap_or(false) } } @@ -167,301 +273,226 @@ impl PartialEq for GetIndexedFieldExpr { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; - use arrow::array::{ArrayRef, Float64Array, GenericListArray, PrimitiveBuilder}; + use crate::expressions::col; + use arrow::array::new_empty_array; + use arrow::array::{ArrayRef, GenericListArray}; use arrow::array::{ - Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder, + BooleanArray, Int64Array, ListBuilder, StringBuilder, StructArray, }; - use arrow::datatypes::{Float64Type, Int64Type}; + use arrow::datatypes::Fields; use arrow::{array::StringArray, datatypes::Field}; - use datafusion_common::cast::{as_int64_array, as_string_array}; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_string_array}; use datafusion_common::Result; - fn build_utf8_lists(list_of_lists: Vec>>) -> GenericListArray { + fn build_list_arguments( + list_of_lists: Vec>>, + list_of_start_indices: Vec>, + list_of_stop_indices: Vec>, + ) -> (GenericListArray, Int64Array, Int64Array) { let builder = StringBuilder::with_capacity(list_of_lists.len(), 1024); - let mut lb = ListBuilder::new(builder); + let mut list_builder = ListBuilder::new(builder); for values in list_of_lists { - let builder = lb.values(); + let builder = list_builder.values(); for value in values { match value { None => builder.append_null(), Some(v) => builder.append_value(v), } } - lb.append(true); + list_builder.append(true); } - lb.finish() + let start_array = Int64Array::from(list_of_start_indices); + let stop_array = Int64Array::from(list_of_stop_indices); + (list_builder.finish(), start_array, stop_array) } - fn get_indexed_field_test( - list_of_lists: Vec>>, - index: i64, - expected: Vec>, - ) -> Result<()> { - let schema = list_schema("l"); - let list_col = build_utf8_lists(list_of_lists); - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?; - let key = ScalarValue::Int64(Some(index)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_string_array(&result).expect("failed to downcast to StringArray"); - let expected = &StringArray::from(expected); - assert_eq!(expected, result); + #[test] + fn get_indexed_field_named_struct_field() -> Result<()> { + let schema = struct_schema(); + let boolean = BooleanArray::from(vec![false, false, true, true]); + let int = Int64Array::from(vec![42, 28, 19, 31]); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, true)), + Arc::new(boolean.clone()) as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int64, true)), + Arc::new(int) as ArrayRef, + ), + ]); + let expr = col("str", &schema).unwrap(); + // only one row should be processed + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; + let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); + let result = + as_boolean_array(&result).expect("failed to downcast to BooleanArray"); + assert_eq!(boolean, result.clone()); Ok(()) } - fn list_schema(col: &str) -> Schema { - Schema::new(vec![Field::new_list( - col, - Field::new("item", DataType::Utf8, true), + fn struct_schema() -> Schema { + Schema::new(vec![Field::new_struct( + "str", + Fields::from(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int64, true), + ]), true, )]) } + fn list_schema(cols: &[&str]) -> Schema { + if cols.len() == 2 { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + ]) + } else { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + Field::new(cols[2], DataType::Int64, true), + ]) + } + } + + #[test] + fn get_indexed_field_list_index() -> Result<()> { + let list_of_lists = vec![ + vec![Some("a"), Some("b"), None], + vec![None, Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], + ]; + let list_of_start_indices = vec![Some(1), Some(2), None]; + let list_of_stop_indices = vec![None]; + let expected_list = vec![Some("a"), Some("c"), None]; + + let schema = list_schema(&["list", "key"]); + let (list_col, key_col, _) = build_list_arguments( + list_of_lists, + list_of_start_indices, + list_of_stop_indices, + ); + let expr = col("list", &schema).unwrap(); + let key = col("key", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_col), Arc::new(key_col)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); + let result = as_string_array(&result).expect("failed to downcast to ListArray"); + let expected = StringArray::from(expected_list); + assert_eq!(expected, result.clone()); + Ok(()) + } + #[test] - fn get_indexed_field_list() -> Result<()> { + fn get_indexed_field_list_range() -> Result<()> { let list_of_lists = vec![ vec![Some("a"), Some("b"), None], vec![None, Some("c"), Some("d")], vec![Some("e"), None, Some("f")], ]; + let list_of_start_indices = vec![Some(1), Some(2), None]; + let list_of_stop_indices = vec![Some(2), None, Some(3)]; let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], + vec![Some("a"), Some("b")], + vec![Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], ]; - for (i, expected) in expected_list.into_iter().enumerate() { - get_indexed_field_test(list_of_lists.clone(), (i + 1) as i64, expected)?; - } + let schema = list_schema(&["list", "start", "stop"]); + let (list_col, start_col, stop_col) = build_list_arguments( + list_of_lists, + list_of_start_indices, + list_of_stop_indices, + ); + let expr = col("list", &schema).unwrap(); + let start = col("start", &schema).unwrap(); + let stop = col("stop", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop)); + let result = expr + .evaluate(&batch)? + .into_array(1) + .expect("Failed to convert to array"); + let result = as_list_array(&result).expect("failed to downcast to ListArray"); + let (expected, _, _) = + build_list_arguments(expected_list, vec![None], vec![None]); + assert_eq!(expected, result.clone()); Ok(()) } #[test] fn get_indexed_field_empty_list() -> Result<()> { - let schema = list_schema("l"); + let schema = list_schema(&["list", "key"]); let builder = StringBuilder::new(); - let mut lb = ListBuilder::new(builder); - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let key = ScalarValue::Int64(Some(1)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let mut list_builder = ListBuilder::new(builder); + let key_array = new_empty_array(&DataType::Int64); + let expr = col("list", &schema).unwrap(); + let key = col("key", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_builder.finish()), key_array], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); assert!(result.is_empty()); Ok(()) } - fn get_indexed_field_test_failure( - schema: Schema, - expr: Arc, - key: ScalarValue, - expected: &str, - ) -> Result<()> { - let builder = StringBuilder::with_capacity(3, 1024); - let mut lb = ListBuilder::new(builder); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let r = expr.evaluate(&batch).map(|_| ()); - assert!(r.is_err()); - assert_eq!(format!("{}", r.unwrap_err()), expected); - Ok(()) - } - - #[test] - fn get_indexed_field_invalid_scalar() -> Result<()> { - let schema = list_schema("l"); - let expr = lit("a"); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int64(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes or \ - struct with utf8 indexes. Tried Utf8 with Int64(0) index") - } - #[test] fn get_indexed_field_invalid_list_index() -> Result<()> { - let schema = list_schema("l"); - let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int8(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes. \ - Tried with Int8(0) index") - } - - fn build_struct( - fields: Vec, - list_of_tuples: Vec<(Option, Vec>)>, - ) -> StructArray { - let foo_builder = Int64Array::builder(list_of_tuples.len()); - let str_builder = StringBuilder::with_capacity(list_of_tuples.len(), 1024); - let bar_builder = ListBuilder::new(str_builder); - let mut builder = StructBuilder::new( - fields, - vec![Box::new(foo_builder), Box::new(bar_builder)], - ); - for (int_value, list_value) in list_of_tuples { - let fb = builder.field_builder::(0).unwrap(); - match int_value { - None => fb.append_null(), - Some(v) => fb.append_value(v), - }; - builder.append(true); - let lb = builder - .field_builder::>(1) - .unwrap(); - for str_value in list_value { - match str_value { - None => lb.values().append_null(), - Some(v) => lb.values().append_value(v), - }; - } - lb.append(true); - } - builder.finish() - } + let schema = list_schema(&["list", "error"]); + let expr = col("list", &schema).unwrap(); + let key = col("error", &schema).unwrap(); + let builder = StringBuilder::with_capacity(3, 1024); + let mut list_builder = ListBuilder::new(builder); + list_builder.values().append_value("hello"); + list_builder.append(true); - fn get_indexed_field_mixed_test( - list_of_tuples: Vec<(Option, Vec>)>, - expected_strings: Vec>>, - expected_ints: Vec>, - ) -> Result<()> { - let struct_col = "s"; - let fields = vec![ - Field::new("foo", DataType::Int64, true), - Field::new_list("bar", Field::new("item", DataType::Utf8, true), true), - ]; - let schema = Schema::new(vec![Field::new( - struct_col, - DataType::Struct(fields.clone().into()), - true, - )]); - let struct_col = build_struct(fields, list_of_tuples.clone()); - - let struct_col_expr = col("s", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_col)])?; - - let int_field_key = ScalarValue::Utf8(Some("foo".to_string())); - let get_field_expr = Arc::new(GetIndexedFieldExpr::new( - struct_col_expr.clone(), - int_field_key, - )); - let result = get_field_expr + let key_array = Int64Array::from(vec![Some(3)]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_builder.finish()), Arc::new(key_array)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr .evaluate(&batch)? - .into_array(batch.num_rows()); - let result = as_int64_array(&result)?; - let expected = &Int64Array::from(expected_ints); - assert_eq!(expected, result); - - let list_field_key = ScalarValue::Utf8(Some("bar".to_string())); - let get_list_expr = - Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); - let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_list_array(&result)?; - let expected = - &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); - assert_eq!(expected, result); - - for (i, expected) in expected_strings.into_iter().enumerate() { - let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new( - get_list_expr.clone(), - ScalarValue::Int64(Some((i + 1) as i64)), - )); - let result = get_nested_str_expr - .evaluate(&batch)? - .into_array(batch.num_rows()); - let result = as_string_array(&result)?; - let expected = &StringArray::from(expected); - assert_eq!(expected, result); - } + .into_array(1) + .expect("Failed to convert to array"); + assert!(result.is_null(0)); Ok(()) } #[test] - fn get_indexed_field_struct() -> Result<()> { - let list_of_structs = vec![ - (Some(10), vec![Some("a"), Some("b"), None]), - (Some(15), vec![None, Some("c"), Some("d")]), - (None, vec![Some("e"), None, Some("f")]), - ]; - - let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], - ]; - - let expected_ints = vec![Some(10), Some(15), None]; - - get_indexed_field_mixed_test( - list_of_structs.clone(), - expected_list, - expected_ints, - )?; + fn get_indexed_field_eq() -> Result<()> { + let schema = list_schema(&["list", "error"]); + let expr = col("list", &schema).unwrap(); + let key = col("error", &schema).unwrap(); + let indexed_field = + Arc::new(GetIndexedFieldExpr::new_index(expr.clone(), key.clone())) + as Arc; + let indexed_field_other = + Arc::new(GetIndexedFieldExpr::new_index(key, expr)) as Arc; + assert!(indexed_field.eq(&indexed_field)); + assert!(!indexed_field.eq(&indexed_field_other)); Ok(()) } - - #[test] - fn get_indexed_field_list_out_of_bounds() { - let fields = vec![ - Field::new("id", DataType::Int64, true), - Field::new_list("a", Field::new("item", DataType::Float64, true), true), - ]; - - let schema = Schema::new(fields); - let mut int_builder = PrimitiveBuilder::::new(); - int_builder.append_value(1); - - let mut lb = ListBuilder::new(PrimitiveBuilder::::new()); - lb.values().append_value(1.0); - lb.values().append_null(); - lb.values().append_value(3.0); - lb.append(true); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(int_builder.finish()), Arc::new(lb.finish())], - ) - .unwrap(); - - let col_a = col("a", &schema).unwrap(); - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 0, float64_array(None)); - - verify_index_evaluation(&batch, col_a.clone(), 1, float64_array(Some(1.0))); - verify_index_evaluation(&batch, col_a.clone(), 2, float64_array(None)); - verify_index_evaluation(&batch, col_a.clone(), 3, float64_array(Some(3.0))); - - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 100, float64_array(None)); - } - - fn verify_index_evaluation( - batch: &RecordBatch, - arg: Arc, - index: i64, - expected_result: ArrayRef, - ) { - let expr = Arc::new(GetIndexedFieldExpr::new( - arg, - ScalarValue::Int64(Some(index)), - )); - let result = expr.evaluate(batch).unwrap().into_array(batch.num_rows()); - assert!( - result == expected_result.clone(), - "result: {result:?} != expected result: {expected_result:?}" - ); - assert_eq!(result.data_type(), &DataType::Float64); - } - - fn float64_array(value: Option) -> ArrayRef { - match value { - Some(v) => Arc::new(Float64Array::from_value(v, 1)), - None => { - let mut b = PrimitiveBuilder::::new(); - b.append_null(); - Arc::new(b.finish()) - } - } - } } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3feb728900adc..625b01ec9a7ea 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -17,28 +17,33 @@ //! Implementation of `InList` expressions: [`InListExpr`] -use ahash::RandomState; use std::any::Any; use std::fmt::Debug; +use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::hash_utils::HashValue; -use crate::physical_expr::down_cast_any_ref; -use crate::utils::expr_list_eq_any_order; +use crate::physical_expr::{down_cast_any_ref, physical_exprs_bag_equal}; use crate::PhysicalExpr; + use arrow::array::*; +use arrow::buffer::BooleanBuffer; +use arrow::compute::kernels::boolean::{not, or_kleene}; +use arrow::compute::kernels::cmp::eq; use arrow::compute::take; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use arrow::util::bit_iterator::BitIndexIterator; use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use datafusion_common::cast::{ + as_boolean_array, as_generic_binary_array, as_string_array, +}; +use datafusion_common::hash_utils::HashValue; use datafusion_common::{ - cast::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, - }, - DataFusionError, Result, ScalarValue, + exec_err, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; + +use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; @@ -63,6 +68,7 @@ impl Debug for InListExpr { /// A type-erased container of array elements pub trait Set: Send + Sync { fn contains(&self, v: &dyn Array, negated: bool) -> Result; + fn has_nulls(&self) -> bool; } struct ArrayHashSet { @@ -95,7 +101,7 @@ impl Set for ArraySet where T: Array + 'static, for<'a> &'a T: ArrayAccessor, - for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue, + for<'a> <&'a T as ArrayAccessor>::Item: IsEqual, { fn contains(&self, v: &dyn Array, negated: bool) -> Result { downcast_dictionary_array! { @@ -119,7 +125,7 @@ where .hash_set .map .raw_entry() - .from_hash(hash, |idx| in_array.value(*idx) == v) + .from_hash(hash, |idx| in_array.value(*idx).is_equal(&v)) .is_some(); match contains { @@ -131,6 +137,10 @@ where }) .collect()) } + + fn has_nulls(&self) -> bool { + self.array.null_count() != 0 + } } /// Computes an [`ArrayHashSet`] for the provided [`Array`] if there @@ -142,7 +152,7 @@ where fn make_hash_set(array: T) -> ArrayHashSet where T: ArrayAccessor, - T::Item: PartialEq + HashValue, + T::Item: IsEqual, { let state = RandomState::new(); let mut map: HashMap = @@ -153,7 +163,7 @@ where let hash = value.hash_one(&state); if let RawEntryMut::Vacant(v) = map .raw_entry_mut() - .from_hash(hash, |x| array.value(*x) == value) + .from_hash(hash, |x| array.value(*x).is_equal(&value)) { v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state)); } @@ -178,14 +188,6 @@ fn make_set(array: &dyn Array) -> Result> { let array = as_boolean_array(array)?; Arc::new(ArraySet::new(array, make_hash_set(array))) }, - DataType::Decimal128(_, _) => { - let array = as_primitive_array::(array)?; - Arc::new(ArraySet::new(array, make_hash_set(array))) - } - DataType::Decimal256(_, _) => { - let array = as_primitive_array::(array)?; - Arc::new(ArraySet::new(array, make_hash_set(array))) - } DataType::Utf8 => { let array = as_string_array(array)?; Arc::new(ArraySet::new(array, make_hash_set(array))) @@ -203,7 +205,7 @@ fn make_set(array: &dyn Array) -> Result> { Arc::new(ArraySet::new(array, make_hash_set(array))) } DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"), - d => return Err(DataFusionError::NotImplemented(format!("DataType::{d} not supported in InList"))) + d => return not_impl_err!("DataType::{d} not supported in InList") }) } @@ -216,9 +218,9 @@ fn evaluate_list( .iter() .map(|expr| { expr.evaluate(batch).and_then(|r| match r { - ColumnarValue::Array(_) => Err(DataFusionError::Execution( - "InList expression must evaluate to a scalar".to_string(), - )), + ColumnarValue::Array(_) => { + exec_err!("InList expression must evaluate to a scalar") + } // Flatten dictionary values ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v), ColumnarValue::Scalar(s) => Ok(s), @@ -237,6 +239,40 @@ fn try_cast_static_filter_to_set( make_set(evaluate_list(list, &batch)?.as_ref()) } +/// Custom equality check function which is used with [`ArrayHashSet`] for existence check. +trait IsEqual: HashValue { + fn is_equal(&self, other: &Self) -> bool; +} + +impl<'a, T: IsEqual + ?Sized> IsEqual for &'a T { + fn is_equal(&self, other: &Self) -> bool { + T::is_equal(self, other) + } +} + +macro_rules! is_equal { + ($($t:ty),+) => { + $(impl IsEqual for $t { + fn is_equal(&self, other: &Self) -> bool { + self == other + } + })* + }; +} +is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); +is_equal!(bool, str, [u8]); + +macro_rules! is_equal_float { + ($($t:ty),+) => { + $(impl IsEqual for $t { + fn is_equal(&self, other: &Self) -> bool { + self.to_bits() == other.to_bits() + } + })* + }; +} +is_equal_float!(half::f16, f32, f64); + impl InListExpr { /// Create a new InList expression pub fn new( @@ -296,16 +332,43 @@ impl PhysicalExpr for InListExpr { } fn nullable(&self, input_schema: &Schema) -> Result { - self.expr.nullable(input_schema) + if self.expr.nullable(input_schema)? { + return Ok(true); + } + + if let Some(static_filter) = &self.static_filter { + Ok(static_filter.has_nulls()) + } else { + for expr in &self.list { + if expr.nullable(input_schema)? { + return Ok(true); + } + } + Ok(false) + } } fn evaluate(&self, batch: &RecordBatch) -> Result { - let value = self.expr.evaluate(batch)?.into_array(1); + let value = self.expr.evaluate(batch)?; let r = match &self.static_filter { - Some(f) => f.contains(value.as_ref(), self.negated)?, + Some(f) => f.contains(value.into_array(1)?.as_ref(), self.negated)?, None => { - let list = evaluate_list(&self.list, batch)?; - make_set(list.as_ref())?.contains(value.as_ref(), self.negated)? + let value = value.into_array(batch.num_rows())?; + let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( + BooleanArray::new(BooleanBuffer::new_unset(batch.num_rows()), None), + |result, expr| -> Result { + Ok(or_kleene( + &result, + &eq(&value, &expr?.into_array(batch.num_rows())?)?, + )?) + }, + )?; + + if self.negated { + not(&found)? + } else { + found + } } }; Ok(ColumnarValue::Array(Arc::new(r))) @@ -330,6 +393,14 @@ impl PhysicalExpr for InListExpr { self.static_filter.clone(), ))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.expr.hash(&mut s); + self.negated.hash(&mut s); + self.list.hash(&mut s); + // Add `self.static_filter` when hash is available + } } impl PartialEq for InListExpr { @@ -338,7 +409,7 @@ impl PartialEq for InListExpr { .downcast_ref::() .map(|x| { self.expr.eq(&x.expr) - && expr_list_eq_any_order(&self.list, &x.list) + && physical_exprs_bag_equal(&self.list, &x.list) && self.negated == x.negated }) .unwrap_or(false) @@ -357,9 +428,9 @@ pub fn in_list( for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; if !expr_data_type.eq(&list_expr_data_type) { - return Err(DataFusionError::Internal(format!( + return internal_err!( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" - ))); + ); } } let static_filter = try_cast_static_filter_to_set(&list, schema).ok(); @@ -378,6 +449,7 @@ mod tests { use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; + use datafusion_common::plan_err; use datafusion_common::Result; use datafusion_expr::type_coercion::binary::comparison_coercion; @@ -397,9 +469,9 @@ mod tests { .collect(); let result_type = get_coerce_type(expr_type, &list_types); match result_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can not find compatible types to compare {expr_type:?} with {list_types:?}" - ))), + ), Some(data_type) => { // find the coerced type let cast_expr = try_cast(expr, input_schema, data_type.clone())?; @@ -419,9 +491,8 @@ mod tests { fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option { list_type .iter() - .fold(Some(expr_type.clone()), |left, right_type| match left { - None => None, - Some(left_type) => comparison_coercion(&left_type, right_type), + .try_fold(expr_type.clone(), |left_type, right_type| { + comparison_coercion(&left_type, right_type) }) } @@ -430,7 +501,10 @@ mod tests { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap(); - let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); + let result = expr + .evaluate(&$BATCH)? + .into_array($BATCH.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($EXPECTED); @@ -609,50 +683,100 @@ mod tests { #[test] fn in_list_float64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); + let a = Float64Array::from(vec![ + Some(0.0), + Some(0.2), + None, + Some(f64::NAN), + Some(-f64::NAN), + ]); let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in (0.0, 0.2)" + // expression: "a in (0.0, 0.1)" let list = vec![lit(0.0f64), lit(0.1f64)]; in_list!( batch, list, &false, - vec![Some(true), Some(false), None], + vec![Some(true), Some(false), None, Some(false), Some(false)], col_a.clone(), &schema ); - // expression: "a not in (0.0, 0.2)" + // expression: "a not in (0.0, 0.1)" let list = vec![lit(0.0f64), lit(0.1f64)]; in_list!( batch, list, &true, - vec![Some(false), Some(true), None], + vec![Some(false), Some(true), None, Some(true), Some(true)], col_a.clone(), &schema ); - // expression: "a in (0.0, 0.2, NULL)" + // expression: "a in (0.0, 0.1, NULL)" let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; in_list!( batch, list, &false, - vec![Some(true), None, None], + vec![Some(true), None, None, None, None], col_a.clone(), &schema ); - // expression: "a not in (0.0, 0.2, NULL)" + // expression: "a not in (0.0, 0.1, NULL)" let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; in_list!( batch, list, &true, - vec![Some(false), None, None], + vec![Some(false), None, None, None, None], + col_a.clone(), + &schema + ); + + // expression: "a in (0.0, 0.1, NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None, Some(true), Some(false)], + col_a.clone(), + &schema + ); + + // expression: "a not in (0.0, 0.1, NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None, Some(false), Some(true)], + col_a.clone(), + &schema + ); + + // expression: "a in (0.0, 0.1, -NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None, Some(false), Some(true)], + col_a.clone(), + &schema + ); + + // expression: "a not in (0.0, 0.1, -NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None, Some(true), Some(false)], col_a.clone(), &schema ); @@ -1044,4 +1168,103 @@ mod tests { ); Ok(()) } + + #[test] + fn in_expr_with_multiple_element_in_list() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Float64, true), + Field::new("c", DataType::Float64, true), + ]); + let a = Float64Array::from(vec![ + Some(0.0), + Some(1.0), + Some(2.0), + Some(f64::NAN), + Some(-f64::NAN), + ]); + let b = Float64Array::from(vec![ + Some(8.0), + Some(1.0), + Some(5.0), + Some(f64::NAN), + Some(3.0), + ]); + let c = Float64Array::from(vec![ + Some(6.0), + Some(7.0), + None, + Some(5.0), + Some(-f64::NAN), + ]); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + )?; + + let list = vec![col_b.clone(), col_c.clone()]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(false), Some(true), None, Some(true), Some(true)], + col_a.clone(), + &schema + ); + + in_list!( + batch, + list, + &true, + vec![Some(true), Some(false), None, Some(false), Some(false)], + col_a.clone(), + &schema + ); + + Ok(()) + } + + macro_rules! test_nullable { + ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{ + let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; + let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap(); + let result = expr.nullable($SCHEMA)?; + assert_eq!($EXPECTED, result); + }}; + } + + #[test] + fn in_list_nullable() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1_nullable", DataType::Int64, true), + Field::new("c2_non_nullable", DataType::Int64, false), + ]); + + let c1_nullable = col("c1_nullable", &schema)?; + let c2_non_nullable = col("c2_non_nullable", &schema)?; + + // static_filter has no nulls + let list = vec![lit(1_i64), lit(2_i64)]; + test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + + // static_filter has nulls + let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; + test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + + let list = vec![c1_nullable.clone()]; + test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + + let list = vec![c2_non_nullable.clone()]; + test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + + let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()]; + test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 32e53e0c1edea..2e6a2bec9cab5 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,6 +17,7 @@ //! IS NOT NULL expression +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::physical_expr::down_cast_any_ref; @@ -31,7 +32,7 @@ use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; /// IS NOT NULL expression -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct IsNotNullExpr { /// The input expression arg: Arc, @@ -91,6 +92,11 @@ impl PhysicalExpr for IsNotNullExpr { ) -> Result> { Ok(Arc::new(IsNotNullExpr::new(children[0].clone()))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for IsNotNullExpr { @@ -126,7 +132,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; // expression: "a is not null" - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 85e111440aaf5..3ad4058dd6493 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,6 +17,7 @@ //! IS NULL expression +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use arrow::compute; @@ -32,7 +33,7 @@ use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; /// IS NULL expression -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct IsNullExpr { /// Input expression arg: Arc, @@ -92,6 +93,11 @@ impl PhysicalExpr for IsNullExpr { ) -> Result> { Ok(Arc::new(IsNullExpr::new(children[0].clone()))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for IsNullExpr { @@ -128,7 +134,10 @@ mod tests { let expr = is_null(col("a", &schema)?).unwrap(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 456e477a1e535..37452e278484a 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,27 +15,19 @@ // specific language governing permissions and limitations // under the License. +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use arrow::{ - array::{new_null_array, Array, ArrayRef, LargeStringArray, StringArray}, - record_batch::RecordBatch, -}; +use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; + +use crate::expressions::datum::apply_cmp; +use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use crate::{physical_expr::down_cast_any_ref, AnalysisContext, PhysicalExpr}; - -use arrow::compute::kernels::comparison::{ - ilike_utf8, like_utf8, nilike_utf8, nlike_utf8, -}; -use arrow::compute::kernels::comparison::{ - ilike_utf8_scalar, like_utf8_scalar, nilike_utf8_scalar, nlike_utf8_scalar, -}; - // Like expression -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct LikeExpr { negated: bool, case_insensitive: bool, @@ -109,61 +101,15 @@ impl PhysicalExpr for LikeExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let expr_value = self.expr.evaluate(batch)?; - let pattern_value = self.pattern.evaluate(batch)?; - let expr_data_type = expr_value.data_type(); - let pattern_data_type = pattern_value.data_type(); - - match ( - &expr_value, - &expr_data_type, - &pattern_value, - &pattern_data_type, - ) { - // Types are equal => valid - (_, l, _, r) if l == r => {} - // Allow comparing a dictionary value with its corresponding scalar value - ( - ColumnarValue::Array(_), - DataType::Dictionary(_, dict_t), - ColumnarValue::Scalar(_), - scalar_t, - ) - | ( - ColumnarValue::Scalar(_), - scalar_t, - ColumnarValue::Array(_), - DataType::Dictionary(_, dict_t), - ) if dict_t.as_ref() == scalar_t => {} - _ => { - return Err(DataFusionError::Internal(format!( - "Cannot evaluate {} expression with types {:?} and {:?}", - self.op_name(), - expr_data_type, - pattern_data_type - ))); - } - } - - // Attempt to use special kernels if one input is scalar and the other is an array - let scalar_result = match (&expr_value, &pattern_value) { - (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { - self.evaluate_array_scalar(array, scalar)? - } - (_, _) => None, // default to array implementation - }; - - if let Some(result) = scalar_result { - return result.map(|a| ColumnarValue::Array(a)); + use arrow::compute::*; + let lhs = self.expr.evaluate(batch)?; + let rhs = self.pattern.evaluate(batch)?; + match (self.negated, self.case_insensitive) { + (false, false) => apply_cmp(&lhs, &rhs, like), + (false, true) => apply_cmp(&lhs, &rhs, ilike), + (true, false) => apply_cmp(&lhs, &rhs, nlike), + (true, true) => apply_cmp(&lhs, &rhs, nilike), } - - // if both arrays or both literals - extract arrays and continue execution - let (expr, pattern) = ( - expr_value.into_array(batch.num_rows()), - pattern_value.into_array(batch.num_rows()), - ); - self.evaluate_array_array(expr, pattern) - .map(|a| ColumnarValue::Array(a)) } fn children(&self) -> Vec> { @@ -182,9 +128,9 @@ impl PhysicalExpr for LikeExpr { ))) } - /// Return the boundaries of this binary expression's result. - fn analyze(&self, context: AnalysisContext) -> AnalysisContext { - context.with_boundaries(None) + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); } } @@ -202,71 +148,6 @@ impl PartialEq for LikeExpr { } } -macro_rules! binary_string_array_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ - let result: Result> = match $LEFT.data_type() { - DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $OP_TYPE), - DataType::LargeUtf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $OP_TYPE), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on string array", - other, stringify!($OP) - ))), - }; - Some(result) - }}; -} - -impl LikeExpr { - /// Evaluate the expression if the input is an array and - /// pattern is literal - use scalar operations - fn evaluate_array_scalar( - &self, - array: &dyn Array, - scalar: &ScalarValue, - ) -> Result>> { - let scalar_result = match (self.negated, self.case_insensitive) { - (false, false) => binary_string_array_op_scalar!( - array, - scalar.clone(), - like, - &DataType::Boolean - ), - (true, false) => binary_string_array_op_scalar!( - array, - scalar.clone(), - nlike, - &DataType::Boolean - ), - (false, true) => binary_string_array_op_scalar!( - array, - scalar.clone(), - ilike, - &DataType::Boolean - ), - (true, true) => binary_string_array_op_scalar!( - array, - scalar.clone(), - nilike, - &DataType::Boolean - ), - }; - Ok(scalar_result) - } - - fn evaluate_array_array( - &self, - left: Arc, - right: Arc, - ) -> Result { - match (self.negated, self.case_insensitive) { - (false, false) => binary_string_array_op!(left, right, like), - (true, false) => binary_string_array_op!(left, right, nlike), - (false, true) => binary_string_array_op!(left, right, ilike), - (true, true) => binary_string_array_op!(left, right, nilike), - } - } -} - /// Create a like expression, erroring if the argument types are not compatible. pub fn like( negated: bool, @@ -278,9 +159,9 @@ pub fn like( let expr_type = &expr.data_type(input_schema)?; let pattern_type = &pattern.data_type(input_schema)?; if !expr_type.eq(pattern_type) { - return Err(DataFusionError::Internal(format!( + return internal_err!( "The type of {expr_type} AND {pattern_type} of like physical should be same" - ))); + ); } Ok(Arc::new(LikeExpr::new( negated, @@ -294,7 +175,7 @@ pub fn like( mod test { use super::*; use crate::expressions::col; - use arrow::array::BooleanArray; + use arrow::array::*; use arrow_schema::Field; use datafusion_common::cast::as_boolean_array; @@ -320,7 +201,10 @@ mod test { )?; // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); let expected = &BooleanArray::from($VEC); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 013169ccf7852..cd3b51f09105a 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,21 +18,22 @@ //! Literal expressions for physical operations use std::any::Any; +use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::physical_expr::down_cast_any_ref; +use crate::sort_properties::SortProperties; +use crate::PhysicalExpr; + use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::physical_expr::down_cast_any_ref; -use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Expr}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Literal { value: ScalarValue, } @@ -62,7 +63,7 @@ impl PhysicalExpr for Literal { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.value.get_datatype()) + Ok(self.value.data_type()) } fn nullable(&self, _input_schema: &Schema) -> Result { @@ -84,14 +85,13 @@ impl PhysicalExpr for Literal { Ok(self) } - /// Return the boundaries of this literal expression (which is the same as - /// the value it represents). - fn analyze(&self, context: AnalysisContext) -> AnalysisContext { - context.with_boundaries(Some(ExprBoundaries::new( - self.value.clone(), - self.value.clone(), - Some(1), - ))) + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } + + fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { + SortProperties::Singleton } } @@ -131,7 +131,10 @@ mod tests { let literal_expr = lit(42i32); assert_eq!("42", format!("{literal_expr}")); - let literal_array = literal_expr.evaluate(&batch)?.into_array(batch.num_rows()); + let literal_array = literal_expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let literal_array = as_int32_array(&literal_array)?; // note that the contents of the literal array are unrelated to the batch contents except for the length of the array @@ -142,20 +145,4 @@ mod tests { Ok(()) } - - #[test] - fn literal_bounds_analysis() -> Result<()> { - let schema = Schema::empty(); - let context = AnalysisContext::new(&schema, vec![]); - - let literal_expr = lit(42i32); - let result_ctx = literal_expr.analyze(context); - let boundaries = result_ctx.boundaries.unwrap(); - assert_eq!(boundaries.min_value, ScalarValue::Int32(Some(42))); - assert_eq!(boundaries.max_value, ScalarValue::Int32(Some(42))); - assert_eq!(boundaries.distinct_count, Some(1)); - assert_eq!(boundaries.selectivity, None); - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0ca132aefdac3..b6d0ad5b91043 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -22,7 +22,7 @@ mod binary; mod case; mod cast; mod column; -mod datetime; +mod datum; mod get_indexed_field; mod in_list; mod is_not_null; @@ -46,6 +46,7 @@ pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; +pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::{Avg, AvgAccumulator}; pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor}; pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; @@ -59,8 +60,10 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; +pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; @@ -79,8 +82,7 @@ pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, cast_column, cast_with_options, CastExpr}; pub use column::{col, Column, UnKnownColumn}; -pub use datetime::{date_time_interval_expr, DateTimeIntervalExpr}; -pub use get_indexed_field::GetIndexedFieldExpr; +pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; @@ -100,17 +102,23 @@ pub use crate::PhysicalSortExpr; #[cfg(test)] pub(crate) mod tests { - use crate::AggregateExpr; + use crate::expressions::{col, create_aggregate_expr, try_cast}; + use crate::{AggregateExpr, EmitTo}; use arrow::record_batch::RecordBatch; + use arrow_array::ArrayRef; + use arrow_schema::{Field, Schema}; use datafusion_common::Result; use datafusion_common::ScalarValue; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; use std::sync::Arc; - /// macro to perform an aggregation and verify the result. + /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the + /// result. #[macro_export] macro_rules! generic_test_op { ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.get_datatype()) + generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) }; ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); @@ -131,6 +139,70 @@ pub(crate) mod tests { }}; } + /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. + /// + /// The difference between this and the above `generic_test_op` is that the former checks + /// the old slow-path [`datafusion_expr::Accumulator`] implementation, while this checks + /// the new [`crate::GroupsAccumulator`] implementation. + #[macro_export] + macro_rules! generic_test_op_new { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + generic_test_op_new!( + $ARRAY, + $DATATYPE, + $OP, + $EXPECTED, + $EXPECTED.data_type().clone() + ) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + )); + let actual = aggregate_new(&batch, agg)?; + assert_eq!($EXPECTED, &actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + + /// Assert `function(array) == expected` performing any necessary type coercion + pub fn assert_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = coerce_types(&function, &[data_type.clone()], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = + create_aggregate_expr(&function, distinct, &[input], &[], &schema, "agg") + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); + } + /// macro to perform an aggregation with two inputs and verify the result. #[macro_export] macro_rules! generic_test_op2 { @@ -142,7 +214,7 @@ pub(crate) mod tests { $DATATYPE2, $OP, $EXPECTED, - $EXPECTED.get_datatype() + $EXPECTED.data_type() ) }; ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ @@ -176,10 +248,30 @@ pub(crate) mod tests { let expr = agg.expressions(); let values = expr .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; accum.update_batch(&values)?; accum.evaluate() } + + pub fn aggregate_new( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_groups_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + let indices = vec![0; batch.num_rows()]; + accum.update_batch(&values, &indices, None, 1)?; + accum.evaluate(EmitTo::All) + } } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 0d6aec879e567..b64b4a0c86def 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,41 +18,27 @@ //! Negation (-) expression use std::any::Any; +use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::compute::kernels::arithmetic::negate; +use crate::physical_expr::down_cast_any_ref; +use crate::sort_properties::SortProperties; +use crate::PhysicalExpr; + use arrow::{ - array::{ - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - }, - datatypes::{DataType, IntervalUnit, Schema}, + compute::kernels::numeric::neg_wrapping, + datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::{ - type_coercion::{is_interval, is_null, is_signed_numeric}, + type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, }; -/// Invoke a compute kernel on array(s) -macro_rules! compute_op { - // invoke unary operator - ($OPERAND:expr, $OP:ident, $DT:ident) => {{ - let operand = $OPERAND - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - Ok(Arc::new($OP(&operand)?)) - }}; -} - /// Negative expression -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct NegativeExpr { /// Input expression arg: Arc, @@ -94,23 +80,8 @@ impl PhysicalExpr for NegativeExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => { - let result: Result = match array.data_type() { - DataType::Int8 => compute_op!(array, negate, Int8Array), - DataType::Int16 => compute_op!(array, negate, Int16Array), - DataType::Int32 => compute_op!(array, negate, Int32Array), - DataType::Int64 => compute_op!(array, negate, Int64Array), - DataType::Float32 => compute_op!(array, negate, Float32Array), - DataType::Float64 => compute_op!(array, negate, Float64Array), - DataType::Interval(IntervalUnit::YearMonth) => compute_op!(array, negate, IntervalYearMonthArray), - DataType::Interval(IntervalUnit::DayTime) => compute_op!(array, negate, IntervalDayTimeArray), - DataType::Interval(IntervalUnit::MonthDayNano) => compute_op!(array, negate, IntervalMonthDayNanoArray), - _ => Err(DataFusionError::Internal(format!( - "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric", - self, - array.data_type(), - ))), - }; - result.map(|a| ColumnarValue::Array(a)) + let result = neg_wrapping(array.as_ref())?; + Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(scalar) => { Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) @@ -128,6 +99,44 @@ impl PhysicalExpr for NegativeExpr { ) -> Result> { Ok(Arc::new(NegativeExpr::new(children[0].clone()))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } + + /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. + /// It replaces the upper and lower bounds after multiplying them with -1. + /// Ex: `(a, b]` => `[-b, -a)` + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) + } + + /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that + /// given the input interval is known to be `children`. + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + let child_interval = children[0]; + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; + + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) + } + + /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. + fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { + -children[0] + } } impl PartialEq for NegativeExpr { @@ -151,10 +160,13 @@ pub fn negative( let data_type = arg.data_type(input_schema)?; if is_null(&data_type) { Ok(arg) - } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) { - Err(DataFusionError::Internal( - format!("Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric"), - )) + } else if !is_signed_numeric(&data_type) + && !is_interval(&data_type) + && !is_timestamp(&data_type) + { + internal_err!( + "Can't create negative physical expr for (- '{arg:?}'), the type of child expr is {data_type}, not signed numeric" + ) } else { Ok(Arc::new(NegativeExpr::new(arg))) } @@ -163,13 +175,14 @@ pub fn negative( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - #[allow(unused_imports)] + use crate::expressions::{col, Column}; + use arrow::array::*; use arrow::datatypes::*; use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; - use arrow_schema::IntervalUnit::{DayTime, MonthDayNano, YearMonth}; - use datafusion_common::{cast::as_primitive_array, Result}; + use datafusion_common::cast::as_primitive_array; + use datafusion_common::Result; + use paste::paste; macro_rules! test_array_negative_op { @@ -190,32 +203,7 @@ mod tests { let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)}; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = - as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str()); - assert_eq!(result, expected); - }; - } - - macro_rules! test_array_negative_op_intervals { - ($DATA_TY:tt, $($VALUE:expr),* ) => { - let schema = Schema::new(vec![Field::new("a", DataType::Interval(IntervalUnit::$DATA_TY), true)]); - let expr = negative(col("a", &schema)?, &schema)?; - assert_eq!(expr.data_type(&schema)?, DataType::Interval(IntervalUnit::$DATA_TY)); - assert!(expr.nullable(&schema)?); - let mut arr = Vec::new(); - let mut arr_expected = Vec::new(); - $( - arr.push(Some($VALUE)); - arr_expected.push(Some(-$VALUE)); - )+ - arr.push(None); - arr_expected.push(None); - let input = paste!{[]::from(arr)}; - let expected = &paste!{[]::from(arr_expected)}; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = as_primitive_array(&result).expect(format!("failed to downcast to {:?}Array", $DATA_TY).as_str()); assert_eq!(result, expected); @@ -230,9 +218,38 @@ mod tests { test_array_negative_op!(Int64, 23456i64, 12345i64); test_array_negative_op!(Float32, 2345.0f32, 1234.0f32); test_array_negative_op!(Float64, 23456.0f64, 12345.0f64); - test_array_negative_op_intervals!(YearMonth, 2345i32, 1234i32); - test_array_negative_op_intervals!(DayTime, 23456i64, 12345i64); - test_array_negative_op_intervals!(MonthDayNano, 234567i128, 123456i128); + Ok(()) + } + + #[test] + fn test_evaluate_bounds() -> Result<()> { + let negative_expr = NegativeExpr { + arg: Arc::new(Column::new("a", 0)), + }; + let child_interval = Interval::make(Some(-2), Some(1))?; + let negative_expr_interval = Interval::make(Some(-1), Some(2))?; + assert_eq!( + negative_expr.evaluate_bounds(&[&child_interval])?, + negative_expr_interval + ); + Ok(()) + } + + #[test] + fn test_propagate_constraints() -> Result<()> { + let negative_expr = NegativeExpr { + arg: Arc::new(Column::new("a", 0)), + }; + let original_child_interval = Interval::make(Some(-2), Some(3))?; + let negative_expr_interval = Interval::make(Some(0), Some(4))?; + let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]); + assert_eq!( + negative_expr.propagate_constraints( + &negative_expr_interval, + &[&original_child_interval] + )?, + after_propagation + ); Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 7c7d8cc8977dc..95e6879a6c2d9 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,6 +18,7 @@ //! NoOp placeholder for physical operations use std::any::Any; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::{ @@ -27,13 +28,13 @@ use arrow::{ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; /// A place holder expression, can not be evaluated. /// /// Used in some cases where an `Arc` is needed, such as `children()` -#[derive(Debug, PartialEq, Eq, Default)] +#[derive(Debug, PartialEq, Eq, Default, Hash)] pub struct NoOp {} impl NoOp { @@ -64,9 +65,7 @@ impl PhysicalExpr for NoOp { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - Err(DataFusionError::Plan( - "NoOp::evaluate() should not be called".to_owned(), - )) + internal_err!("NoOp::evaluate() should not be called") } fn children(&self) -> Vec> { @@ -79,6 +78,11 @@ impl PhysicalExpr for NoOp { ) -> Result> { Ok(self) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for NoOp { diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index bf935aa97e617..4ceccc6932fe4 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -19,17 +19,20 @@ use std::any::Any; use std::fmt; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + cast::as_boolean_array, internal_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; /// Not expression -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct NotExpr { /// Input expression arg: Arc, @@ -80,12 +83,12 @@ impl PhysicalExpr for NotExpr { if scalar.is_null() { return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); } - let value_type = scalar.get_datatype(); + let value_type = scalar.data_type(); if value_type != DataType::Boolean { - return Err(DataFusionError::Internal(format!( + return internal_err!( "NOT '{:?}' can't be evaluated because the expression's type is {:?}, not boolean or NULL", - self.arg, value_type, - ))); + self.arg, value_type + ); } let bool_value: bool = scalar.try_into()?; Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( @@ -105,6 +108,11 @@ impl PhysicalExpr for NotExpr { ) -> Result> { Ok(Arc::new(NotExpr::new(children[0].clone()))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for NotExpr { @@ -142,7 +150,10 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); let result = as_boolean_array(&result).expect("failed to downcast to BooleanArray"); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index a2a61d16af418..dcd883f92965b 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -16,54 +16,44 @@ // under the License. use arrow::array::Array; -use arrow::compute::eq_dyn; +use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::{cast::as_boolean_array, DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use super::binary::array_eq_scalar; - /// Implements NULLIF(expr1, expr2) /// Args: 0 - left expr is any array /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. /// pub fn nullif_func(args: &[ColumnarValue]) -> Result { if args.len() != 2 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but NULLIF takes exactly two args", - args.len(), - ))); + args.len() + ); } let (lhs, rhs) = (&args[0], &args[1]); match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let cond_array = array_eq_scalar(lhs, rhs)?; - - let array = nullif(lhs, as_boolean_array(&cond_array)?)?; + let rhs = rhs.to_scalar()?; + let array = nullif(lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - // Get args0 == args1 evaluated and produce a boolean array - let cond_array = eq_dyn(lhs, rhs)?; - - // Now, invoke nullif on the result - let array = nullif(lhs, as_boolean_array(&cond_array)?)?; + let array = nullif(lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - // Similar to Array-Array case, except of ScalarValue -> Array cast - let lhs = lhs.to_array_of_size(rhs.len()); - let cond_array = eq_dyn(&lhs, rhs)?; - - let array = nullif(&lhs, as_boolean_array(&cond_array)?)?; + let lhs = lhs.to_array_of_size(rhs.len())?; + let array = nullif(&lhs, &eq(&lhs, &rhs)?)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { let val: ScalarValue = match lhs.eq(rhs) { - true => lhs.get_datatype().try_into()?, + true => lhs.data_type().try_into()?, false => lhs.clone(), }; @@ -99,7 +89,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(1), @@ -125,7 +115,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ None, @@ -150,7 +140,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef; @@ -164,10 +154,10 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string()))); + let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(StringArray::from(vec![ Some("foo"), @@ -188,7 +178,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[lit_array, a])?; - let result = result.into_array(0); + let result = result.into_array(0).expect("Failed to convert to array"); let expected = Arc::new(Int32Array::from(vec![ Some(2), @@ -208,7 +198,7 @@ mod tests { let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; - let result_eq = result_eq.into_array(1); + let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef; @@ -218,7 +208,9 @@ mod tests { let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; - let result_neq = result_neq.into_array(1); + let result_neq = result_neq + .into_array(1) + .expect("Failed to convert to array"); let expected_neq = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef; assert_eq!(expected_neq.as_ref(), result_neq.as_ref()); diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index bbb29d6fb5b0a..0f7909097a106 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -17,21 +17,22 @@ use std::any::Any; use std::fmt; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; -use arrow::compute::kernels; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and retuns NULL on invalid cast -#[derive(Debug)] +#[derive(Debug, Hash)] pub struct TryCastExpr { /// The expression to cast expr: Arc, @@ -78,14 +79,18 @@ impl PhysicalExpr for TryCastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; + let options = CastOptions { + safe: true, + format_options: DEFAULT_FORMAT_OPTIONS, + }; match value { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(kernels::cast::cast( - &array, - &self.cast_type, - )?)), + ColumnarValue::Array(array) => { + let cast = cast_with_options(&array, &self.cast_type, &options)?; + Ok(ColumnarValue::Array(cast)) + } ColumnarValue::Scalar(scalar) => { - let scalar_array = scalar.to_array(); - let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?; + let array = scalar.to_array()?; + let cast_array = cast_with_options(&array, &self.cast_type, &options)?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; Ok(ColumnarValue::Scalar(cast_scalar)) } @@ -105,6 +110,11 @@ impl PhysicalExpr for TryCastExpr { self.cast_type.clone(), ))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } } impl PartialEq for TryCastExpr { @@ -131,9 +141,7 @@ pub fn try_cast( } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(TryCastExpr::new(expr, cast_type))) } else { - Err(DataFusionError::NotImplemented(format!( - "Unsupported TRY_CAST from {expr_type:?} to {cast_type:?}" - ))) + not_impl_err!("Unsupported TRY_CAST from {expr_type:?} to {cast_type:?}") } } @@ -179,7 +187,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -227,7 +238,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $TYPE); @@ -261,13 +275,13 @@ mod tests { DataType::Decimal128(10, 3), Decimal128Array, DataType::Decimal128(20, 6), - vec![ + [ Some(1_234_000), Some(2_222_000), Some(3_000), Some(4_000_000), Some(5_000_000), - None, + None ] ); @@ -277,7 +291,7 @@ mod tests { DataType::Decimal128(10, 3), Decimal128Array, DataType::Decimal128(10, 2), - vec![Some(123), Some(222), Some(0), Some(400), Some(500), None,] + [Some(123), Some(222), Some(0), Some(400), Some(500), None] ); Ok(()) @@ -295,13 +309,13 @@ mod tests { DataType::Decimal128(10, 0), Int8Array, DataType::Int8, - vec![ + [ Some(1_i8), Some(2_i8), Some(3_i8), Some(4_i8), Some(5_i8), - None, + None ] ); @@ -312,13 +326,13 @@ mod tests { DataType::Decimal128(10, 0), Int16Array, DataType::Int16, - vec![ + [ Some(1_i16), Some(2_i16), Some(3_i16), Some(4_i16), Some(5_i16), - None, + None ] ); @@ -329,13 +343,13 @@ mod tests { DataType::Decimal128(10, 0), Int32Array, DataType::Int32, - vec![ + [ Some(1_i32), Some(2_i32), Some(3_i32), Some(4_i32), Some(5_i32), - None, + None ] ); @@ -346,13 +360,13 @@ mod tests { DataType::Decimal128(10, 0), Int64Array, DataType::Int64, - vec![ + [ Some(1_i64), Some(2_i64), Some(3_i64), Some(4_i64), Some(5_i64), - None, + None ] ); @@ -364,13 +378,13 @@ mod tests { DataType::Decimal128(10, 3), Float32Array, DataType::Float32, - vec![ + [ Some(1.234_f32), Some(2.222_f32), Some(0.003_f32), Some(4.0_f32), Some(5.0_f32), - None, + None ] ); // decimal to float64 @@ -380,13 +394,13 @@ mod tests { DataType::Decimal128(20, 6), Float64Array, DataType::Float64, - vec![ + [ Some(0.001234_f64), Some(0.002222_f64), Some(0.000003_f64), Some(0.004_f64), Some(0.005_f64), - None, + None ] ); @@ -402,7 +416,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(3, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),] + [Some(1), Some(2), Some(3), Some(4), Some(5)] ); // int16 @@ -412,7 +426,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(5, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),] + [Some(1), Some(2), Some(3), Some(4), Some(5)] ); // int32 @@ -422,7 +436,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(10, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),] + [Some(1), Some(2), Some(3), Some(4), Some(5)] ); // int64 @@ -432,7 +446,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(20, 0), - vec![Some(1), Some(2), Some(3), Some(4), Some(5),] + [Some(1), Some(2), Some(3), Some(4), Some(5)] ); // int64 to different scale @@ -442,7 +456,7 @@ mod tests { vec![1, 2, 3, 4, 5], Decimal128Array, DataType::Decimal128(20, 2), - vec![Some(100), Some(200), Some(300), Some(400), Some(500),] + [Some(100), Some(200), Some(300), Some(400), Some(500)] ); // float32 @@ -452,7 +466,7 @@ mod tests { vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, DataType::Decimal128(10, 2), - vec![Some(150), Some(250), Some(300), Some(112), Some(550),] + [Some(150), Some(250), Some(300), Some(112), Some(550)] ); // float64 @@ -462,12 +476,12 @@ mod tests { vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, DataType::Decimal128(20, 4), - vec![ + [ Some(15000), Some(25000), Some(30000), Some(11235), - Some(55000), + Some(55000) ] ); Ok(()) @@ -481,7 +495,7 @@ mod tests { vec![1, 2, 3, 4, 5], UInt32Array, DataType::UInt32, - vec![ + [ Some(1_u32), Some(2_u32), Some(3_u32), @@ -500,7 +514,7 @@ mod tests { vec![1, 2, 3, 4, 5], StringArray, DataType::Utf8, - vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")] + [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")] ); Ok(()) } @@ -513,7 +527,7 @@ mod tests { vec!["a", "2", "3", "b", "5"], Int32Array, DataType::Int32, - vec![None, Some(2), Some(3), None, Some(5)] + [None, Some(2), Some(3), None, Some(5)] ); Ok(()) } @@ -541,7 +555,11 @@ mod tests { // Ensure a useful error happens at plan time if invalid casts are used let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary); + let result = try_cast( + col("a", &schema).unwrap(), + &schema, + DataType::Interval(IntervalUnit::MonthDayNano), + ); result.expect_err("expected Invalid TRY_CAST"); } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 648dd4a144c6d..53de858439190 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -31,22 +31,24 @@ //! argument is automatically is coerced to f64. use crate::execution_props::ExecutionProps; +use crate::sort_properties::SortProperties; use crate::{ array_expressions, conditional_expressions, datetime_expressions, - expressions::{cast_column, nullif_func}, - math_expressions, string_expressions, struct_expressions, PhysicalExpr, - ScalarFunctionExpr, + expressions::nullif_func, math_expressions, string_expressions, struct_expressions, + PhysicalExpr, ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, compute::kernels::length::{bit_length, length}, - datatypes::TimeUnit, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ - function, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, + ScalarFunctionImplementation, }; +use std::ops::Neg; use std::sync::Arc; /// Create a physical (function) expression. @@ -62,125 +64,45 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - let data_type = function::return_type(fun, &input_expr_types)?; + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, &fun.signature())?; - let fun_expr: ScalarFunctionImplementation = match fun { - // These functions need args and input schema to pick an implementation - // Unlike the string functions, which actually figure out the function to use with each array, - // here we return either a cast fn or string timestamp translation based on the expression data type - // so we don't have to pay a per-array/batch cost. - BuiltinScalarFunction::ToTimestamp => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function to_timestamp", - ))); - } - }) - } - BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Millisecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function to_timestamp_millis", - ))); - } - }) - } - BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Microsecond, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function to_timestamp_micros", - ))); - } - }) - } - BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function to_timestamp_seconds", - ))); - } - } - }), - BuiltinScalarFunction::FromUnixtime => Arc::new({ - match input_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) => |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - None, - ) - }, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function from_unixtime", - ))); - } - } - }), - BuiltinScalarFunction::ArrowTypeof => { - let input_data_type = input_phy_exprs[0].data_type(input_schema)?; - Arc::new(move |_| { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!( - "{input_data_type}" - ))))) - }) - } - // These don't need args and input schema - _ => create_physical_fun(fun, execution_props)?, - }; + let data_type = fun.return_type(&input_expr_types)?; + + let fun_expr: ScalarFunctionImplementation = + create_physical_fun(fun, execution_props)?; + + let monotonicity = fun.monotonicity(); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), fun_expr, input_phy_exprs.to_vec(), - &data_type, + data_type, + monotonicity, ))) } +#[cfg(feature = "encoding_expressions")] +macro_rules! invoke_if_encoding_expressions_feature_flag { + ($FUNC:ident, $NAME:expr) => {{ + use crate::encoding_expressions; + encoding_expressions::$FUNC + }}; +} + +#[cfg(not(feature = "encoding_expressions"))] +macro_rules! invoke_if_encoding_expressions_feature_flag { + ($FUNC:ident, $NAME:expr) => { + |_: &[ColumnarValue]| -> Result { + internal_err!( + "function {} requires compilation with feature flag: encoding_expressions.", + $NAME + ) + } + }; +} + #[cfg(feature = "crypto_expressions")] macro_rules! invoke_if_crypto_expressions_feature_flag { ($FUNC:ident, $NAME:expr) => {{ @@ -193,10 +115,10 @@ macro_rules! invoke_if_crypto_expressions_feature_flag { macro_rules! invoke_if_crypto_expressions_feature_flag { ($FUNC:ident, $NAME:expr) => { |_: &[ColumnarValue]| -> Result { - Err(DataFusionError::Internal(format!( + internal_err!( "function {} requires compilation with feature flag: crypto_expressions.", $NAME - ))) + ) } }; } @@ -213,10 +135,10 @@ macro_rules! invoke_on_array_if_regex_expressions_feature_flag { macro_rules! invoke_on_array_if_regex_expressions_feature_flag { ($FUNC:ident, $T:tt, $NAME:expr) => { |_: &[ArrayRef]| -> Result { - Err(DataFusionError::Internal(format!( + internal_err!( "function {} requires compilation with feature flag: regex_expressions.", $NAME - ))) + ) } }; } @@ -233,10 +155,10 @@ macro_rules! invoke_on_columnar_value_if_regex_expressions_feature_flag { macro_rules! invoke_on_columnar_value_if_regex_expressions_feature_flag { ($FUNC:ident, $T:tt, $NAME:expr) => { |_: &[ColumnarValue]| -> Result { - Err(DataFusionError::Internal(format!( + internal_err!( "function {} requires compilation with feature flag: regex_expressions.", $NAME - ))) + ) } }; } @@ -253,10 +175,10 @@ macro_rules! invoke_if_unicode_expressions_feature_flag { macro_rules! invoke_if_unicode_expressions_feature_flag { ($FUNC:ident, $T:tt, $NAME:expr) => { |_: &[ArrayRef]| -> Result { - Err(DataFusionError::Internal(format!( + internal_err!( "function {} requires compilation with feature flag: unicode_expressions.", $NAME - ))) + ) } }; } @@ -302,6 +224,8 @@ where ColumnarValue::Array(a) => Some(a.len()), }); + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); let args = args .iter() @@ -315,15 +239,16 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>(); + .collect::>>()?; let result = (inner)(&args); - // maybe back to scalar - if len.is_some() { - result.map(ColumnarValue::Array) + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) } else { - ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) + result.map(ColumnarValue::Array) } }) } @@ -335,7 +260,9 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions - BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs), + BuiltinScalarFunction::Abs => { + Arc::new(|args| make_scalar_function(math_expressions::abs_invoke)(args)) + } BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), @@ -354,12 +281,21 @@ pub fn create_physical_fun( BuiltinScalarFunction::Gcd => { Arc::new(|args| make_scalar_function(math_expressions::gcd)(args)) } + BuiltinScalarFunction::Isnan => { + Arc::new(|args| make_scalar_function(math_expressions::isnan)(args)) + } + BuiltinScalarFunction::Iszero => { + Arc::new(|args| make_scalar_function(math_expressions::iszero)(args)) + } BuiltinScalarFunction::Lcm => { Arc::new(|args| make_scalar_function(math_expressions::lcm)(args)) } BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Nanvl => { + Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args)) + } BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), BuiltinScalarFunction::Round => { @@ -372,7 +308,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt), BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), - BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), + BuiltinScalarFunction::Trunc => { + Arc::new(|args| make_scalar_function(math_expressions::trunc)(args)) + } BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function(math_expressions::power)(args)) @@ -383,32 +321,114 @@ pub fn create_physical_fun( BuiltinScalarFunction::Log => { Arc::new(|args| make_scalar_function(math_expressions::log)(args)) } + BuiltinScalarFunction::Cot => { + Arc::new(|args| make_scalar_function(math_expressions::cot)(args)) + } // array functions - BuiltinScalarFunction::ArrayAppend => Arc::new(array_expressions::array_append), - BuiltinScalarFunction::ArrayConcat => Arc::new(array_expressions::array_concat), - BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims), - BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), - BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length), - BuiltinScalarFunction::ArrayNdims => Arc::new(array_expressions::array_ndims), + BuiltinScalarFunction::ArrayAppend => { + Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) + } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } + BuiltinScalarFunction::ArrayConcat => { + Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) + } + BuiltinScalarFunction::ArrayEmpty => { + Arc::new(|args| make_scalar_function(array_expressions::array_empty)(args)) + } + BuiltinScalarFunction::ArrayHasAll => { + Arc::new(|args| make_scalar_function(array_expressions::array_has_all)(args)) + } + BuiltinScalarFunction::ArrayHasAny => { + Arc::new(|args| make_scalar_function(array_expressions::array_has_any)(args)) + } + BuiltinScalarFunction::ArrayHas => { + Arc::new(|args| make_scalar_function(array_expressions::array_has)(args)) + } + BuiltinScalarFunction::ArrayDims => { + Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) + } + BuiltinScalarFunction::ArrayDistinct => { + Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args)) + } + BuiltinScalarFunction::ArrayElement => { + Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) + } + BuiltinScalarFunction::ArrayExcept => { + Arc::new(|args| make_scalar_function(array_expressions::array_except)(args)) + } + BuiltinScalarFunction::ArrayLength => { + Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) + } + BuiltinScalarFunction::Flatten => { + Arc::new(|args| make_scalar_function(array_expressions::flatten)(args)) + } + BuiltinScalarFunction::ArrayNdims => { + Arc::new(|args| make_scalar_function(array_expressions::array_ndims)(args)) + } + BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { + make_scalar_function(array_expressions::array_pop_front)(args) + }), + BuiltinScalarFunction::ArrayPopBack => { + Arc::new(|args| make_scalar_function(array_expressions::array_pop_back)(args)) + } BuiltinScalarFunction::ArrayPosition => { - Arc::new(array_expressions::array_position) + Arc::new(|args| make_scalar_function(array_expressions::array_position)(args)) + } + BuiltinScalarFunction::ArrayPositions => Arc::new(|args| { + make_scalar_function(array_expressions::array_positions)(args) + }), + BuiltinScalarFunction::ArrayPrepend => { + Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args)) + } + BuiltinScalarFunction::ArrayRepeat => { + Arc::new(|args| make_scalar_function(array_expressions::array_repeat)(args)) + } + BuiltinScalarFunction::ArrayRemove => { + Arc::new(|args| make_scalar_function(array_expressions::array_remove)(args)) + } + BuiltinScalarFunction::ArrayRemoveN => { + Arc::new(|args| make_scalar_function(array_expressions::array_remove_n)(args)) + } + BuiltinScalarFunction::ArrayRemoveAll => Arc::new(|args| { + make_scalar_function(array_expressions::array_remove_all)(args) + }), + BuiltinScalarFunction::ArrayReplace => { + Arc::new(|args| make_scalar_function(array_expressions::array_replace)(args)) + } + BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { + make_scalar_function(array_expressions::array_replace_n)(args) + }), + BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { + make_scalar_function(array_expressions::array_replace_all)(args) + }), + BuiltinScalarFunction::ArraySlice => { + Arc::new(|args| make_scalar_function(array_expressions::array_slice)(args)) + } + BuiltinScalarFunction::ArrayToString => Arc::new(|args| { + make_scalar_function(array_expressions::array_to_string)(args) + }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), + BuiltinScalarFunction::Range => { + Arc::new(|args| make_scalar_function(array_expressions::gen_range)(args)) + } + BuiltinScalarFunction::Cardinality => { + Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } - BuiltinScalarFunction::ArrayPositions => { - Arc::new(array_expressions::array_positions) + BuiltinScalarFunction::MakeArray => { + Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - BuiltinScalarFunction::ArrayPrepend => Arc::new(array_expressions::array_prepend), - BuiltinScalarFunction::ArrayRemove => Arc::new(array_expressions::array_remove), - BuiltinScalarFunction::ArrayReplace => Arc::new(array_expressions::array_replace), - BuiltinScalarFunction::ArrayToString => { - Arc::new(array_expressions::array_to_string) + BuiltinScalarFunction::ArrayUnion => { + Arc::new(|args| make_scalar_function(array_expressions::array_union)(args)) } - BuiltinScalarFunction::Cardinality => Arc::new(array_expressions::cardinality), - BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::array), - BuiltinScalarFunction::TrimArray => Arc::new(array_expressions::trim_array), + // struct functions + BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), // string functions - BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ascii::)(args) @@ -416,9 +436,7 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::ascii::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function ascii", - ))), + other => internal_err!("Unsupported data type {other:?} for function ascii"), }), BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), @@ -439,9 +457,7 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::btrim::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function btrim", - ))), + other => internal_err!("Unsupported data type {other:?} for function btrim"), }), BuiltinScalarFunction::CharacterLength => { Arc::new(|args| match args[0].data_type() { @@ -461,9 +477,9 @@ pub fn create_physical_fun( ); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function character_length", - ))), + other => internal_err!( + "Unsupported data type {other:?} for function character_length" + ), }) } BuiltinScalarFunction::Chr => { @@ -495,6 +511,24 @@ pub fn create_physical_fun( execution_props.query_execution_start_time, )) } + BuiltinScalarFunction::ToTimestamp => { + Arc::new(datetime_expressions::to_timestamp_invoke) + } + BuiltinScalarFunction::ToTimestampMillis => { + Arc::new(datetime_expressions::to_timestamp_millis_invoke) + } + BuiltinScalarFunction::ToTimestampMicros => { + Arc::new(datetime_expressions::to_timestamp_micros_invoke) + } + BuiltinScalarFunction::ToTimestampNanos => { + Arc::new(datetime_expressions::to_timestamp_nanos_invoke) + } + BuiltinScalarFunction::ToTimestampSeconds => { + Arc::new(datetime_expressions::to_timestamp_seconds_invoke) + } + BuiltinScalarFunction::FromUnixtime => { + Arc::new(datetime_expressions::from_unixtime_invoke) + } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::initcap::)(args) @@ -502,9 +536,9 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::initcap::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function initcap", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function initcap") + } }), BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -515,9 +549,7 @@ pub fn create_physical_fun( let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function left", - ))), + other => internal_err!("Unsupported data type {other:?} for function left"), }), BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { @@ -529,9 +561,7 @@ pub fn create_physical_fun( let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function lpad", - ))), + other => internal_err!("Unsupported data type {other:?} for function lpad"), }), BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -540,9 +570,7 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::ltrim::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function ltrim", - ))), + other => internal_err!("Unsupported data type {other:?} for function ltrim"), }), BuiltinScalarFunction::MD5 => { Arc::new(invoke_if_crypto_expressions_feature_flag!(md5, "md5")) @@ -550,6 +578,12 @@ pub fn create_physical_fun( BuiltinScalarFunction::Digest => { Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) } + BuiltinScalarFunction::Decode => Arc::new( + invoke_if_encoding_expressions_feature_flag!(decode, "decode"), + ), + BuiltinScalarFunction::Encode => Arc::new( + invoke_if_encoding_expressions_feature_flag!(encode, "encode"), + ), BuiltinScalarFunction::NullIf => Arc::new(nullif_func), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), @@ -581,9 +615,9 @@ pub fn create_physical_fun( ); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Unsupported data type {other:?} for function regexp_match" - ))), + ), }) } BuiltinScalarFunction::RegexpReplace => { @@ -606,9 +640,9 @@ pub fn create_physical_fun( let func = specializer_func(args)?; func(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function regexp_replace", - ))), + other => internal_err!( + "Unsupported data type {other:?} for function regexp_replace" + ), }) } BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { @@ -618,9 +652,7 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::repeat::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function repeat", - ))), + other => internal_err!("Unsupported data type {other:?} for function repeat"), }), BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -629,9 +661,9 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::replace::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function replace", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function replace") + } }), BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -644,9 +676,9 @@ pub fn create_physical_fun( invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function reverse", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function reverse") + } }), BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -659,9 +691,7 @@ pub fn create_physical_fun( invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function right", - ))), + other => internal_err!("Unsupported data type {other:?} for function right"), }), BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -672,9 +702,7 @@ pub fn create_physical_fun( let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function rpad", - ))), + other => internal_err!("Unsupported data type {other:?} for function rpad"), }), BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -683,9 +711,7 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::rtrim::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function rtrim", - ))), + other => internal_err!("Unsupported data type {other:?} for function rtrim"), }), BuiltinScalarFunction::SHA224 => { Arc::new(invoke_if_crypto_expressions_feature_flag!(sha224, "sha224")) @@ -706,10 +732,25 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::split_part::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function split_part", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function split_part") + } }), + BuiltinScalarFunction::StringToArray => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(array_expressions::string_to_array::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(array_expressions::string_to_array::)(args) + } + other => { + internal_err!( + "Unsupported data type {other:?} for function string_to_array" + ) + } + }) + } BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::starts_with::)(args) @@ -717,9 +758,9 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::starts_with::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function starts_with", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function starts_with") + } }), BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -734,9 +775,7 @@ pub fn create_physical_fun( ); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function strpos", - ))), + other => internal_err!("Unsupported data type {other:?} for function strpos"), }), BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -749,9 +788,7 @@ pub fn create_physical_fun( invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function substr", - ))), + other => internal_err!("Unsupported data type {other:?} for function substr"), }), BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { DataType::Int32 => { @@ -760,9 +797,7 @@ pub fn create_physical_fun( DataType::Int64 => { make_scalar_function(string_expressions::to_hex::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function to_hex", - ))), + other => internal_err!("Unsupported data type {other:?} for function to_hex"), }), BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -781,9 +816,9 @@ pub fn create_physical_fun( ); make_scalar_function(func)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function translate", - ))), + other => { + internal_err!("Unsupported data type {other:?} for function translate") + } }), BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -792,20 +827,155 @@ pub fn create_physical_fun( DataType::LargeUtf8 => { make_scalar_function(string_expressions::btrim::)(args) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function trim", - ))), + other => internal_err!("Unsupported data type {other:?} for function trim"), }), BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - _ => { - return Err(DataFusionError::Internal(format!( - "create_physical_fun: Unsupported scalar function {fun:?}" - ))); + BuiltinScalarFunction::ArrowTypeof => Arc::new(move |args| { + if args.len() != 1 { + return internal_err!( + "arrow_typeof function requires 1 arguments, got {}", + args.len() + ); + } + + let input_data_type = args[0].data_type(); + Ok(ColumnarValue::Scalar(ScalarValue::from(format!( + "{input_data_type}" + )))) + }), + BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::overlay::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function overlay", + ))), + }), + BuiltinScalarFunction::Levenshtein => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::levenshtein::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function levenshtein", + ))), + }) } + BuiltinScalarFunction::SubstrIndex => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i32, + "substr_index" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + substr_index, + i64, + "substr_index" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function substr_index", + ))), + }) + } + BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int32Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + find_in_set, + Int64Type, + "find_in_set" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function find_in_set", + ))), + }), }) } +#[deprecated( + since = "32.0.0", + note = "Moved to `expr` crate. Please use `BuiltinScalarFunction::monotonicity()` instead" +)] +pub fn get_func_monotonicity(fun: &BuiltinScalarFunction) -> Option { + fun.monotonicity() +} + +/// Determines a [`ScalarFunctionExpr`]'s monotonicity for the given arguments +/// and the function's behavior depending on its arguments. +pub fn out_ordering( + func: &FuncMonotonicity, + arg_orderings: &[SortProperties], +) -> SortProperties { + func.iter().zip(arg_orderings).fold( + SortProperties::Singleton, + |prev_sort, (item, arg)| { + let current_sort = func_order_in_one_dimension(item, arg); + + match (prev_sort, current_sort) { + (_, SortProperties::Unordered) => SortProperties::Unordered, + (SortProperties::Singleton, SortProperties::Ordered(_)) => current_sort, + (SortProperties::Ordered(prev), SortProperties::Ordered(current)) + if prev.descending != current.descending => + { + SortProperties::Unordered + } + _ => prev_sort, + } + }, + ) +} + +/// This function decides the monotonicity property of a [`ScalarFunctionExpr`] for a single argument (i.e. across a single dimension), given that argument's sort properties. +fn func_order_in_one_dimension( + func_monotonicity: &Option, + arg: &SortProperties, +) -> SortProperties { + if *arg == SortProperties::Singleton { + SortProperties::Singleton + } else { + match func_monotonicity { + None => SortProperties::Unordered, + Some(false) => { + if let SortProperties::Ordered(_) = arg { + arg.neg() + } else { + SortProperties::Unordered + } + } + Some(true) => { + if let SortProperties::Ordered(_) = arg { + *arg + } else { + SortProperties::Unordered + } + } + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -820,6 +990,7 @@ mod tests { record_batch::RecordBatch, }; use datafusion_common::cast::as_uint64_array; + use datafusion_common::{exec_err, plan_err}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; @@ -851,7 +1022,7 @@ mod tests { match expected { Ok(expected) => { let result = expr.evaluate(&batch)?; - let result = result.into_array(batch.num_rows()); + let result = result.into_array(batch.num_rows()).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); // value is correct @@ -865,7 +1036,7 @@ mod tests { match expr.evaluate(&batch) { Ok(_) => assert!(false, "expected error"), Err(error) => { - assert_eq!(error.to_string(), expected_error.to_string()); + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); } } } @@ -1023,9 +1194,9 @@ mod tests { test_function!( CharacterLength, &[lit("josé")], - Err(DataFusionError::Internal( - "function character_length requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function character_length requires compilation with feature flag: unicode_expressions." + ), i32, Int32, Int32Array @@ -1073,9 +1244,7 @@ mod tests { test_function!( Chr, &[lit(ScalarValue::Int64(Some(0)))], - Err(DataFusionError::Execution( - "null character not permitted.".to_string(), - )), + exec_err!("null character not permitted."), &str, Utf8, StringArray @@ -1083,9 +1252,7 @@ mod tests { test_function!( Chr, &[lit(ScalarValue::Int64(Some(i64::MAX)))], - Err(DataFusionError::Execution( - "requested character too large for encoding.".to_string(), - )), + exec_err!("requested character too large for encoding."), &str, Utf8, StringArray @@ -1300,9 +1467,9 @@ mod tests { lit("abcde"), lit(ScalarValue::Int8(Some(2))), ], - Err(DataFusionError::Internal( - "function left requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function left requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -1451,9 +1618,9 @@ mod tests { lit("josé"), lit(ScalarValue::Int64(Some(5))), ], - Err(DataFusionError::Internal( - "function lpad requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function lpad requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -1537,9 +1704,9 @@ mod tests { test_function!( MD5, &[lit("tom")], - Err(DataFusionError::Internal( - "function md5 requires compilation with feature flag: crypto_expressions.".to_string() - )), + internal_err!( + "function md5 requires compilation with feature flag: crypto_expressions." + ), &str, Utf8, StringArray @@ -1687,9 +1854,9 @@ mod tests { lit("b.."), lit("X"), ], - Err(DataFusionError::Internal( - "function regexp_replace requires compilation with feature flag: regex_expressions.".to_string() - )), + internal_err!( + "function regexp_replace requires compilation with feature flag: regex_expressions." + ), &str, Utf8, StringArray @@ -1761,9 +1928,9 @@ mod tests { test_function!( Reverse, &[lit("abcde")], - Err(DataFusionError::Internal( - "function reverse requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function reverse requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -1859,9 +2026,9 @@ mod tests { lit("abcde"), lit(ScalarValue::Int8(Some(2))), ], - Err(DataFusionError::Internal( - "function right requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function right requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -2010,9 +2177,9 @@ mod tests { lit("josé"), lit(ScalarValue::Int64(Some(5))), ], - Err(DataFusionError::Internal( - "function rpad requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function rpad requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -2104,9 +2271,9 @@ mod tests { test_function!( SHA224, &[lit("tom")], - Err(DataFusionError::Internal( - "function sha224 requires compilation with feature flag: crypto_expressions.".to_string() - )), + internal_err!( + "function sha224 requires compilation with feature flag: crypto_expressions." + ), &[u8], Binary, BinaryArray @@ -2150,9 +2317,9 @@ mod tests { test_function!( SHA256, &[lit("tom")], - Err(DataFusionError::Internal( - "function sha256 requires compilation with feature flag: crypto_expressions.".to_string() - )), + internal_err!( + "function sha256 requires compilation with feature flag: crypto_expressions." + ), &[u8], Binary, BinaryArray @@ -2200,9 +2367,9 @@ mod tests { test_function!( SHA384, &[lit("tom")], - Err(DataFusionError::Internal( - "function sha384 requires compilation with feature flag: crypto_expressions.".to_string() - )), + internal_err!( + "function sha384 requires compilation with feature flag: crypto_expressions." + ), &[u8], Binary, BinaryArray @@ -2252,9 +2419,9 @@ mod tests { test_function!( SHA512, &[lit("tom")], - Err(DataFusionError::Internal( - "function sha512 requires compilation with feature flag: crypto_expressions.".to_string() - )), + internal_err!( + "function sha512 requires compilation with feature flag: crypto_expressions." + ), &[u8], Binary, BinaryArray @@ -2290,9 +2457,7 @@ mod tests { lit("~@~"), lit(ScalarValue::Int64(Some(-1))), ], - Err(DataFusionError::Execution( - "field position must be greater than zero".to_string(), - )), + exec_err!("field position must be greater than zero"), &str, Utf8, StringArray @@ -2390,9 +2555,9 @@ mod tests { lit("joséésoj"), lit(ScalarValue::Utf8(None)), ], - Err(DataFusionError::Internal( - "function strpos requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function strpos requires compilation with feature flag: unicode_expressions." + ), i32, Int32, Int32Array @@ -2593,9 +2758,7 @@ mod tests { lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(-1))), ], - Err(DataFusionError::Execution( - "negative substring length not allowed: substr(, 1, -1)".to_string(), - )), + exec_err!("negative substring length not allowed: substr(, 1, -1)"), &str, Utf8, StringArray @@ -2620,9 +2783,9 @@ mod tests { lit("alphabet"), lit(ScalarValue::Int64(Some(0))), ], - Err(DataFusionError::Internal( - "function substr requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function substr requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -2680,9 +2843,9 @@ mod tests { lit("143"), lit("ax"), ], - Err(DataFusionError::Internal( - "function translate requires compilation with feature flag: unicode_expressions.".to_string() - )), + internal_err!( + "function translate requires compilation with feature flag: unicode_expressions." + ), &str, Utf8, StringArray @@ -2769,21 +2932,16 @@ mod tests { match expr { Ok(..) => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Builtin scalar function {fun} does not support empty arguments" - ))); + ); } - Err(DataFusionError::Plan(err)) => { - if !err - .contains("No function matches the given name and argument types") - { - return Err(DataFusionError::Internal(format!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"))); - } + Err(DataFusionError::Plan(_)) => { + // Continue the loop } Err(..) => { - return Err(DataFusionError::Internal(format!( - "Builtin scalar function {fun} didn't got the right error with empty arguments"))); + return internal_err!( + "Builtin scalar function {fun} didn't got the right error with empty arguments"); } } } @@ -2833,7 +2991,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -2872,7 +3033,10 @@ mod tests { // evaluate works let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // downcast works let result = as_list_array(&result)?; @@ -2921,7 +3085,7 @@ mod tests { execution_props: &ExecutionProps, ) -> Result> { let type_coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap(); + coerce(input_phy_exprs, input_schema, &fun.signature()).unwrap(); create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props) } @@ -2935,9 +3099,7 @@ mod tests { if let ColumnarValue::Array(array) = col? { Ok(as_uint64_array(&array)?.values().to_vec()) } else { - Err(DataFusionError::Internal( - "Unexpected scalar created by a test function".to_string(), - )) + internal_err!("Unexpected scalar created by a test function") } } @@ -2946,8 +3108,11 @@ mod tests { let adapter_func = make_scalar_function(dummy_function); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -2959,8 +3124,11 @@ mod tests { let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 5]); @@ -2975,8 +3143,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); @@ -2985,8 +3156,11 @@ mod tests { #[test] fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let adapter_func = make_scalar_function_with_hints( dummy_function, vec![Hint::Pad, Hint::AcceptsSingular], @@ -3006,8 +3180,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg, scalar_arg.clone(), @@ -3026,8 +3203,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[ array_arg.clone(), scalar_arg.clone(), @@ -3054,8 +3234,11 @@ mod tests { ); let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = - ColumnarValue::Array(ScalarValue::Int64(Some(1)).to_array_of_size(5)); + let array_arg = ColumnarValue::Array( + ScalarValue::Int64(Some(1)) + .to_array_of_size(5) + .expect("Failed to convert to array of size"), + ); let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; assert_eq!(result, vec![5, 1]); diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index a1698e66511a1..5064ad8d5c487 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -21,24 +21,23 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::sync::Arc; -use arrow_schema::DataType; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; +use super::utils::{ + convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, +}; +use crate::expressions::Literal; +use crate::utils::{build_dag, ExprTreeNode}; +use crate::PhysicalExpr; + +use arrow_schema::{DataType, Schema}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; + use petgraph::graph::NodeIndex; use petgraph::stable_graph::{DefaultIx, StableGraph}; use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; use petgraph::Outgoing; -use crate::expressions::{BinaryExpr, CastExpr, Column, Literal}; -use crate::intervals::interval_aritmetic::{ - apply_operator, is_operator_supported, Interval, -}; -use crate::utils::{build_dag, ExprTreeNode}; -use crate::PhysicalExpr; - -use super::IntervalBound; - // Interval arithmetic provides a way to perform mathematical operations on // intervals, which represent a range of possible values rather than a single // point value. This allows for the propagation of ranges through mathematical @@ -119,7 +118,7 @@ use super::IntervalBound; /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ExprIntervalGraph { graph: StableGraph, root: NodeIndex, @@ -147,7 +146,7 @@ pub enum PropagationResult { } /// This is a node in the DAEG; it encapsulates a reference to the actual -/// [PhysicalExpr] as well as an interval containing expression bounds. +/// [`PhysicalExpr`] as well as an interval containing expression bounds. #[derive(Clone, Debug)] pub struct ExprIntervalGraphNode { expr: Arc, @@ -162,11 +161,9 @@ impl Display for ExprIntervalGraphNode { impl ExprIntervalGraphNode { /// Constructs a new DAEG node with an [-∞, ∞] range. - pub fn new(expr: Arc) -> Self { - ExprIntervalGraphNode { - expr, - interval: Interval::default(), - } + pub fn new_unbounded(expr: Arc, dt: &DataType) -> Result { + Interval::make_unbounded(dt) + .map(|interval| ExprIntervalGraphNode { expr, interval }) } /// Constructs a new DAEG node with the given range. @@ -179,39 +176,28 @@ impl ExprIntervalGraphNode { &self.interval } - /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). - pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = node.expression().clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); - let interval = Interval::new( - IntervalBound::new(value.clone(), false), - IntervalBound::new(value.clone(), false), - ); - ExprIntervalGraphNode::new_with_interval(expr, interval) + Interval::try_new(value.clone(), value.clone()) + .map(|interval| Self::new_with_interval(expr, interval)) } else { - ExprIntervalGraphNode::new(expr) + expr.data_type(schema) + .and_then(|dt| Self::new_unbounded(expr, &dt)) } } } impl PartialEq for ExprIntervalGraphNode { - fn eq(&self, other: &ExprIntervalGraphNode) -> bool { + fn eq(&self, other: &Self) -> bool { self.expr.eq(&other.expr) } } -// This function returns the inverse operator of the given operator. -fn get_inverse_op(op: Operator) -> Operator { - match op { - Operator::Plus => Operator::Minus, - Operator::Minus => Operator::Plus, - _ => unreachable!(), - } -} - /// This function refines intervals `left_child` and `right_child` by applying /// constraint propagation through `parent` via operation. The main idea is /// that we can shrink ranges of variables x and y using parent interval p. @@ -224,78 +210,150 @@ fn get_inverse_op(op: Operator) -> Operator { /// - For minus operation, specifically, we would first do /// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then /// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +/// - For multiplication operation, specifically, we would first do +/// - [xL, xU] <- ([pL, pU] / [yL, yU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([pL, pU] / [xL, xU]) ∩ [yL, yU]. +/// - For division operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] * [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] / [pL, pU]) ∩ [yL, yU]. pub fn propagate_arithmetic( op: &Operator, parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let inverse_op = get_inverse_op(*op); - // First, propagate to the left: - match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? { - // Left is feasible: - Some(value) => { - // Propagate to the right using the new left. - let right = match op { - Operator::Minus => apply_operator(op, &value, parent), - Operator::Plus => apply_operator(&inverse_op, parent, &value), - _ => unreachable!(), - }? - .intersect(right_child)?; - // Return intervals for both children: - Ok((Some(value), right)) +) -> Result> { + let inverse_op = get_inverse_op(*op)?; + match (left_child.data_type(), right_child.data_type()) { + // If we have a child whose type is a time interval (i.e. DataType::Interval), + // we need special handling since timestamp differencing results in a + // Duration type. + (DataType::Timestamp(..), DataType::Interval(_)) => { + propagate_time_interval_at_right( + left_child, + right_child, + parent, + op, + &inverse_op, + ) + } + (DataType::Interval(_), DataType::Timestamp(..)) => { + propagate_time_interval_at_left( + left_child, + right_child, + parent, + op, + &inverse_op, + ) + } + _ => { + // First, propagate to the left: + match apply_operator(&inverse_op, parent, right_child)? + .intersect(left_child)? + { + // Left is feasible: + Some(value) => Ok( + // Propagate to the right using the new left. + propagate_right(&value, parent, right_child, op, &inverse_op)? + .map(|right| (value, right)), + ), + // If the left child is infeasible, short-circuit. + None => Ok(None), + } } - // If the left child is infeasible, short-circuit. - None => Ok((None, None)), } } -/// This function provides a target parent interval for comparison operators. -/// If we have expression > 0, expression must have the range [0, ∞]. -/// If we have expression < 0, expression must have the range [-∞, 0]. -/// Currently, we only support strict inequalities since open/closed intervals -/// are not implemented yet. -fn comparison_operator_target( - left_datatype: &DataType, - op: &Operator, - right_datatype: &DataType, -) -> Result { - let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?; - let unbounded = IntervalBound::make_unbounded(&datatype)?; - let zero = ScalarValue::new_zero(&datatype)?; - Ok(match *op { - Operator::GtEq => Interval::new(IntervalBound::new(zero, false), unbounded), - Operator::Gt => Interval::new(IntervalBound::new(zero, true), unbounded), - Operator::LtEq => Interval::new(unbounded, IntervalBound::new(zero, false)), - Operator::Lt => Interval::new(unbounded, IntervalBound::new(zero, true)), - _ => unreachable!(), - }) -} - -/// This function propagates constraints arising from comparison operators. -/// The main idea is that we can analyze an inequality like x > y through the -/// equivalent inequality x - y > 0. Assuming that x and y has ranges [xL, xU] -/// and [yL, yU], we simply apply constraint propagation across [xL, xU], -/// [yL, yH] and [0, ∞]. Specifically, we would first do -/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then -/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU]. +/// This function refines intervals `left_child` and `right_child` by applying +/// comparison propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. +/// Two intervals can be ordered in 6 ways for a Gt `>` operator: +/// ```text +/// (1): Infeasible, short-circuit +/// left: | ================ | +/// right: | ======================== | +/// +/// (2): Update both interval +/// left: | ====================== | +/// right: | ====================== | +/// | +/// V +/// left: | ======= | +/// right: | ======= | +/// +/// (3): Update left interval +/// left: | ============================== | +/// right: | ========== | +/// | +/// V +/// left: | ===================== | +/// right: | ========== | +/// +/// (4): Update right interval +/// left: | ========== | +/// right: | =========================== | +/// | +/// V +/// left: | ========== | +/// right | ================== | +/// +/// (5): No change +/// left: | ============================ | +/// right: | =================== | +/// +/// (6): No change +/// left: | ==================== | +/// right: | =============== | +/// +/// -inf --------------------------------------------------------------- +inf +/// ``` pub fn propagate_comparison( op: &Operator, + parent: &Interval, left_child: &Interval, right_child: &Interval, -) -> Result<(Option, Option)> { - let parent = comparison_operator_target( - &left_child.get_datatype()?, - op, - &right_child.get_datatype()?, - )?; - propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child) +) -> Result> { + if parent == &Interval::CERTAINLY_TRUE { + match op { + Operator::Eq => left_child.intersect(right_child).map(|result| { + result.map(|intersection| (intersection.clone(), intersection)) + }), + Operator::Gt => satisfy_greater(left_child, right_child, true), + Operator::GtEq => satisfy_greater(left_child, right_child, false), + Operator::Lt => satisfy_greater(right_child, left_child, true) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(right_child, left_child, false) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), + } + } else if parent == &Interval::CERTAINLY_FALSE { + match op { + Operator::Eq => { + // TODO: Propagation is not possible until we support interval sets. + Ok(None) + } + Operator::Gt => satisfy_greater(right_child, left_child, false), + Operator::GtEq => satisfy_greater(right_child, left_child, true), + Operator::Lt => satisfy_greater(left_child, right_child, false) + .map(|t| t.map(reverse_tuple)), + Operator::LtEq => satisfy_greater(left_child, right_child, true) + .map(|t| t.map(reverse_tuple)), + _ => internal_err!( + "The operator must be a comparison operator to propagate intervals" + ), + } + } else { + // Uncertainty cannot change any end-point of the intervals. + Ok(None) + } } impl ExprIntervalGraph { - pub fn try_new(expr: Arc) -> Result { + pub fn try_new(expr: Arc, schema: &Schema) -> Result { // Build the full graph: - let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; + let (root, graph) = + build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?; Ok(Self { graph, root }) } @@ -348,7 +406,7 @@ impl ExprIntervalGraph { // // ``` - /// This function associates stable node indices with [PhysicalExpr]s so + /// This function associates stable node indices with [`PhysicalExpr`]s so /// that we can match `Arc` and NodeIndex objects during /// membership tests. pub fn gather_node_indices( @@ -402,6 +460,33 @@ impl ExprIntervalGraph { nodes } + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_ranges( + &mut self, + leaf_bounds: &mut [(usize, Interval)], + given_range: Interval, + ) -> Result { + self.assign_intervals(leaf_bounds); + let bounds = self.evaluate_bounds()?; + // There are three possible cases to consider: + // (1) given_range ⊇ bounds => Nothing to propagate + // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate + // (3) Disjoint sets => Infeasible + if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE { + // First case: + Ok(PropagationResult::CannotPropagate) + } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE { + // Second case: + let result = self.propagate_constraints(given_range); + self.update_intervals(leaf_bounds); + result + } else { + // Third case: + Ok(PropagationResult::Infeasible) + } + } + /// This function assigns given ranges to expressions in the DAEG. /// The argument `assignments` associates indices of sought expressions /// with their corresponding new ranges. @@ -431,34 +516,43 @@ impl ExprIntervalGraph { /// # Examples /// /// ``` - /// use std::sync::Arc; - /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; - /// use datafusion_physical_expr::intervals::{Interval, IntervalBound, ExprIntervalGraph}; - /// use datafusion_physical_expr::PhysicalExpr; - /// let expr = Arc::new(BinaryExpr::new( - /// Arc::new(Column::new("gnz", 0)), - /// Operator::Plus, - /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), - /// )); - /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); - /// // Do it once, while constructing. - /// let node_indices = graph + /// use arrow::datatypes::DataType; + /// use arrow::datatypes::Field; + /// use arrow::datatypes::Schema; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::interval_arithmetic::Interval; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; + /// use datafusion_physical_expr::PhysicalExpr; + /// use std::sync::Arc; + /// + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// + /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]); + /// + /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); - /// let left_index = node_indices.get(0).unwrap().1; - /// // Provide intervals for leaf variables (here, there is only one). - /// let intervals = vec![( + /// let left_index = node_indices.get(0).unwrap().1; + /// + /// // Provide intervals for leaf variables (here, there is only one). + /// let intervals = vec![( /// left_index, - /// Interval::make(Some(10), Some(20), (true, true)), - /// )]; - /// // Evaluate bounds for the composite expression: - /// graph.assign_intervals(&intervals); - /// assert_eq!( - /// graph.evaluate_bounds().unwrap(), - /// &Interval::make(Some(20), Some(30), (true, true)), - /// ) + /// Interval::make(Some(10), Some(20)).unwrap(), + /// )]; /// + /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); + /// assert_eq!( + /// graph.evaluate_bounds().unwrap(), + /// &Interval::make(Some(20), Some(30)).unwrap(), + /// ) /// ``` pub fn evaluate_bounds(&mut self) -> Result<&Interval> { let mut dfs = DfsPostOrder::new(&self.graph, self.root); @@ -470,7 +564,7 @@ impl ExprIntervalGraph { // If the current expression is a leaf, its interval should already // be set externally, just continue with the evaluation procedure: if !children_intervals.is_empty() { - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children_intervals.reverse(); self.graph[node].interval = self.graph[node].expr.evaluate_bounds(&children_intervals)?; @@ -481,8 +575,19 @@ impl ExprIntervalGraph { /// Updates/shrinks bounds for leaf expressions using interval arithmetic /// via a top-down traversal. - fn propagate_constraints(&mut self) -> Result { + fn propagate_constraints( + &mut self, + given_range: Interval, + ) -> Result { let mut bfs = Bfs::new(&self.graph, self.root); + + // Adjust the root node with the given range: + if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? { + self.graph[self.root].interval = interval; + } else { + return Ok(PropagationResult::Infeasible); + } + while let Some(node) = bfs.next(&self.graph) { let neighbors = self.graph.neighbors_directed(node, Outgoing); let mut children = neighbors.collect::>(); @@ -491,7 +596,7 @@ impl ExprIntervalGraph { if children.is_empty() { continue; } - // Reverse to align with [PhysicalExpr]'s children: + // Reverse to align with `PhysicalExpr`'s children: children.reverse(); let children_intervals = children .iter() @@ -501,66 +606,132 @@ impl ExprIntervalGraph { let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; - for (child, interval) in children.into_iter().zip(propagated_intervals) { - if let Some(interval) = interval { + if let Some(propagated_intervals) = propagated_intervals { + for (child, interval) in children.into_iter().zip(propagated_intervals) { self.graph[child].interval = interval; - } else { - // The constraint is infeasible, report: - return Ok(PropagationResult::Infeasible); } + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); } } Ok(PropagationResult::Success) } - /// Updates intervals for all expressions in the DAEG by successive - /// bottom-up and top-down traversals. - pub fn update_ranges( - &mut self, - leaf_bounds: &mut [(usize, Interval)], - ) -> Result { - self.assign_intervals(leaf_bounds); - let bounds = self.evaluate_bounds()?; - if bounds == &Interval::CERTAINLY_FALSE { - Ok(PropagationResult::Infeasible) - } else if bounds == &Interval::UNCERTAIN { - let result = self.propagate_constraints(); - self.update_intervals(leaf_bounds); - result - } else { - Ok(PropagationResult::CannotPropagate) - } + /// Returns the interval associated with the node at the given `index`. + pub fn get_interval(&self, index: usize) -> Interval { + self.graph[NodeIndex::new(index)].interval.clone() } } -/// Indicates whether interval arithmetic is supported for the given expression. -/// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. -/// We do not support every type of [`Operator`]s either. Over time, this check -/// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. -/// Currently, [`CastExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. -pub fn check_support(expr: &Arc) -> bool { - let expr_any = expr.as_any(); - let expr_supported = if let Some(binary_expr) = expr_any.downcast_ref::() - { - is_operator_supported(binary_expr.op()) +/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child. +fn propagate_right( + left: &Interval, + parent: &Interval, + right: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + match op { + Operator::Minus => apply_operator(op, left, parent), + Operator::Plus => apply_operator(inverse_op, parent, left), + Operator::Divide => apply_operator(op, left, parent), + Operator::Multiply => apply_operator(inverse_op, parent, left), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), + }? + .intersect(right) +} + +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the left side of the operation. +fn propagate_time_interval_at_left( + left_child: &Interval, + right_child: &Interval, + parent: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + // We check if the child's time interval(s) has a non-zero month or day field(s). + // If so, we return it as is without propagating. Otherwise, we first convert + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(left_child) { + match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? { + Some(value) => { + let left = convert_duration_type_to_interval(&value); + let right = propagate_right(&value, parent, right_child, op, inverse_op)?; + match (left, right) { + (Some(left), Some(right)) => Some((left, right)), + _ => None, + } + } + None => None, + } + } else { + propagate_right(left_child, parent, right_child, op, inverse_op)? + .map(|right| (left_child.clone(), right)) + }; + Ok(result) +} + +/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`], +/// if there exists a `timestamp - timestamp` operation, the result would be +/// of type `Duration`. However, we may encounter a situation where a time interval +/// is involved in an arithmetic operation with a `Duration` type. This function +/// offers special handling for such cases, where the time interval resides on +/// the right side of the operation. +fn propagate_time_interval_at_right( + left_child: &Interval, + right_child: &Interval, + parent: &Interval, + op: &Operator, + inverse_op: &Operator, +) -> Result> { + // We check if the child's time interval(s) has a non-zero month or day field(s). + // If so, we return it as is without propagating. Otherwise, we first convert + // the time intervals to the `Duration` type, then propagate, and then convert + // the bounds to time intervals again. + let result = if let Some(duration) = convert_interval_type_to_duration(right_child) { + match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? { + Some(value) => { + propagate_right(left_child, parent, &duration, op, inverse_op)? + .and_then(|right| convert_duration_type_to_interval(&right)) + .map(|right| (value, right)) + } + None => None, + } } else { - expr_any.is::() || expr_any.is::() || expr_any.is::() + apply_operator(inverse_op, parent, right_child)? + .intersect(left_child)? + .map(|value| (value, right_child.clone())) }; - expr_supported && expr.children().iter().all(check_support) + Ok(result) +} + +fn reverse_tuple((first, second): (T, U)) -> (U, T) { + (second, first) } #[cfg(test)] mod tests { use super::*; - use itertools::Itertools; - use crate::expressions::{BinaryExpr, Column}; use crate::intervals::test_utils::gen_conjunctive_numerical_expr; + + use arrow::datatypes::TimeUnit; + use arrow_schema::{DataType, Field}; use datafusion_common::ScalarValue; + + use itertools::Itertools; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rstest::*; + #[allow(clippy::too_many_arguments)] fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), @@ -569,6 +740,7 @@ mod tests { left_expected: Interval, right_expected: Interval, result: PropagationResult, + schema: &Schema, ) -> Result<()> { let col_stats = vec![ (exprs_with_interval.0.clone(), left_interval), @@ -578,7 +750,7 @@ mod tests { (exprs_with_interval.0.clone(), left_expected), (exprs_with_interval.1.clone(), right_expected), ]; - let mut graph = ExprIntervalGraph::try_new(expr)?; + let mut graph = ExprIntervalGraph::try_new(expr, schema)?; let expr_indexes = graph .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); @@ -593,14 +765,37 @@ mod tests { .map(|((_, interval), (_, index))| (*index, interval.clone())) .collect_vec(); - let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; + let exp_result = + graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?; assert_eq!(exp_result, result); col_stat_nodes.iter().zip(expected_nodes.iter()).for_each( |((_, calculated_interval_node), (_, expected))| { // NOTE: These randomized tests only check for conservative containment, // not openness/closedness of endpoints. - assert!(calculated_interval_node.lower.value <= expected.lower.value); - assert!(calculated_interval_node.upper.value >= expected.upper.value); + + // Calculated bounds are relaxed by 1 to cover all strict and + // and non-strict comparison cases since we have only closed bounds. + let one = ScalarValue::new_one(&expected.data_type()).unwrap(); + assert!( + calculated_interval_node.lower() + <= &expected.lower().add(&one).unwrap(), + "{}", + format!( + "Calculated {} must be less than or equal {}", + calculated_interval_node.lower(), + expected.lower() + ) + ); + assert!( + calculated_interval_node.upper() + >= &expected.upper().sub(&one).unwrap(), + "{}", + format!( + "Calculated {} must be greater than or equal {}", + calculated_interval_node.upper(), + expected.upper() + ) + ); }, ); Ok(()) @@ -640,12 +835,24 @@ mod tests { experiment( expr, - (left_col, right_col), - Interval::make(left_given.0, left_given.1, (true, true)), - Interval::make(right_given.0, right_given.1, (true, true)), - Interval::make(left_expected.0, left_expected.1, (true, true)), - Interval::make(right_expected.0, right_expected.1, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(left_given.0, left_given.1).unwrap(), + Interval::make(right_given.0, right_given.1).unwrap(), + Interval::make(left_expected.0, left_expected.1).unwrap(), + Interval::make(right_expected.0, right_expected.1).unwrap(), PropagationResult::Success, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::$SCALAR, + true, + ), + ]), ) } }; @@ -669,12 +876,24 @@ mod tests { let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); experiment( expr, - (left_col, right_col), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), - Interval::make(Some(10), Some(20), (true, true)), - Interval::make(Some(100), None, (true, true)), + (left_col.clone(), right_col.clone()), + Interval::make(Some(10_i32), Some(20_i32))?, + Interval::make(Some(100), None)?, + Interval::make(Some(10), Some(20))?, + Interval::make(Some(100), None)?, PropagationResult::Infeasible, + &Schema::new(vec![ + Field::new( + left_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + Field::new( + right_col.as_any().downcast_ref::().unwrap().name(), + DataType::Int32, + true, + ), + ]), ) } @@ -979,7 +1198,14 @@ mod tests { Arc::new(Column::new("b", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1018,7 +1244,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1057,7 +1292,15 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1080,9 +1323,9 @@ mod tests { fn test_gather_node_indices_cannot_provide() -> Result<()> { // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. - // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. - // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. - // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. + // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. + // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. + // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future. let left_expr = Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1099,7 +1342,16 @@ mod tests { Arc::new(Column::new("z", 1)), )); let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); - let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + let mut graph = ExprIntervalGraph::try_new( + expr, + &Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]), + ) + .unwrap(); // Define a test leaf node. let leaf_node = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), @@ -1116,4 +1368,281 @@ mod tests { assert_eq!(prev_node_count, final_node_count); Ok(()) } + + #[test] + fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> { + let expression = BinaryExpr::new( + Arc::new(Column::new("ts_column", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))), + ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None), + // 16.10.2020 - 10:11:12.000_000_321 AM + ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None), + )?; + let left_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None), + )?; + let right_child = Interval::try_new( + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + // 1 day 321 ns + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + )?; + let children = vec![&left_child, &right_child]; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); + + assert_eq!( + vec![ + Interval::try_new( + // 14.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond( + Some(1_602_670_272_000_000_000), + None + ), + // 15.10.2020 - 10:11:12 AM + ScalarValue::TimestampNanosecond( + Some(1_602_756_672_000_000_000), + None + ), + )?, + Interval::try_new( + // 1 day 321 ns in Duration type + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + // 1 day 321 ns in Duration type + ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + )? + ], + result + ); + + Ok(()) + } + + #[test] + fn test_propagate_constraints_column_interval_at_left() -> Result<()> { + let expression = BinaryExpr::new( + Arc::new(Column::new("interval_column", 1)), + Operator::Plus, + Arc::new(Column::new("ts_column", 0)), + ); + let parent = Interval::try_new( + // 15.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None), + // 16.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None), + )?; + let right_child = Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 20.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), + )?; + let left_child = Interval::try_new( + // 2 days + ScalarValue::IntervalDayTime(Some(172_800_000)), + // 10 days + ScalarValue::IntervalDayTime(Some(864_000_000)), + )?; + let children = vec![&left_child, &right_child]; + let result = expression + .propagate_constraints(&parent, &children)? + .unwrap(); + + assert_eq!( + vec![ + Interval::try_new( + // 2 days + ScalarValue::IntervalDayTime(Some(172_800_000)), + // 6 days + ScalarValue::IntervalDayTime(Some(518_400_000)), + )?, + Interval::try_new( + // 10.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None), + // 14.10.2020 - 10:11:12 AM + ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None), + )? + ], + result + ); + + Ok(()) + } + + #[test] + fn test_propagate_comparison() -> Result<()> { + // In the examples below: + // `left` is unbounded: [?, ?], + // `right` is known to be [1000,1000] + // so `left` < `right` results in no new knowledge of `right` but knowing that `left` is now < 1000:` [?, 999] + let left = Interval::make_unbounded(&DataType::Int64)?; + let right = Interval::make(Some(1000_i64), Some(1000_i64))?; + assert_eq!( + (Some(( + Interval::make(None, Some(999_i64))?, + Interval::make(Some(1000_i64), Some(1000_i64))?, + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? + ); + + let left = + Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )?; + assert_eq!( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( + TimeUnit::Nanosecond, + None + )) + .unwrap(), + ScalarValue::TimestampNanosecond(Some(999), None), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(1000), None), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? + ); + + let left = Interval::make_unbounded(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + ))?; + let right = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )?; + assert_eq!( + (Some(( + Interval::try_new( + ScalarValue::try_from(&DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+05:00".into()), + )) + .unwrap(), + ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())), + )?, + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())), + )? + ))), + propagate_comparison( + &Operator::Lt, + &Interval::CERTAINLY_TRUE, + &left, + &right + )? + ); + + Ok(()) + } + + #[test] + fn test_propagate_or() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE], + vec![&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE], + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + ]; + for children in children_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE], + ); + } + + let parent = Interval::CERTAINLY_FALSE; + let children_set = vec![ + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + ]; + for children in children_set { + assert_eq!(expr.propagate_constraints(&parent, &children)?, None,); + } + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE] + ); + + let parent = Interval::CERTAINLY_TRUE; + let children = vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN]; + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + // Empty means unchanged intervals. + vec![] + ); + + Ok(()) + } + + #[test] + fn test_propagate_certainly_false_and() -> Result<()> { + let expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + Arc::new(Column::new("b", 1)), + )); + let parent = Interval::CERTAINLY_FALSE; + let children_and_results_set = vec![ + ( + vec![&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN], + vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE], + vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE], + ), + ( + vec![&Interval::UNCERTAIN, &Interval::UNCERTAIN], + // Empty means unchanged intervals. + vec![], + ), + ( + vec![&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], + vec![], + ), + ]; + for (children, result) in children_and_results_set { + assert_eq!( + expr.propagate_constraints(&parent, &children)?.unwrap(), + result + ); + } + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs deleted file mode 100644 index 3e2b4697a11b5..0000000000000 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ /dev/null @@ -1,1175 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Interval arithmetic library - -use std::borrow::Borrow; -use std::fmt; -use std::fmt::{Display, Formatter}; - -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::type_coercion::binary::get_result_type; -use datafusion_expr::Operator; - -use crate::aggregate::min_max::{max, min}; -use crate::intervals::rounding::alter_fp_rounding_mode; - -/// This type represents a single endpoint of an [`Interval`]. An endpoint can -/// be open or closed, denoting whether the interval includes or excludes the -/// endpoint itself. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct IntervalBound { - pub value: ScalarValue, - pub open: bool, -} - -impl IntervalBound { - /// Creates a new `IntervalBound` object using the given value. - pub const fn new(value: ScalarValue, open: bool) -> IntervalBound { - IntervalBound { value, open } - } - - /// This convenience function creates an unbounded interval endpoint. - pub fn make_unbounded>(data_type: T) -> Result { - ScalarValue::try_from(data_type.borrow()).map(|v| IntervalBound::new(v, true)) - } - - /// This convenience function returns the data type associated with this - /// `IntervalBound`. - pub fn get_datatype(&self) -> DataType { - self.value.get_datatype() - } - - /// This convenience function checks whether the `IntervalBound` represents - /// an unbounded interval endpoint. - pub fn is_unbounded(&self) -> bool { - self.value.is_null() - } - - /// This function casts the `IntervalBound` to the given data type. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - cast_scalar_value(&self.value, data_type, cast_options) - .map(|value| IntervalBound::new(value, self.open)) - } - - /// This function adds the given `IntervalBound` to this `IntervalBound`. - /// The result is unbounded if either is; otherwise, their values are - /// added. The result is closed if both original bounds are closed, or open - /// otherwise. - pub fn add>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Plus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.add(rhs) - }) - } - _ => self.value.add(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function subtracts the given `IntervalBound` from `self`. - /// The result is unbounded if either is; otherwise, their values are - /// subtracted. The result is closed if both original bounds are closed, - /// or open otherwise. - pub fn sub>( - &self, - other: T, - ) -> Result { - let rhs = other.borrow(); - if self.is_unbounded() || rhs.is_unbounded() { - return IntervalBound::make_unbounded(get_result_type( - &self.get_datatype(), - &Operator::Minus, - &rhs.get_datatype(), - )?); - } - match self.get_datatype() { - DataType::Float64 | DataType::Float32 => { - alter_fp_rounding_mode::(&self.value, &rhs.value, |lhs, rhs| { - lhs.sub(rhs) - }) - } - _ => self.value.sub(&rhs.value), - } - .map(|v| IntervalBound::new(v, self.open || rhs.open)) - } - - /// This function chooses one of the given `IntervalBound`s according to - /// the given function `decide`. The result is unbounded if both are. If - /// only one of the arguments is unbounded, the other one is chosen by - /// default. If neither is unbounded, the function `decide` is used. - pub fn choose( - first: &IntervalBound, - second: &IntervalBound, - decide: fn(&ScalarValue, &ScalarValue) -> Result, - ) -> Result { - Ok(if first.is_unbounded() { - second.clone() - } else if second.is_unbounded() { - first.clone() - } else if first.value != second.value { - let chosen = decide(&first.value, &second.value)?; - if chosen.eq(&first.value) { - first.clone() - } else { - second.clone() - } - } else { - IntervalBound::new(second.value.clone(), first.open || second.open) - }) - } -} - -impl Display for IntervalBound { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "IntervalBound [{}]", self.value) - } -} - -/// This type represents an interval, which is used to calculate reliable -/// bounds for expressions. Currently, we only support addition and -/// subtraction, but more capabilities will be added in the future. -/// Upper/lower bounds having NULL values indicate an unbounded side. For -/// example; [10, 20], [10, ∞), (-∞, 100] and (-∞, ∞) are all valid intervals. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Interval { - pub lower: IntervalBound, - pub upper: IntervalBound, -} - -impl Default for Interval { - fn default() -> Self { - Interval::new( - IntervalBound::new(ScalarValue::Null, true), - IntervalBound::new(ScalarValue::Null, true), - ) - } -} - -impl Display for Interval { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Interval [{}, {}]", self.lower, self.upper) - } -} - -impl Interval { - /// Creates a new interval object using the given bounds. - /// For boolean intervals, having an open false lower bound is equivalent - /// to having a true closed lower bound. Similarly, open true upper bound - /// is equivalent to having a false closed upper bound. Also for boolean - /// intervals, having an unbounded left endpoint is equivalent to having a - /// false closed lower bound, while having an unbounded right endpoint is - /// equivalent to having a true closed upper bound. Therefore; input - /// parameters to construct an Interval can have different types, but they - /// all result in [false, false], [false, true] or [true, true]. - pub fn new(lower: IntervalBound, upper: IntervalBound) -> Interval { - // Boolean intervals need a special handling. - if let ScalarValue::Boolean(_) = lower.value { - let standardized_lower = match lower.value { - ScalarValue::Boolean(None) if lower.open => { - ScalarValue::Boolean(Some(false)) - } - ScalarValue::Boolean(Some(false)) if lower.open => { - ScalarValue::Boolean(Some(true)) - } - // The rest may include some invalid interval cases. The validation of - // interval construction parameters will be implemented later. - // For now, let's return them unchanged. - _ => lower.value, - }; - let standardized_upper = match upper.value { - ScalarValue::Boolean(None) if upper.open => { - ScalarValue::Boolean(Some(true)) - } - ScalarValue::Boolean(Some(true)) if upper.open => { - ScalarValue::Boolean(Some(false)) - } - _ => upper.value, - }; - Interval { - lower: IntervalBound::new(standardized_lower, false), - upper: IntervalBound::new(standardized_upper, false), - } - } else { - Interval { lower, upper } - } - } - - pub fn make(lower: Option, upper: Option, open: (bool, bool)) -> Interval - where - ScalarValue: From>, - { - Interval::new( - IntervalBound::new(ScalarValue::from(lower), open.0), - IntervalBound::new(ScalarValue::from(upper), open.1), - ) - } - - /// Casts this interval to `data_type` using `cast_options`. - pub(crate) fn cast_to( - &self, - data_type: &DataType, - cast_options: &CastOptions, - ) -> Result { - let lower = self.lower.cast_to(data_type, cast_options)?; - let upper = self.upper.cast_to(data_type, cast_options)?; - Ok(Interval::new(lower, upper)) - } - - /// This function returns the data type of this interval. If both endpoints - /// do not have the same data type, returns an error. - pub(crate) fn get_datatype(&self) -> Result { - let lower_type = self.lower.get_datatype(); - let upper_type = self.upper.get_datatype(); - if lower_type == upper_type { - Ok(lower_type) - } else { - Err(DataFusionError::Internal(format!( - "Interval bounds have different types: {lower_type} != {upper_type}", - ))) - } - } - - /// Decide if this interval is certainly greater than, possibly greater than, - /// or can't be greater than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - { - // Values in this interval are certainly less than or equal to those - // in the given interval. - (false, false) - } else if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - && (self.lower.value > rhs.upper.value || self.lower.open || rhs.upper.open) - { - // Values in this interval are certainly greater than those in the - // given interval. - (true, true) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly greater than or equal to, possibly greater than - /// or equal to, or can't be greater than or equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn gt_eq>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value >= rhs.upper.value - { - // Values in this interval are certainly greater than or equal to those - // in the given interval. - (true, true) - } else if !self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value <= rhs.lower.value - && (self.upper.value < rhs.lower.value || self.upper.open || rhs.lower.open) - { - // Values in this interval are certainly less than those in the - // given interval. - (false, false) - } else { - // All outcomes are possible. - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Decide if this interval is certainly less than, possibly less than, - /// or can't be less than `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn lt>(&self, other: T) -> Interval { - other.borrow().gt(self) - } - - /// Decide if this interval is certainly less than or equal to, possibly - /// less than or equal to, or can't be less than or equal to `other` by returning - /// [true, true], [false, true] or [false, false] respectively. - pub(crate) fn lt_eq>(&self, other: T) -> Interval { - other.borrow().gt_eq(self) - } - - /// Decide if this interval is certainly equal to, possibly equal to, - /// or can't be equal to `other` by returning [true, true], - /// [false, true] or [false, false] respectively. - pub(crate) fn equal>(&self, other: T) -> Interval { - let rhs = other.borrow(); - let flags = if !self.lower.is_unbounded() - && (self.lower.value == self.upper.value) - && (rhs.lower.value == rhs.upper.value) - && (self.lower.value == rhs.lower.value) - { - (true, true) - } else if self.gt(rhs) == Interval::CERTAINLY_TRUE - || self.lt(rhs) == Interval::CERTAINLY_TRUE - { - (false, false) - } else { - (false, true) - }; - - Interval::make(Some(flags.0), Some(flags.1), (false, false)) - } - - /// Compute the logical conjunction of this (boolean) interval with the given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { - let rhs = other.borrow(); - match ( - &self.lower.value, - &self.upper.value, - &rhs.lower.value, - &rhs.upper.value, - ) { - ( - ScalarValue::Boolean(Some(self_lower)), - ScalarValue::Boolean(Some(self_upper)), - ScalarValue::Boolean(Some(other_lower)), - ScalarValue::Boolean(Some(other_upper)), - ) => { - let lower = *self_lower && *other_lower; - let upper = *self_upper && *other_upper; - - Ok(Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(lower)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(upper)), false), - }) - } - _ => Err(DataFusionError::Internal( - "Incompatible types for logical conjunction".to_string(), - )), - } - } - - /// Compute the intersection of the interval with the given interval. - /// If the intersection is empty, return None. - pub(crate) fn intersect>( - &self, - other: T, - ) -> Result> { - let rhs = other.borrow(); - // If it is evident that the result is an empty interval, - // do not make any calculation and directly return None. - if (!self.lower.is_unbounded() - && !rhs.upper.is_unbounded() - && self.lower.value > rhs.upper.value) - || (!self.upper.is_unbounded() - && !rhs.lower.is_unbounded() - && self.upper.value < rhs.lower.value) - { - // This None value signals an empty interval. - return Ok(None); - } - - let lower = IntervalBound::choose(&self.lower, &rhs.lower, max)?; - let upper = IntervalBound::choose(&self.upper, &rhs.upper, min)?; - - let non_empty = lower.is_unbounded() - || upper.is_unbounded() - || lower.value != upper.value - || (!lower.open && !upper.open); - Ok(non_empty.then_some(Interval::new(lower, upper))) - } - - /// Add the given interval (`other`) to this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. - /// Note that this represents all possible values the sum can take if - /// one can choose single values arbitrarily from each of the operands. - pub fn add>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.add::(&rhs.lower)?, - self.upper.add::(&rhs.upper)?, - )) - } - - /// Subtract the given interval (`other`) from this interval. Say we have - /// intervals [a1, b1] and [a2, b2], then their sum is [a1 - b2, b1 - a2]. - /// Note that this represents all possible values the difference can take - /// if one can choose single values arbitrarily from each of the operands. - pub fn sub>(&self, other: T) -> Result { - let rhs = other.borrow(); - Ok(Interval::new( - self.lower.sub::(&rhs.upper)?, - self.upper.sub::(&rhs.lower)?, - )) - } - - pub const CERTAINLY_FALSE: Interval = Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(false)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(false)), false), - }; - - pub const UNCERTAIN: Interval = Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(false)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(true)), false), - }; - - pub const CERTAINLY_TRUE: Interval = Interval { - lower: IntervalBound::new(ScalarValue::Boolean(Some(true)), false), - upper: IntervalBound::new(ScalarValue::Boolean(Some(true)), false), - }; -} - -/// Indicates whether interval arithmetic is supported for the given operator. -pub fn is_operator_supported(op: &Operator) -> bool { - matches!( - op, - &Operator::Plus - | &Operator::Minus - | &Operator::And - | &Operator::Gt - | &Operator::GtEq - | &Operator::Lt - | &Operator::LtEq - ) -} - -/// Indicates whether interval arithmetic is supported for the given data type. -pub fn is_datatype_supported(data_type: &DataType) -> bool { - matches!( - data_type, - &DataType::Int64 - | &DataType::Int32 - | &DataType::Int16 - | &DataType::Int8 - | &DataType::UInt64 - | &DataType::UInt32 - | &DataType::UInt16 - | &DataType::UInt8 - | &DataType::Float64 - | &DataType::Float32 - ) -} - -pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { - match *op { - Operator::Eq => Ok(lhs.equal(rhs)), - Operator::Gt => Ok(lhs.gt(rhs)), - Operator::GtEq => Ok(lhs.gt_eq(rhs)), - Operator::Lt => Ok(lhs.lt(rhs)), - Operator::LtEq => Ok(lhs.lt_eq(rhs)), - Operator::And => lhs.and(rhs), - Operator::Plus => lhs.add(rhs), - Operator::Minus => lhs.sub(rhs), - _ => Ok(Interval::default()), - } -} - -/// Cast scalar value to the given data type using an arrow kernel. -fn cast_scalar_value( - value: &ScalarValue, - data_type: &DataType, - cast_options: &CastOptions, -) -> Result { - let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; - ScalarValue::try_from_array(&cast_array, 0) -} - -#[cfg(test)] -mod tests { - use crate::intervals::{Interval, IntervalBound}; - use datafusion_common::{Result, ScalarValue}; - use ScalarValue::Boolean; - - fn open_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, true)) - } - - fn open_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (true, false)) - } - - fn closed_open(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, true)) - } - - fn closed_closed(lower: Option, upper: Option) -> Interval - where - ScalarValue: From>, - { - Interval::make(lower, upper, (false, false)) - } - - #[test] - fn intersect_test() -> Result<()> { - let possible_cases = vec![ - (Some(1000_i64), None, None, None, Some(1000_i64), None), - (None, Some(1000_i64), None, None, None, Some(1000_i64)), - (None, None, Some(1000_i64), None, Some(1000_i64), None), - (None, None, None, Some(1000_i64), None, Some(1000_i64)), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(1000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - Some(999_i64), - Some(1000_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in possible_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - Some(open_open(case.4, case.5)) - ) - } - - let empty_cases = vec![ - (None, Some(1000_i64), Some(1001_i64), None), - (Some(1001_i64), None, None, Some(1000_i64)), - (None, Some(1000_i64), Some(1001_i64), Some(1002_i64)), - (Some(1001_i64), Some(1002_i64), None, Some(1000_i64)), - ]; - - for case in empty_cases { - assert_eq!( - open_open(case.0, case.1).intersect(open_open(case.2, case.3))?, - None - ) - } - - Ok(()) - } - - #[test] - fn gt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, false, false), - (None, Some(1000_i64), Some(1001_i64), None, false, false), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - false, - false, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - ( - Some(1002_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - ( - Some(1003_i64), - None, - Some(999_i64), - Some(1002_i64), - true, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).gt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn lt_test() { - let cases = vec![ - (Some(1000_i64), None, None, None, false, true), - (None, Some(1000_i64), None, None, false, true), - (None, None, Some(1000_i64), None, false, true), - (None, None, None, Some(1000_i64), false, true), - (None, Some(1000_i64), Some(1000_i64), None, true, true), - (None, Some(1000_i64), Some(1001_i64), None, true, true), - (Some(1000_i64), None, Some(1000_i64), None, false, true), - ( - None, - Some(1000_i64), - Some(1001_i64), - Some(1002_i64), - true, - true, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - false, - true, - ), - (None, None, None, None, false, true), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).lt(open_open(case.2, case.3)), - closed_closed(Some(case.4), Some(case.5)) - ); - } - } - - #[test] - fn and_test() -> Result<()> { - let cases = vec![ - (false, true, false, false, false, false), - (false, false, false, true, false, false), - (false, true, false, true, false, true), - (false, true, true, true, false, true), - (false, false, false, false, false, false), - (true, true, true, true, true, true), - ]; - - for case in cases { - assert_eq!( - open_open(Some(case.0), Some(case.1)) - .and(open_open(Some(case.2), Some(case.3)))?, - open_open(Some(case.4), Some(case.5)) - ); - } - Ok(()) - } - - #[test] - fn add_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - ( - Some(1000_i64), - None, - Some(1000_i64), - None, - Some(2000_i64), - None, - ), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(2002_i64), - ), - (None, Some(1000_i64), Some(1000_i64), None, None, None), - ( - Some(2001_i64), - Some(1_i64), - Some(1005_i64), - Some(-999_i64), - Some(3006_i64), - Some(-998_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).add(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test() -> Result<()> { - let cases = vec![ - (Some(1000_i64), None, None, None, None, None), - (None, Some(1000_i64), None, None, None, None), - (None, None, Some(1000_i64), None, None, None), - (None, None, None, Some(1000_i64), None, None), - (Some(1000_i64), None, Some(1000_i64), None, None, None), - ( - None, - Some(1000_i64), - Some(999_i64), - Some(1002_i64), - None, - Some(1_i64), - ), - ( - None, - Some(1000_i64), - Some(1000_i64), - None, - None, - Some(0_i64), - ), - ( - Some(2001_i64), - Some(1000_i64), - Some(1005), - Some(999_i64), - Some(1002_i64), - Some(-5_i64), - ), - (None, None, None, None, None, None), - ]; - - for case in cases { - assert_eq!( - open_open(case.0, case.1).sub(open_open(case.2, case.3))?, - open_open(case.4, case.5) - ); - } - Ok(()) - } - - #[test] - fn sub_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - closed_open(Some(200_i64), None), - open_closed(None, Some(0_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_closed(Some(300_i64), Some(150_i64)), - closed_open(Some(-50_i64), Some(-100_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(None, Some(0_i64)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(-10_i64), Some(-10_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.sub(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn add_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(200_i64)), - open_closed(None, Some(400_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - closed_open(Some(-300_i64), Some(150_i64)), - closed_open(Some(-200_i64), Some(350_i64)), - ), - ( - closed_open(Some(100_i64), Some(200_i64)), - open_open(Some(200_i64), None), - open_open(Some(300_i64), None), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_closed(Some(11_i64), Some(11_i64)), - closed_closed(Some(12_i64), Some(12_i64)), - ), - ]; - for case in cases { - assert_eq!(case.0.add(case.1)?, case.2) - } - Ok(()) - } - - #[test] - fn lt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt(case.1), case.2) - } - Ok(()) - } - - #[test] - fn lt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ]; - for case in cases { - assert_eq!(case.0.lt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn gt_eq_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(true), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - closed_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(true)), - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - closed_closed(Some(false), Some(false)), - ), - ]; - for case in cases { - assert_eq!(case.0.gt_eq(case.1), case.2) - } - Ok(()) - } - - #[test] - fn intersect_test_various_bounds() -> Result<()> { - let cases = vec![ - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_closed(None, Some(100_i64)), - Some(closed_closed(Some(100_i64), Some(100_i64))), - ), - ( - closed_closed(Some(100_i64), Some(200_i64)), - open_open(None, Some(100_i64)), - None, - ), - ( - open_open(Some(100_i64), Some(200_i64)), - closed_closed(Some(0_i64), Some(100_i64)), - None, - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_closed(Some(1_i64), Some(2_i64)), - Some(closed_closed(Some(2_i64), Some(2_i64))), - ), - ( - closed_closed(Some(2_i64), Some(2_i64)), - closed_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(1_i64)), - open_open(Some(1_i64), Some(2_i64)), - None, - ), - ( - closed_closed(Some(1_i64), Some(3_i64)), - open_open(Some(1_i64), Some(2_i64)), - Some(open_open(Some(1_i64), Some(2_i64))), - ), - ]; - for case in cases { - assert_eq!(case.0.intersect(case.1)?, case.2) - } - Ok(()) - } - - // This function tests if valid constructions produce standardized objects - // ([false, false], [false, true], [true, true]) for boolean intervals. - #[test] - fn non_standard_interval_constructs() { - let cases = vec![ - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), false), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(Some(true)), true), - closed_closed(Some(false), Some(false)), - ), - ( - IntervalBound::new(Boolean(Some(false)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(true)), false), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ( - IntervalBound::new(Boolean(None), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(false), Some(true)), - ), - ( - IntervalBound::new(Boolean(Some(false)), true), - IntervalBound::new(Boolean(None), true), - closed_closed(Some(true), Some(true)), - ), - ]; - - for case in cases { - assert_eq!(Interval::new(case.0, case.1), case.2) - } - } - - macro_rules! capture_mode_change { - ($TYPE:ty) => { - paste::item! { - capture_mode_change_helper!([], - [], - $TYPE); - } - }; - } - - macro_rules! capture_mode_change_helper { - ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => { - fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval { - Interval::make(Some(lower as $TYPE), Some(upper as $TYPE), (true, true)) - } - - fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool, expect_high: bool) { - assert!(expect_low || expect_high); - let interval1 = $CREATE_FN_NAME(input.0, input.0); - let interval2 = $CREATE_FN_NAME(input.1, input.1); - let result = interval1.add(&interval2).unwrap(); - let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 + input.1); - assert!( - (!expect_low || result.lower.value < without_fe.lower.value) - && (!expect_high || result.upper.value > without_fe.upper.value) - ); - } - }; - } - - capture_mode_change!(f32); - capture_mode_change!(f64); - - #[cfg(all( - any(target_arch = "x86_64", target_arch = "aarch64"), - not(target_os = "windows") - ))] - #[test] - fn test_add_intervals_lower_affected_f32() { - // Lower is affected - let lower = f32::from_bits(1073741887); //1000000000000000000000000111111 - let upper = f32::from_bits(1098907651); //1000001100000000000000000000011 - capture_mode_change_f32((lower, upper), true, false); - - // Upper is affected - let lower = f32::from_bits(1072693248); //111111111100000000000000000000 - let upper = f32::from_bits(715827883); //101010101010101010101010101011 - capture_mode_change_f32((lower, upper), false, true); - - // Lower is affected - let lower = 1.0; // 0x3FF0000000000000 - let upper = 0.3; // 0x3FD3333333333333 - capture_mode_change_f64((lower, upper), true, false); - - // Upper is affected - let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF - let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F - capture_mode_change_f64((lower, upper), false, true); - } - - #[cfg(any( - not(any(target_arch = "x86_64", target_arch = "aarch64")), - target_os = "windows" - ))] - #[test] - fn test_next_impl_add_intervals_f64() { - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f64((lower, upper), true, true); - - let lower = 1.5; - let upper = 1.5; - capture_mode_change_f32((lower, upper), true, true); - } -} diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs index a9255752fea44..9752ca27b5a38 100644 --- a/datafusion/physical-expr/src/intervals/mod.rs +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -18,9 +18,5 @@ //! Interval arithmetic and constraint propagation library pub mod cp_solver; -pub mod interval_aritmetic; -pub mod rounding; - pub mod test_utils; -pub use cp_solver::{check_support, ExprIntervalGraph}; -pub use interval_aritmetic::*; +pub mod utils; diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index 8e695c2556965..075b8240353d2 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use crate::expressions::{date_time_interval_expr, BinaryExpr, Literal}; +use crate::expressions::{binary, BinaryExpr, Literal}; use crate::PhysicalExpr; use arrow_schema::Schema; use datafusion_common::{DataFusionError, ScalarValue}; @@ -78,22 +78,10 @@ pub fn gen_conjunctive_temporal_expr( d: ScalarValue, schema: &Schema, ) -> Result, DataFusionError> { - let left_and_1 = date_time_interval_expr( - left_col.clone(), - op_1, - Arc::new(Literal::new(a)), - schema, - )?; - let left_and_2 = date_time_interval_expr( - right_col.clone(), - op_2, - Arc::new(Literal::new(b)), - schema, - )?; - let right_and_1 = - date_time_interval_expr(left_col, op_3, Arc::new(Literal::new(c)), schema)?; - let right_and_2 = - date_time_interval_expr(right_col, op_4, Arc::new(Literal::new(d)), schema)?; + let left_and_1 = binary(left_col.clone(), op_1, Arc::new(Literal::new(a)), schema)?; + let left_and_2 = binary(right_col.clone(), op_2, Arc::new(Literal::new(b)), schema)?; + let right_and_1 = binary(left_col, op_3, Arc::new(Literal::new(c)), schema)?; + let right_and_2 = binary(right_col, op_4, Arc::new(Literal::new(d)), schema)?; let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); Ok(Arc::new(BinaryExpr::new( diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs new file mode 100644 index 0000000000000..03d13632104dd --- /dev/null +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utility functions for the interval arithmetic library + +use std::sync::Arc; + +use crate::{ + expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, + PhysicalExpr, +}; + +use arrow_schema::{DataType, SchemaRef}; +use datafusion_common::{ + internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::Operator; + +const MDN_DAY_MASK: i128 = 0xFFFF_FFFF_0000_0000_0000_0000; +const MDN_NS_MASK: i128 = 0xFFFF_FFFF_FFFF_FFFF; +const DT_MS_MASK: i64 = 0xFFFF_FFFF; + +/// Indicates whether interval arithmetic is supported for the given expression. +/// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. +/// We do not support every type of [`Operator`]s either. Over time, this check +/// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. +/// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. +pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { + let expr_any = expr.as_any(); + if let Some(binary_expr) = expr_any.downcast_ref::() { + is_operator_supported(binary_expr.op()) + && check_support(binary_expr.left(), schema) + && check_support(binary_expr.right(), schema) + } else if let Some(column) = expr_any.downcast_ref::() { + if let Ok(field) = schema.field_with_name(column.name()) { + is_datatype_supported(field.data_type()) + } else { + return false; + } + } else if let Some(literal) = expr_any.downcast_ref::() { + if let Ok(dt) = literal.data_type(schema) { + is_datatype_supported(&dt) + } else { + return false; + } + } else if let Some(cast) = expr_any.downcast_ref::() { + check_support(cast.expr(), schema) + } else if let Some(negative) = expr_any.downcast_ref::() { + check_support(negative.arg(), schema) + } else { + false + } +} + +// This function returns the inverse operator of the given operator. +pub fn get_inverse_op(op: Operator) -> Result { + match op { + Operator::Plus => Ok(Operator::Minus), + Operator::Minus => Ok(Operator::Plus), + Operator::Multiply => Ok(Operator::Divide), + Operator::Divide => Ok(Operator::Multiply), + _ => internal_err!("Interval arithmetic does not support the operator {}", op), + } +} + +/// Indicates whether interval arithmetic is supported for the given operator. +pub fn is_operator_supported(op: &Operator) -> bool { + matches!( + op, + &Operator::Plus + | &Operator::Minus + | &Operator::And + | &Operator::Gt + | &Operator::GtEq + | &Operator::Lt + | &Operator::LtEq + | &Operator::Eq + | &Operator::Multiply + | &Operator::Divide + ) +} + +/// Indicates whether interval arithmetic is supported for the given data type. +pub fn is_datatype_supported(data_type: &DataType) -> bool { + matches!( + data_type, + &DataType::Int64 + | &DataType::Int32 + | &DataType::Int16 + | &DataType::Int8 + | &DataType::UInt64 + | &DataType::UInt32 + | &DataType::UInt16 + | &DataType::UInt8 + | &DataType::Float64 + | &DataType::Float32 + ) +} + +/// Converts an [`Interval`] of time intervals to one of `Duration`s, if applicable. Otherwise, returns [`None`]. +pub fn convert_interval_type_to_duration(interval: &Interval) -> Option { + if let (Some(lower), Some(upper)) = ( + convert_interval_bound_to_duration(interval.lower()), + convert_interval_bound_to_duration(interval.upper()), + ) { + Interval::try_new(lower, upper).ok() + } else { + None + } +} + +/// Converts an [`ScalarValue`] containing a time interval to one containing a `Duration`, if applicable. Otherwise, returns [`None`]. +fn convert_interval_bound_to_duration( + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::IntervalMonthDayNano(Some(mdn)) => interval_mdn_to_duration_ns(mdn) + .ok() + .map(|duration| ScalarValue::DurationNanosecond(Some(duration))), + ScalarValue::IntervalDayTime(Some(dt)) => interval_dt_to_duration_ms(dt) + .ok() + .map(|duration| ScalarValue::DurationMillisecond(Some(duration))), + _ => None, + } +} + +/// Converts an [`Interval`] of `Duration`s to one of time intervals, if applicable. Otherwise, returns [`None`]. +pub fn convert_duration_type_to_interval(interval: &Interval) -> Option { + if let (Some(lower), Some(upper)) = ( + convert_duration_bound_to_interval(interval.lower()), + convert_duration_bound_to_interval(interval.upper()), + ) { + Interval::try_new(lower, upper).ok() + } else { + None + } +} + +/// Converts a [`ScalarValue`] containing a `Duration` to one containing a time interval, if applicable. Otherwise, returns [`None`]. +fn convert_duration_bound_to_interval( + interval_bound: &ScalarValue, +) -> Option { + match interval_bound { + ScalarValue::DurationNanosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration)) + } + ScalarValue::DurationMicrosecond(Some(duration)) => { + Some(ScalarValue::new_interval_mdn(0, 0, *duration * 1000)) + } + ScalarValue::DurationMillisecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32)) + } + ScalarValue::DurationSecond(Some(duration)) => { + Some(ScalarValue::new_interval_dt(0, *duration as i32 * 1000)) + } + _ => None, + } +} + +/// If both the month and day fields of [`ScalarValue::IntervalMonthDayNano`] are zero, this function returns the nanoseconds part. +/// Otherwise, it returns an error. +fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { + let months = mdn >> 96; + let days = (mdn & MDN_DAY_MASK) >> 64; + let nanoseconds = mdn & MDN_NS_MASK; + + if months == 0 && days == 0 { + nanoseconds + .try_into() + .map_err(|_| internal_datafusion_err!("Resulting duration exceeds i64::MAX")) + } else { + internal_err!( + "The interval cannot have a non-zero month or day value for duration convertibility" + ) + } +} + +/// If the day field of the [`ScalarValue::IntervalDayTime`] is zero, this function returns the milliseconds part. +/// Otherwise, it returns an error. +fn interval_dt_to_duration_ms(dt: &i64) -> Result { + let days = dt >> 32; + let milliseconds = dt & DT_MS_MASK; + + if days == 0 { + Ok(milliseconds) + } else { + internal_err!( + "The interval cannot have a non-zero day value for duration convertibility" + ) + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 0a2e0e58df7a9..fffa8f602d875 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -16,24 +16,28 @@ // under the License. pub mod aggregate; +pub mod analysis; pub mod array_expressions; pub mod conditional_expressions; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; pub mod datetime_expressions; +#[cfg(feature = "encoding_expressions")] +pub mod encoding_expressions; pub mod equivalence; pub mod execution_props; pub mod expressions; pub mod functions; -pub mod hash_utils; pub mod intervals; pub mod math_expressions; +mod partitioning; mod physical_expr; pub mod planner; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; mod scalar_function; mod sort_expr; +pub mod sort_properties; pub mod string_expressions; pub mod struct_expressions; pub mod tree_node; @@ -44,21 +48,21 @@ pub mod utils; pub mod var_provider; pub mod window; -// reexport this to maintain compatibility with anything that used from_slice previously +pub use aggregate::groups_accumulator::{ + EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, +}; pub use aggregate::AggregateExpr; -pub use equivalence::{ - project_equivalence_properties, project_ordering_equivalence_properties, - EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties, - OrderingEquivalentClass, +pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; +pub use equivalence::EquivalenceProperties; +pub use partitioning::{Distribution, Partitioning}; +pub use physical_expr::{ + physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, + PhysicalExpr, PhysicalExprRef, }; -pub use physical_expr::{AnalysisContext, ExprBoundaries, PhysicalExpr, PhysicalExprRef}; pub use planner::create_physical_expr; pub use scalar_function::ScalarFunctionExpr; pub use sort_expr::{ - LexOrdering, LexOrderingReq, PhysicalSortExpr, PhysicalSortRequirement, -}; -pub use utils::{ - expr_list_eq_any_order, expr_list_eq_strict_order, - normalize_expr_with_equivalence_properties, normalize_out_expr_with_columns_map, - reverse_order_bys, split_conjunction, + LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, + PhysicalSortRequirement, }; +pub use utils::{reverse_order_bys, split_conjunction}; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index fbfb82814e0fe..af66862aecc5a 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -18,10 +18,15 @@ //! Math expressions use arrow::array::ArrayRef; -use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow::array::{ + BooleanArray, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, +}; use arrow::datatypes::DataType; +use arrow::error::ArrowError; use datafusion_common::ScalarValue; -use datafusion_common::ScalarValue::Float32; +use datafusion_common::ScalarValue::{Float32, Int64}; +use datafusion_common::{internal_err, not_impl_err}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use rand::{thread_rng, Rng}; @@ -30,6 +35,8 @@ use std::iter; use std::mem::swap; use std::sync::Arc; +type MathArrayFunction = fn(&[ArrayRef]) -> Result; + macro_rules! downcast_compute_op { ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{ let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); @@ -39,10 +46,7 @@ macro_rules! downcast_compute_op { arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); Ok(Arc::new(res)) } - _ => Err(DataFusionError::Internal(format!( - "Invalid data type for {}", - $NAME - ))), + _ => internal_err!("Invalid data type for {}", $NAME), } }}; } @@ -59,10 +63,11 @@ macro_rules! unary_primitive_array_op { let result = downcast_compute_op!(array, $NAME, $FUNC, Float64Array); Ok(ColumnarValue::Array(result?)) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "Unsupported data type {:?} for function {}", - other, $NAME, - ))), + other, + $NAME + ), }, ColumnarValue::Scalar(a) => match a { ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( @@ -71,11 +76,11 @@ macro_rules! unary_primitive_array_op { ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( ScalarValue::Float64(a.map(|x| x.$FUNC())), )), - _ => Err(DataFusionError::Internal(format!( + _ => internal_err!( "Unsupported data type {:?} for function {}", ($VALUE).data_type(), - $NAME, - ))), + $NAME + ), }, } }}; @@ -142,6 +147,19 @@ macro_rules! make_function_inputs2 { }}; } +macro_rules! make_function_scalar_inputs_return_type { + ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$RETURN_TYPE>() + }}; +} + math_unary_function!("sqrt", sqrt); math_unary_function!("cbrt", cbrt); math_unary_function!("sin", sin); @@ -158,7 +176,6 @@ math_unary_function!("acosh", acosh); math_unary_function!("atanh", atanh); math_unary_function!("floor", floor); math_unary_function!("ceil", ceil); -math_unary_function!("trunc", trunc); math_unary_function!("abs", abs); math_unary_function!("signum", signum); math_unary_function!("exp", exp); @@ -177,9 +194,7 @@ pub fn factorial(args: &[ArrayRef]) -> Result { Int64Array, { |value: i64| { (1..=value).product() } } )) as ArrayRef), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function factorial." - ))), + other => internal_err!("Unsupported data type {other:?} for function factorial."), } } @@ -226,9 +241,7 @@ pub fn gcd(args: &[ArrayRef]) -> Result { Int64Array, { compute_gcd } )) as ArrayRef), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function gcd" - ))), + other => internal_err!("Unsupported data type {other:?} for function gcd"), } } @@ -254,18 +267,105 @@ pub fn lcm(args: &[ArrayRef]) -> Result { Int64Array, { compute_lcm } )) as ArrayRef), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function lcm" - ))), + other => internal_err!("Unsupported data type {other:?} for function lcm"), + } +} + +/// Nanvl SQL function +pub fn nanvl(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => { + let compute_nanvl = |x: f64, y: f64| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float64Array, + { compute_nanvl } + )) as ArrayRef) + } + + DataType::Float32 => { + let compute_nanvl = |x: f32, y: f32| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float32Array, + { compute_nanvl } + )) as ArrayRef) + } + + other => internal_err!("Unsupported data type {other:?} for function nanvl"), + } +} + +/// Isnan SQL function +pub fn isnan(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { f64::is_nan } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { f32::is_nan } + )) as ArrayRef), + + other => internal_err!("Unsupported data type {other:?} for function isnan"), + } +} + +/// Iszero SQL function +pub fn iszero(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { |x: f64| { x == 0_f64 } } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { |x: f32| { x == 0_f32 } } + )) as ArrayRef), + + other => internal_err!("Unsupported data type {other:?} for function iszero"), } } /// Pi SQL function pub fn pi(args: &[ColumnarValue]) -> Result { if !matches!(&args[0], ColumnarValue::Array(_)) { - return Err(DataFusionError::Internal( - "Expect pi function to take no param".to_string(), - )); + return internal_err!("Expect pi function to take no param"); } let array = Float64Array::from_value(std::f64::consts::PI, 1); Ok(ColumnarValue::Array(Arc::new(array))) @@ -275,11 +375,7 @@ pub fn pi(args: &[ColumnarValue]) -> Result { pub fn random(args: &[ColumnarValue]) -> Result { let len: usize = match &args[0] { ColumnarValue::Array(array) => array.len(), - _ => { - return Err(DataFusionError::Internal( - "Expect random function to take no param".to_string(), - )) - } + _ => return internal_err!("Expect random function to take no param"), }; let mut rng = thread_rng(); let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); @@ -290,10 +386,10 @@ pub fn random(args: &[ColumnarValue]) -> Result { /// Round SQL function pub fn round(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "round function requires one or two arguments, got {}", args.len() - ))); + ); } let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); @@ -334,10 +430,9 @@ pub fn round(args: &[ArrayRef]) -> Result { } } )) as ArrayRef), - _ => Err(DataFusionError::Internal( + _ => internal_err!( "round function requires a scalar or array for decimal_places" - .to_string(), - )), + ), }, DataType::Float32 => match decimal_places { @@ -371,15 +466,12 @@ pub fn round(args: &[ArrayRef]) -> Result { } } )) as ArrayRef), - _ => Err(DataFusionError::Internal( + _ => internal_err!( "round function requires a scalar or array for decimal_places" - .to_string(), - )), + ), }, - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function round" - ))), + other => internal_err!("Unsupported data type {other:?} for function round"), } } @@ -404,9 +496,7 @@ pub fn power(args: &[ArrayRef]) -> Result { { i64::pow } )) as ArrayRef), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function power" - ))), + other => internal_err!("Unsupported data type {other:?} for function power"), } } @@ -431,9 +521,7 @@ pub fn atan2(args: &[ArrayRef]) -> Result { { f32::atan2 } )) as ArrayRef), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function atan2" - ))), + other => internal_err!("Unsupported data type {other:?} for function atan2"), } } @@ -466,9 +554,7 @@ pub fn log(args: &[ArrayRef]) -> Result { Float64Array, { f64::log } )) as ArrayRef), - _ => Err(DataFusionError::Internal( - "log function requires a scalar or array for base".to_string(), - )), + _ => internal_err!("log function requires a scalar or array for base"), }, DataType::Float32 => match base { @@ -486,30 +572,205 @@ pub fn log(args: &[ArrayRef]) -> Result { Float32Array, { f32::log } )) as ArrayRef), - _ => Err(DataFusionError::Internal( - "log function requires a scalar or array for base".to_string(), - )), + _ => internal_err!("log function requires a scalar or array for base"), }, - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function log" - ))), + other => internal_err!("Unsupported data type {other:?} for function log"), } } +///cot SQL function +pub fn cot(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float64Array, + { compute_cot64 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float32Array, + { compute_cot32 } + )) as ArrayRef), + + other => internal_err!("Unsupported data type {other:?} for function cot"), + } +} + +fn compute_cot32(x: f32) -> f32 { + let a = f32::tan(x); + 1.0 / a +} + +fn compute_cot64(x: f64) -> f64 { + let a = f64::tan(x); + 1.0 / a +} + +/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function +pub fn trunc(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return internal_err!( + "truncate function requires one or two arguments, got {}", + args.len() + ); + } + + //if only one arg then invoke toolchain trunc(num) and precision = 0 by default + //or then invoke the compute_truncate method to process precision + let num = &args[0]; + let precision = if args.len() == 1 { + ColumnarValue::Scalar(Int64(Some(0))) + } else { + ColumnarValue::Array(args[1].clone()) + }; + + match args[0].data_type() { + DataType::Float64 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float64Array, + Int64Array, + { compute_truncate64 } + )) as ArrayRef), + _ => internal_err!("trunc function requires a scalar or array for precision"), + }, + DataType::Float32 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float32Array, + Int64Array, + { compute_truncate32 } + )) as ArrayRef), + _ => internal_err!("trunc function requires a scalar or array for precision"), + }, + other => internal_err!("Unsupported data type {other:?} for function trunc"), + } +} + +fn compute_truncate32(x: f32, y: i64) -> f32 { + let factor = 10.0_f32.powi(y as i32); + (x * factor).round() / factor +} + +fn compute_truncate64(x: f64, y: i64) -> f64 { + let factor = 10.0_f64.powi(y as i32); + (x * factor).round() / factor +} + +macro_rules! make_abs_function { + ($ARRAY_TYPE:ident) => {{ + |args: &[ArrayRef]| { + let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let res: $ARRAY_TYPE = array.unary(|x| x.abs()); + Ok(Arc::new(res) as ArrayRef) + } + }}; +} + +macro_rules! make_try_abs_function { + ($ARRAY_TYPE:ident) => {{ + |args: &[ArrayRef]| { + let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let res: $ARRAY_TYPE = array.try_unary(|x| { + x.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({})", + stringify!($ARRAY_TYPE), + x + )) + }) + })?; + Ok(Arc::new(res) as ArrayRef) + } + }}; +} + +macro_rules! make_decimal_abs_function { + ($ARRAY_TYPE:ident) => {{ + |args: &[ArrayRef]| { + let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let res: $ARRAY_TYPE = array + .unary(|x| x.wrapping_abs()) + .with_data_type(args[0].data_type().clone()); + Ok(Arc::new(res) as ArrayRef) + } + }}; +} + +/// Abs SQL function +/// Return different implementations based on input datatype to reduce branches during execution +pub(super) fn create_abs_function( + input_data_type: &DataType, +) -> Result { + match input_data_type { + DataType::Float32 => Ok(make_abs_function!(Float32Array)), + DataType::Float64 => Ok(make_abs_function!(Float64Array)), + + // Types that may overflow, such as abs(-128_i8). + DataType::Int8 => Ok(make_try_abs_function!(Int8Array)), + DataType::Int16 => Ok(make_try_abs_function!(Int16Array)), + DataType::Int32 => Ok(make_try_abs_function!(Int32Array)), + DataType::Int64 => Ok(make_try_abs_function!(Int64Array)), + + // Types of results are the same as the input. + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())), + + // Decimal types + DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), + DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), + + other => not_impl_err!("Unsupported data type {other:?} for function abs"), + } +} + +/// abs() SQL function implementation +pub fn abs_invoke(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return internal_err!("abs function requires 1 argument, got {}", args.len()); + } + + let input_data_type = args[0].data_type(); + let abs_fun = create_abs_function(input_data_type)?; + + abs_fun(args) +} + #[cfg(test)] mod tests { use super::*; use arrow::array::{Float64Array, NullArray}; - use datafusion_common::cast::{as_float32_array, as_float64_array, as_int64_array}; + use datafusion_common::cast::{ + as_boolean_array, as_float32_array, as_float64_array, as_int64_array, + }; #[test] fn test_random_expression() { let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; let array = random(&args) .expect("failed to initialize function random") - .into_array(1); + .into_array(1) + .expect("Failed to convert to array"); let floats = as_float64_array(&array).expect("failed to initialize function random"); @@ -739,4 +1000,230 @@ mod tests { assert_eq!(ints.value(2), 75); assert_eq!(ints.value(3), 16); } + + #[test] + fn test_cot_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float32_array(&result).expect("failed to initialize function cot"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + + #[test] + fn test_cot_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float64_array(&result).expect("failed to initialize function cot"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + + #[test] + fn test_truncate_32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![ + 15.0, + 1_234.267_8, + 1_233.123_4, + 3.312_979_2, + -21.123_4, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float32_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 15.0); + assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(2), 1_233.12); + assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(4), -21.123_4); + } + + #[test] + fn test_truncate_64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812_176, + 123.123_456_789, + 123.312_979_313_2, + -321.123_1, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(2), 123.12); + assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(4), -321.123_1); + } + + #[test] + fn test_truncate_64_one_arg() { + let args: Vec = vec![Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812, + 123.123_45, + 123.312_979_313_2, + -321.123, + ]))]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.0); + assert_eq!(floats.value(2), 123.0); + assert_eq!(floats.value(3), 123.0); + assert_eq!(floats.value(4), -321.0); + } + + #[test] + fn test_nanvl_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function nanvl"); + let floats = + as_float64_array(&result).expect("failed to initialize function nanvl"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } + + #[test] + fn test_nanvl_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function nanvl"); + let floats = + as_float32_array(&result).expect("failed to initialize function nanvl"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } + + #[test] + fn test_isnan_f64() { + let args: Vec = vec![Arc::new(Float64Array::from(vec![ + 1.0, + f64::NAN, + 3.0, + -f64::NAN, + ]))]; + + let result = isnan(&args).expect("failed to initialize function isnan"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function isnan"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_isnan_f32() { + let args: Vec = vec![Arc::new(Float32Array::from(vec![ + 1.0, + f32::NAN, + 3.0, + f32::NAN, + ]))]; + + let result = isnan(&args).expect("failed to initialize function isnan"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function isnan"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } } diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs new file mode 100644 index 0000000000000..301f12e9aa2ea --- /dev/null +++ b/datafusion/physical-expr/src/partitioning.rs @@ -0,0 +1,311 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Partitioning`] and [`Distribution`] for `ExecutionPlans` + +use std::fmt; +use std::sync::Arc; + +use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; + +/// Output partitioning supported by [`ExecutionPlan`]s. +/// +/// When `executed`, `ExecutionPlan`s produce one or more independent stream of +/// data batches in parallel, referred to as partitions. The streams are Rust +/// `async` [`Stream`]s (a special kind of future). The number of output +/// partitions varies based on the input and the operation performed. +/// +/// For example, an `ExecutionPlan` that has output partitioning of 3 will +/// produce 3 distinct output streams as the result of calling +/// `ExecutionPlan::execute(0)`, `ExecutionPlan::execute(1)`, and +/// `ExecutionPlan::execute(2)`, as shown below: +/// +/// ```text +/// ... ... ... +/// ... ▲ ▲ ▲ +/// │ │ │ +/// ▲ │ │ │ +/// │ │ │ │ +/// │ ┌───┴────┐ ┌───┴────┐ ┌───┴────┐ +/// ┌────────────────────┐ │ Stream │ │ Stream │ │ Stream │ +/// │ ExecutionPlan │ │ (0) │ │ (1) │ │ (2) │ +/// └────────────────────┘ └────────┘ └────────┘ └────────┘ +/// ▲ ▲ ▲ ▲ +/// │ │ │ │ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ │ │ +/// Input │ │ │ │ +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │ │ │ +/// ▲ ┌ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ┌ ─ ─ ─ ─ +/// │ Input │ Input │ Input │ +/// │ │ Stream │ Stream │ Stream +/// (0) │ (1) │ (2) │ +/// ... └ ─ ▲ ─ ─ └ ─ ▲ ─ ─ └ ─ ▲ ─ ─ +/// │ │ │ +/// │ │ │ +/// │ │ │ +/// +/// ExecutionPlan with 1 input 3 (async) streams, one for each +/// that has 3 partitions, which itself output partition +/// has 3 output partitions +/// ``` +/// +/// It is common (but not required) that an `ExecutionPlan` has the same number +/// of input partitions as output partitions. However, some plans have different +/// numbers such as the `RepartitionExec` that redistributes batches from some +/// number of inputs to some number of outputs +/// +/// ```text +/// ... ... ... ... +/// +/// ▲ ▲ ▲ +/// ▲ │ │ │ +/// │ │ │ │ +/// ┌────────┴───────────┐ │ │ │ +/// │ RepartitionExec │ ┌────┴───┐ ┌────┴───┐ ┌────┴───┐ +/// └────────────────────┘ │ Stream │ │ Stream │ │ Stream │ +/// ▲ │ (0) │ │ (1) │ │ (2) │ +/// │ └────────┘ └────────┘ └────────┘ +/// │ ▲ ▲ ▲ +/// ... │ │ │ +/// └──────────┐│┌──────────┘ +/// │││ +/// │││ +/// RepartitionExec with one input +/// that has 3 partitions, but 3 (async) streams, that internally +/// itself has only 1 output partition pull from the same input stream +/// ... +/// ``` +/// +/// # Additional Examples +/// +/// A simple `FileScanExec` might produce one output stream (partition) for each +/// file (note the actual DataFusion file scaners can read individual files in +/// parallel, potentially producing multiple partitions per file) +/// +/// Plans such as `SortPreservingMerge` produce a single output stream +/// (1 output partition) by combining some number of input streams (input partitions) +/// +/// Plans such as `FilterExec` produce the same number of output streams +/// (partitions) as input streams (partitions). +/// +/// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html +/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html +#[derive(Debug, Clone)] +pub enum Partitioning { + /// Allocate batches using a round-robin algorithm and the specified number of partitions + RoundRobinBatch(usize), + /// Allocate rows based on a hash of one of more expressions and the specified number of + /// partitions + Hash(Vec>, usize), + /// Unknown partitioning scheme with a known number of partitions + UnknownPartitioning(usize), +} + +impl fmt::Display for Partitioning { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Partitioning::RoundRobinBatch(size) => write!(f, "RoundRobinBatch({size})"), + Partitioning::Hash(phy_exprs, size) => { + let phy_exprs_str = phy_exprs + .iter() + .map(|e| format!("{e}")) + .collect::>() + .join(", "); + write!(f, "Hash([{phy_exprs_str}], {size})") + } + Partitioning::UnknownPartitioning(size) => { + write!(f, "UnknownPartitioning({size})") + } + } + } +} +impl Partitioning { + /// Returns the number of partitions in this partitioning scheme + pub fn partition_count(&self) -> usize { + use Partitioning::*; + match self { + RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, + } + } + + /// Returns true when the guarantees made by this [[Partitioning]] are sufficient to + /// satisfy the partitioning scheme mandated by the `required` [[Distribution]] + pub fn satisfy EquivalenceProperties>( + &self, + required: Distribution, + eq_properties: F, + ) -> bool { + match required { + Distribution::UnspecifiedDistribution => true, + Distribution::SinglePartition if self.partition_count() == 1 => true, + Distribution::HashPartitioned(required_exprs) => { + match self { + // Here we do not check the partition count for hash partitioning and assumes the partition count + // and hash functions in the system are the same. In future if we plan to support storage partition-wise joins, + // then we need to have the partition count and hash functions validation. + Partitioning::Hash(partition_exprs, _) => { + let fast_match = + physical_exprs_equal(&required_exprs, partition_exprs); + // If the required exprs do not match, need to leverage the eq_properties provided by the child + // and normalize both exprs based on the equivalent groups. + if !fast_match { + let eq_properties = eq_properties(); + let eq_groups = eq_properties.eq_group(); + if !eq_groups.is_empty() { + let normalized_required_exprs = required_exprs + .iter() + .map(|e| eq_groups.normalize_expr(e.clone())) + .collect::>(); + let normalized_partition_exprs = partition_exprs + .iter() + .map(|e| eq_groups.normalize_expr(e.clone())) + .collect::>(); + return physical_exprs_equal( + &normalized_required_exprs, + &normalized_partition_exprs, + ); + } + } + fast_match + } + _ => false, + } + } + _ => false, + } + } +} + +impl PartialEq for Partitioning { + fn eq(&self, other: &Partitioning) -> bool { + match (self, other) { + ( + Partitioning::RoundRobinBatch(count1), + Partitioning::RoundRobinBatch(count2), + ) if count1 == count2 => true, + (Partitioning::Hash(exprs1, count1), Partitioning::Hash(exprs2, count2)) + if physical_exprs_equal(exprs1, exprs2) && (count1 == count2) => + { + true + } + _ => false, + } + } +} + +/// How data is distributed amongst partitions. See [`Partitioning`] for more +/// details. +#[derive(Debug, Clone)] +pub enum Distribution { + /// Unspecified distribution + UnspecifiedDistribution, + /// A single partition is required + SinglePartition, + /// Requires children to be distributed in such a way that the same + /// values of the keys end up in the same partition + HashPartitioned(Vec>), +} + +impl Distribution { + /// Creates a `Partitioning` that satisfies this `Distribution` + pub fn create_partitioning(&self, partition_count: usize) -> Partitioning { + match self { + Distribution::UnspecifiedDistribution => { + Partitioning::UnknownPartitioning(partition_count) + } + Distribution::SinglePartition => Partitioning::UnknownPartitioning(1), + Distribution::HashPartitioned(expr) => { + Partitioning::Hash(expr.clone(), partition_count) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::expressions::Column; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + + #[test] + fn partitioning_satisfy_distribution() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("column_1", DataType::Int64, false), + Field::new("column_2", DataType::Utf8, false), + ])); + + let partition_exprs1: Vec> = vec![ + Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), + Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), + ]; + + let partition_exprs2: Vec> = vec![ + Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), + Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), + ]; + + let distribution_types = vec![ + Distribution::UnspecifiedDistribution, + Distribution::SinglePartition, + Distribution::HashPartitioned(partition_exprs1.clone()), + ]; + + let single_partition = Partitioning::UnknownPartitioning(1); + let unspecified_partition = Partitioning::UnknownPartitioning(10); + let round_robin_partition = Partitioning::RoundRobinBatch(10); + let hash_partition1 = Partitioning::Hash(partition_exprs1, 10); + let hash_partition2 = Partitioning::Hash(partition_exprs2, 10); + + for distribution in distribution_types { + let result = ( + single_partition.satisfy(distribution.clone(), || { + EquivalenceProperties::new(schema.clone()) + }), + unspecified_partition.satisfy(distribution.clone(), || { + EquivalenceProperties::new(schema.clone()) + }), + round_robin_partition.satisfy(distribution.clone(), || { + EquivalenceProperties::new(schema.clone()) + }), + hash_partition1.satisfy(distribution.clone(), || { + EquivalenceProperties::new(schema.clone()) + }), + hash_partition2.satisfy(distribution.clone(), || { + EquivalenceProperties::new(schema.clone()) + }), + ); + + match distribution { + Distribution::UnspecifiedDistribution => { + assert_eq!(result, (true, true, true, true, true)) + } + Distribution::SinglePartition => { + assert_eq!(result, (true, false, false, false, false)) + } + Distribution::HashPartitioned(_) => { + assert_eq!(result, (false, false, false, true, false)) + } + } + } + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index d6dd14e8a116d..a8d1e3638a177 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -15,30 +15,29 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Schema}; +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; -use arrow::record_batch::RecordBatch; +use crate::sort_properties::SortProperties; +use crate::utils::scatter; +use arrow::array::BooleanArray; +use arrow::compute::filter_record_batch; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; use datafusion_common::utils::DataPtr; -use datafusion_common::{ - ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, -}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; -use std::cmp::Ordering; -use std::fmt::{Debug, Display}; - -use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; -use arrow::compute::{and_kleene, filter_record_batch, is_not_null, SlicesIterator}; - -use crate::intervals::Interval; -use std::any::Any; -use std::sync::Arc; +use itertools::izip; /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { - /// Returns the physical expression as [`Any`](std::any::Any) so that it can be + /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input @@ -57,13 +56,12 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { let tmp_batch = filter_record_batch(batch, selection)?; let tmp_result = self.evaluate(&tmp_batch)?; - // All values from the `selection` filter are true. + if batch.num_rows() == tmp_batch.num_rows() { - return Ok(tmp_result); - } - if let ColumnarValue::Array(a) = tmp_result { - let result = scatter(selection, a.as_ref())?; - Ok(ColumnarValue::Array(result)) + // All values from the `selection` filter are true. + Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + scatter(selection, a.as_ref()).map(ColumnarValue::Array) } else { Ok(tmp_result) } @@ -78,162 +76,110 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { children: Vec>, ) -> Result>; - /// Return the boundaries of this expression. This method (and all the - /// related APIs) are experimental and subject to change. - fn analyze(&self, context: AnalysisContext) -> AnalysisContext { - context - } - - /// Computes bounds for the expression using interval arithmetic. + /// Computes the output interval for the expression, given the input + /// intervals. + /// + /// # Arguments + /// + /// * `children` are the intervals for the children (inputs) of this + /// expression. + /// + /// # Example + /// + /// If the expression is `a + b`, and the input intervals are `a: [1, 2]` + /// and `b: [3, 4]`, then the output interval would be `[4, 6]`. fn evaluate_bounds(&self, _children: &[&Interval]) -> Result { - Err(DataFusionError::NotImplemented(format!( - "Not implemented for {self}" - ))) + not_impl_err!("Not implemented for {self}") } - /// Updates/shrinks bounds for the expression using interval arithmetic. - /// If constraint propagation reveals an infeasibility, returns [None] for - /// the child causing infeasibility. If none of the children intervals - /// change, may return an empty vector instead of cloning `children`. + /// Updates bounds for child expressions, given a known interval for this + /// expression. + /// + /// This is used to propagate constraints down through an expression tree. + /// + /// # Arguments + /// + /// * `interval` is the currently known interval for this expression. + /// * `children` are the current intervals for the children of this expression. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of propagation, + /// may return an empty vector instead of cloning `children`. This is the default + /// (and conservative) return value. + /// + /// # Example + /// + /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the + /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then + /// propagation would would return `[0, 2]` and `[2, 4]` as `b` must be at + /// least `2` to make the output at least `4`. fn propagate_constraints( &self, _interval: &Interval, _children: &[&Interval], - ) -> Result>> { - Err(DataFusionError::NotImplemented(format!( - "Not implemented for {self}" - ))) - } -} - -/// Shared [`PhysicalExpr`]. -pub type PhysicalExprRef = Arc; - -/// The shared context used during the analysis of an expression. Includes -/// the boundaries for all known columns. -#[derive(Clone, Debug, PartialEq)] -pub struct AnalysisContext { - /// A list of known column boundaries, ordered by the index - /// of the column in the current schema. - pub column_boundaries: Vec>, - // Result of the current analysis. - pub boundaries: Option, -} - -impl AnalysisContext { - pub fn new( - input_schema: &Schema, - column_boundaries: Vec>, - ) -> Self { - assert_eq!(input_schema.fields().len(), column_boundaries.len()); - Self { - column_boundaries, - boundaries: None, - } - } - - /// Create a new analysis context from column statistics. - pub fn from_statistics(input_schema: &Schema, statistics: &Statistics) -> Self { - // Even if the underlying statistics object doesn't have any column level statistics, - // we can still create an analysis context with the same number of columns and see whether - // we can infer it during the way. - let column_boundaries = match &statistics.column_statistics { - Some(columns) => columns - .iter() - .map(ExprBoundaries::from_column) - .collect::>(), - None => vec![None; input_schema.fields().len()], - }; - Self::new(input_schema, column_boundaries) - } - - pub fn boundaries(&self) -> Option<&ExprBoundaries> { - self.boundaries.as_ref() + ) -> Result>> { + Ok(Some(vec![])) } - /// Set the result of the current analysis. - pub fn with_boundaries(mut self, result: Option) -> Self { - self.boundaries = result; - self + /// Update the hash `state` with this expression requirements from + /// [`Hash`]. + /// + /// This method is required to support hashing [`PhysicalExpr`]s. To + /// implement it, typically the type implementing + /// [`PhysicalExpr`] implements [`Hash`] and + /// then the following boiler plate is used: + /// + /// # Example: + /// ``` + /// // User defined expression that derives Hash + /// #[derive(Hash, Debug, PartialEq, Eq)] + /// struct MyExpr { + /// val: u64 + /// } + /// + /// // impl PhysicalExpr { + /// // ... + /// # impl MyExpr { + /// // Boiler plate to call the derived Hash impl + /// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { + /// use std::hash::Hash; + /// let mut s = state; + /// self.hash(&mut s); + /// } + /// // } + /// # } + /// ``` + /// Note: [`PhysicalExpr`] is not constrained by [`Hash`] + /// directly because it must remain object safe. + fn dyn_hash(&self, _state: &mut dyn Hasher); + + /// The order information of a PhysicalExpr can be estimated from its children. + /// This is especially helpful for projection expressions. If we can ensure that the + /// order of a PhysicalExpr to project matches with the order of SortExec, we can + /// eliminate that SortExecs. + /// + /// By recursively calling this function, we can obtain the overall order + /// information of the PhysicalExpr. Since `SortOptions` cannot fully handle + /// the propagation of unordered columns and literals, the `SortProperties` + /// struct is used. + fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { + SortProperties::Unordered } - - /// Update the boundaries of a column. - pub fn with_column_update( - mut self, - column: usize, - boundaries: ExprBoundaries, - ) -> Self { - self.column_boundaries[column] = Some(boundaries); - self - } -} - -/// Represents the boundaries of the resulting value from a physical expression, -/// if it were to be an expression, if it were to be evaluated. -#[derive(Clone, Debug, PartialEq)] -pub struct ExprBoundaries { - /// Minimum value this expression's result can have. - pub min_value: ScalarValue, - /// Maximum value this expression's result can have. - pub max_value: ScalarValue, - /// Maximum number of distinct values this expression can produce, if known. - pub distinct_count: Option, - /// The estimated percantage of rows that this expression would select, if - /// it were to be used as a boolean predicate on a filter. The value will be - /// between 0.0 (selects nothing) and 1.0 (selects everything). - pub selectivity: Option, } -impl ExprBoundaries { - /// Create a new `ExprBoundaries`. - pub fn new( - min_value: ScalarValue, - max_value: ScalarValue, - distinct_count: Option, - ) -> Self { - Self::new_with_selectivity(min_value, max_value, distinct_count, None) - } - - /// Create a new `ExprBoundaries` with a selectivity value. - pub fn new_with_selectivity( - min_value: ScalarValue, - max_value: ScalarValue, - distinct_count: Option, - selectivity: Option, - ) -> Self { - assert!(!matches!( - min_value.partial_cmp(&max_value), - Some(Ordering::Greater) - )); - Self { - min_value, - max_value, - distinct_count, - selectivity, - } - } - - /// Create a new `ExprBoundaries` from a column level statistics. - pub fn from_column(column: &ColumnStatistics) -> Option { - Some(Self { - min_value: column.min_value.clone()?, - max_value: column.max_value.clone()?, - distinct_count: column.distinct_count, - selectivity: None, - }) - } - - /// Try to reduce the boundaries into a single scalar value, if possible. - pub fn reduce(&self) -> Option { - // TODO: should we check distinct_count is `Some(1) | None`? - if self.min_value == self.max_value { - Some(self.min_value.clone()) - } else { - None - } +impl Hash for dyn PhysicalExpr { + fn hash(&self, state: &mut H) { + self.dyn_hash(state); } } +/// Shared [`PhysicalExpr`]. +pub type PhysicalExprRef = Arc; + /// Returns a copy of this expr if we change any child according to the pointer comparison. /// The size of `children` must be equal to the size of `PhysicalExpr::children()`. pub fn with_new_children_if_necessary( @@ -242,9 +188,7 @@ pub fn with_new_children_if_necessary( ) -> Result> { let old_children = expr.children(); if children.len() != old_children.len() { - Err(DataFusionError::Internal( - "PhysicalExpr: Wrong number of children".to_string(), - )) + internal_err!("PhysicalExpr: Wrong number of children") } else if children.is_empty() || children .iter() @@ -271,162 +215,226 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { } } -/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` -/// are taken, when the mask evaluates `false` values null values are filled. -/// -/// # Arguments -/// * `mask` - Boolean values used to determine where to put the `truthy` values -/// * `truthy` - All values of this array are to scatter according to `mask` into final result. -fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { - let truthy = truthy.to_data(); - - // update the mask so that any null values become false - // (SlicesIterator doesn't respect nulls) - let mask = and_kleene(mask, &is_not_null(mask)?)?; - - let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); - - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values - - // keep track of how much is filled - let mut filled = 0; - // keep track of current position we have in truthy array - let mut true_pos = 0; - - SlicesIterator::new(&mask).for_each(|(start, end)| { - // the gap needs to be filled with nulls - if start > filled { - mutable.extend_nulls(start - filled); +/// This function is similar to the `contains` method of `Vec`. It finds +/// whether `expr` is among `physical_exprs`. +pub fn physical_exprs_contains( + physical_exprs: &[Arc], + expr: &Arc, +) -> bool { + physical_exprs + .iter() + .any(|physical_expr| physical_expr.eq(expr)) +} + +/// Checks whether the given physical expression slices are equal. +pub fn physical_exprs_equal( + lhs: &[Arc], + rhs: &[Arc], +) -> bool { + lhs.len() == rhs.len() && izip!(lhs, rhs).all(|(lhs, rhs)| lhs.eq(rhs)) +} + +/// Checks whether the given physical expression slices are equal in the sense +/// of bags (multi-sets), disregarding their orderings. +pub fn physical_exprs_bag_equal( + lhs: &[Arc], + rhs: &[Arc], +) -> bool { + // TODO: Once we can use `HashMap`s with `Arc`, this + // function should use a `HashMap` to reduce computational complexity. + if lhs.len() == rhs.len() { + let mut rhs_vec = rhs.to_vec(); + for expr in lhs { + if let Some(idx) = rhs_vec.iter().position(|e| expr.eq(e)) { + rhs_vec.swap_remove(idx); + } else { + return false; + } } - // fill with truthy values - let len = end - start; - mutable.extend(0, true_pos, true_pos + len); - true_pos += len; - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - mutable.extend_nulls(mask.len() - filled); + true + } else { + false } - - let data = mutable.freeze(); - Ok(make_array(data)) } -#[macro_export] -// If the given expression is None, return the given context -// without setting the boundaries. -macro_rules! analysis_expect { - ($context: ident, $expr: expr) => { - match $expr { - Some(expr) => expr, - None => return $context.with_boundaries(None), +/// This utility function removes duplicates from the given `exprs` vector. +/// Note that this function does not necessarily preserve its input ordering. +pub fn deduplicate_physical_exprs(exprs: &mut Vec>) { + // TODO: Once we can use `HashSet`s with `Arc`, this + // function should use a `HashSet` to reduce computational complexity. + // See issue: https://github.com/apache/arrow-datafusion/issues/8027 + let mut idx = 0; + while idx < exprs.len() { + let mut rest_idx = idx + 1; + while rest_idx < exprs.len() { + if exprs[idx].eq(&exprs[rest_idx]) { + exprs.swap_remove(rest_idx); + } else { + rest_idx += 1; + } } - }; + idx += 1; + } } #[cfg(test)] mod tests { use std::sync::Arc; - use super::*; - use arrow::array::Int32Array; - use datafusion_common::{ - cast::{as_boolean_array, as_int32_array}, - Result, + use crate::expressions::{Column, Literal}; + use crate::physical_expr::{ + deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, + physical_exprs_equal, PhysicalExpr, }; - #[test] - fn scatter_int() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = - Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } + use datafusion_common::ScalarValue; #[test] - fn scatter_int_end_with_false() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, false, true, false, false, false]); - - // output should be same length as mask - let expected = - Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) + fn test_physical_exprs_contains() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit4 = + Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + + // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b) + let physical_exprs: Vec> = vec![ + lit_true.clone(), + lit_false.clone(), + lit4.clone(), + lit2.clone(), + col_a_expr.clone(), + col_b_expr.clone(), + ]; + // below expressions are inside physical_exprs + assert!(physical_exprs_contains(&physical_exprs, &lit_true)); + assert!(physical_exprs_contains(&physical_exprs, &lit2)); + assert!(physical_exprs_contains(&physical_exprs, &col_b_expr)); + + // below expressions are not inside physical_exprs + assert!(!physical_exprs_contains(&physical_exprs, &col_c_expr)); + assert!(!physical_exprs_contains(&physical_exprs, &lit1)); } #[test] - fn scatter_with_null_mask() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); - let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] - .into_iter() - .collect(); - - // output should treat nulls as though they are false - let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) + fn test_physical_exprs_equal() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let vec1 = vec![lit_true.clone(), lit_false.clone()]; + let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; + let vec3 = vec![lit2.clone(), lit1.clone()]; + let vec4 = vec![lit_true.clone(), lit_false.clone()]; + + // these vectors are same + assert!(physical_exprs_equal(&vec1, &vec1)); + assert!(physical_exprs_equal(&vec1, &vec4)); + assert!(physical_exprs_bag_equal(&vec1, &vec1)); + assert!(physical_exprs_bag_equal(&vec1, &vec4)); + + // these vectors are different + assert!(!physical_exprs_equal(&vec1, &vec2)); + assert!(!physical_exprs_equal(&vec1, &vec3)); + assert!(!physical_exprs_bag_equal(&vec1, &vec2)); + assert!(!physical_exprs_bag_equal(&vec1, &vec3)); } #[test] - fn scatter_boolean() -> Result<()> { - let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = BooleanArray::from_iter(vec![ - Some(false), - Some(false), - None, - None, - Some(false), - ]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_boolean_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) + fn test_physical_exprs_set_equal() { + let list1: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list2: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + ]; + assert!(!physical_exprs_bag_equal( + list1.as_slice(), + list2.as_slice() + )); + assert!(!physical_exprs_bag_equal( + list2.as_slice(), + list1.as_slice() + )); + assert!(!physical_exprs_equal(list1.as_slice(), list2.as_slice())); + assert!(!physical_exprs_equal(list2.as_slice(), list1.as_slice())); + + let list3: Vec> = vec![ + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("b", 1)), + ]; + let list4: Vec> = vec![ + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("b", 1)), + Arc::new(Column::new("a", 0)), + Arc::new(Column::new("c", 2)), + Arc::new(Column::new("a", 0)), + ]; + assert!(physical_exprs_bag_equal(list3.as_slice(), list4.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); + assert!(!physical_exprs_equal(list3.as_slice(), list4.as_slice())); + assert!(!physical_exprs_equal(list4.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); + assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); } #[test] - fn reduce_boundaries() -> Result<()> { - let different_boundaries = ExprBoundaries::new( - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(10)), - None, - ); - assert_eq!(different_boundaries.reduce(), None); - - let scalar_boundaries = ExprBoundaries::new( - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(1)), - None, - ); - assert_eq!( - scalar_boundaries.reduce(), - Some(ScalarValue::Int32(Some(1))) - ); - - // Can still reduce. - let no_boundaries = - ExprBoundaries::new(ScalarValue::Int32(None), ScalarValue::Int32(None), None); - assert_eq!(no_boundaries.reduce(), Some(ScalarValue::Int32(None))); - - Ok(()) + fn test_deduplicate_physical_exprs() { + let lit_true = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc); + let lit_false = &(Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc); + let lit4 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(4)))) + as Arc); + let lit2 = &(Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) + as Arc); + let col_a_expr = &(Arc::new(Column::new("a", 0)) as Arc); + let col_b_expr = &(Arc::new(Column::new("b", 1)) as Arc); + + // First vector in the tuple is arguments, second one is the expected value. + let test_cases = vec![ + // ---------- TEST CASE 1----------// + ( + vec![ + lit_true, lit_false, lit4, lit2, col_a_expr, col_a_expr, col_b_expr, + lit_true, lit2, + ], + vec![lit_true, lit_false, lit4, lit2, col_a_expr, col_b_expr], + ), + // ---------- TEST CASE 2----------// + ( + vec![lit_true, lit_true, lit_false, lit4], + vec![lit_true, lit4, lit_false], + ), + ]; + for (exprs, expected) in test_cases { + let mut exprs = exprs.into_iter().cloned().collect::>(); + let expected = expected.into_iter().cloned().collect::>(); + deduplicate_physical_exprs(&mut exprs); + assert!(physical_exprs_equal(&exprs, &expected)); + } } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 40241a459e6e5..9c212cb81f6b3 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. +use crate::expressions::GetFieldAccessExpr; use crate::var_provider::is_system_variables; use crate::{ execution_props::ExecutionProps, - expressions::{ - self, binary, date_time_interval_expr, like, Column, GetIndexedFieldExpr, Literal, - }, + expressions::{self, binary, like, Column, GetIndexedFieldExpr, Literal}, functions, udf, var_provider::VarType, PhysicalExpr, }; -use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{Cast, InList, ScalarFunction, ScalarUDF}; +use arrow::datatypes::Schema; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, + ScalarValue, +}; +use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::{ - binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast, + binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, + Operator, ScalarFunctionDefinition, TryCast, }; use std::sync::Arc; @@ -49,15 +52,15 @@ pub fn create_physical_expr( execution_props: &ExecutionProps, ) -> Result> { if input_schema.fields.len() != input_dfschema.fields().len() { - return Err(DataFusionError::Internal(format!( + return internal_err!( "create_physical_expr expected same number of fields, got \ Arrow schema with {} and DataFusion schema with {}", input_schema.fields.len(), input_dfschema.fields().len() - ))); + ); } match e { - Expr::Alias(expr, ..) => Ok(create_physical_expr( + Expr::Alias(Alias { expr, .. }) => Ok(create_physical_expr( expr, input_dfschema, input_schema, @@ -75,9 +78,7 @@ pub fn create_physical_expr( let scalar_value = provider.get_value(variable_names.clone())?; Ok(Arc::new(Literal::new(scalar_value))) } - _ => Err(DataFusionError::Plan( - "No system variable provider found".to_string(), - )), + _ => plan_err!("No system variable provider found"), } } else { match execution_props.get_var_provider(VarType::UserDefined) { @@ -85,9 +86,7 @@ pub fn create_physical_expr( let scalar_value = provider.get_value(variable_names.clone())?; Ok(Arc::new(Literal::new(scalar_value))) } - _ => Err(DataFusionError::Plan( - "No user defined variable provider found".to_string(), - )), + _ => plan_err!("No user defined variable provider found"), } } } @@ -183,87 +182,24 @@ pub fn create_physical_expr( input_schema, execution_props, )?; - // Match the data types and operator to determine the appropriate expression, if - // they are supported temporal types and operations, create DateTimeIntervalExpr, - // else create BinaryExpr. - match ( - lhs.data_type(input_schema)?, - op, - rhs.data_type(input_schema)?, - ) { - ( - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - Operator::Plus | Operator::Minus, - DataType::Interval(_), - ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?), - ( - DataType::Interval(_), - Operator::Plus | Operator::Minus, - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), - ) => Ok(date_time_interval_expr(rhs, *op, lhs, input_schema)?), - ( - DataType::Timestamp(_, _), - Operator::Minus, - DataType::Timestamp(_, _), - ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?), - ( - DataType::Interval(_), - Operator::Plus | Operator::Minus, - DataType::Interval(_), - ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?), - _ => { - // Note that the logical planner is responsible - // for type coercion on the arguments (e.g. if one - // argument was originally Int32 and one was - // Int64 they will both be coerced to Int64). - // - // There should be no coercion during physical - // planning. - binary(lhs, *op, rhs, input_schema) - } - } + // Note that the logical planner is responsible + // for type coercion on the arguments (e.g. if one + // argument was originally Int32 and one was + // Int64 they will both be coerced to Int64). + // + // There should be no coercion during physical + // planning. + binary(lhs, *op, rhs, input_schema) } Expr::Like(Like { negated, expr, pattern, escape_char, + case_insensitive, }) => { if escape_char.is_some() { - return Err(DataFusionError::Execution( - "LIKE does not support escape_char".to_string(), - )); - } - let physical_expr = create_physical_expr( - expr, - input_dfschema, - input_schema, - execution_props, - )?; - let physical_pattern = create_physical_expr( - pattern, - input_dfschema, - input_schema, - execution_props, - )?; - like( - *negated, - false, - physical_expr, - physical_pattern, - input_schema, - ) - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - if escape_char.is_some() { - return Err(DataFusionError::Execution( - "ILIKE does not support escape_char".to_string(), - )); + return exec_err!("LIKE does not support escape_char"); } let physical_expr = create_physical_expr( expr, @@ -279,7 +215,7 @@ pub fn create_physical_expr( )?; like( *negated, - true, + *case_insensitive, physical_expr, physical_pattern, input_schema, @@ -371,7 +307,36 @@ pub fn create_physical_expr( input_schema, execution_props, )?), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + GetFieldAccessExpr::NamedStructField { name: name.clone() } + } + GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex { + key: create_physical_expr( + key, + input_dfschema, + input_schema, + execution_props, + )?, + }, + GetFieldAccess::ListRange { start, stop } => { + GetFieldAccessExpr::ListRange { + start: create_physical_expr( + start, + input_dfschema, + input_schema, + execution_props, + )?, + stop: create_physical_expr( + stop, + input_dfschema, + input_schema, + execution_props, + )?, + } + } + }; Ok(Arc::new(GetIndexedFieldExpr::new( create_physical_expr( expr, @@ -379,39 +344,41 @@ pub fn create_physical_expr( input_schema, execution_props, )?, - key.clone(), + field, ))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let physical_args = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let mut physical_args = args .iter() .map(|e| { create_physical_expr(e, input_dfschema, input_schema, execution_props) }) .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - )?); - } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + functions::create_physical_expr( + fun, + &physical_args, + input_schema, + execution_props, + ) + } + ScalarFunctionDefinition::UDF(fun) => { + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } } - udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) } Expr::Between(Between { expr, @@ -478,8 +445,42 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, - other => Err(DataFusionError::NotImplemented(format!( - "Physical plan does not support logical expression {other:?}" - ))), + other => { + not_impl_err!("Physical plan does not support logical expression {other:?}") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{col, left, Literal}; + + #[test] + fn test_create_physical_expr_scalar_input_output() -> Result<()> { + let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit())); + + let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); + let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; + let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?; + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StringArray::from_iter_values(vec![ + "A", "B", "C", "D", + ]))], + )?; + let result = p.evaluate(&batch)?; + let result = result.into_array(4).expect("Failed to convert to array"); + + assert_eq!( + &result, + &(Arc::new(BooleanArray::from(vec![true, false, false, false,])) as ArrayRef) + ); + + Ok(()) } } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 3965897093ace..41cd01949595a 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -26,12 +26,14 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::compute; -use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result}; +use datafusion_common::plan_err; +use datafusion_common::{ + cast::as_generic_string_array, internal_err, DataFusionError, Result, +}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use hashbrown::HashMap; -use lazy_static::lazy_static; use regex::Regex; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hint}; @@ -65,24 +67,25 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { match flags { Some(f) if f.iter().any(|s| s == Some("g")) => { - Err(DataFusionError::Plan("regexp_match() does not support the \"global\" option".to_owned())) + plan_err!("regexp_match() does not support the \"global\" option") }, _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), } } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "regexp_match was called with {other} arguments. It requires at least 2 and at most 3." - ))), + ), } } /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { - lazy_static! { - static ref CAPTURE_GROUPS_RE: Regex = Regex::new(r"(\\)(\d*)").unwrap(); + fn capture_groups_re() -> &'static Regex { + static CAPTURE_GROUPS_RE_LOCK: OnceLock = OnceLock::new(); + CAPTURE_GROUPS_RE_LOCK.get_or_init(|| Regex::new(r"(\\)(\d*)").unwrap()) } - CAPTURE_GROUPS_RE + capture_groups_re() .replace_all(replacement, "$${$2}") .into_owned() } @@ -185,9 +188,9 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result Ok(Arc::new(result) as ArrayRef) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." - ))), + ), } } @@ -218,9 +221,9 @@ fn _regexp_replace_static_pattern_replace( 3 => None, 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort)), other => { - return Err(DataFusionError::Internal(format!( + return internal_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." - ))) + ) } }; @@ -389,7 +392,7 @@ mod tests { regexp_match::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); - assert_eq!(re_err.to_string(), "Error during planning: regexp_match() does not support the \"global\" option"); + assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option"); } #[test] @@ -497,7 +500,7 @@ mod tests { ]); let pattern_err = re.expect_err("broken pattern should have failed"); assert_eq!( - pattern_err.to_string(), + pattern_err.strip_backtrace(), "External error: regex parse error:\n [\n ^\nerror: unclosed character class" ); } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index da47a55aa9e39..0a9d69720e19a 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -29,19 +29,23 @@ //! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed //! to a function that supports f64, it is coerced to f64. -use crate::physical_expr::down_cast_any_ref; -use crate::utils::expr_list_eq_strict_order; +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::functions::out_ordering; +use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; +use crate::sort_properties::SortProperties; use crate::PhysicalExpr; + use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_expr::BuiltinScalarFunction; -use datafusion_expr::ColumnarValue; -use datafusion_expr::ScalarFunctionImplementation; -use std::any::Any; -use std::fmt::Debug; -use std::fmt::{self, Formatter}; -use std::sync::Arc; +use datafusion_expr::{ + expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, + ScalarFunctionImplementation, +}; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { @@ -49,6 +53,11 @@ pub struct ScalarFunctionExpr { name: String, args: Vec>, return_type: DataType, + // Keeps monotonicity information of the function. + // FuncMonotonicity vector is one to one mapped to `args`, + // and it specifies the effect of an increase or decrease in + // the corresponding `arg` to the function value. + monotonicity: Option, } impl Debug for ScalarFunctionExpr { @@ -68,13 +77,15 @@ impl ScalarFunctionExpr { name: &str, fun: ScalarFunctionImplementation, args: Vec>, - return_type: &DataType, + return_type: DataType, + monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type: return_type.clone(), + return_type, + monotonicity, } } @@ -97,20 +108,16 @@ impl ScalarFunctionExpr { pub fn return_type(&self) -> &DataType { &self.return_type } + + /// Monotonicity information of the function + pub fn monotonicity(&self) -> &Option { + &self.monotonicity + } } impl fmt::Display for ScalarFunctionExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}({})", - self.name, - self.args - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - ) + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) } } @@ -132,7 +139,14 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the arguments, if there are no arguments we'll instead pass in a null array // indicating the batch size (as a convention) let inputs = match (self.args.len(), self.name.parse::()) { - (0, Ok(scalar_fun)) if scalar_fun.supports_zero_argument() => { + // MakeArray support zero argument but has the different behavior from the array with one null. + (0, Ok(scalar_fun)) + if scalar_fun + .signature() + .type_signature + .supports_zero_argument() + && scalar_fun != BuiltinScalarFunction::MakeArray => + { vec![ColumnarValue::create_null_array(batch.num_rows())] } _ => self @@ -159,9 +173,25 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, self.fun.clone(), children, - self.return_type(), + self.return_type().clone(), + self.monotonicity.clone(), ))) } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.name.hash(&mut s); + self.args.hash(&mut s); + self.return_type.hash(&mut s); + // Add `self.fun` when hash is available + } + + fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { + self.monotonicity + .as_ref() + .map(|monotonicity| out_ordering(monotonicity, children)) + .unwrap_or(SortProperties::Unordered) + } } impl PartialEq for ScalarFunctionExpr { @@ -171,7 +201,7 @@ impl PartialEq for ScalarFunctionExpr { .downcast_ref::() .map(|x| { self.name == x.name - && expr_list_eq_strict_order(&self.args, &x.args) + && physical_exprs_equal(&self.args, &x.args) && self.return_type == x.return_type }) .unwrap_or(false) diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index df519551d8a69..914d76f9261a1 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -17,12 +17,17 @@ //! Sort expressions +use std::fmt::Display; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + use crate::PhysicalExpr; + use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use arrow_schema::Schema; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; -use std::sync::Arc; /// Represents Sort operation for a column in a RecordBatch #[derive(Clone, Debug)] @@ -39,6 +44,15 @@ impl PartialEq for PhysicalSortExpr { } } +impl Eq for PhysicalSortExpr {} + +impl Hash for PhysicalSortExpr { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.options.hash(state); + } +} + impl std::fmt::Display for PhysicalSortExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{} {}", self.expr, to_str(&self.options)) @@ -51,11 +65,7 @@ impl PhysicalSortExpr { let value_to_sort = self.expr.evaluate(batch)?; let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - return Err(DataFusionError::Plan(format!( - "Sort operation is not applicable to scalar value {scalar}" - ))); - } + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, }; Ok(SortColumn { values: array_to_sort, @@ -63,18 +73,46 @@ impl PhysicalSortExpr { }) } - /// Check whether sort expression satisfies [`PhysicalSortRequirement`]. - /// - /// If sort options is Some in `PhysicalSortRequirement`, `expr` - /// and `options` field are compared for equality. - /// - /// If sort options is None in `PhysicalSortRequirement`, only - /// `expr` is compared for equality. - pub fn satisfy(&self, requirement: &PhysicalSortRequirement) -> bool { + /// Checks whether this sort expression satisfies the given `requirement`. + /// If sort options are unspecified in `requirement`, only expressions are + /// compared for inequality. + pub fn satisfy( + &self, + requirement: &PhysicalSortRequirement, + schema: &Schema, + ) -> bool { + // If the column is not nullable, NULLS FIRST/LAST is not important. + let nullable = self.expr.nullable(schema).unwrap_or(true); self.expr.eq(&requirement.expr) - && requirement - .options - .map_or(true, |opts| self.options == opts) + && if nullable { + requirement + .options + .map_or(true, |opts| self.options == opts) + } else { + requirement + .options + .map_or(true, |opts| self.options.descending == opts.descending) + } + } + + /// Returns a [`Display`]able list of `PhysicalSortExpr`. + pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [PhysicalSortExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for sort_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", sort_expr)?; + } + Ok(()) + } + } + DisplayableList(input) } } @@ -214,8 +252,18 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` is a type alias for lexicographical ordering definition`Vec` +///`LexOrdering` is an alias for the type `Vec`, which represents +/// a lexicographical ordering. pub type LexOrdering = Vec; -///`LexOrderingReq` is a type alias for lexicographical ordering requirement definition`Vec` -pub type LexOrderingReq = Vec; +///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents +/// a reference to a lexicographical ordering. +pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; + +///`LexRequirement` is an alias for the type `Vec`, which +/// represents a lexicographical ordering requirement. +pub type LexRequirement = Vec; + +///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which +/// represents a reference to a lexicographical ordering requirement. +pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; diff --git a/datafusion/physical-expr/src/sort_properties.rs b/datafusion/physical-expr/src/sort_properties.rs new file mode 100644 index 0000000000000..f513744617769 --- /dev/null +++ b/datafusion/physical-expr/src/sort_properties.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ops::Neg, sync::Arc}; + +use arrow_schema::SortOptions; + +use crate::PhysicalExpr; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::Result; + +/// To propagate [`SortOptions`] across the [`PhysicalExpr`], it is insufficient +/// to simply use `Option`: There must be a differentiation between +/// unordered columns and literal values, since literals may not break the ordering +/// when they are used as a child of some binary expression when the other child has +/// some ordering. On the other hand, unordered columns cannot maintain ordering when +/// they take part in such operations. +/// +/// Example: ((a_ordered + b_unordered) + c_ordered) expression cannot end up with +/// sorted data; however the ((a_ordered + 999) + c_ordered) expression can. Therefore, +/// we need two different variants for literals and unordered columns as literals are +/// often more ordering-friendly under most mathematical operations. +#[derive(PartialEq, Debug, Clone, Copy, Default)] +pub enum SortProperties { + /// Use the ordinary [`SortOptions`] struct to represent ordered data: + Ordered(SortOptions), + // This alternative represents unordered data: + #[default] + Unordered, + // Singleton is used for single-valued literal numbers: + Singleton, +} + +impl SortProperties { + pub fn add(&self, rhs: &Self) -> Self { + match (self, rhs) { + (Self::Singleton, _) => *rhs, + (_, Self::Singleton) => *self, + (Self::Ordered(lhs), Self::Ordered(rhs)) + if lhs.descending == rhs.descending => + { + Self::Ordered(SortOptions { + descending: lhs.descending, + nulls_first: lhs.nulls_first || rhs.nulls_first, + }) + } + _ => Self::Unordered, + } + } + + pub fn sub(&self, rhs: &Self) -> Self { + match (self, rhs) { + (Self::Singleton, Self::Singleton) => Self::Singleton, + (Self::Singleton, Self::Ordered(rhs)) => Self::Ordered(SortOptions { + descending: !rhs.descending, + nulls_first: rhs.nulls_first, + }), + (_, Self::Singleton) => *self, + (Self::Ordered(lhs), Self::Ordered(rhs)) + if lhs.descending != rhs.descending => + { + Self::Ordered(SortOptions { + descending: lhs.descending, + nulls_first: lhs.nulls_first || rhs.nulls_first, + }) + } + _ => Self::Unordered, + } + } + + pub fn gt_or_gteq(&self, rhs: &Self) -> Self { + match (self, rhs) { + (Self::Singleton, Self::Ordered(rhs)) => Self::Ordered(SortOptions { + descending: !rhs.descending, + nulls_first: rhs.nulls_first, + }), + (_, Self::Singleton) => *self, + (Self::Ordered(lhs), Self::Ordered(rhs)) + if lhs.descending != rhs.descending => + { + *self + } + _ => Self::Unordered, + } + } + + pub fn and_or(&self, rhs: &Self) -> Self { + match (self, rhs) { + (Self::Ordered(lhs), Self::Ordered(rhs)) + if lhs.descending == rhs.descending => + { + Self::Ordered(SortOptions { + descending: lhs.descending, + nulls_first: lhs.nulls_first || rhs.nulls_first, + }) + } + (Self::Ordered(opt), Self::Singleton) + | (Self::Singleton, Self::Ordered(opt)) => Self::Ordered(SortOptions { + descending: opt.descending, + nulls_first: opt.nulls_first, + }), + (Self::Singleton, Self::Singleton) => Self::Singleton, + _ => Self::Unordered, + } + } +} + +impl Neg for SortProperties { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + SortProperties::Ordered(SortOptions { + descending, + nulls_first, + }) => SortProperties::Ordered(SortOptions { + descending: !descending, + nulls_first, + }), + SortProperties::Singleton => SortProperties::Singleton, + SortProperties::Unordered => SortProperties::Unordered, + } + } +} + +/// The `ExprOrdering` struct is designed to aid in the determination of ordering (represented +/// by [`SortProperties`]) for a given [`PhysicalExpr`]. When analyzing the orderings +/// of a [`PhysicalExpr`], the process begins by assigning the ordering of its leaf nodes. +/// By propagating these leaf node orderings upwards in the expression tree, the overall +/// ordering of the entire [`PhysicalExpr`] can be derived. +/// +/// This struct holds the necessary state information for each expression in the [`PhysicalExpr`]. +/// It encapsulates the orderings (`state`) associated with the expression (`expr`), and +/// orderings of the children expressions (`children_states`). The [`ExprOrdering`] of a parent +/// expression is determined based on the [`ExprOrdering`] states of its children expressions. +#[derive(Debug)] +pub struct ExprOrdering { + pub expr: Arc, + pub state: SortProperties, + pub children: Vec, +} + +impl ExprOrdering { + /// Creates a new [`ExprOrdering`] with [`SortProperties::Unordered`] states + /// for `expr` and its children. + pub fn new(expr: Arc) -> Self { + let children = expr.children(); + Self { + expr, + state: Default::default(), + children: children.into_iter().map(Self::new).collect(), + } + } + + /// Get a reference to each child state. + pub fn children_state(&self) -> Vec { + self.children.iter().map(|c| c.state).collect() + } +} + +impl TreeNode for ExprOrdering { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in &self.children { + match op(child)? { + VisitRecursion::Continue => {} + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } + } + Ok(VisitRecursion::Continue) + } + + fn map_children(mut self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + if self.children.is_empty() { + Ok(self) + } else { + self.children = self + .children + .into_iter() + .map(transform) + .collect::>>()?; + Ok(self) + } + } +} diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 34319b9d97805..7d9fecf614075 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,21 +23,25 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait, - StringArray, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + OffsetSizeTrait, StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use datafusion_common::utils::datafusion_strsim; use datafusion_common::{ cast::{ as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, }, - ScalarValue, + exec_err, ScalarValue, }; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; -use std::iter; use std::sync::Arc; +use std::{ + fmt::{Display, Formatter}, + iter, +}; use uuid::Uuid; /// applies a unary expression to `args[0]` that is expected to be downcastable to @@ -58,11 +62,11 @@ where F: Fn(&'a str) -> R, { if args.len() != 1 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "{:?} args were supplied but {} takes exactly one argument", args.len(), - name, - ))); + name + ); } let string_array = as_generic_string_array::(args[0])?; @@ -98,9 +102,7 @@ where &[a.as_ref()], op, name )?))) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {name}", - ))), + other => internal_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(a) => { @@ -111,9 +113,7 @@ where let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function {name}", - ))), + other => internal_err!("Unsupported data type {other:?} for function {name}"), }, } } @@ -136,53 +136,6 @@ pub fn ascii(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - string.trim_start_matches(' ').trim_end_matches(' ') - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some( - string - .trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]), - ) - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => Err(DataFusionError::Internal(format!( - "btrim was called with {other} arguments. It requires at least 1 and at most 2." - ))), - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -195,15 +148,13 @@ pub fn chr(args: &[ArrayRef]) -> Result { integer .map(|integer| { if integer == 0 { - Err(DataFusionError::Execution( - "null character not permitted.".to_string(), - )) + exec_err!("null character not permitted.") } else { match core::char::from_u32(integer as u32) { Some(integer) => Ok(integer.to_string()), - None => Err(DataFusionError::Execution( - "requested character too large for encoding.".to_string(), - )), + None => { + exec_err!("requested character too large for encoding.") + } } } }) @@ -219,10 +170,10 @@ pub fn chr(args: &[ArrayRef]) -> Result { pub fn concat(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal(format!( + return internal_err!( "concat was called with {} arguments. It requires at least 1.", args.len() - ))); + ); } // first, decide whether to return a scalar or a vector. @@ -285,10 +236,10 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { // do not accept 0 or 1 arguments. if args.len() < 2 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "concat_ws was called with {} arguments. It requires at least 2.", args.len() - ))); + ); } // first map is the iterator, second is for the `Option<_>` @@ -351,44 +302,95 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Removes the longest string containing only characters in characters (a space by default) from the start of string. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { +enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + match args.len() { 1 => { - let string_array = as_generic_string_array::(&args[0])?; - let result = string_array .iter() - .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| func(string, " "))) .collect::>(); Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_start_matches(&chars[..])) - } + (Some(string), Some(characters)) => Some(func(string, characters)), _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } - other => Err(DataFusionError::Internal(format!( - "ltrim was called with {other} arguments. It requires at least 1 and at most 2." - ))), + other => { + internal_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } } } +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -427,44 +429,6 @@ pub fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Removes the longest string containing only characters in characters (a space by default) from the end of string. -/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - match args.len() { - 1 => { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => { - let chars: Vec = characters.chars().collect(); - Some(string.trim_end_matches(&chars[..])) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => Err(DataFusionError::Internal(format!( - "rtrim was called with {other} arguments. It requires at least 1 and at most 2." - ))), - } -} - /// Splits string at occurrences of delimiter and returns the n'th field (counting from one). /// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' pub fn split_part(args: &[ArrayRef]) -> Result { @@ -478,9 +442,7 @@ pub fn split_part(args: &[ArrayRef]) -> Result { .map(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { if n <= 0 { - Err(DataFusionError::Execution( - "field position must be greater than zero".to_string(), - )) + exec_err!("field position must be greater than zero") } else { let split_string: Vec<&str> = string.split(delimiter).collect(); match split_string.get(n as usize - 1) { @@ -531,9 +493,7 @@ where } else if let Some(value_isize) = value.to_isize() { Ok(Some(format!("{value_isize:x}"))) } else { - Err(DataFusionError::Internal(format!( - "Unsupported data type {integer:?} for function to_hex", - ))) + internal_err!("Unsupported data type {integer:?} for function to_hex") } } else { Ok(None) @@ -555,11 +515,7 @@ pub fn upper(args: &[ColumnarValue]) -> Result { pub fn uuid(args: &[ColumnarValue]) -> Result { let len: usize = match &args[0] { ColumnarValue::Array(array) => array.len(), - _ => { - return Err(DataFusionError::Internal( - "Expect uuid function to take no param".to_string(), - )) - } + _ => return internal_err!("Expect uuid function to take no param"), }; let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); @@ -567,11 +523,149 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "overlay was called with {other} arguments. It requires 3 or 4." + ) + } + } +} + +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein function requires two arguments, got {}", + args.len() + ))); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow_array::Int64Array; + use datafusion_common::cast::as_int32_array; use super::*; @@ -613,4 +707,36 @@ mod tests { Ok(()) } + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index dc8812b1ee242..0eed1d16fba8c 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -19,16 +19,14 @@ use arrow::array::*; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use std::sync::Arc; fn array_struct(args: &[ArrayRef]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal( - "struct requires at least one argument".to_string(), - )); + return exec_err!("struct requires at least one argument"); } let vec: Vec<_> = args @@ -57,9 +55,9 @@ fn array_struct(args: &[ArrayRef]) -> Result { )), arg.clone(), )), - data_type => Err(DataFusionError::NotImplemented(format!( - "Struct is not implemented for type '{data_type:?}'." - ))), + data_type => { + not_impl_err!("Struct is not implemented for type '{data_type:?}'.") + } } }) .collect::>>()?; @@ -69,12 +67,67 @@ fn array_struct(args: &[ArrayRef]) -> Result { /// put values in a struct array. pub fn struct_expr(values: &[ColumnarValue]) -> Result { - let arrays: Vec = values + let arrays = values .iter() - .map(|x| match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), + .map(|x| { + Ok(match x { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), + }) }) - .collect(); + .collect::>>()?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_struct_array; + use datafusion_common::ScalarValue; + + #[test] + fn test_struct() { + // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} + let args = [ + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ]; + let struc = struct_expr(&args) + .expect("failed to initialize function struct") + .into_array(1) + .expect("Failed to convert to array"); + let result = + as_struct_array(&struc).expect("failed to initialize function struct"); + assert_eq!( + &Int64Array::from(vec![1]), + result + .column_by_name("c0") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![2]), + result + .column_by_name("c1") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![3]), + result + .column_by_name("c2") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + } +} diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 5aca1df8a8005..0ec1cf3f256b0 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -35,9 +35,10 @@ pub fn create_physical_expr( .collect::>>()?; Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), + fun.name(), + fun.fun().clone(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type(&input_exprs_types)?, + None, ))) } diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 6654904cf1b7e..240efe4223c33 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -27,7 +27,7 @@ use arrow::{ }; use datafusion_common::{ cast::{as_generic_string_array, as_int64_array}, - DataFusionError, Result, + exec_err, internal_err, DataFusionError, Result, }; use hashbrown::HashMap; use std::cmp::{max, Ordering}; @@ -102,9 +102,9 @@ pub fn lpad(args: &[ArrayRef]) -> Result { .map(|(string, length)| match (string, length) { (Some(string), Some(length)) => { if length > i32::MAX as i64 { - return Err(DataFusionError::Internal(format!( + return exec_err!( "lpad requested length {length} too large" - ))); + ); } let length = if length < 0 { 0 } else { length as usize }; @@ -139,9 +139,9 @@ pub fn lpad(args: &[ArrayRef]) -> Result { .map(|((string, length), fill)| match (string, length, fill) { (Some(string), Some(length), Some(fill)) => { if length > i32::MAX as i64 { - return Err(DataFusionError::Internal(format!( + return exec_err!( "lpad requested length {length} too large" - ))); + ); } let length = if length < 0 { 0 } else { length as usize }; @@ -178,9 +178,9 @@ pub fn lpad(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - other => Err(DataFusionError::Internal(format!( + other => exec_err!( "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ))), + ), } } @@ -245,9 +245,9 @@ pub fn rpad(args: &[ArrayRef]) -> Result { .map(|(string, length)| match (string, length) { (Some(string), Some(length)) => { if length > i32::MAX as i64 { - return Err(DataFusionError::Internal(format!( + return exec_err!( "rpad requested length {length} too large" - ))); + ); } let length = if length < 0 { 0 } else { length as usize }; @@ -281,9 +281,9 @@ pub fn rpad(args: &[ArrayRef]) -> Result { .map(|((string, length), fill)| match (string, length, fill) { (Some(string), Some(length), Some(fill)) => { if length > i32::MAX as i64 { - return Err(DataFusionError::Internal(format!( + return exec_err!( "rpad requested length {length} too large" - ))); + ); } let length = if length < 0 { 0 } else { length as usize }; @@ -312,9 +312,9 @@ pub fn rpad(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - other => Err(DataFusionError::Internal(format!( + other => internal_err!( "rpad was called with {other} arguments. It requires at least 2 and at most 3." - ))), + ), } } @@ -391,9 +391,9 @@ pub fn substr(args: &[ArrayRef]) -> Result { .map(|((string, start), count)| match (string, start, count) { (Some(string), Some(start), Some(count)) => { if count < 0 { - Err(DataFusionError::Execution(format!( + exec_err!( "negative substring length not allowed: substr(, {start}, {count})" - ))) + ) } else { let skip = max(0, start - 1); let count = max(0, count + (if start < 1 {start - 1} else {0})); @@ -406,9 +406,9 @@ pub fn substr(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - other => Err(DataFusionError::Internal(format!( - "substr was called with {other} arguments. It requires 2 or 3." - ))), + other => { + internal_err!("substr was called with {other} arguments. It requires 2 or 3.") + } } } @@ -455,3 +455,107 @@ pub fn translate(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return internal_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + let mut res = String::new(); + match n { + 0 => { + "".to_string(); + } + _other => { + if n > 0 { + let idx = string + .split(delimiter) + .take(n as usize) + .fold(0, |len, x| len + x.len() + delimiter.len()) + - delimiter.len(); + res.push_str(if idx >= string.len() { + string + } else { + &string[..idx] + }); + } else { + let idx = (string.split(delimiter).take((-n) as usize).fold( + string.len() as isize, + |len, x| { + len - x.len() as isize - delimiter.len() as isize + }, + ) + delimiter.len() as isize) + as usize; + res.push_str(if idx >= string.len() { + string + } else { + &string[idx..] + }); + } + } + } + Some(res) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return internal_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index f95ec032eb9e8..71a7ff5fb7785 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::equivalence::{ - EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties, - OrderingEquivalentClass, -}; -use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; -use crate::{PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement}; +use std::borrow::Borrow; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use crate::expressions::{BinaryExpr, Column}; +use crate::{PhysicalExpr, PhysicalSortExpr}; + +use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, @@ -29,45 +31,9 @@ use datafusion_common::tree_node::{ use datafusion_common::Result; use datafusion_expr::Operator; +use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; -use std::borrow::Borrow; -use std::collections::HashMap; -use std::collections::HashSet; -use std::ops::Range; -use std::sync::Arc; - -/// Compare the two expr lists are equal no matter the order. -/// For example two InListExpr can be considered to be equals no matter the order: -/// -/// In('a','b','c') == In('c','b','a') -pub fn expr_list_eq_any_order( - list1: &[Arc], - list2: &[Arc], -) -> bool { - if list1.len() == list2.len() { - let mut expr_vec1 = list1.to_vec(); - let mut expr_vec2 = list2.to_vec(); - while let Some(expr1) = expr_vec1.pop() { - if let Some(idx) = expr_vec2.iter().position(|expr2| expr1.eq(expr2)) { - expr_vec2.swap_remove(idx); - } else { - break; - } - } - expr_vec1.is_empty() && expr_vec2.is_empty() - } else { - false - } -} - -/// Strictly compare the two expr lists are equal in the given order. -pub fn expr_list_eq_strict_order( - list1: &[Arc], - list2: &[Arc], -) -> bool { - list1.len() == list2.len() && list1.iter().zip(list2.iter()).all(|(e1, e2)| e1.eq(e2)) -} /// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs. /// @@ -100,336 +66,6 @@ fn split_conjunction_impl<'a>( } } -/// Normalize the output expressions based on Columns Map. -/// -/// If there is a mapping in Columns Map, replace the Column in the output expressions with the 1st Column in the Columns Map. -/// Otherwise, replace the Column with a place holder of [UnKnownColumn] -/// -pub fn normalize_out_expr_with_columns_map( - expr: Arc, - columns_map: &HashMap>, -) -> Arc { - expr.clone() - .transform(&|expr| { - let normalized_form = match expr.as_any().downcast_ref::() { - Some(column) => columns_map - .get(column) - .map(|c| Arc::new(c[0].clone()) as _) - .or_else(|| Some(Arc::new(UnKnownColumn::new(column.name())) as _)), - None => None, - }; - Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) - } else { - Transformed::No(expr) - }) - }) - .unwrap_or(expr) -} - -pub fn normalize_expr_with_equivalence_properties( - expr: Arc, - eq_properties: &[EquivalentClass], -) -> Arc { - expr.clone() - .transform(&|expr| { - let normalized_form = - expr.as_any().downcast_ref::().and_then(|column| { - for class in eq_properties { - if class.contains(column) { - return Some(Arc::new(class.head().clone()) as _); - } - } - None - }); - Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) - } else { - Transformed::No(expr) - }) - }) - .unwrap_or(expr) -} - -fn normalize_sort_requirement_with_equivalence_properties( - mut sort_requirement: PhysicalSortRequirement, - eq_properties: &[EquivalentClass], -) -> PhysicalSortRequirement { - sort_requirement.expr = - normalize_expr_with_equivalence_properties(sort_requirement.expr, eq_properties); - sort_requirement -} - -/// This function searches for the slice `section` inside the slice `given`. -/// It returns each range where `section` is compatible with the corresponding -/// slice in `given`. -fn get_compatible_ranges( - given: &[PhysicalSortRequirement], - section: &[PhysicalSortRequirement], -) -> Vec> { - let n_section = section.len(); - let n_end = if given.len() >= n_section { - given.len() - n_section + 1 - } else { - 0 - }; - (0..n_end) - .filter_map(|idx| { - let end = idx + n_section; - given[idx..end] - .iter() - .zip(section) - .all(|(req, given)| given.compatible(req)) - .then_some(Range { start: idx, end }) - }) - .collect() -} - -/// This function constructs a duplicate-free vector by filtering out duplicate -/// entries inside the given vector `input`. -fn collapse_vec(input: Vec) -> Vec { - let mut output = vec![]; - for item in input { - if !output.contains(&item) { - output.push(item); - } - } - output -} - -/// Transform `sort_exprs` vector, to standardized version using `eq_properties` and `ordering_eq_properties` -/// Assume `eq_properties` states that `Column a` and `Column b` are aliases. -/// Also assume `ordering_eq_properties` states that ordering `vec![d ASC]` and `vec![a ASC, c ASC]` are -/// ordering equivalent (in the sense that both describe the ordering of the table). -/// If the `sort_exprs` input to this function were `vec![b ASC, c ASC]`, -/// This function converts `sort_exprs` `vec![b ASC, c ASC]` to first `vec![a ASC, c ASC]` after considering `eq_properties` -/// Then converts `vec![a ASC, c ASC]` to `vec![d ASC]` after considering `ordering_eq_properties`. -/// Standardized version `vec![d ASC]` is used in subsequent operations. -pub fn normalize_sort_exprs( - sort_exprs: &[PhysicalSortExpr], - eq_properties: &[EquivalentClass], - ordering_eq_properties: &[OrderingEquivalentClass], -) -> Vec { - let sort_requirements = PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()); - let normalized_exprs = normalize_sort_requirements( - &sort_requirements, - eq_properties, - ordering_eq_properties, - ); - let normalized_exprs = PhysicalSortRequirement::to_sort_exprs(normalized_exprs); - collapse_vec(normalized_exprs) -} - -/// Transform `sort_reqs` vector, to standardized version using `eq_properties` and `ordering_eq_properties` -/// Assume `eq_properties` states that `Column a` and `Column b` are aliases. -/// Also assume `ordering_eq_properties` states that ordering `vec![d ASC]` and `vec![a ASC, c ASC]` are -/// ordering equivalent (in the sense that both describe the ordering of the table). -/// If the `sort_reqs` input to this function were `vec![b Some(ASC), c None]`, -/// This function converts `sort_exprs` `vec![b Some(ASC), c None]` to first `vec![a Some(ASC), c None]` after considering `eq_properties` -/// Then converts `vec![a Some(ASC), c None]` to `vec![d Some(ASC)]` after considering `ordering_eq_properties`. -/// Standardized version `vec![d Some(ASC)]` is used in subsequent operations. -pub fn normalize_sort_requirements( - sort_reqs: &[PhysicalSortRequirement], - eq_properties: &[EquivalentClass], - ordering_eq_properties: &[OrderingEquivalentClass], -) -> Vec { - let mut normalized_exprs = sort_reqs - .iter() - .map(|sort_req| { - normalize_sort_requirement_with_equivalence_properties( - sort_req.clone(), - eq_properties, - ) - }) - .collect::>(); - for ordering_eq_class in ordering_eq_properties { - for item in ordering_eq_class.others() { - let item = item - .clone() - .into_iter() - .map(|elem| elem.into()) - .collect::>(); - let ranges = get_compatible_ranges(&normalized_exprs, &item); - let mut offset: i64 = 0; - for Range { start, end } in ranges { - let mut head = ordering_eq_class - .head() - .clone() - .into_iter() - .map(|elem| elem.into()) - .collect::>(); - let updated_start = (start as i64 + offset) as usize; - let updated_end = (end as i64 + offset) as usize; - let range = end - start; - offset += head.len() as i64 - range as i64; - let all_none = normalized_exprs[updated_start..updated_end] - .iter() - .all(|req| req.options.is_none()); - if all_none { - for req in head.iter_mut() { - req.options = None; - } - } - normalized_exprs.splice(updated_start..updated_end, head); - } - } - } - collapse_vec(normalized_exprs) -} - -/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. -pub fn ordering_satisfy< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => ordering_satisfy_concrete( - provided, - required, - equal_properties, - ordering_equal_properties, - ), - } -} - -/// Checks whether the required [`PhysicalSortExpr`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_concrete< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: &[PhysicalSortExpr], - required: &[PhysicalSortExpr], - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let ordering_eq_classes = oeq_properties.classes(); - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - let required_normalized = - normalize_sort_exprs(required, eq_classes, ordering_eq_classes); - let provided_normalized = - normalize_sort_exprs(provided, eq_classes, ordering_eq_classes); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given == req) -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_requirement< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortRequirement]>, - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => ordering_satisfy_requirement_concrete( - provided, - required, - equal_properties, - ordering_equal_properties, - ), - } -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are satisfied by the -/// provided [`PhysicalSortExpr`]s. -pub fn ordering_satisfy_requirement_concrete< - F: FnOnce() -> EquivalenceProperties, - F2: FnOnce() -> OrderingEquivalenceProperties, ->( - provided: &[PhysicalSortExpr], - required: &[PhysicalSortRequirement], - equal_properties: F, - ordering_equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let ordering_eq_classes = oeq_properties.classes(); - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - let required_normalized = - normalize_sort_requirements(required, eq_classes, ordering_eq_classes); - let provided_normalized = - normalize_sort_exprs(provided, eq_classes, ordering_eq_classes); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given.satisfy(&req)) -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are equal or more -/// specific than the provided [`PhysicalSortRequirement`]s. -pub fn requirements_compatible< - F: FnOnce() -> OrderingEquivalenceProperties, - F2: FnOnce() -> EquivalenceProperties, ->( - provided: Option<&[PhysicalSortRequirement]>, - required: Option<&[PhysicalSortRequirement]>, - ordering_equal_properties: F, - equal_properties: F2, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => requirements_compatible_concrete( - provided, - required, - ordering_equal_properties, - equal_properties, - ), - } -} - -/// Checks whether the given [`PhysicalSortRequirement`]s are equal or more -/// specific than the provided [`PhysicalSortRequirement`]s. -fn requirements_compatible_concrete< - F: FnOnce() -> OrderingEquivalenceProperties, - F2: FnOnce() -> EquivalenceProperties, ->( - provided: &[PhysicalSortRequirement], - required: &[PhysicalSortRequirement], - ordering_equal_properties: F, - equal_properties: F2, -) -> bool { - let oeq_properties = ordering_equal_properties(); - let ordering_eq_classes = oeq_properties.classes(); - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - - let required_normalized = - normalize_sort_requirements(required, eq_classes, ordering_eq_classes); - let provided_normalized = - normalize_sort_requirements(provided, eq_classes, ordering_eq_classes); - if required_normalized.len() > provided_normalized.len() { - return false; - } - required_normalized - .into_iter() - .zip(provided_normalized) - .all(|(req, given)| given.compatible(&req)) -} - /// This function maps back requirement after ProjectionExec /// to the Executor for its input. // Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor. @@ -472,33 +108,9 @@ pub fn convert_to_expr>( .collect() } -/// This function finds the indices of `targets` within `items`, taking into -/// account equivalences according to `equal_properties`. -pub fn get_indices_of_matching_exprs< - T: Borrow>, - F: FnOnce() -> EquivalenceProperties, ->( - targets: impl IntoIterator, - items: &[Arc], - equal_properties: F, -) -> Vec { - if let eq_classes @ [_, ..] = equal_properties().classes() { - let normalized_targets = targets.into_iter().map(|e| { - normalize_expr_with_equivalence_properties(e.borrow().clone(), eq_classes) - }); - let normalized_items = items - .iter() - .map(|e| normalize_expr_with_equivalence_properties(e.clone(), eq_classes)) - .collect::>(); - get_indices_of_exprs_strict(normalized_targets, &normalized_items) - } else { - get_indices_of_exprs_strict(targets, items) - } -} - /// This function finds the indices of `targets` within `items` using strict /// equality. -fn get_indices_of_exprs_strict>>( +pub fn get_indices_of_exprs_strict>>( targets: impl IntoIterator, items: &[Arc], ) -> Vec { @@ -517,10 +129,11 @@ pub struct ExprTreeNode { impl ExprTreeNode { pub fn new(expr: Arc) -> Self { + let children = expr.children(); ExprTreeNode { expr, data: None, - child_nodes: vec![], + child_nodes: children.into_iter().map(Self::new).collect_vec(), } } @@ -528,12 +141,8 @@ impl ExprTreeNode { &self.expr } - pub fn children(&self) -> Vec> { - self.expr - .children() - .into_iter() - .map(ExprTreeNode::new) - .collect() + pub fn children(&self) -> &[ExprTreeNode] { + &self.child_nodes } } @@ -543,7 +152,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(&Self) -> Result, { for child in self.children() { - match op(&child)? { + match op(child)? { VisitRecursion::Continue => {} VisitRecursion::Skip => return Ok(VisitRecursion::Continue), VisitRecursion::Stop => return Ok(VisitRecursion::Stop), @@ -558,7 +167,7 @@ impl TreeNode for ExprTreeNode { F: FnMut(Self) -> Result, { self.child_nodes = self - .children() + .child_nodes .into_iter() .map(transform) .collect::>>()?; @@ -571,7 +180,7 @@ impl TreeNode for ExprTreeNode { /// identical expressions in one node. Caller specifies the node type in the /// DAEG via the `constructor` argument, which constructs nodes in the DAEG /// from the [ExprTreeNode] ancillary object. -struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, // A vector of visited expression nodes and their corresponding node indices. @@ -580,7 +189,7 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter for PhysicalExprDAEGBuilder<'a, T, F> { type N = ExprTreeNode; @@ -601,7 +210,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> T> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)); + let node_idx = self.graph.add_node((self.constructor)(&node)?); for expr_node in node.child_nodes.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } @@ -622,7 +231,7 @@ pub fn build_dag( constructor: &F, ) -> Result<(NodeIndex, StableGraph)> where - F: Fn(&ExprTreeNode) -> T, + F: Fn(&ExprTreeNode) -> Result, { // Create a new expression tree node from the input expression. let init = ExprTreeNode::new(expr); @@ -692,44 +301,76 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec EquivalenceProperties, - F2: Fn() -> OrderingEquivalenceProperties, ->( - req1: &'a [PhysicalSortExpr], - req2: &'a [PhysicalSortExpr], - eq_properties: F, - ordering_eq_properties: F2, -) -> Option<&'a [PhysicalSortExpr]> { - if ordering_satisfy_concrete(req1, req2, &eq_properties, &ordering_eq_properties) { - // Finer requirement is `provided`, since it satisfies the other: - return Some(req1); - } - if ordering_satisfy_concrete(req2, req1, &eq_properties, &ordering_eq_properties) { - // Finer requirement is `req`, since it satisfies the other: - return Some(req2); +/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` +/// are taken, when the mask evaluates `false` values null values are filled. +/// +/// # Arguments +/// * `mask` - Boolean values used to determine where to put the `truthy` values +/// * `truthy` - All values of this array are to scatter according to `mask` into final result. +pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { + let truthy = truthy.to_data(); + + // update the mask so that any null values become false + // (SlicesIterator doesn't respect nulls) + let mask = and_kleene(mask, &is_not_null(mask)?)?; + + let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); + + // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to + // fill with falsy values + + // keep track of how much is filled + let mut filled = 0; + // keep track of current position we have in truthy array + let mut true_pos = 0; + + SlicesIterator::new(&mask).for_each(|(start, end)| { + // the gap needs to be filled with nulls + if start > filled { + mutable.extend_nulls(start - filled); + } + // fill with truthy values + let len = end - start; + mutable.extend(0, true_pos, true_pos + len); + true_pos += len; + filled = end; + }); + // the remaining part is falsy + if filled < mask.len() { + mutable.extend_nulls(mask.len() - filled); } - // Neither `provided` nor `req` satisfies one another, they are incompatible. - None + + let data = mutable.freeze(); + Ok(make_array(data)) +} + +/// Merge left and right sort expressions, checking for duplicates. +pub fn merge_vectors( + left: &[PhysicalSortExpr], + right: &[PhysicalSortExpr], +) -> Vec { + left.iter() + .cloned() + .chain(right.iter().cloned()) + .unique() + .collect() } #[cfg(test)] mod tests { + use std::fmt::{Display, Formatter}; + use std::sync::Arc; + use super::*; use crate::expressions::{binary, cast, col, in_list, lit, Column, Literal}; use crate::PhysicalSortExpr; - use arrow::compute::SortOptions; - use datafusion_common::{Result, ScalarValue}; - use std::fmt::{Display, Formatter}; - use crate::equivalence::OrderingEquivalenceProperties; + use arrow_array::Int32Array; use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use datafusion_common::{Result, ScalarValue}; + use petgraph::visit::Bfs; - use std::sync::Arc; #[derive(Clone)] struct DummyProperty { @@ -750,7 +391,7 @@ mod tests { } } - fn make_dummy_node(node: &ExprTreeNode) -> PhysicalExprDummyNode { + fn make_dummy_node(node: &ExprTreeNode) -> Result { let expr = node.expression().clone(); let dummy_property = if expr.as_any().is::() { "Binary" @@ -762,85 +403,12 @@ mod tests { "Other" } .to_owned(); - PhysicalExprDummyNode { + Ok(PhysicalExprDummyNode { expr, property: DummyProperty { expr_type: dummy_property, }, - } - } - - // Generate a schema which consists of 5 columns (a, b, c, d, e) - fn create_test_schema() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, true); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, true); - let e = Field::new("e", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); - - Ok(schema) - } - - fn create_test_params() -> Result<( - SchemaRef, - EquivalenceProperties, - OrderingEquivalenceProperties, - )> { - // Assume schema satisfies ordering a ASC NULLS LAST - // and d ASC NULLS LAST, b ASC NULLS LAST and e DESC NULLS FIRST, b ASC NULLS LAST - // Assume that column a and c are aliases. - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - let test_schema = create_test_schema()?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions((col_a, col_c)); - let mut ordering_eq_properties = - OrderingEquivalenceProperties::new(test_schema.clone()); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![ - PhysicalSortExpr { - expr: Arc::new(col_d.clone()), - options: option1, - }, - PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: option1, - }, - ], - )); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![ - PhysicalSortExpr { - expr: Arc::new(col_e.clone()), - options: option2, - }, - PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: option1, - }, - ], - )); - Ok((test_schema, eq_properties, ordering_eq_properties)) + }) } #[test] @@ -921,9 +489,7 @@ mod tests { } #[test] - fn test_get_indices_of_matching_exprs() { - let empty_schema = &Arc::new(Schema::empty()); - let equal_properties = || EquivalenceProperties::new(empty_schema.clone()); + fn test_get_indices_of_exprs_strict() { let list1: Vec> = vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), @@ -935,274 +501,8 @@ mod tests { Arc::new(Column::new("c", 2)), Arc::new(Column::new("a", 0)), ]; - assert_eq!( - get_indices_of_matching_exprs(&list1, &list2, equal_properties), - vec![2, 0, 1] - ); - assert_eq!( - get_indices_of_matching_exprs(&list2, &list1, equal_properties), - vec![1, 2, 0] - ); - } - - #[test] - fn expr_list_eq_test() -> Result<()> { - let list1: Vec> = vec![ - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - ]; - let list2: Vec> = vec![ - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("a", 0)), - ]; - assert!(!expr_list_eq_any_order(list1.as_slice(), list2.as_slice())); - assert!(!expr_list_eq_any_order(list2.as_slice(), list1.as_slice())); - - assert!(!expr_list_eq_strict_order( - list1.as_slice(), - list2.as_slice() - )); - assert!(!expr_list_eq_strict_order( - list2.as_slice(), - list1.as_slice() - )); - - let list3: Vec> = vec![ - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("b", 1)), - ]; - let list4: Vec> = vec![ - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("b", 1)), - Arc::new(Column::new("a", 0)), - Arc::new(Column::new("c", 2)), - Arc::new(Column::new("a", 0)), - ]; - assert!(expr_list_eq_any_order(list3.as_slice(), list4.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list3.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list4.as_slice())); - - assert!(!expr_list_eq_strict_order( - list3.as_slice(), - list4.as_slice() - )); - assert!(!expr_list_eq_strict_order( - list4.as_slice(), - list3.as_slice() - )); - assert!(expr_list_eq_any_order(list3.as_slice(), list3.as_slice())); - assert!(expr_list_eq_any_order(list4.as_slice(), list4.as_slice())); - - Ok(()) - } - - #[test] - fn test_ordering_satisfy() -> Result<()> { - let crude = vec![PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }]; - let crude = Some(&crude[..]); - let finer = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - let finer = Some(&finer[..]); - let empty_schema = &Arc::new(Schema::empty()); - assert!(ordering_satisfy( - finer, - crude, - || { EquivalenceProperties::new(empty_schema.clone()) }, - || { OrderingEquivalenceProperties::new(empty_schema.clone()) }, - )); - assert!(!ordering_satisfy( - crude, - finer, - || { EquivalenceProperties::new(empty_schema.clone()) }, - || { OrderingEquivalenceProperties::new(empty_schema.clone()) }, - )); - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - // The schema is ordered by a ASC NULLS LAST, b ASC NULLS LAST - let provided = vec![ - PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }, - PhysicalSortExpr { - expr: Arc::new(col_b.clone()), - options: option1, - }, - ]; - let provided = Some(&provided[..]); - let (_test_schema, eq_properties, ordering_eq_properties) = create_test_params()?; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option1)], true), - (vec![(col_a, option2)], false), - // Test whether equivalence works as expected - (vec![(col_c, option1)], true), - (vec![(col_c, option2)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option1)], false), - (vec![(col_d, option1), (col_b, option1)], true), - (vec![(col_d, option2), (col_b, option1)], false), - (vec![(col_e, option2), (col_b, option1)], true), - (vec![(col_e, option1), (col_b, option1)], false), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_d, option1), - (col_b, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option2), - (col_b, option1), - ], - true, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_d, option2), - (col_b, option1), - ], - false, - ), - ( - vec![ - (col_d, option1), - (col_b, option1), - (col_e, option1), - (col_b, option1), - ], - false, - ), - ]; - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(col, options)| PhysicalSortExpr { - expr: Arc::new(col.clone()), - options, - }) - .collect::>(); - - let required = Some(&required[..]); - assert_eq!( - ordering_satisfy( - provided, - required, - || eq_properties.clone(), - || ordering_eq_properties.clone(), - ), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - fn convert_to_requirement( - in_data: &[(&Column, Option)], - ) -> Vec { - in_data - .iter() - .map(|(col, options)| { - PhysicalSortRequirement::new(Arc::new((*col).clone()) as _, *options) - }) - .collect::>() - } - - #[test] - fn test_normalize_sort_reqs() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let option2 = SortOptions { - descending: true, - nulls_first: true, - }; - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - (vec![(col_a, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_a, None)], vec![(col_a, None)]), - // Test whether equivalence works as expected - (vec![(col_c, Some(option1))], vec![(col_a, Some(option1))]), - (vec![(col_c, None)], vec![(col_a, None)]), - // Test whether ordering equivalence works as expected - ( - vec![(col_d, Some(option1)), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - ), - (vec![(col_d, None), (col_b, None)], vec![(col_a, None)]), - ( - vec![(col_e, Some(option2)), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - ), - // We should be able to normalize in compatible requirements also (not exactly equal) - ( - vec![(col_e, Some(option2)), (col_b, None)], - vec![(col_a, Some(option1))], - ), - (vec![(col_e, None), (col_b, None)], vec![(col_a, None)]), - ]; - let (_test_schema, eq_properties, ordering_eq_properties) = create_test_params()?; - let eq_classes = eq_properties.classes(); - let ordering_eq_classes = ordering_eq_properties.classes(); - for (reqs, expected_normalized) in requirements.into_iter() { - let req = convert_to_requirement(&reqs); - let expected_normalized = convert_to_requirement(&expected_normalized); - - assert_eq!( - normalize_sort_requirements(&req, eq_classes, ordering_eq_classes), - expected_normalized - ); - } - Ok(()) + assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]); + assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]); } #[test] @@ -1243,225 +543,88 @@ mod tests { } #[test] - fn test_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let _col_d = &Column::new("d", 3); - let _col_e = &Column::new("e", 4); - // Assume that column a and c are aliases. - let (_test_schema, eq_properties, _ordering_eq_properties) = - create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. - let expressions = vec![ - // Normalized version of the column a and c should go to a (since a is head) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), - // Cannot normalize column b - (&col_b_expr, &col_b_expr), - ]; - for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&normalize_expr_with_equivalence_properties( - expr.clone(), - eq_properties.classes() - )), - "error in test: expr: {expr:?}" - ); - } - + fn test_collect_columns() -> Result<()> { + let expr1 = Arc::new(Column::new("col1", 2)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col1", 2)); + assert_eq!(collect_columns(&expr1), expected); + + let expr2 = Arc::new(Column::new("col2", 5)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col2", 5)); + assert_eq!(collect_columns(&expr2), expected); + + let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _; + let mut expected = HashSet::new(); + expected.insert(Column::new("col1", 2)); + expected.insert(Column::new("col2", 5)); + assert_eq!(collect_columns(&expr3), expected); Ok(()) } #[test] - fn test_normalize_sort_requirement_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let _col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let _col_e = &Column::new("e", 4); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Assume that column a and c are aliases. - let (_test_schema, eq_properties, _ordering_eq_properties) = - create_test_params()?; + fn scatter_int() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); - // Test cases for equivalence normalization - // First entry in the tuple is PhysicalExpr, second entry is its ordering, third entry is result after normalization. - let expressions = vec![ - (&col_a, Some(option1), &col_a, Some(option1)), - (&col_c, Some(option1), &col_a, Some(option1)), - (&col_c, None, &col_a, None), - // Cannot normalize column d, since it is not in equivalence properties. - (&col_d, Some(option1), &col_d, Some(option1)), - ]; - for (expr, sort_options, expected_col, expected_options) in - expressions.into_iter() - { - let expected = PhysicalSortRequirement::new( - Arc::new((*expected_col).clone()) as _, - expected_options, - ); - let arg = PhysicalSortRequirement::new( - Arc::new((*expr).clone()) as _, - sort_options, - ); - assert!( - expected.eq(&normalize_sort_requirement_with_equivalence_properties( - arg.clone(), - eq_properties.classes() - )), - "error in test: expr: {expr:?}, sort_options: {sort_options:?}" - ); - } + // the output array is expected to be the same length as the mask array + let expected = + Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + assert_eq!(&expected, result); Ok(()) } #[test] - fn test_ordering_satisfy_different_lengths() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); - let col_d = &Column::new("d", 3); - let col_e = &Column::new("e", 4); - let test_schema = create_test_schema()?; - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - // Column a and c are aliases. - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions((col_a, col_c)); + fn scatter_int_end_with_false() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, false, true, false, false, false]); - // Column a and e are ordering equivalent (e.g global ordering of the table can be described both as a ASC and e ASC.) - let mut ordering_eq_properties = OrderingEquivalenceProperties::new(test_schema); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: option1, - }], - &vec![PhysicalSortExpr { - expr: Arc::new(col_e.clone()), - options: option1, - }], - )); - let sort_req_a = PhysicalSortExpr { - expr: Arc::new((col_a).clone()) as _, - options: option1, - }; - let sort_req_b = PhysicalSortExpr { - expr: Arc::new((col_b).clone()) as _, - options: option1, - }; - let sort_req_c = PhysicalSortExpr { - expr: Arc::new((col_c).clone()) as _, - options: option1, - }; - let sort_req_d = PhysicalSortExpr { - expr: Arc::new((col_d).clone()) as _, - options: option1, - }; - let sort_req_e = PhysicalSortExpr { - expr: Arc::new((col_e).clone()) as _, - options: option1, - }; - - assert!(ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC, d ASC - &[sort_req_a.clone(), sort_req_b.clone(), sort_req_d.clone()], - // After normalization would be a ASC, b ASC, d ASC - &[ - sort_req_c.clone(), - sort_req_b.clone(), - sort_req_a.clone(), - sort_req_d.clone(), - sort_req_e.clone(), - ], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); - - assert!(!ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC - &[sort_req_a.clone(), sort_req_b.clone()], - // After normalization would be a ASC, b ASC, d ASC - &[ - sort_req_c.clone(), - sort_req_b.clone(), - sort_req_a.clone(), - sort_req_d.clone(), - sort_req_e.clone(), - ], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); - - assert!(!ordering_satisfy_concrete( - // After normalization would be a ASC, b ASC, d ASC - &[sort_req_a.clone(), sort_req_b.clone(), sort_req_d.clone()], - // After normalization would be a ASC, d ASC, b ASC - &[sort_req_c, sort_req_d, sort_req_a, sort_req_b, sort_req_e,], - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )); + // output should be same length as mask + let expected = + Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + assert_eq!(&expected, result); Ok(()) } #[test] - fn test_get_compatible_ranges() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let option1 = SortOptions { - descending: false, - nulls_first: false, - }; - let test_data = vec![ - ( - vec![(col_a, Some(option1)), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - vec![(0, 1)], - ), - ( - vec![(col_a, None), (col_b, Some(option1))], - vec![(col_a, Some(option1))], - vec![(0, 1)], - ), - ( - vec![ - (col_a, None), - (col_b, Some(option1)), - (col_a, Some(option1)), - ], - vec![(col_a, Some(option1))], - vec![(0, 1), (2, 3)], - ), - ]; - for (searched, to_search, expected) in test_data { - let searched = convert_to_requirement(&searched); - let to_search = convert_to_requirement(&to_search); - let expected = expected - .into_iter() - .map(|(start, end)| Range { start, end }) - .collect::>(); - assert_eq!(get_compatible_ranges(&searched, &to_search), expected); - } + fn scatter_with_null_mask() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); + let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] + .into_iter() + .collect(); + + // output should treat nulls as though they are false + let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); Ok(()) } #[test] - fn test_collapse_vec() -> Result<()> { - assert_eq!(collapse_vec(vec![1, 2, 3]), vec![1, 2, 3]); - assert_eq!(collapse_vec(vec![1, 2, 3, 2, 3]), vec![1, 2, 3]); - assert_eq!(collapse_vec(vec![3, 1, 2, 3, 2, 3]), vec![3, 1, 2]); + fn scatter_boolean() -> Result<()> { + let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = BooleanArray::from_iter(vec![ + Some(false), + Some(false), + None, + None, + Some(false), + ]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_boolean_array(&result)?; + + assert_eq!(&expected, result); Ok(()) } } diff --git a/datafusion/physical-expr/src/var_provider.rs b/datafusion/physical-expr/src/var_provider.rs index faa07665e4f33..e00cf74072377 100644 --- a/datafusion/physical-expr/src/var_provider.rs +++ b/datafusion/physical-expr/src/var_provider.rs @@ -29,7 +29,7 @@ pub enum VarType { UserDefined, } -/// A var provider for @variable +/// A var provider for `@variable` and `@@variable` runtime values. pub trait VarProvider: std::fmt::Debug { /// Get variable value fn get_value(&self, var_names: Vec) -> Result; diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index c8a4797a52880..5892f7f3f3b05 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -155,8 +155,7 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - self.aggregate.supports_bounded_execution() - && !self.window_frame.end_bound.is_unbounded() + !self.window_frame.end_bound.is_unbounded() } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 030c20c5743c6..665ceb70d6584 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -21,22 +21,19 @@ use std::any::Any; use std::ops::Range; use std::sync::Arc; -use super::window_frame_state::WindowFrameContext; -use super::BuiltInWindowFunctionExpr; -use super::WindowExpr; -use crate::window::window_expr::{ - BuiltinWindowState, NthValueKind, NthValueState, WindowFn, -}; -use crate::window::{ - PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState, -}; -use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; -use arrow::array::{new_empty_array, Array, ArrayRef}; +use super::{BuiltInWindowFunctionExpr, WindowExpr}; +use crate::expressions::PhysicalSortExpr; +use crate::window::window_expr::{get_orderby_values, WindowFn}; +use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; +use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; + +use arrow::array::{new_empty_array, ArrayRef}; use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. @@ -68,6 +65,35 @@ impl BuiltInWindowExpr { pub fn get_built_in_func_expr(&self) -> &Arc { &self.expr } + + /// Adds any equivalent orderings generated by the `self.expr` + /// to `builder`. + /// + /// If `self.expr` doesn't have an ordering, ordering equivalence properties + /// are not updated. Otherwise, ordering equivalence properties are updated + /// by the ordering of `self.expr`. + pub fn add_equal_orderings(&self, eq_properties: &mut EquivalenceProperties) { + let schema = eq_properties.schema(); + if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { + if self.partition_by.is_empty() { + // In the absence of a PARTITION BY, ordering of `self.expr` is global: + eq_properties.add_new_orderings([vec![fn_res_ordering]]); + } else { + // If we have a PARTITION BY, built-in functions can not introduce + // a global ordering unless the existing ordering is compatible + // with PARTITION BY expressions. To elaborate, when PARTITION BY + // expressions and existing ordering expressions are equal (w.r.t. + // set equality), we can prefix the ordering of `self.expr` with + // the existing ordering. + let (mut ordering, _) = + eq_properties.find_longest_permutation(&self.partition_by); + if ordering.len() == self.partition_by.len() { + ordering.push(fn_res_ordering); + eq_properties.add_new_orderings([ordering]); + } + } + } + } } impl WindowExpr for BuiltInWindowExpr { @@ -97,37 +123,42 @@ impl WindowExpr for BuiltInWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let evaluator = self.expr.create_evaluator()?; + let mut evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); - if self.expr.uses_window_frame() { + if evaluator.uses_window_frame() { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; - let (values, order_bys) = self.get_values_orderbys(batch)?; + let mut values = self.evaluate_args(batch)?; + let order_bys = get_orderby_values(self.order_by_columns(batch)?); + let n_args = values.len(); + values.extend(order_bys); + let order_bys_ref = &values[n_args..]; + let mut window_frame_ctx = WindowFrameContext::new(self.window_frame.clone(), sort_options); let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { let range = window_frame_ctx.calculate_range( - &order_bys, + order_bys_ref, &last_range, num_rows, idx, )?; - let value = evaluator.evaluate_inside_range(&values, &range)?; + let value = evaluator.evaluate(&values, &range)?; row_wise_results.push(value); last_range = range; } - ScalarValue::iter_to_array(row_wise_results.into_iter()) + ScalarValue::iter_to_array(row_wise_results) } else if evaluator.include_rank() { - let columns = self.sort_columns(batch)?; + let columns = self.order_by_columns(batch)?; let sort_partition_points = evaluate_partition_ranges(num_rows, &columns)?; - evaluator.evaluate_with_rank(num_rows, &sort_partition_points) + evaluator.evaluate_all_with_rank(num_rows, &sort_partition_points) } else { - let (values, _) = self.get_values_orderbys(batch)?; - evaluator.evaluate(&values, num_rows) + let values = self.evaluate_args(batch)?; + evaluator.evaluate_all(&values, num_rows) } } @@ -160,21 +191,23 @@ impl WindowExpr for BuiltInWindowExpr { }; let state = &mut window_state.state; - let (values, order_bys) = - self.get_values_orderbys(&partition_batch_state.record_batch)?; + let batch_ref = &partition_batch_state.record_batch; + let mut values = self.evaluate_args(batch_ref)?; + let order_bys = if evaluator.uses_window_frame() || evaluator.include_rank() { + get_orderby_values(self.order_by_columns(batch_ref)?) + } else { + vec![] + }; + let n_args = values.len(); + values.extend(order_bys); + let order_bys_ref = &values[n_args..]; // We iterate on each row to perform a running calculation. let record_batch = &partition_batch_state.record_batch; let num_rows = record_batch.num_rows(); - let sort_partition_points = if evaluator.include_rank() { - let columns = self.sort_columns(record_batch)?; - evaluate_partition_ranges(num_rows, &columns)? - } else { - vec![] - }; let mut row_wise_results: Vec = vec![]; for idx in state.last_calculated_index..num_rows { - let frame_range = if self.expr.uses_window_frame() { + let frame_range = if evaluator.uses_window_frame() { state .window_frame_ctx .get_or_insert_with(|| { @@ -184,7 +217,7 @@ impl WindowExpr for BuiltInWindowExpr { ) }) .calculate_range( - &order_bys, + order_bys_ref, // Start search from the last range &state.window_frame_range, num_rows, @@ -200,8 +233,8 @@ impl WindowExpr for BuiltInWindowExpr { } // Update last range state.window_frame_range = frame_range; - evaluator.update_state(state, idx, &order_bys, &sort_partition_points)?; - row_wise_results.push(evaluator.evaluate_stateful(&values)?); + row_wise_results + .push(evaluator.evaluate(&values, &state.window_frame_range)?); } let out_col = if row_wise_results.is_empty() { new_empty_array(out_type) @@ -211,13 +244,7 @@ impl WindowExpr for BuiltInWindowExpr { state.update(&out_col, partition_batch_state)?; if self.window_frame.start_bound.is_unbounded() { - let mut evaluator_state = evaluator.state()?; - if let BuiltinWindowState::NthValue(nth_value_state) = - &mut evaluator_state - { - memoize_nth_value(state, nth_value_state)?; - evaluator.set_state(&evaluator_state)?; - } + evaluator.memoize(state)?; } } Ok(()) @@ -239,40 +266,12 @@ impl WindowExpr for BuiltInWindowExpr { } fn uses_bounded_memory(&self) -> bool { - self.expr.supports_bounded_execution() - && (!self.expr.uses_window_frame() - || !self.window_frame.end_bound.is_unbounded()) - } -} - -// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), for -// FIRST_VALUE, LAST_VALUE and NTH_VALUE functions: we can memoize result. -// Once result is calculated it will always stay same. Hence, we do not -// need to keep past data as we process the entire dataset. This feature -// enables us to prune rows from table. -fn memoize_nth_value( - state: &mut WindowAggState, - nth_value_state: &mut NthValueState, -) -> Result<()> { - let out = &state.out_col; - let size = out.len(); - let (is_prunable, new_prunable) = match nth_value_state.kind { - NthValueKind::First => { - let n_range = state.window_frame_range.end - state.window_frame_range.start; - (n_range > 0 && size > 0, true) - } - NthValueKind::Last => (true, false), - NthValueKind::Nth(n) => { - let n_range = state.window_frame_range.end - state.window_frame_range.start; - (n_range >= (n as usize) && size >= (n as usize), true) - } - }; - if is_prunable { - if nth_value_state.finalized_result.is_none() && new_prunable { - let result = ScalarValue::try_from_array(out, size - 1)?; - nth_value_state.finalized_result = Some(result); + if let Ok(evaluator) = self.expr.create_evaluator() { + evaluator.supports_bounded_execution() + && (!evaluator.uses_window_frame() + || !self.window_frame.end_bound.is_unbounded()) + } else { + false } - state.window_frame_range.start = state.window_frame_range.end.saturating_sub(1); } - Ok(()) } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 59438a72f2759..7aa4f6536a6e4 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. -use super::partition_evaluator::PartitionEvaluator; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; +use arrow_schema::SchemaRef; use datafusion_common::Result; +use datafusion_expr::PartitionEvaluator; + use std::any::Any; use std::sync::Arc; @@ -34,7 +37,7 @@ use std::sync::Arc; /// `nth_value` need the value. #[allow(rustdoc::private_intra_doc_links)] pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { - /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -57,8 +60,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } @@ -79,20 +84,11 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { None } - /// Can the window function be incrementally computed using - /// bounded memory? - /// - /// If this function returns true, [`Self::create_evaluator`] must - /// implement [`PartitionEvaluator::evaluate_stateful`] - fn supports_bounded_execution(&self) -> bool { - false - } - - /// Does the window function use the values from its window frame? - /// - /// If this function returns true, [`Self::create_evaluator`] must - /// implement [`PartitionEvaluator::evaluate_inside_range`] - fn uses_window_frame(&self) -> bool { - false + /// Returns the ordering introduced by the window function, if applicable. + /// Most window functions don't introduce an ordering, hence the default + /// value is `None`. Note that this information is used to update ordering + /// equivalences. + fn get_result_ordering(&self, _schema: &SchemaRef) -> Option { + None } } diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 46997578001d3..edef77c51c315 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -18,13 +18,13 @@ //! Defines physical expression for `cume_dist` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::Float64Array; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; +use datafusion_expr::PartitionEvaluator; use std::any::Any; use std::iter; use std::ops::Range; @@ -70,11 +70,7 @@ impl BuiltInWindowFunctionExpr for CumeDist { pub(crate) struct CumeDistEvaluator; impl PartitionEvaluator for CumeDistEvaluator { - fn include_rank(&self) -> bool { - true - } - - fn evaluate_with_rank( + fn evaluate_all_with_rank( &self, num_rows: usize, ranks_in_partition: &[Range], @@ -94,6 +90,10 @@ impl PartitionEvaluator for CumeDistEvaluator { ); Ok(Arc::new(result)) } + + fn include_rank(&self) -> bool { + true + } } #[cfg(test)] @@ -109,7 +109,7 @@ mod tests { ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(num_rows, &ranks)?; + .evaluate_all_with_rank(num_rows, &ranks)?; let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); @@ -117,6 +117,7 @@ mod tests { } #[test] + #[allow(clippy::single_range_in_vec_init)] fn test_cume_dist() -> Result<()> { let r = cume_dist("arr".into()); diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 8d97d5ebc0b33..d22660d41ebd7 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -18,15 +18,14 @@ //! Defines physical expression for `lead` and `lag` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_expr::{BuiltinWindowState, LeadLagState}; -use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; +use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::PartitionEvaluator; use std::any::Any; use std::cmp::min; use std::ops::{Neg, Range}; @@ -47,6 +46,11 @@ impl WindowShift { pub fn get_shift_offset(&self) -> i64 { self.shift_offset } + + /// Get the default_value for window shift expression. + pub fn get_default_value(&self) -> Option { + self.default_value.clone() + } } /// lead() window function @@ -104,16 +108,11 @@ impl BuiltInWindowFunctionExpr for WindowShift { fn create_evaluator(&self) -> Result> { Ok(Box::new(WindowShiftEvaluator { - state: LeadLagState { idx: 0 }, shift_offset: self.shift_offset, default_value: self.default_value.clone(), })) } - fn supports_bounded_execution(&self) -> bool { - true - } - fn reverse_expr(&self) -> Option> { Some(Arc::new(Self { name: self.name.clone(), @@ -127,7 +126,6 @@ impl BuiltInWindowFunctionExpr for WindowShift { #[derive(Debug)] pub(crate) struct WindowShiftEvaluator { - state: LeadLagState, shift_offset: i64, default_value: Option, } @@ -141,6 +139,7 @@ fn create_empty_array( let array = value .as_ref() .map(|scalar| scalar.to_array_of_size(size)) + .transpose()? .unwrap_or_else(|| new_null_array(data_type, size)); if array.data_type() != data_type { cast(&array, data_type).map_err(DataFusionError::ArrowError) @@ -182,22 +181,6 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::LeadLag(self.state.clone())) - } - - fn update_state( - &mut self, - _state: &WindowAggState, - idx: usize, - _range_columns: &[ArrayRef], - _sort_partition_points: &[Range], - ) -> Result<()> { - self.state.idx = idx; - Ok(()) - } - fn get_range(&self, idx: usize, n_rows: usize) -> Result> { if self.shift_offset > 0 { let offset = self.shift_offset as usize; @@ -211,10 +194,21 @@ impl PartitionEvaluator for WindowShiftEvaluator { } } - fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { let array = &values[0]; let dtype = array.data_type(); - let idx = self.state.idx as i64 - self.shift_offset; + // LAG mode + let idx = if self.shift_offset > 0 { + range.end as i64 - self.shift_offset - 1 + } else { + // LEAD mode + range.start as i64 - self.shift_offset + }; + if idx < 0 || idx as usize >= array.len() { get_default_value(self.default_value.as_ref(), dtype) } else { @@ -222,11 +216,19 @@ impl PartitionEvaluator for WindowShiftEvaluator { } } - fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result { + fn evaluate_all( + &mut self, + values: &[ArrayRef], + _num_rows: usize, + ) -> Result { // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; shift_with_default_value(value, self.shift_offset, self.default_value.as_ref()) } + + fn supports_bounded_execution(&self) -> bool { + true + } } fn get_default_value( @@ -237,9 +239,7 @@ fn get_default_value( if let ScalarValue::Int64(Some(val)) = value { ScalarValue::try_from_string(val.to_string(), dtype) } else { - Err(DataFusionError::Internal( - "Expects default value to have Int64 type".to_string(), - )) + internal_err!("Expects default value to have Int64 type") } } else { Ok(ScalarValue::try_from(dtype)?) @@ -263,7 +263,7 @@ mod tests { let values = expr.evaluate_args(&batch)?; let result = expr .create_evaluator()? - .evaluate(&values, batch.num_rows())?; + .evaluate_all(&values, batch.num_rows())?; let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) @@ -279,7 +279,7 @@ mod tests { None, None, ), - vec![ + [ Some(-2), Some(3), Some(-4), @@ -301,7 +301,7 @@ mod tests { None, None, ), - vec![ + [ None, Some(1), Some(-2), @@ -323,7 +323,7 @@ mod tests { None, Some(ScalarValue::Int32(Some(100))), ), - vec![ + [ Some(100), Some(1), Some(-2), diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 4c8b8b5a4e4b2..644edae36c9ca 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -22,21 +22,18 @@ pub(crate) mod cume_dist; pub(crate) mod lead_lag; pub(crate) mod nth_value; pub(crate) mod ntile; -pub(crate) mod partition_evaluator; pub(crate) mod rank; pub(crate) mod row_number; mod sliding_aggregate; mod window_expr; -mod window_frame_state; pub use aggregate::PlainAggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; pub use sliding_aggregate::SlidingAggregateWindowExpr; -pub use window_expr::PartitionBatchState; +pub use window_expr::NthValueKind; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; -pub use window_expr::WindowAggState; pub use window_expr::WindowExpr; pub use window_expr::WindowState; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 4bfe514c38daf..b3c89122ebad2 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -15,20 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions for `first_value`, `last_value`, and `nth_value` -//! that can evaluated at runtime during query execution +//! Defines physical expressions for `FIRST_VALUE`, `LAST_VALUE`, and `NTH_VALUE` +//! functions that can be evaluated at run time during query execution. -use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_expr::{BuiltinWindowState, NthValueKind, NthValueState}; -use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; +use std::any::Any; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; + +use crate::window::window_expr::{NthValueKind, NthValueState}; +use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; + use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::ScalarValue; +use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; +use datafusion_expr::window_state::WindowAggState; +use datafusion_expr::PartitionEvaluator; /// nth_value expression #[derive(Debug)] @@ -76,19 +80,17 @@ impl NthValue { n: u32, ) -> Result { match n { - 0 => Err(DataFusionError::Execution( - "nth_value expect n to be > 0".to_owned(), - )), + 0 => exec_err!("NTH_VALUE expects n to be non-zero"), _ => Ok(Self { name: name.into(), expr, data_type, - kind: NthValueKind::Nth(n), + kind: NthValueKind::Nth(n as i64), }), } } - /// Get nth_value kind + /// Get the NTH_VALUE kind pub fn get_kind(&self) -> NthValueKind { self.kind } @@ -122,19 +124,11 @@ impl BuiltInWindowFunctionExpr for NthValue { Ok(Box::new(NthValueEvaluator { state })) } - fn supports_bounded_execution(&self) -> bool { - true - } - - fn uses_window_frame(&self) -> bool { - true - } - fn reverse_expr(&self) -> Option> { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => return None, + NthValueKind::Nth(idx) => NthValueKind::Nth(-idx), }; Some(Arc::new(Self { name: self.name.clone(), @@ -152,64 +146,113 @@ pub(crate) struct NthValueEvaluator { } impl PartitionEvaluator for NthValueEvaluator { - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::NthValue(self.state.clone())) - } - - fn update_state( - &mut self, - state: &WindowAggState, - _idx: usize, - _range_columns: &[ArrayRef], - _sort_partition_points: &[Range], - ) -> Result<()> { - // If we do not use state, update_state does nothing - self.state.range.clone_from(&state.window_frame_range); - Ok(()) - } - - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { - if let BuiltinWindowState::NthValue(nth_value_state) = state { - self.state = nth_value_state.clone() + /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING), + /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we + /// can memoize the result. Once result is calculated, it will always stay + /// same. Hence, we do not need to keep past data as we process the entire + /// dataset. + fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { + let out = &state.out_col; + let size = out.len(); + let mut buffer_size = 1; + // Decide if we arrived at a final result yet: + let (is_prunable, is_reverse_direction) = match self.state.kind { + NthValueKind::First => { + let n_range = + state.window_frame_range.end - state.window_frame_range.start; + (n_range > 0 && size > 0, false) + } + NthValueKind::Last => (true, true), + NthValueKind::Nth(n) => { + let n_range = + state.window_frame_range.end - state.window_frame_range.start; + match n.cmp(&0) { + Ordering::Greater => { + (n_range >= (n as usize) && size > (n as usize), false) + } + Ordering::Less => { + let reverse_index = (-n) as usize; + buffer_size = reverse_index; + // Negative index represents reverse direction. + (n_range >= reverse_index, true) + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } + } + } + }; + if is_prunable { + if self.state.finalized_result.is_none() && !is_reverse_direction { + let result = ScalarValue::try_from_array(out, size - 1)?; + self.state.finalized_result = Some(result); + } + state.window_frame_range.start = + state.window_frame_range.end.saturating_sub(buffer_size); } Ok(()) } - fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { - if let Some(ref result) = self.state.finalized_result { - Ok(result.clone()) - } else { - self.evaluate_inside_range(values, &self.state.range) - } - } - - fn evaluate_inside_range( - &self, + fn evaluate( + &mut self, values: &[ArrayRef], range: &Range, ) -> Result { - // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. - let arr = &values[0]; - let n_range = range.end - range.start; - if n_range == 0 { - // We produce None if the window is empty. - return ScalarValue::try_from(arr.data_type()); - } - match self.state.kind { - NthValueKind::First => ScalarValue::try_from_array(arr, range.start), - NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), - NthValueKind::Nth(n) => { - // We are certain that n > 0. - let index = (n as usize) - 1; - if index >= n_range { - ScalarValue::try_from(arr.data_type()) - } else { - ScalarValue::try_from_array(arr, range.start + index) + if let Some(ref result) = self.state.finalized_result { + Ok(result.clone()) + } else { + // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1. + let arr = &values[0]; + let n_range = range.end - range.start; + if n_range == 0 { + // We produce None if the window is empty. + return ScalarValue::try_from(arr.data_type()); + } + match self.state.kind { + NthValueKind::First => ScalarValue::try_from_array(arr, range.start), + NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1), + NthValueKind::Nth(n) => { + match n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } else { + ScalarValue::try_from_array(arr, range.start + index) + } + } + Ordering::Less => { + let reverse_index = (-n) as usize; + if n_range >= reverse_index { + ScalarValue::try_from_array( + arr, + range.start + n_range - reverse_index, + ) + } else { + // Outside the range, return NULL: + ScalarValue::try_from(arr.data_type()) + } + } + Ordering::Equal => { + // The case n = 0 is not valid for the NTH_VALUE function. + unreachable!(); + } + } } } } } + + fn supports_bounded_execution(&self) -> bool { + true + } + + fn uses_window_frame(&self) -> bool { + true + } } #[cfg(test)] @@ -233,11 +276,11 @@ mod tests { end: i + 1, }) } - let evaluator = expr.create_evaluator()?; + let mut evaluator = expr.create_evaluator()?; let values = expr.evaluate_args(&batch)?; let result = ranges .iter() - .map(|range| evaluator.evaluate_inside_range(&values, range)) + .map(|range| evaluator.evaluate(&values, range)) .collect::>>()?; let result = ScalarValue::iter_to_array(result.into_iter())?; let result = as_int32_array(&result)?; diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index 479fa263337a8..f5442e1b0fee4 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -18,13 +18,16 @@ //! Defines physical expression for `ntile` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; +use crate::expressions::Column; use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::Field; -use arrow_schema::DataType; +use arrow_schema::{DataType, SchemaRef, SortOptions}; use datafusion_common::Result; +use datafusion_expr::PartitionEvaluator; + use std::any::Any; use std::sync::Arc; @@ -38,6 +41,10 @@ impl Ntile { pub fn new(name: String, n: u64) -> Self { Self { name, n } } + + pub fn get_n(&self) -> u64 { + self.n + } } impl BuiltInWindowFunctionExpr for Ntile { @@ -62,6 +69,18 @@ impl BuiltInWindowFunctionExpr for Ntile { fn create_evaluator(&self) -> Result> { Ok(Box::new(NtileEvaluator { n: self.n })) } + + fn get_result_ordering(&self, schema: &SchemaRef) -> Option { + // The built-in NTILE window function introduces a new ordering: + schema.column_with_name(self.name()).map(|(idx, field)| { + let expr = Arc::new(Column::new(field.name(), idx)); + let options = SortOptions { + descending: false, + nulls_first: false, + }; // ASC, NULLS LAST + PhysicalSortExpr { expr, options } + }) + } } #[derive(Debug)] @@ -70,11 +89,16 @@ pub(crate) struct NtileEvaluator { } impl PartitionEvaluator for NtileEvaluator { - fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); for i in 0..num_rows { - let res = i * self.n / num_rows; + let res = i * n / num_rows; vec.push(res + 1) } Ok(Arc::new(UInt64Array::from(vec))) diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs deleted file mode 100644 index db60fdd5f1fa6..0000000000000 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ /dev/null @@ -1,220 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Partition evaluation module - -use crate::window::window_expr::BuiltinWindowState; -use crate::window::WindowAggState; -use arrow::array::ArrayRef; -use datafusion_common::Result; -use datafusion_common::{DataFusionError, ScalarValue}; -use std::fmt::Debug; -use std::ops::Range; - -/// Partition evaluator for Window Functions -/// -/// # Background -/// -/// An implementation of this trait is created and used for each -/// partition defined by an `OVER` clause and is instantiated by -/// [`BuiltInWindowFunctionExpr::create_evaluator`] -/// -/// For example, evaluating `window_func(val) OVER (PARTITION BY col)` -/// on the following data: -/// -/// ```text -/// col | val -/// --- + ---- -/// A | 10 -/// A | 10 -/// C | 20 -/// D | 30 -/// D | 30 -/// ``` -/// -/// Will instantiate three `PartitionEvaluator`s, one each for the -/// partitions defined by `col=A`, `col=B`, and `col=C`. -/// -/// ```text -/// col | val -/// --- + ---- -/// A | 10 <--- partition 1 -/// A | 10 -/// -/// col | val -/// --- + ---- -/// C | 20 <--- partition 2 -/// -/// col | val -/// --- + ---- -/// D | 30 <--- partition 3 -/// D | 30 -/// ``` -/// -/// Different methods on this trait will be called depending on the -/// capabilities described by [`BuiltInWindowFunctionExpr`]: -/// -/// # Stateless `PartitionEvaluator` -/// -/// In this case, [`Self::evaluate`], [`Self::evaluate_with_rank`] or -/// [`Self::evaluate_inside_range`] is called with values for the -/// entire partition. -/// -/// # Stateful `PartitionEvaluator` -/// -/// In this case, [`Self::evaluate_stateful`] is called to calculate -/// the results of the window function incrementally for each new -/// batch, saving and restoring any state needed to do so as -/// [`BuiltinWindowState`]. -/// -/// For example, when computing `ROW_NUMBER` incrementally, -/// [`Self::evaluate_stateful`] will be called multiple times with -/// different batches. For all batches after the first, the output -/// `row_number` must start from last `row_number` produced for the -/// previous batch. The previous row number is saved and restored as -/// the state. -/// -/// [`BuiltInWindowFunctionExpr`]: crate::window::BuiltInWindowFunctionExpr -/// [`BuiltInWindowFunctionExpr::create_evaluator`]: crate::window::BuiltInWindowFunctionExpr::create_evaluator -pub trait PartitionEvaluator: Debug + Send { - /// Can this evaluator be evaluated with (only) rank - /// - /// If `include_rank` is true, then [`Self::evaluate_with_rank`] - /// will be called for each partition, which includes the - /// `rank`. - fn include_rank(&self) -> bool { - false - } - - /// Returns the internal state of the window function - /// - /// Only used for stateful evaluation - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::Default) - } - - /// Updates the internal state for window function - /// - /// Only used for stateful evaluation - /// - /// `state`: is useful to update internal state for window function. - /// `idx`: is the index of last row for which result is calculated. - /// `range_columns`: is the result of order by column values. It is used to calculate rank boundaries - /// `sort_partition_points`: is the boundaries of each rank in the range_column. It is used to update rank. - fn update_state( - &mut self, - _state: &WindowAggState, - _idx: usize, - _range_columns: &[ArrayRef], - _sort_partition_points: &[Range], - ) -> Result<()> { - // If we do not use state, update_state does nothing - Ok(()) - } - - /// Sets the internal state for window function - /// - /// Only used for stateful evaluation - fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> { - Err(DataFusionError::NotImplemented( - "set_state is not implemented for this window function".to_string(), - )) - } - - /// Gets the range where the window function result is calculated. - /// - /// `idx`: is the index of last row for which result is calculated. - /// `n_rows`: is the number of rows of the input record batch (Used during bounds check) - fn get_range(&self, _idx: usize, _n_rows: usize) -> Result> { - Err(DataFusionError::NotImplemented( - "get_range is not implemented for this window function".to_string(), - )) - } - - /// Called for window functions that *do not use* values from the - /// the window frame, such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, - /// `PERCENT_RANK`, `CUME_DIST`, `LEAD`, `LAG`). - fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { - Err(DataFusionError::NotImplemented( - "evaluate is not implemented by default".into(), - )) - } - - /// Evaluate window function result inside given range. - /// - /// Only used for stateful evaluation - fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { - Err(DataFusionError::NotImplemented( - "evaluate_stateful is not implemented by default".into(), - )) - } - - /// [`PartitionEvaluator::evaluate_with_rank`] is called for window - /// functions that only need the rank of a row within its window - /// frame. - /// - /// Evaluate the partition evaluator against the partition using - /// the row ranks. For example, `RANK(col)` produces - /// - /// ```text - /// col | rank - /// --- + ---- - /// A | 1 - /// A | 1 - /// C | 3 - /// D | 4 - /// D | 5 - /// ``` - /// - /// For this case, `num_rows` would be `5` and the - /// `ranks_in_partition` would be called with - /// - /// ```text - /// [ - /// (0,1), - /// (2,2), - /// (3,4), - /// ] - /// ``` - /// - /// See [`Self::include_rank`] for more details - fn evaluate_with_rank( - &self, - _num_rows: usize, - _ranks_in_partition: &[Range], - ) -> Result { - Err(DataFusionError::NotImplemented( - "evaluate_partition_with_rank is not implemented by default".into(), - )) - } - - /// Called for window functions that use values from window frame, - /// such as `FIRST_VALUE`, `LAST_VALUE`, `NTH_VALUE` and produce a - /// single value for every row in the partition. - /// - /// Returns a [`ScalarValue`] that is the value of the window function for the entire partition - fn evaluate_inside_range( - &self, - _values: &[ArrayRef], - _range: &Range, - ) -> Result { - Err(DataFusionError::NotImplemented( - "evaluate_inside_range is not implemented by default".into(), - )) - } -} diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 89ca40dd564f1..9bc36728f46ef 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -18,15 +18,19 @@ //! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated //! at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_expr::{BuiltinWindowState, RankState}; -use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; -use crate::PhysicalExpr; +use crate::expressions::Column; +use crate::window::window_expr::RankState; +use crate::window::BuiltInWindowFunctionExpr; +use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; +use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::utils::get_row_at_idx; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::PartitionEvaluator; + use std::any::Any; use std::iter; use std::ops::Range; @@ -100,16 +104,24 @@ impl BuiltInWindowFunctionExpr for Rank { &self.name } - fn supports_bounded_execution(&self) -> bool { - matches!(self.rank_type, RankType::Basic | RankType::Dense) - } - fn create_evaluator(&self) -> Result> { Ok(Box::new(RankEvaluator { state: RankState::default(), rank_type: self.rank_type, })) } + + fn get_result_ordering(&self, schema: &SchemaRef) -> Option { + // The built-in RANK window function (in all modes) introduces a new ordering: + schema.column_with_name(self.name()).map(|(idx, field)| { + let expr = Arc::new(Column::new(field.name(), idx)); + let options = SortOptions { + descending: false, + nulls_first: false, + }; // ASC, NULLS LAST + PhysicalSortExpr { expr, options } + }) + } } #[derive(Debug)] @@ -119,61 +131,38 @@ pub(crate) struct RankEvaluator { } impl PartitionEvaluator for RankEvaluator { - fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { - let start = idx; - let end = idx + 1; - Ok(Range { start, end }) - } - - fn state(&self) -> Result { - Ok(BuiltinWindowState::Rank(self.state.clone())) - } - - fn update_state( + /// Evaluates the window function inside the given range. + fn evaluate( &mut self, - state: &WindowAggState, - idx: usize, - range_columns: &[ArrayRef], - sort_partition_points: &[Range], - ) -> Result<()> { - // find range inside `sort_partition_points` containing `idx` - let chunk_idx = sort_partition_points - .iter() - .position(|elem| elem.start <= idx && idx < elem.end) - .ok_or_else(|| { - DataFusionError::Execution( - "Expects sort_partition_points to contain idx".to_string(), - ) - })?; - let chunk = &sort_partition_points[chunk_idx]; - let last_rank_data = get_row_at_idx(range_columns, chunk.end - 1)?; + values: &[ArrayRef], + range: &Range, + ) -> Result { + let row_idx = range.start; + // There is no argument, values are order by column values (where rank is calculated) + let range_columns = values; + let last_rank_data = get_row_at_idx(range_columns, row_idx)?; let empty = self.state.last_rank_data.is_empty(); if empty || self.state.last_rank_data != last_rank_data { self.state.last_rank_data = last_rank_data; - self.state.last_rank_boundary = state.offset_pruned_rows + chunk.start; - self.state.n_rank = 1 + if empty { chunk_idx } else { self.state.n_rank }; + self.state.last_rank_boundary += self.state.current_group_count; + self.state.current_group_count = 1; + self.state.n_rank += 1; + } else { + // data is still in the same rank + self.state.current_group_count += 1; } - Ok(()) - } - - /// evaluate window function result inside given range - fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { match self.rank_type { RankType::Basic => Ok(ScalarValue::UInt64(Some( self.state.last_rank_boundary as u64 + 1, ))), RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), - RankType::Percent => Err(DataFusionError::Execution( - "Can not execute PERCENT_RANK in a streaming fashion".to_string(), - )), + RankType::Percent => { + exec_err!("Can not execute PERCENT_RANK in a streaming fashion") + } } } - fn include_rank(&self) -> bool { - true - } - - fn evaluate_with_rank( + fn evaluate_all_with_rank( &self, num_rows: usize, ranks_in_partition: &[Range], @@ -219,6 +208,14 @@ impl PartitionEvaluator for RankEvaluator { }; Ok(result) } + + fn supports_bounded_execution(&self) -> bool { + matches!(self.rank_type, RankType::Basic | RankType::Dense) + } + + fn include_rank(&self) -> bool { + true + } } #[cfg(test)] @@ -230,6 +227,7 @@ mod tests { test_i32_result(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected) } + #[allow(clippy::single_range_in_vec_init)] fn test_without_rank(expr: &Rank, expected: Vec) -> Result<()> { test_i32_result(expr, vec![0..8], expected) } @@ -242,7 +240,7 @@ mod tests { ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(num_rows, &ranks)?; + .evaluate_all_with_rank(num_rows, &ranks)?; let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); @@ -254,7 +252,7 @@ mod tests { ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr.create_evaluator()?.evaluate_with_rank(8, &ranks)?; + let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?; let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); @@ -278,6 +276,7 @@ mod tests { } #[test] + #[allow(clippy::single_range_in_vec_init)] fn test_percent_rank() -> Result<()> { let r = percent_rank("arr".into()); diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index 9883d67f7cd8d..f5e2f65a656e5 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -17,13 +17,17 @@ //! Defines physical expression for `row_number` that can evaluated at runtime during query execution -use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; +use crate::expressions::Column; +use crate::window::window_expr::NumRowsState; use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; +use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::PartitionEvaluator; + use std::any::Any; use std::ops::Range; use std::sync::Arc; @@ -61,12 +65,20 @@ impl BuiltInWindowFunctionExpr for RowNumber { &self.name } - fn create_evaluator(&self) -> Result> { - Ok(Box::::default()) + fn get_result_ordering(&self, schema: &SchemaRef) -> Option { + // The built-in ROW_NUMBER window function introduces a new ordering: + schema.column_with_name(self.name()).map(|(idx, field)| { + let expr = Arc::new(Column::new(field.name(), idx)); + let options = SortOptions { + descending: false, + nulls_first: false, + }; // ASC, NULLS LAST + PhysicalSortExpr { expr, options } + }) } - fn supports_bounded_execution(&self) -> bool { - true + fn create_evaluator(&self) -> Result> { + Ok(Box::::default()) } } @@ -76,28 +88,29 @@ pub(crate) struct NumRowsEvaluator { } impl PartitionEvaluator for NumRowsEvaluator { - fn state(&self) -> Result { - // If we do not use state we just return Default - Ok(BuiltinWindowState::NumRows(self.state.clone())) - } - - fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { - let start = idx; - let end = idx + 1; - Ok(Range { start, end }) - } - /// evaluate window function result inside given range - fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &Range, + ) -> Result { self.state.n_rows += 1; Ok(ScalarValue::UInt64(Some(self.state.n_rows as u64))) } - fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, ))) } + + fn supports_bounded_execution(&self) -> bool { + true + } } #[cfg(test)] @@ -118,7 +131,7 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, batch.num_rows())?; + .evaluate_all(&values, batch.num_rows())?; let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *result); @@ -136,7 +149,7 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, batch.num_rows())?; + .evaluate_all(&values, batch.num_rows())?; let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *result); diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 709f8d23be366..1494129cf8976 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -139,8 +139,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - self.aggregate.supports_bounded_execution() - && !self.window_frame.end_bound.is_unbounded() + !self.window_frame.end_bound.is_unbounded() } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 7fe616feda610..4211a616e100a 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,22 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::window_frame_state::WindowFrameContext; +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + use crate::{PhysicalExpr, PhysicalSortExpr}; + use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::window_state::{ + PartitionBatchState, WindowAggState, WindowFrameContext, +}; +use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame}; + use indexmap::IndexMap; -use std::any::Any; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; /// Common trait for [window function] implementations /// @@ -58,7 +61,7 @@ use std::sync::Arc; /// [`PlainAggregateWindowExpr`]: crate::window::PlainAggregateWindowExpr /// [`SlidingAggregateWindowExpr`]: crate::window::SlidingAggregateWindowExpr pub trait WindowExpr: Send + Sync + Debug { - /// Returns the window expression as [`Any`](std::any::Any) so that it can be + /// Returns the window expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -81,8 +84,10 @@ pub trait WindowExpr: Send + Sync + Debug { fn evaluate_args(&self, batch: &RecordBatch) -> Result> { self.expressions() .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect() } @@ -96,10 +101,7 @@ pub trait WindowExpr: Send + Sync + Debug { _partition_batches: &PartitionBatches, _window_agg_state: &mut PartitionWindowAggStates, ) -> Result<()> { - Err(DataFusionError::Internal(format!( - "evaluate_stateful is not implemented for {}", - self.name() - ))) + internal_err!("evaluate_stateful is not implemented for {}", self.name()) } /// Expressions that's from the window function's partition by clause, empty if absent @@ -116,25 +118,6 @@ pub trait WindowExpr: Send + Sync + Debug { .collect::>>() } - /// Get sort columns that can be used for peer evaluation, empty if absent - fn sort_columns(&self, batch: &RecordBatch) -> Result> { - let order_by_columns = self.order_by_columns(batch)?; - Ok(order_by_columns) - } - - /// Get values columns (argument of Window Function) - /// and order by columns (columns of the ORDER BY expression) used in evaluators - fn get_values_orderbys( - &self, - record_batch: &RecordBatch, - ) -> Result<(Vec, Vec)> { - let values = self.evaluate_args(record_batch)?; - let order_by_columns = self.order_by_columns(record_batch)?; - let order_bys: Vec = - order_by_columns.iter().map(|s| s.values.clone()).collect(); - Ok((values, order_bys)) - } - /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; @@ -243,7 +226,8 @@ pub trait AggregateWindowExpr: WindowExpr { mut idx: usize, not_end: bool, ) -> Result { - let (values, order_bys) = self.get_values_orderbys(record_batch)?; + let values = self.evaluate_args(record_batch)?; + let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); // We iterate on each row to perform a running calculation. let length = values[0].len(); let mut row_wise_results: Vec = vec![]; @@ -271,10 +255,14 @@ pub trait AggregateWindowExpr: WindowExpr { let out_type = field.data_type(); Ok(new_empty_array(out_type)) } else { - ScalarValue::iter_to_array(row_wise_results.into_iter()) + ScalarValue::iter_to_array(row_wise_results) } } } +/// Get order by expression results inside `order_by_columns`. +pub(crate) fn get_orderby_values(order_by_columns: Vec) -> Vec { + order_by_columns.into_iter().map(|s| s.values).collect() +} #[derive(Debug)] pub enum WindowFn { @@ -289,6 +277,8 @@ pub struct RankState { pub last_rank_data: Vec, /// The index where last_rank_boundary is started pub last_rank_boundary: usize, + /// Keep the number of entries in current rank + pub current_group_count: usize, /// Rank number kept from the start pub n_rank: usize, } @@ -304,7 +294,7 @@ pub struct NumRowsState { pub enum NthValueKind { First, Last, - Nth(u32), + Nth(i64), } #[derive(Debug, Clone)] @@ -322,98 +312,6 @@ pub struct NthValueState { pub kind: NthValueKind, } -#[derive(Debug, Clone, Default)] -pub struct LeadLagState { - pub idx: usize, -} - -#[derive(Debug, Clone, Default)] -pub enum BuiltinWindowState { - Rank(RankState), - NumRows(NumRowsState), - NthValue(NthValueState), - LeadLag(LeadLagState), - #[default] - Default, -} - -#[derive(Debug)] -pub struct WindowAggState { - /// The range that we calculate the window function - pub window_frame_range: Range, - pub window_frame_ctx: Option, - /// The index of the last row that its result is calculated inside the partition record batch buffer. - pub last_calculated_index: usize, - /// The offset of the deleted row number - pub offset_pruned_rows: usize, - /// Stores the results calculated by window frame - pub out_col: ArrayRef, - /// Keeps track of how many rows should be generated to be in sync with input record_batch. - // (For each row in the input record batch we need to generate a window result). - pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition - pub is_end: bool, -} - -impl WindowAggState { - pub fn prune_state(&mut self, n_prune: usize) { - self.window_frame_range = Range { - start: self.window_frame_range.start - n_prune, - end: self.window_frame_range.end - n_prune, - }; - self.last_calculated_index -= n_prune; - self.offset_pruned_rows += n_prune; - - match self.window_frame_ctx.as_mut() { - // Rows have no state do nothing - Some(WindowFrameContext::Rows(_)) => {} - Some(WindowFrameContext::Range { .. }) => {} - Some(WindowFrameContext::Groups { state, .. }) => { - let mut n_group_to_del = 0; - for (_, end_idx) in &state.group_end_indices { - if n_prune < *end_idx { - break; - } - n_group_to_del += 1; - } - state.group_end_indices.drain(0..n_group_to_del); - state - .group_end_indices - .iter_mut() - .for_each(|(_, start_idx)| *start_idx -= n_prune); - state.current_group_idx -= n_group_to_del; - } - None => {} - }; - } -} - -impl WindowAggState { - pub fn update( - &mut self, - out_col: &ArrayRef, - partition_batch_state: &PartitionBatchState, - ) -> Result<()> { - self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; - self.n_row_result_missing = - partition_batch_state.record_batch.num_rows() - self.last_calculated_index; - self.is_end = partition_batch_state.is_end; - Ok(()) - } -} - -/// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] -pub struct PartitionBatchState { - /// The record_batch belonging to current partition - pub record_batch: RecordBatch, - /// Flag indicating whether we have received all data for this partition - pub is_end: bool, - /// Number of rows emitted for each partition - pub n_out_row: usize, -} - /// Key for IndexMap for each unique partition /// /// For instance, if window frame is `OVER(PARTITION BY a,b)`, @@ -429,18 +327,3 @@ pub type PartitionWindowAggStates = IndexMap; /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. pub type PartitionBatches = IndexMap; - -impl WindowAggState { - pub fn new(out_type: &DataType) -> Result { - let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); - Ok(Self { - window_frame_range: Range { start: 0, end: 0 }, - window_frame_ctx: None, - last_calculated_index: 0, - offset_pruned_rows: 0, - out_col: empty_out_col, - n_row_result_missing: 0, - is_end: false, - }) - } -} diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml new file mode 100644 index 0000000000000..6c761fc9687c3 --- /dev/null +++ b/datafusion/physical-plan/Cargo.toml @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-physical-plan" +description = "Physical (ExecutionPlan) implementations for DataFusion query engine" +keywords = ["arrow", "query", "sql"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_physical_plan" +path = "src/lib.rs" + +[dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } +async-trait = { workspace = true } +chrono = { version = "0.4.23", default-features = false } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr = { workspace = true } +futures = { workspace = true } +half = { version = "2.1", default-features = false } +hashbrown = { version = "0.14", features = ["raw"] } +indexmap = { workspace = true } +itertools = { version = "0.12", features = ["use_std"] } +log = { workspace = true } +once_cell = "1.18.0" +parking_lot = { workspace = true } +pin-project-lite = "^0.2.7" +rand = { workspace = true } +tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] } +uuid = { version = "^1.2", features = ["v4"] } + +[dev-dependencies] +rstest = { workspace = true } +termtree = "0.4.1" +tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } diff --git a/datafusion/physical-plan/README.md b/datafusion/physical-plan/README.md new file mode 100644 index 0000000000000..366a6b555150e --- /dev/null +++ b/datafusion/physical-plan/README.md @@ -0,0 +1,27 @@ + + +# DataFusion Common + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that contains the `ExecutionPlan` trait and the various implementations of that +trait for built in operators such as filters, projections, joins, aggregations, etc. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs new file mode 100644 index 0000000000000..cafa385eac39b --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_array::{downcast_primitive, ArrayRef}; +use arrow_schema::SchemaRef; +use datafusion_common::Result; +use datafusion_physical_expr::EmitTo; + +pub(crate) mod primitive; +use primitive::GroupValuesPrimitive; + +mod row; +use row::GroupValuesRows; + +/// An interning store for group keys +pub trait GroupValues: Send { + /// Calculates the `groups` for each input row of `cols` + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + + /// Returns the number of bytes used by this [`GroupValues`] + fn size(&self) -> usize; + + /// Returns true if this [`GroupValues`] is empty + fn is_empty(&self) -> bool; + + /// The number of values stored in this [`GroupValues`] + fn len(&self) -> usize; + + /// Emits the group values + fn emit(&mut self, emit_to: EmitTo) -> Result>; + + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) + fn clear_shrink(&mut self, batch: &RecordBatch); +} + +pub fn new_group_values(schema: SchemaRef) -> Result> { + if schema.fields.len() == 1 { + let d = schema.fields[0].data_type(); + + macro_rules! downcast_helper { + ($t:ty, $d:ident) => { + return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone()))) + }; + } + + downcast_primitive! { + d => (downcast_helper, d), + _ => {} + } + } + + Ok(Box::new(GroupValuesRows::try_new(schema)?)) +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs new file mode 100644 index 0000000000000..e3ba284797d15 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::array::BooleanBufferBuilder; +use arrow::buffer::NullBuffer; +use arrow::datatypes::i256; +use arrow::record_batch::RecordBatch; +use arrow_array::cast::AsArray; +use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray}; +use arrow_schema::DataType; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_physical_expr::EmitTo; +use half::f16; +use hashbrown::raw::RawTable; +use std::sync::Arc; + +/// A trait to allow hashing of floating point numbers +pub(crate) trait HashValue { + fn hash(&self, state: &RandomState) -> u64; +} + +macro_rules! hash_integer { + ($($t:ty),+) => { + $(impl HashValue for $t { + #[cfg(not(feature = "force_hash_collisions"))] + fn hash(&self, state: &RandomState) -> u64 { + state.hash_one(self) + } + + #[cfg(feature = "force_hash_collisions")] + fn hash(&self, _state: &RandomState) -> u64 { + 0 + } + })+ + }; +} +hash_integer!(i8, i16, i32, i64, i128, i256); +hash_integer!(u8, u16, u32, u64); + +macro_rules! hash_float { + ($($t:ty),+) => { + $(impl HashValue for $t { + #[cfg(not(feature = "force_hash_collisions"))] + fn hash(&self, state: &RandomState) -> u64 { + state.hash_one(self.to_bits()) + } + + #[cfg(feature = "force_hash_collisions")] + fn hash(&self, _state: &RandomState) -> u64 { + 0 + } + })+ + }; +} + +hash_float!(f16, f32, f64); + +/// A [`GroupValues`] storing a single column of primitive values +/// +/// This specialization is significantly faster than using the more general +/// purpose `Row`s format +pub struct GroupValuesPrimitive { + /// The data type of the output array + data_type: DataType, + /// Stores the group index based on the hash of its value + /// + /// We don't store the hashes as hashing fixed width primitives + /// is fast enough for this not to benefit performance + map: RawTable, + /// The group index of the null value if any + null_group: Option, + /// The values for each group index + values: Vec, + /// The random state used to generate hashes + random_state: RandomState, +} + +impl GroupValuesPrimitive { + pub fn new(data_type: DataType) -> Self { + assert!(PrimitiveArray::::is_compatible(&data_type)); + Self { + data_type, + map: RawTable::with_capacity(128), + values: Vec::with_capacity(128), + null_group: None, + random_state: Default::default(), + } + } +} + +impl GroupValues for GroupValuesPrimitive +where + T::Native: HashValue, +{ + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + assert_eq!(cols.len(), 1); + groups.clear(); + + for v in cols[0].as_primitive::() { + let group_id = match v { + None => *self.null_group.get_or_insert_with(|| { + let group_id = self.values.len(); + self.values.push(Default::default()); + group_id + }), + Some(key) => { + let state = &self.random_state; + let hash = key.hash(state); + let insert = self.map.find_or_find_insert_slot( + hash, + |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, + |g| unsafe { self.values.get_unchecked(*g).hash(state) }, + ); + + // SAFETY: No mutation occurred since find_or_find_insert_slot + unsafe { + match insert { + Ok(v) => *v.as_ref(), + Err(slot) => { + let g = self.values.len(); + self.map.insert_in_slot(hash, slot, g); + self.values.push(key); + g + } + } + } + } + }; + groups.push(group_id) + } + Ok(()) + } + + fn size(&self) -> usize { + self.map.capacity() * std::mem::size_of::() + self.values.allocated_size() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn len(&self) -> usize { + self.values.len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + fn build_primitive( + values: Vec, + null_idx: Option, + ) -> PrimitiveArray { + let nulls = null_idx.map(|null_idx| { + let mut buffer = BooleanBufferBuilder::new(values.len()); + buffer.append_n(values.len(), true); + buffer.set_bit(null_idx, false); + unsafe { NullBuffer::new_unchecked(buffer.finish(), 1) } + }); + PrimitiveArray::::new(values.into(), nulls) + } + + let array: PrimitiveArray = match emit_to { + EmitTo::All => { + self.map.clear(); + build_primitive(std::mem::take(&mut self.values), self.null_group.take()) + } + EmitTo::First(n) => { + // SAFETY: self.map outlives iterator and is not modified concurrently + unsafe { + for bucket in self.map.iter() { + // Decrement group index by n + match bucket.as_ref().checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => *bucket.as_mut() = sub, + // Group index was < n, so remove from table + None => self.map.erase(bucket), + } + } + } + let null_group = match &mut self.null_group { + Some(v) if *v >= n => { + *v -= n; + None + } + Some(_) => self.null_group.take(), + None => None, + }; + let mut split = self.values.split_off(n); + std::mem::swap(&mut self.values, &mut split); + build_primitive(split, null_group) + } + }; + Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.values.clear(); + self.values.shrink_to(count); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs new file mode 100644 index 0000000000000..e7c7a42cf9029 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::record_batch::RecordBatch; +use arrow::row::{RowConverter, Rows, SortField}; +use arrow_array::ArrayRef; +use arrow_schema::SchemaRef; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_physical_expr::EmitTo; +use hashbrown::raw::RawTable; + +/// A [`GroupValues`] making use of [`Rows`] +pub struct GroupValuesRows { + /// Converter for the group values + row_converter: RowConverter, + + /// Logically maps group values to a group_index in + /// [`Self::group_values`] and in each accumulator + /// + /// Uses the raw API of hashbrown to avoid actually storing the + /// keys (group values) in the table + /// + /// keys: u64 hashes of the GroupValue + /// values: (hash, group_index) + map: RawTable<(u64, usize)>, + + /// The size of `map` in bytes + map_size: usize, + + /// The actual group by values, stored in arrow [`Row`] format. + /// `group_values[i]` holds the group value for group_index `i`. + /// + /// The row format is used to compare group keys quickly and store + /// them efficiently in memory. Quick comparison is especially + /// important for multi-column group keys. + /// + /// [`Row`]: arrow::row::Row + group_values: Option, + + // buffer to be reused to store hashes + hashes_buffer: Vec, + + /// Random state for creating hashes + random_state: RandomState, +} + +impl GroupValuesRows { + pub fn try_new(schema: SchemaRef) -> Result { + let row_converter = RowConverter::new( + schema + .fields() + .iter() + .map(|f| SortField::new(f.data_type().clone())) + .collect(), + )?; + + let map = RawTable::with_capacity(0); + + Ok(Self { + row_converter, + map, + map_size: 0, + group_values: None, + hashes_buffer: Default::default(), + random_state: Default::default(), + }) + } +} + +impl GroupValues for GroupValuesRows { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + // Convert the group keys into the row format + // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available + let group_rows = self.row_converter.convert_columns(cols)?; + let n_rows = group_rows.num_rows(); + + let mut group_values = match self.group_values.take() { + Some(group_values) => group_values, + None => self.row_converter.empty_rows(0, 0), + }; + + // tracks to which group each of the input rows belongs + groups.clear(); + + // 1.1 Calculate the group keys for the group values + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, batch_hashes)?; + + for (row, &hash) in batch_hashes.iter().enumerate() { + let entry = self.map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + group_rows.row(row) == group_values.row(*group_idx) + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save newly created index + let group_idx = group_values.num_rows(); + group_values.push(group_rows.row(row)); + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + + self.group_values = Some(group_values); + + Ok(()) + } + + fn size(&self) -> usize { + let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); + self.row_converter.size() + + group_values_size + + self.map_size + + self.hashes_buffer.allocated_size() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + self.group_values + .as_ref() + .map(|group_values| group_values.num_rows()) + .unwrap_or(0) + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let mut group_values = self + .group_values + .take() + .expect("Can not emit from empty rows"); + + let output = match emit_to { + EmitTo::All => { + let output = self.row_converter.convert_rows(&group_values)?; + group_values.clear(); + output + } + EmitTo::First(n) => { + let groups_rows = group_values.iter().take(n); + let output = self.row_converter.convert_rows(groups_rows)?; + // Clear out first n group keys by copying them to a new Rows. + // TODO file some ticket in arrow-rs to make this more efficent? + let mut new_group_values = self.row_converter.empty_rows(0, 0); + for row in group_values.iter().skip(n) { + new_group_values.push(row); + } + std::mem::swap(&mut new_group_values, &mut group_values); + + // SAFETY: self.map outlives iterator and is not modified concurrently + unsafe { + for bucket in self.map.iter() { + // Decrement group index by n + match bucket.as_ref().1.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => bucket.as_mut().1 = sub, + // Group index was < n, so remove from table + None => self.map.erase(bucket), + } + } + } + output + } + }; + + self.group_values = Some(group_values); + Ok(output) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.group_values = self.group_values.take().map(|mut rows| { + rows.clear(); + rows + }); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.hashes_buffer.clear(); + self.hashes_buffer.shrink_to(count); + } +} diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs similarity index 53% rename from datafusion/core/src/physical_plan/aggregates/mod.rs rename to datafusion/physical-plan/src/aggregates/mod.rs index a2ae41de1f746..2f69ed061ce10 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -17,46 +17,50 @@ //! Aggregates functionalities -use crate::physical_plan::aggregates::{ - bounded_aggregate_stream::BoundedAggregateStream, no_grouping::AggregateStream, - row_hash::GroupedHashAggregateStream, +use std::any::Any; +use std::sync::Arc; + +use super::DisplayAs; +use crate::aggregates::{ + no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, + topk_stream::GroupedTopKAggregateStream, }; -use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::physical_plan::{ - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, + +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::windows::{get_ordered_partition_by_indices, get_window_mode}; +use crate::{ + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, }; + use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::utils::longest_consecutive_prefix; -use datafusion_common::{DataFusionError, Result}; +use arrow_schema::DataType; +use datafusion_common::stats::Precision; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ - aggregate::row_accumulator::RowAccumulator, - equivalence::project_equivalence_properties, - expressions::{Avg, CastExpr, Column, Sum}, - normalize_out_expr_with_columns_map, reverse_order_bys, - utils::{convert_to_expr, get_indices_of_matching_exprs}, - AggregateExpr, LexOrdering, LexOrderingReq, OrderingEquivalenceProperties, - PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + aggregate::is_order_sensitive, + equivalence::collapse_lex_req, + expressions::{Column, Max, Min, UnKnownColumn}, + physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties, + LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; -use std::any::Any; -use std::collections::HashMap; -use std::sync::Arc; -mod bounded_aggregate_stream; +use itertools::{izip, Itertools}; + +mod group_values; mod no_grouping; +mod order; mod row_hash; -mod utils; +mod topk; +mod topk_stream; pub use datafusion_expr::AggregateFunction; -use datafusion_physical_expr::aggregate::is_order_sensitive; +use datafusion_physical_expr::equivalence::ProjectionMapping; pub use datafusion_physical_expr::expressions::create_aggregate_expr; -use datafusion_physical_expr::utils::{ - get_finer_ordering, ordering_satisfy_requirement_concrete, -}; /// Hash aggregate modes #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -75,21 +79,54 @@ pub enum AggregateMode { /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. + /// This mode requires tha the input is a single partition (like Final) Single, + /// Applies the entire logical aggregation operation in a single operator, + /// as opposed to Partial / Final modes which apply the logical aggregation using + /// two operators. + /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) + SinglePartitioned, +} + +impl AggregateMode { + /// Checks whether this aggregation step describes a "first stage" calculation. + /// In other words, its input is not another aggregation result and the + /// `merge_batch` method will not be called for these modes. + pub fn is_first_stage(&self) -> bool { + match self { + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => true, + AggregateMode::Final | AggregateMode::FinalPartitioned => false, + } + } } /// Group By expression modes -#[derive(Debug, Clone, PartialEq, Eq)] +/// +/// `PartiallyOrdered` and `FullyOrdered` are used to reason about +/// when certain group by keys will never again be seen (and thus can +/// be emitted by the grouping operator). +/// +/// Specifically, each distinct combination of the relevant columns +/// are contiguous in the input, and once a new combination is seen +/// previous combinations are guaranteed never to appear again +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum GroupByOrderMode { - /// Some of the expressions in the GROUP BY clause have an ordering. - // For example, if the input is ordered by a, b, c and we group by b, a, d; - // the mode will be `PartiallyOrdered` meaning a subset of group b, a, d - // defines a preset for the existing ordering, e.g a, b defines a preset. + /// The input is known to be ordered by a preset (prefix but + /// possibly reordered) of the expressions in the `GROUP BY` clause. + /// + /// For example, if the input is ordered by `a, b, c` and we group + /// by `b, a, d`, `PartiallyOrdered` means a subset of group `b, + /// a, d` defines a preset for the existing ordering, in this case + /// `a, b`. PartiallyOrdered, - /// All the expressions in the GROUP BY clause have orderings. - // For example, if the input is ordered by a, b, c, d and we group by b, a; - // the mode will be `Ordered` meaning a all of the of group b, d - // defines a preset for the existing ordering, e.g a, b defines a preset. + /// The input is known to be ordered by *all* the expressions in the + /// `GROUP BY` clause. + /// + /// For example, if the input is ordered by `a, b, c, d` and we group by b, a, + /// `Ordered` means that all of the of group by expressions appear + /// as a preset for the existing ordering, in this case `a, b`. FullyOrdered, } @@ -100,6 +137,7 @@ pub enum GroupByOrderMode { /// into multiple groups, using null expressions to align each group. /// For example, with a group by clause `GROUP BY GROUPING SET ((a,b),(a),(b))` the planner should /// create a `PhysicalGroupBy` like +/// ```text /// PhysicalGroupBy { /// expr: [(col(a), a), (col(b), b)], /// null_expr: [(NULL, a), (NULL, b)], @@ -109,6 +147,7 @@ pub enum GroupByOrderMode { /// [true, false] // (b) <=> (NULL, b) /// ] /// } +/// ``` #[derive(Clone, Debug, Default)] pub struct PhysicalGroupBy { /// Distinct (Physical Expr, Alias) in the grouping set @@ -171,6 +210,28 @@ impl PhysicalGroupBy { pub fn is_empty(&self) -> bool { self.expr.is_empty() } + + /// Check whether grouping set is single group + pub fn is_single(&self) -> bool { + self.null_expr.is_empty() + } + + /// Calculate GROUP BY expressions according to input schema. + pub fn input_exprs(&self) -> Vec> { + self.expr + .iter() + .map(|(expr, _alias)| expr.clone()) + .collect() + } + + /// Return grouping expressions as they occur in the output schema. + pub fn output_exprs(&self) -> Vec> { + self.expr + .iter() + .enumerate() + .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) + .collect() + } } impl PartialEq for PhysicalGroupBy { @@ -193,153 +254,58 @@ impl PartialEq for PhysicalGroupBy { enum StreamType { AggregateStream(AggregateStream), - GroupedHashAggregateStream(GroupedHashAggregateStream), - BoundedAggregate(BoundedAggregateStream), + GroupedHash(GroupedHashAggregateStream), + GroupedPriorityQueue(GroupedTopKAggregateStream), } impl From for SendableRecordBatchStream { fn from(stream: StreamType) -> Self { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), - StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream), - StreamType::BoundedAggregate(stream) => Box::pin(stream), + StreamType::GroupedHash(stream) => Box::pin(stream), + StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } } } -/// This object encapsulates ordering-related information on GROUP BY columns. -#[derive(Debug, Clone)] -pub(crate) struct AggregationOrdering { - /// Specifies whether the GROUP BY columns are partially or fully ordered. - mode: GroupByOrderMode, - /// Stores indices such that when we iterate with these indices, GROUP BY - /// expressions match input ordering. - order_indices: Vec, - /// Actual ordering information of the GROUP BY columns. - ordering: LexOrdering, -} - /// Hash aggregate execution plan #[derive(Debug)] pub struct AggregateExec { /// Aggregation mode (full, partial) - pub(crate) mode: AggregateMode, + mode: AggregateMode, /// Group by expressions - pub(crate) group_by: PhysicalGroupBy, + group_by: PhysicalGroupBy, /// Aggregate expressions - pub(crate) aggr_expr: Vec>, + aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression - pub(crate) filter_expr: Vec>>, + filter_expr: Vec>>, /// (ORDER BY clause) expression for each aggregate expression - pub(crate) order_by_expr: Vec>, + order_by_expr: Vec>, + /// Set if the output of this aggregation is truncated by a upstream sort/limit clause + limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate - pub(crate) input: Arc, + pub input: Arc, + /// Original aggregation schema, could be different from `schema` before dictionary group + /// keys get materialized + original_schema: SchemaRef, /// Schema after the aggregate is applied schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the /// same as input.schema() but for the final aggregate it will be the same as the input - /// to the partial aggregate - pub(crate) input_schema: SchemaRef, - /// The columns map used to normalize out expressions like Partitioning and PhysicalSortExpr - /// The key is the column from the input schema and the values are the columns from the output schema - columns_map: HashMap>, - /// Execution Metrics + /// to the partial aggregate, i.e., partial and final aggregates have same `input_schema`. + /// We need the input schema of partial aggregate to be able to deserialize aggregate + /// expressions from protobuf for final aggregate. + pub input_schema: SchemaRef, + /// The mapping used to normalize expressions like Partitioning and + /// PhysicalSortExpr that maps input to output + projection_mapping: ProjectionMapping, + /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Stores mode and output ordering information for the `AggregateExec`. - aggregation_ordering: Option, - required_input_ordering: Option, -} - -/// Calculates the working mode for `GROUP BY` queries. -/// - If no GROUP BY expression has an ordering, returns `None`. -/// - If some GROUP BY expressions have an ordering, returns `Some(GroupByOrderMode::PartiallyOrdered)`. -/// - If all GROUP BY expressions have orderings, returns `Some(GroupByOrderMode::Ordered)`. -fn get_working_mode( - input: &Arc, - group_by: &PhysicalGroupBy, -) -> Option<(GroupByOrderMode, Vec)> { - if group_by.groups.len() > 1 { - // We do not currently support streaming execution if we have more - // than one group (e.g. we have grouping sets). - return None; - }; - - let output_ordering = input.output_ordering().unwrap_or(&[]); - // Since direction of the ordering is not important for GROUP BY columns, - // we convert PhysicalSortExpr to PhysicalExpr in the existing ordering. - let ordering_exprs = convert_to_expr(output_ordering); - let groupby_exprs = group_by - .expr - .iter() - .map(|(item, _)| item.clone()) - .collect::>(); - // Find where each expression of the GROUP BY clause occurs in the existing - // ordering (if it occurs): - let mut ordered_indices = - get_indices_of_matching_exprs(&groupby_exprs, &ordering_exprs, || { - input.equivalence_properties() - }); - ordered_indices.sort(); - // Find out how many expressions of the existing ordering define ordering - // for expressions in the GROUP BY clause. For example, if the input is - // ordered by a, b, c, d and we group by b, a, d; the result below would be. - // 2, meaning 2 elements (a, b) among the GROUP BY columns define ordering. - let first_n = longest_consecutive_prefix(ordered_indices); - if first_n == 0 { - // No GROUP by columns are ordered, we can not do streaming execution. - return None; - } - let ordered_exprs = ordering_exprs[0..first_n].to_vec(); - // Find indices for the GROUP BY expressions such that when we iterate with - // these indices, we would match existing ordering. For the example above, - // this would produce 1, 0; meaning 1st and 0th entries (a, b) among the - // GROUP BY expressions b, a, d match input ordering. - let ordered_group_by_indices = - get_indices_of_matching_exprs(&ordered_exprs, &groupby_exprs, || { - input.equivalence_properties() - }); - Some(if first_n == group_by.expr.len() { - (GroupByOrderMode::FullyOrdered, ordered_group_by_indices) - } else { - (GroupByOrderMode::PartiallyOrdered, ordered_group_by_indices) - }) -} - -/// This function gathers the ordering information for the GROUP BY columns. -fn calc_aggregation_ordering( - input: &Arc, - group_by: &PhysicalGroupBy, -) -> Option { - get_working_mode(input, group_by).map(|(mode, order_indices)| { - let existing_ordering = input.output_ordering().unwrap_or(&[]); - let out_group_expr = output_group_expr_helper(group_by); - // Calculate output ordering information for the operator: - let out_ordering = order_indices - .iter() - .zip(existing_ordering) - .map(|(idx, input_col)| PhysicalSortExpr { - expr: out_group_expr[*idx].clone(), - options: input_col.options, - }) - .collect::>(); - AggregationOrdering { - mode, - order_indices, - ordering: out_ordering, - } - }) -} - -/// This function returns grouping expressions as they occur in the output schema. -fn output_group_expr_helper(group_by: &PhysicalGroupBy) -> Vec> { - // Update column indices. Since the group by columns come first in the output schema, their - // indices are simply 0..self.group_expr(len). - group_by - .expr() - .iter() - .enumerate() - .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) - .collect() + required_input_ordering: Option, + /// Describes how the input is ordered relative to the group by columns + input_order_mode: InputOrderMode, + /// Describe how the output is ordered + output_ordering: Option, } /// This function returns the ordering requirement of the first non-reversible @@ -368,46 +334,57 @@ fn get_init_req( /// This function gets the finest ordering requirement among all the aggregation /// functions. If requirements are conflicting, (i.e. we can not compute the /// aggregations in a single [`AggregateExec`]), the function returns an error. -fn get_finest_requirement< - F: Fn() -> EquivalenceProperties, - F2: Fn() -> OrderingEquivalenceProperties, ->( +fn get_finest_requirement( aggr_expr: &mut [Arc], order_by_expr: &mut [Option], - eq_properties: F, - ordering_eq_properties: F2, + eq_properties: &EquivalenceProperties, ) -> Result> { + // First, we check if all the requirements are satisfied by the existing + // ordering. If so, we return `None` to indicate this. + let mut all_satisfied = true; + for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { + if eq_properties.ordering_satisfy(fn_req.as_deref().unwrap_or(&[])) { + continue; + } + if let Some(reverse) = aggr_expr.reverse_expr() { + let reverse_req = fn_req.as_ref().map(|item| reverse_order_bys(item)); + if eq_properties.ordering_satisfy(reverse_req.as_deref().unwrap_or(&[])) { + // We need to update `aggr_expr` with its reverse since only its + // reverse requirement is compatible with the existing requirements: + *aggr_expr = reverse; + *fn_req = reverse_req; + continue; + } + } + // Requirement is not satisfied: + all_satisfied = false; + } + if all_satisfied { + // All of the requirements are already satisfied. + return Ok(None); + } let mut finest_req = get_init_req(aggr_expr, order_by_expr); for (aggr_expr, fn_req) in aggr_expr.iter_mut().zip(order_by_expr.iter_mut()) { - let fn_req = if let Some(fn_req) = fn_req { - fn_req - } else { + let Some(fn_req) = fn_req else { continue; }; + if let Some(finest_req) = &mut finest_req { - if let Some(finer) = get_finer_ordering( - finest_req, - fn_req, - &eq_properties, - &ordering_eq_properties, - ) { - *finest_req = finer.to_vec(); + if let Some(finer) = eq_properties.get_finer_ordering(finest_req, fn_req) { + *finest_req = finer; continue; } // If an aggregate function is reversible, analyze whether its reverse // direction is compatible with existing requirements: if let Some(reverse) = aggr_expr.reverse_expr() { let fn_req_reverse = reverse_order_bys(fn_req); - if let Some(finer) = get_finer_ordering( - finest_req, - &fn_req_reverse, - &eq_properties, - &ordering_eq_properties, - ) { + if let Some(finer) = + eq_properties.get_finer_ordering(finest_req, &fn_req_reverse) + { // We need to update `aggr_expr` with its reverse, since only its // reverse requirement is compatible with existing requirements: *aggr_expr = reverse; - *finest_req = finer.to_vec(); + *finest_req = finer; *fn_req = fn_req_reverse; continue; } @@ -415,9 +392,9 @@ fn get_finest_requirement< // If neither of the requirements satisfy the other, this means // requirements are conflicting. Currently, we do not support // conflicting requirements. - return Err(DataFusionError::NotImplemented( - "Conflicting ordering requirements in aggregate functions is not supported".to_string(), - )); + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); } else { finest_req = Some(fn_req.clone()); } @@ -425,101 +402,63 @@ fn get_finest_requirement< Ok(finest_req) } -/// Calculate the required input ordering for the [`AggregateExec`] by considering -/// ordering requirements of order-sensitive aggregation functions. -fn calc_required_input_ordering( +/// Calculates search_mode for the aggregation +fn get_aggregate_search_mode( + group_by: &PhysicalGroupBy, input: &Arc, aggr_expr: &mut [Arc], - aggregator_reqs: LexOrderingReq, - aggregator_reverse_reqs: Option, - aggregation_ordering: &Option, -) -> Result> { - let mut required_input_ordering = vec![]; - // Boolean shows that whether `required_input_ordering` stored comes from - // `aggregator_reqs` or `aggregator_reverse_reqs` - let mut reverse_req = false; - // If reverse aggregator is None, there is no way to run aggregators in reverse mode. Hence ignore it during analysis - let aggregator_requirements = - if let Some(aggregator_reverse_reqs) = aggregator_reverse_reqs { - // If existing ordering doesn't satisfy requirement, we should do calculations - // on naive requirement (by convention, otherwise the final plan will be unintuitive), - // even if reverse ordering is possible. - // Hence, while iterating consider naive requirement last, by this way - // we prioritize naive requirement over reverse requirement, when - // reverse requirement is not helpful with removing SortExec from the plan. - vec![(true, aggregator_reverse_reqs), (false, aggregator_reqs)] - } else { - vec![(false, aggregator_reqs)] - }; - for (is_reverse, aggregator_requirement) in aggregator_requirements.into_iter() { - if let Some(AggregationOrdering { - ordering, - // If the mode is FullyOrdered or PartiallyOrdered (i.e. we are - // running with bounded memory, without breaking the pipeline), - // then we append the aggregator ordering requirement to the existing - // ordering. This way, we can still run with bounded memory. - mode: GroupByOrderMode::FullyOrdered | GroupByOrderMode::PartiallyOrdered, - .. - }) = aggregation_ordering - { - // Get the section of the input ordering that enables us to run in - // FullyOrdered or PartiallyOrdered modes: - let requirement_prefix = - if let Some(existing_ordering) = input.output_ordering() { - &existing_ordering[0..ordering.len()] - } else { - &[] - }; - let mut requirement = - PhysicalSortRequirement::from_sort_exprs(requirement_prefix.iter()); - for req in aggregator_requirement { - if requirement.iter().all(|item| req.expr.ne(&item.expr)) { - requirement.push(req); - } - } - required_input_ordering = requirement; - } else { - required_input_ordering = aggregator_requirement; - } - // keep track of from which direction required_input_ordering is constructed - reverse_req = is_reverse; - // If all of the order-sensitive aggregate functions are reversible (such as all of the order-sensitive aggregators are - // either FIRST_VALUE or LAST_VALUE). We can run aggregate expressions both in the direction of naive required ordering - // (e.g finest requirement that satisfy each aggregate function requirement) and in its reversed (opposite) direction. - // We analyze these two possibilities, and use the version that satisfies existing ordering (This saves us adding - // unnecessary SortExec to the final plan). If none of the versions satisfy existing ordering, we use naive required ordering. - // In short, if running aggregators in reverse order, helps us to remove a `SortExec`, we do so. Otherwise, we use aggregators as is. - let existing_ordering = input.output_ordering().unwrap_or(&[]); - if ordering_satisfy_requirement_concrete( - existing_ordering, - &required_input_ordering, - || input.equivalence_properties(), - || input.ordering_equivalence_properties(), - ) { - break; - } + order_by_expr: &mut [Option], + ordering_req: &mut Vec, +) -> InputOrderMode { + let groupby_exprs = group_by + .expr + .iter() + .map(|(item, _)| item.clone()) + .collect::>(); + let mut input_order_mode = InputOrderMode::Linear; + if !group_by.is_single() || groupby_exprs.is_empty() { + return input_order_mode; } - // If `required_input_ordering` is constructed using reverse requirement, we should reverse - // each `aggr_expr` to be able to correctly calculate their result in reverse order. - if reverse_req { - aggr_expr - .iter_mut() - .map(|elem| { - if is_order_sensitive(elem) { - if let Some(reverse) = elem.reverse_expr() { - *elem = reverse; + + if let Some((should_reverse, mode)) = + get_window_mode(&groupby_exprs, ordering_req, input) + { + let all_reversible = aggr_expr + .iter() + .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()); + if should_reverse && all_reversible { + izip!(aggr_expr.iter_mut(), order_by_expr.iter_mut()).for_each( + |(aggr, order_by)| { + if let Some(reverse) = aggr.reverse_expr() { + *aggr = reverse; } else { - return Err(DataFusionError::Execution( - "Aggregate expression should have a reverse expression" - .to_string(), - )); + unreachable!(); } - } - Ok(()) - }) - .collect::>>()?; + *order_by = order_by.as_ref().map(|ob| reverse_order_bys(ob)); + }, + ); + *ordering_req = reverse_order_bys(ordering_req); + } + input_order_mode = mode; } - Ok((!required_input_ordering.is_empty()).then_some(required_input_ordering)) + input_order_mode +} + +/// Check whether group by expression contains all of the expression inside `requirement` +// As an example Group By (c,b,a) contains all of the expressions in the `requirement`: (a ASC, b DESC) +fn group_by_contains_all_requirements( + group_by: &PhysicalGroupBy, + requirement: &LexOrdering, +) -> bool { + let physical_exprs = group_by.input_exprs(); + // When we have multiple groups (grouping set) + // since group by may be calculated on the subset of the group_by.expr() + // it is not guaranteed to have all of the requirements among group by expressions. + // Hence do the analysis: whether group by contains all requirements in the single group case. + group_by.is_single() + && requirement + .iter() + .all(|req| physical_exprs_contains(&physical_exprs, &req.expr)) } impl AggregateExec { @@ -529,11 +468,12 @@ impl AggregateExec { group_by: PhysicalGroupBy, mut aggr_expr: Vec>, filter_expr: Vec>>, + // Ordering requirement of each aggregate expression mut order_by_expr: Vec>, input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( + let original_schema = create_schema( &input.schema(), &group_by.expr, &aggr_expr, @@ -541,85 +481,77 @@ impl AggregateExec { mode, )?; - let schema = Arc::new(schema); + let schema = Arc::new(materialize_dict_group_keys( + &original_schema, + group_by.expr.len(), + )); + let original_schema = Arc::new(original_schema); // Reset ordering requirement to `None` if aggregator is not order-sensitive order_by_expr = aggr_expr .iter() - .zip(order_by_expr.into_iter()) + .zip(order_by_expr) .map(|(aggr_expr, fn_reqs)| { - // If aggregation function is ordering sensitive, keep ordering requirement as is; otherwise ignore requirement - if is_order_sensitive(aggr_expr) { - fn_reqs - } else { - None - } + // If + // - aggregation function is order-sensitive and + // - aggregation is performing a "first stage" calculation, and + // - at least one of the aggregate function requirement is not inside group by expression + // keep the ordering requirement as is; otherwise ignore the ordering requirement. + // In non-first stage modes, we accumulate data (using `merge_batch`) + // from different partitions (i.e. merge partial results). During + // this merge, we consider the ordering of each partial result. + // Hence, we do not need to use the ordering requirement in such + // modes as long as partial results are generated with the + // correct ordering. + fn_reqs.filter(|req| { + is_order_sensitive(aggr_expr) + && mode.is_first_stage() + && !group_by_contains_all_requirements(&group_by, req) + }) }) .collect::>(); + let requirement = get_finest_requirement( + &mut aggr_expr, + &mut order_by_expr, + &input.equivalence_properties(), + )?; + let mut ordering_req = requirement.unwrap_or(vec![]); + let input_order_mode = get_aggregate_search_mode( + &group_by, + &input, + &mut aggr_expr, + &mut order_by_expr, + &mut ordering_req, + ); - let mut aggregator_reqs = vec![]; - let mut aggregator_reverse_reqs = None; - // Currently we support order-sensitive aggregation only in `Single` mode. - // For `Final` and `FinalPartitioned` modes, we cannot guarantee they will receive - // data according to ordering requirements. As long as we cannot produce correct result - // in `Final` mode, it is not important to produce correct result in `Partial` mode. - // We only support `Single` mode, where we are sure that output produced is final, and it - // is produced in a single step. - if mode == AggregateMode::Single { - let requirement = get_finest_requirement( - &mut aggr_expr, - &mut order_by_expr, - || input.equivalence_properties(), - || input.ordering_equivalence_properties(), - )?; - let aggregator_requirement = requirement - .as_ref() - .map(|exprs| PhysicalSortRequirement::from_sort_exprs(exprs.iter())); - aggregator_reqs = aggregator_requirement.unwrap_or(vec![]); - // If all aggregate expressions are reversible, also consider reverse - // requirement(s). The reason is that existing ordering may satisfy the - // given requirement or its reverse. By considering both, we can generate better plans. - if aggr_expr - .iter() - .all(|expr| !is_order_sensitive(expr) || expr.reverse_expr().is_some()) - { - let reverse_agg_requirement = requirement.map(|reqs| { - PhysicalSortRequirement::from_sort_exprs( - reverse_order_bys(&reqs).iter(), - ) - }); - aggregator_reverse_reqs = reverse_agg_requirement; - } - } - - // construct a map from the input columns to the output columns of the Aggregation - let mut columns_map: HashMap> = HashMap::new(); - for (expression, name) in group_by.expr.iter() { - if let Some(column) = expression.as_any().downcast_ref::() { - let new_col_idx = schema.index_of(name)?; - let entry = columns_map.entry(column.clone()).or_insert_with(Vec::new); - entry.push(Column::new(name, new_col_idx)); - }; - } + // Get GROUP BY expressions: + let groupby_exprs = group_by.input_exprs(); + // If existing ordering satisfies a prefix of the GROUP BY expressions, + // prefix requirements with this section. In this case, aggregation will + // work more efficiently. + let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); + let mut new_requirement = indices + .into_iter() + .map(|idx| PhysicalSortRequirement { + expr: groupby_exprs[idx].clone(), + options: None, + }) + .collect::>(); + // Postfix ordering requirement of the aggregation to the requirement. + let req = PhysicalSortRequirement::from_sort_exprs(&ordering_req); + new_requirement.extend(req); + new_requirement = collapse_lex_req(new_requirement); - let aggregation_ordering = calc_aggregation_ordering(&input, &group_by); + // construct a map from the input expression to the output expression of the Aggregation group by + let projection_mapping = + ProjectionMapping::try_new(&group_by.expr, &input.schema())?; - let required_input_ordering = calc_required_input_ordering( - &input, - &mut aggr_expr, - aggregator_reqs, - aggregator_reverse_reqs, - &aggregation_ordering, - )?; + let required_input_ordering = + (!new_requirement.is_empty()).then_some(new_requirement); - // If aggregator is working on multiple partitions and there is an order-sensitive aggregator with a requirement return error. - if input.output_partitioning().partition_count() > 1 - && order_by_expr.iter().any(|req| req.is_some()) - { - return Err(DataFusionError::NotImplemented( - "Order-sensitive aggregators is not supported on multiple partitions" - .to_string(), - )); - } + let aggregate_eqs = input + .equivalence_properties() + .project(&projection_mapping, schema.clone()); + let output_ordering = aggregate_eqs.oeq_class().output_ordering(); Ok(AggregateExec { mode, @@ -628,12 +560,15 @@ impl AggregateExec { filter_expr, order_by_expr, input, + original_schema, schema, input_schema, - columns_map, + projection_mapping, metrics: ExecutionPlanMetricsSet::new(), - aggregation_ordering, required_input_ordering, + limit: None, + input_order_mode, + output_ordering, }) } @@ -642,6 +577,11 @@ impl AggregateExec { &self.mode } + /// Set the `limit` of this AggExec + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } /// Grouping expressions pub fn group_expr(&self) -> &PhysicalGroupBy { &self.group_by @@ -649,7 +589,7 @@ impl AggregateExec { /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { - output_group_expr_helper(&self.group_by) + self.group_by.output_exprs() } /// Aggregate expressions @@ -677,159 +617,98 @@ impl AggregateExec { self.input_schema.clone() } + /// number of rows soft limit of the AggregateExec + pub fn limit(&self) -> Option { + self.limit + } + fn execute_typed( &self, partition: usize, context: Arc, ) -> Result { + // no group by at all if self.group_by.expr.is_empty() { - Ok(StreamType::AggregateStream(AggregateStream::new( + return Ok(StreamType::AggregateStream(AggregateStream::new( self, context, partition, - )?)) - } else if let Some(aggregation_ordering) = &self.aggregation_ordering { - let aggregation_ordering = aggregation_ordering.clone(); - Ok(StreamType::BoundedAggregate(BoundedAggregateStream::new( - self, - context, - partition, - aggregation_ordering, - )?)) - } else { - Ok(StreamType::GroupedHashAggregateStream( - GroupedHashAggregateStream::new(self, context, partition)?, - )) + )?)); } - } -} -impl ExecutionPlan for AggregateExec { - /// Return a reference to Any that can be used for down-casting - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - /// Get the output partitioning of this plan - fn output_partitioning(&self) -> Partitioning { - match &self.mode { - AggregateMode::Partial | AggregateMode::Single => { - // Partial and Single Aggregation will not change the output partitioning but need to respect the Alias - let input_partition = self.input.output_partitioning(); - match input_partition { - Partitioning::Hash(exprs, part) => { - let normalized_exprs = exprs - .into_iter() - .map(|expr| { - normalize_out_expr_with_columns_map( - expr, - &self.columns_map, - ) - }) - .collect::>(); - Partitioning::Hash(normalized_exprs, part) - } - _ => input_partition, - } + // grouping by an expression that has a sort/limit upstream + if let Some(limit) = self.limit { + if !self.is_unordered_unfiltered_group_by_distinct() { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, limit)?, + )); } - // Final Aggregation's output partitioning is the same as its real input - _ => self.input.output_partitioning(), } + + // grouping by something else and we need to just materialize all results + Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( + self, context, partition, + )?)) } - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns an error to indicate this. - fn unbounded_output(&self, children: &[bool]) -> Result { - if children[0] { - if self.aggregation_ordering.is_none() { - // Cannot run without breaking pipeline. - Err(DataFusionError::Plan( - "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs.".to_string(), - )) - } else { - Ok(true) - } + /// Finds the DataType and SortDirection for this Aggregate, if there is one + pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; + if let Some(max) = agg_expr.as_any().downcast_ref::() { + Some((max.field().ok()?, true)) + } else if let Some(min) = agg_expr.as_any().downcast_ref::() { + Some((min.field().ok()?, false)) } else { - Ok(false) + None } } - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.aggregation_ordering - .as_ref() - .map(|item: &AggregationOrdering| item.ordering.as_slice()) + pub fn group_by(&self) -> &PhysicalGroupBy { + &self.group_by } - fn required_input_distribution(&self) -> Vec { - match &self.mode { - AggregateMode::Partial | AggregateMode::Single => { - vec![Distribution::UnspecifiedDistribution] - } - AggregateMode::FinalPartitioned => { - vec![Distribution::HashPartitioned(self.output_group_expr())] - } - AggregateMode::Final => vec![Distribution::SinglePartition], + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions + /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule + /// on an AggregateExec. + pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + // ensure there is a group by + if self.group_by().is_empty() { + return false; } + // ensure there are no aggregate expressions + if !self.aggr_expr().is_empty() { + return false; + } + // ensure there are no filters on aggregate expressions; the above check + // may preclude this case + if self.filter_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there are no order by expressions + if self.order_by_expr().iter().any(|e| e.is_some()) { + return false; + } + // ensure there is no output ordering; can this rule be relaxed? + if self.output_ordering().is_some() { + return false; + } + // ensure no ordering is required on the input + if self.required_input_ordering()[0].is_some() { + return false; + } + true } +} - fn required_input_ordering(&self) -> Vec> { - vec![self.required_input_ordering.clone()] - } - - fn equivalence_properties(&self) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(self.schema()); - project_equivalence_properties( - self.input.equivalence_properties(), - &self.columns_map, - &mut new_properties, - ); - new_properties - } - - fn children(&self) -> Vec> { - vec![self.input.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(AggregateExec::try_new( - self.mode, - self.group_by.clone(), - self.aggr_expr.clone(), - self.filter_expr.clone(), - self.order_by_expr.clone(), - children[0].clone(), - self.input_schema.clone(), - )?)) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.execute_typed(partition, context) - .map(|stream| stream.into()) - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - +impl DisplayAs for AggregateExec { fn fmt_as( &self, t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { match t { - DisplayFormatType::Default => { + DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "AggregateExec: mode={:?}", self.mode)?; - let g: Vec = if self.group_by.groups.len() == 1 { + let g: Vec = if self.group_by.is_single() { self.group_by .expr .iter() @@ -884,37 +763,181 @@ impl ExecutionPlan for AggregateExec { .map(|agg| agg.name().to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; + if let Some(limit) = self.limit { + write!(f, ", lim=[{limit}]")?; + } + + if self.input_order_mode != InputOrderMode::Linear { + write!(f, ", ordering_mode={:?}", self.input_order_mode)?; + } + } + } + Ok(()) + } +} + +impl ExecutionPlan for AggregateExec { + /// Return a reference to Any that can be used for down-casting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + let input_partition = self.input.output_partitioning(); + if self.mode.is_first_stage() { + // First stage aggregation will not change the output partitioning, + // but needs to respect aliases (e.g. mapping in the GROUP BY + // expression). + let input_eq_properties = self.input.equivalence_properties(); + // First stage Aggregation will not change the output partitioning but need to respect the Alias + let input_partition = self.input.output_partitioning(); + if let Partitioning::Hash(exprs, part) = input_partition { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + input_eq_properties + .project_expr(&expr, &self.projection_mapping) + .unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) + }) + .collect(); + return Partitioning::Hash(normalized_exprs, part); + } + } + // Final Aggregation's output partitioning is the same as its real input + input_partition + } + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, children: &[bool]) -> Result { + if children[0] { + if self.input_order_mode == InputOrderMode::Linear { + // Cannot run without breaking pipeline. + plan_err!( + "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." + ) + } else { + Ok(true) + } + } else { + Ok(false) + } + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + fn required_input_distribution(&self) -> Vec { + match &self.mode { + AggregateMode::Partial => { + vec![Distribution::UnspecifiedDistribution] + } + AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { + vec![Distribution::HashPartitioned(self.output_group_expr())] + } + AggregateMode::Final | AggregateMode::Single => { + vec![Distribution::SinglePartition] + } + } + } + + fn required_input_ordering(&self) -> Vec> { + vec![self.required_input_ordering.clone()] + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + self.input + .equivalence_properties() + .project(&self.projection_mapping, self.schema()) + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let mut me = AggregateExec::try_new( + self.mode, + self.group_by.clone(), + self.aggr_expr.clone(), + self.filter_expr.clone(), + self.order_by_expr.clone(), + children[0].clone(), + self.input_schema.clone(), + )?; + me.limit = self.limit; + Ok(Arc::new(me)) + } - if let Some(aggregation_ordering) = &self.aggregation_ordering { - write!(f, ", ordering_mode={:?}", aggregation_ordering.mode)?; - } - } - } - Ok(()) + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.execute_typed(partition, context) + .map(|stream| stream.into()) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: // - aggregations somtimes also preserve invariants such as min, max... + let column_statistics = Statistics::unknown_column(&self.schema()); match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned if self.group_by.expr.is_empty() => { - Statistics { - num_rows: Some(1), - is_exact: true, - ..Default::default() - } + Ok(Statistics { + num_rows: Precision::Exact(1), + column_statistics, + total_byte_size: Precision::Absent, + }) + } + _ => { + // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. + // When it is larger than 1, we degrade the precision since it may decrease after aggregation. + let num_rows = if let Some(value) = + self.input().statistics()?.num_rows.get_value() + { + if *value > 1 { + self.input().statistics()?.num_rows.to_inexact() + } else if *value == 0 { + // Aggregation on an empty table creates a null row. + self.input() + .statistics()? + .num_rows + .add(&Precision::Exact(1)) + } else { + // num_rows = 1 case + self.input().statistics()?.num_rows + } + } else { + Precision::Absent + }; + Ok(Statistics { + num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) } - _ => Statistics { - // the output row count is surely not larger than its input row count - num_rows: self.input.statistics().num_rows, - is_exact: false, - ..Default::default() - }, } } } @@ -947,7 +970,8 @@ fn create_schema( } AggregateMode::Final | AggregateMode::FinalPartitioned - | AggregateMode::Single => { + | AggregateMode::Single + | AggregateMode::SinglePartitioned => { // in final mode, the field with the final result of the accumulator for expr in aggr_expr { fields.push(expr.field()?) @@ -958,12 +982,30 @@ fn create_schema( Ok(Schema::new(fields)) } +/// returns schema with dictionary group keys materialized as their value types +/// The actual convertion happens in `RowConverter` and we don't do unnecessary +/// conversion back into dictionaries +fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema { + let fields = schema + .fields + .iter() + .enumerate() + .map(|(i, field)| match field.data_type() { + DataType::Dictionary(_, value_data_type) if i < group_count => { + Field::new(field.name(), *value_data_type.clone(), field.is_nullable()) + } + _ => Field::clone(field), + }) + .collect::>(); + Schema::new(fields) +} + fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { let group_fields = schema.fields()[0..group_count].to_vec(); Arc::new(Schema::new(group_fields)) } -/// returns physical expressions to evaluate against a batch +/// returns physical expressions for arguments to evaluate against a batch /// The expressions are different depending on `mode`: /// * Partial: AggregateExpr::expressions /// * Final: columns of `AggregateExpr::state_fields()` @@ -973,42 +1015,25 @@ fn aggregate_expressions( col_idx_base: usize, ) -> Result>>> { match mode { - AggregateMode::Partial | AggregateMode::Single => Ok(aggr_expr + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => Ok(aggr_expr .iter() .map(|agg| { - let pre_cast_type = if let Some(Sum { - data_type, - pre_cast_to_sum_type, - .. - }) = agg.as_any().downcast_ref::() - { - if *pre_cast_to_sum_type { - Some(data_type.clone()) - } else { - None - } - } else if let Some(Avg { - sum_data_type, - pre_cast_to_sum_type, - .. - }) = agg.as_any().downcast_ref::() - { - if *pre_cast_to_sum_type { - Some(sum_data_type.clone()) - } else { - None + let mut result = agg.expressions().clone(); + // In partial mode, append ordering requirements to expressions' results. + // Ordering requirements are used by subsequent executors to satisfy the required + // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes. + if matches!(mode, AggregateMode::Partial) { + if let Some(ordering_req) = agg.order_bys() { + let ordering_exprs = ordering_req + .iter() + .map(|item| item.expr.clone()) + .collect::>(); + result.extend(ordering_exprs); } - } else { - None - }; - agg.expressions() - .into_iter() - .map(|expr| { - pre_cast_type.clone().map_or(expr.clone(), |cast_type| { - Arc::new(CastExpr::new(expr, cast_type, None)) - }) - }) - .collect::>() + } + result }) .collect()), // in this mode, we build the merge expressions of the aggregation @@ -1045,7 +1070,6 @@ fn merge_expressions( } pub(crate) type AccumulatorItem = Box; -pub(crate) type RowAccumulatorItem = Box; fn create_accumulators( aggr_expr: &[Arc], @@ -1056,20 +1080,6 @@ fn create_accumulators( .collect::>>() } -fn create_row_accumulators( - aggr_expr: &[Arc], -) -> Result> { - let mut state_index = 0; - aggr_expr - .iter() - .map(|expr| { - let result = expr.create_row_accumulator(state_index); - state_index += expr.state_fields().unwrap().len(); - result - }) - .collect::>>() -} - /// returns a vector of ArrayRefs, where each entry corresponds to either the /// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial) fn finalize_aggregation( @@ -1081,10 +1091,11 @@ fn finalize_aggregation( // build the vector of states let a = accumulators .iter() - .map(|accumulator| accumulator.state()) - .map(|value| { - value.map(|e| { - e.iter().map(|v| v.to_array()).collect::>() + .map(|accumulator| { + accumulator.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() }) }) .collect::>>()?; @@ -1092,11 +1103,12 @@ fn finalize_aggregation( } AggregateMode::Final | AggregateMode::FinalPartitioned - | AggregateMode::Single => { + | AggregateMode::Single + | AggregateMode::SinglePartitioned => { // merge the state to the final value accumulators .iter() - .map(|accumulator| accumulator.evaluate().map(|v| v.to_array())) + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) .collect::>>() } } @@ -1108,13 +1120,15 @@ fn evaluate( batch: &RecordBatch, ) -> Result> { expr.iter() - .map(|expr| expr.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect() } /// Evaluates expressions against a record batch. -fn evaluate_many( +pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { @@ -1130,14 +1144,26 @@ fn evaluate_optional( expr.iter() .map(|expr| { expr.as_ref() - .map(|expr| expr.evaluate(batch)) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .transpose() - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) }) .collect::>>() } -fn evaluate_group_by( +/// Evaluate a group by expression against a `RecordBatch` +/// +/// Arguments: +/// `group_by`: the expression to evaluate +/// `batch`: the `RecordBatch` to evaluate against +/// +/// Returns: A Vec of Vecs of Array of results +/// The outer Vect appears to be for grouping sets +/// The inner Vect contains the results per expression +/// The inner-inner Array contains the results per row +pub(crate) fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { @@ -1146,7 +1172,7 @@ fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1155,7 +1181,7 @@ fn evaluate_group_by( .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; - Ok(value.into_array(batch.num_rows())) + value.into_array(batch.num_rows()) }) .collect::>>()?; @@ -1180,44 +1206,45 @@ fn evaluate_group_by( #[cfg(test)] mod tests { + use std::any::Any; + use std::sync::Arc; + use std::task::{Context, Poll}; + use super::*; - use crate::execution::context::SessionConfig; - use crate::physical_plan::aggregates::{ - get_finest_requirement, get_working_mode, AggregateExec, AggregateMode, - PhysicalGroupBy, + use crate::aggregates::{ + get_finest_requirement, AggregateExec, AggregateMode, PhysicalGroupBy, }; - use crate::physical_plan::expressions::{col, Avg}; + use crate::coalesce_batches::CoalesceBatchesExec; + use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::common; + use crate::expressions::{col, Avg}; + use crate::memory::MemoryExec; + use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; - use crate::test::{assert_is_pending, csv_exec_sorted}; - use crate::{assert_batches_sorted_eq, physical_plan::common}; + use crate::{ + DisplayAs, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }; + use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; - use datafusion_common::{DataFusionError, Result, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, + Result, ScalarValue, + }; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Column, Count, FirstValue, Median, + lit, ApproxDistinct, Count, FirstValue, LastValue, Median, }; use datafusion_physical_expr::{ - AggregateExpr, EquivalenceProperties, OrderingEquivalenceProperties, - PhysicalExpr, PhysicalSortExpr, + AggregateExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, }; - use futures::{FutureExt, Stream}; - use std::any::Any; - use std::sync::Arc; - use std::task::{Context, Poll}; - use super::StreamType; - use crate::physical_plan::aggregates::GroupByOrderMode::{ - FullyOrdered, PartiallyOrdered, - }; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::{ - ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, - Statistics, - }; - use crate::prelude::SessionContext; + use datafusion_execution::memory_pool::FairSpillPool; + use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) fn create_test_schema() -> Result { @@ -1231,79 +1258,6 @@ mod tests { Ok(schema) } - /// make PhysicalSortExpr with default options - fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { - sort_expr_options(name, schema, SortOptions::default()) - } - - /// PhysicalSortExpr with specified options - fn sort_expr_options( - name: &str, - schema: &Schema, - options: SortOptions, - ) -> PhysicalSortExpr { - PhysicalSortExpr { - expr: col(name, schema).unwrap(), - options, - } - } - - #[tokio::test] - async fn test_get_working_mode() -> Result<()> { - let test_schema = create_test_schema()?; - // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST - // Column d, e is not ordered. - let sort_exprs = vec![ - sort_expr("a", &test_schema), - sort_expr("b", &test_schema), - sort_expr("c", &test_schema), - ]; - let input = csv_exec_sorted(&test_schema, sort_exprs, true); - - // test cases consists of vector of tuples. Where each tuple represents a single test case. - // First field in the tuple is Vec where each element in the vector represents GROUP BY columns - // For instance `vec!["a", "b"]` corresponds to GROUP BY a, b - // Second field in the tuple is Option, which corresponds to expected algorithm mode. - // None represents that existing ordering is not sufficient to run executor with any one of the algorithms - // (We need to add SortExec to be able to run it). - // Some(GroupByOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in - // GroupByOrderMode. - let test_cases = vec![ - (vec!["a"], Some((FullyOrdered, vec![0]))), - (vec!["b"], None), - (vec!["c"], None), - (vec!["b", "a"], Some((FullyOrdered, vec![1, 0]))), - (vec!["c", "b"], None), - (vec!["c", "a"], Some((PartiallyOrdered, vec![1]))), - (vec!["c", "b", "a"], Some((FullyOrdered, vec![2, 1, 0]))), - (vec!["d", "a"], Some((PartiallyOrdered, vec![1]))), - (vec!["d", "b"], None), - (vec!["d", "c"], None), - (vec!["d", "b", "a"], Some((PartiallyOrdered, vec![2, 1]))), - (vec!["d", "c", "b"], None), - (vec!["d", "c", "a"], Some((PartiallyOrdered, vec![2]))), - ( - vec!["d", "c", "b", "a"], - Some((PartiallyOrdered, vec![3, 2, 1])), - ), - ]; - for (case_idx, test_case) in test_cases.iter().enumerate() { - let (group_by_columns, expected) = &test_case; - let mut group_by_exprs = vec![]; - for col_name in group_by_columns { - group_by_exprs.push((col(col_name, &test_schema)?, col_name.to_string())); - } - let group_bys = PhysicalGroupBy::new_single(group_by_exprs); - let res = get_working_mode(&input, &group_bys); - assert_eq!( - res, *expected, - "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" - ); - } - - Ok(()) - } - /// some mock data to aggregates fn some_data() -> (Arc, Vec) { // define a schema. @@ -1336,7 +1290,76 @@ mod tests { ) } - async fn check_grouping_sets(input: Arc) -> Result<()> { + /// Generates some mock data for aggregate tests. + fn some_data_v2() -> (Arc, Vec) { + // Define a schema: + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Generate data so that first and last value results are at 2nd and + // 3rd partitions. With this construction, we guarantee we don't receive + // the expected result by accident, but merging actually works properly; + // i.e. it doesn't depend on the data insertion order. + ( + schema.clone(), + vec![ + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema, + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])), + ], + ) + .unwrap(), + ], + ) + } + + fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc { + let session_config = SessionConfig::new().with_batch_size(batch_size); + let runtime = Arc::new( + RuntimeEnv::new( + RuntimeConfig::default() + .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))), + ) + .unwrap(), + ); + let task_ctx = TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime); + Arc::new(task_ctx) + } + + async fn check_grouping_sets( + input: Arc, + spill: bool, + ) -> Result<()> { let input_schema = input.schema(); let grouping_set = PhysicalGroupBy { @@ -1361,8 +1384,11 @@ mod tests { DataType::Int64, ))]; - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = if spill { + new_spill_ctx(4, 1000) + } else { + Arc::new(TaskContext::default()) + }; let partial_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Partial, @@ -1377,24 +1403,53 @@ mod tests { let result = common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; - let expected = vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", - ]; + let expected = if spill { + vec![ + "+---+-----+-----------------+", + "| a | b | COUNT(1)[count] |", + "+---+-----+-----------------+", + "| | 1.0 | 1 |", + "| | 1.0 | 1 |", + "| | 2.0 | 1 |", + "| | 2.0 | 1 |", + "| | 3.0 | 1 |", + "| | 3.0 | 1 |", + "| | 4.0 | 1 |", + "| | 4.0 | 1 |", + "| 2 | | 1 |", + "| 2 | | 1 |", + "| 2 | 1.0 | 1 |", + "| 2 | 1.0 | 1 |", + "| 3 | | 1 |", + "| 3 | | 2 |", + "| 3 | 2.0 | 2 |", + "| 3 | 3.0 | 1 |", + "| 4 | | 1 |", + "| 4 | | 2 |", + "| 4 | 3.0 | 1 |", + "| 4 | 4.0 | 2 |", + "+---+-----+-----------------+", + ] + } else { + vec![ + "+---+-----+-----------------+", + "| a | b | COUNT(1)[count] |", + "+---+-----+-----------------+", + "| | 1.0 | 2 |", + "| | 2.0 | 2 |", + "| | 3.0 | 2 |", + "| | 4.0 | 2 |", + "| 2 | | 2 |", + "| 2 | 1.0 | 2 |", + "| 3 | | 3 |", + "| 3 | 2.0 | 2 |", + "| 3 | 3.0 | 1 |", + "| 4 | | 3 |", + "| 4 | 3.0 | 1 |", + "| 4 | 4.0 | 2 |", + "+---+-----+-----------------+", + ] + }; assert_batches_sorted_eq!(expected, &result); let groups = partial_aggregate.group_expr().expr().to_vec(); @@ -1408,6 +1463,12 @@ mod tests { let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let task_ctx = if spill { + new_spill_ctx(4, 3160) + } else { + task_ctx + }; + let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, final_grouping_set, @@ -1453,7 +1514,7 @@ mod tests { } /// build the aggregates on the data from some_data() and check the results - async fn check_aggregates(input: Arc) -> Result<()> { + async fn check_aggregates(input: Arc, spill: bool) -> Result<()> { let input_schema = input.schema(); let grouping_set = PhysicalGroupBy { @@ -1468,8 +1529,11 @@ mod tests { DataType::Float64, ))]; - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = if spill { + new_spill_ctx(2, 1500) + } else { + Arc::new(TaskContext::default()) + }; let partial_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Partial, @@ -1484,15 +1548,29 @@ mod tests { let result = common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; - let expected = vec![ - "+---+---------------+-------------+", - "| a | AVG(b)[count] | AVG(b)[sum] |", - "+---+---------------+-------------+", - "| 2 | 2 | 2.0 |", - "| 3 | 3 | 7.0 |", - "| 4 | 3 | 11.0 |", - "+---+---------------+-------------+", - ]; + let expected = if spill { + vec![ + "+---+---------------+-------------+", + "| a | AVG(b)[count] | AVG(b)[sum] |", + "+---+---------------+-------------+", + "| 2 | 1 | 1.0 |", + "| 2 | 1 | 1.0 |", + "| 3 | 1 | 2.0 |", + "| 3 | 2 | 5.0 |", + "| 4 | 3 | 11.0 |", + "+---+---------------+-------------+", + ] + } else { + vec![ + "+---+---------------+-------------+", + "| a | AVG(b)[count] | AVG(b)[sum] |", + "+---+---------------+-------------+", + "| 2 | 2 | 2.0 |", + "| 3 | 3 | 7.0 |", + "| 4 | 3 | 11.0 |", + "+---+---------------+-------------+", + ] + }; assert_batches_sorted_eq!(expected, &result); let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); @@ -1535,7 +1613,13 @@ mod tests { let metrics = merged_aggregate.metrics().unwrap(); let output_rows = metrics.output_rows().unwrap(); - assert_eq!(3, output_rows); + if spill { + // When spilling, the output rows metrics become partial output size + final output size + // This is because final aggregation starts while partial aggregation is still emitting + assert_eq!(8, output_rows); + } else { + assert_eq!(3, output_rows); + } Ok(()) } @@ -1548,6 +1632,20 @@ mod tests { pub yield_first: bool, } + impl DisplayAs for TestYieldingExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "TestYieldingExec") + } + } + } + } + impl ExecutionPlan for TestYieldingExec { fn as_any(&self) -> &dyn Any { self @@ -1572,9 +1670,7 @@ mod tests { self: Arc, _: Vec>, ) -> Result> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {self:?}" - ))) + internal_err!("Children cannot be replaced in {self:?}") } fn execute( @@ -1591,9 +1687,13 @@ mod tests { Ok(Box::pin(stream)) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let (_, batches) = some_data(); - common::compute_record_batch_statistics(&[batches], &self.schema(), None) + Ok(common::compute_record_batch_statistics( + &[batches], + &self.schema(), + None, + )) } } @@ -1637,14 +1737,14 @@ mod tests { } } - //// Tests //// + //--- Tests ---// #[tokio::test] async fn aggregate_source_not_yielding() -> Result<()> { let input: Arc = Arc::new(TestYieldingExec { yield_first: false }); - check_aggregates(input).await + check_aggregates(input, false).await } #[tokio::test] @@ -1652,7 +1752,7 @@ mod tests { let input: Arc = Arc::new(TestYieldingExec { yield_first: false }); - check_grouping_sets(input).await + check_grouping_sets(input, false).await } #[tokio::test] @@ -1660,7 +1760,7 @@ mod tests { let input: Arc = Arc::new(TestYieldingExec { yield_first: true }); - check_aggregates(input).await + check_aggregates(input, false).await } #[tokio::test] @@ -1668,7 +1768,40 @@ mod tests { let input: Arc = Arc::new(TestYieldingExec { yield_first: true }); - check_grouping_sets(input).await + check_grouping_sets(input, false).await + } + + #[tokio::test] + async fn aggregate_source_not_yielding_with_spill() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: false }); + + check_aggregates(input, true).await + } + + #[tokio::test] + async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: false }); + + check_grouping_sets(input, true).await + } + + #[tokio::test] + #[ignore] + async fn aggregate_source_with_yielding_with_spill() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); + + check_aggregates(input, true).await + } + + #[tokio::test] + async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); + + check_grouping_sets(input, true).await } #[tokio::test] @@ -1677,14 +1810,11 @@ mod tests { Arc::new(TestYieldingExec { yield_first: true }); let input_schema = input.schema(); - let session_ctx = SessionContext::with_config_rt( - SessionConfig::default(), - Arc::new( - RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)) - .unwrap(), - ), + let runtime = Arc::new( + RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)).unwrap(), ); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let groups_none = PhysicalGroupBy::default(); let groups_some = PhysicalGroupBy { @@ -1738,10 +1868,10 @@ mod tests { assert!(matches!(stream, StreamType::AggregateStream(_))); } 1 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHash(_))); } 2 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHash(_))); } _ => panic!("Unknown version: {version}"), } @@ -1762,8 +1892,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1787,7 +1916,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(aggregate_exec, task_ctx); + let fut = crate::collect(aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1799,8 +1928,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_with_groups() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), @@ -1827,7 +1955,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(aggregate_exec, task_ctx); + let fut = crate::collect(aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1837,6 +1965,141 @@ mod tests { Ok(()) } + #[tokio::test] + async fn run_first_last_multi_partitions() -> Result<()> { + for use_coalesce_batches in [false, true] { + for is_first_acc in [false, true] { + for spill in [false, true] { + first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill) + .await? + } + } + } + Ok(()) + } + + // This function either constructs the physical plan below, + // + // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", + // " CoalesceBatchesExec: target_batch_size=1024", + // " CoalescePartitionsExec", + // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", + // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", + // + // or + // + // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", + // " CoalescePartitionsExec", + // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", + // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", + // + // and checks whether the function `merge_batch` works correctly for + // FIRST_VALUE and LAST_VALUE functions. + async fn first_last_multi_partitions( + use_coalesce_batches: bool, + is_first_acc: bool, + spill: bool, + ) -> Result<()> { + let task_ctx = if spill { + new_spill_ctx(2, 2886) + } else { + Arc::new(TaskContext::default()) + }; + + let (schema, data) = some_data_v2(); + let partition1 = data[0].clone(); + let partition2 = data[1].clone(); + let partition3 = data[2].clone(); + let partition4 = data[3].clone(); + + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + + let ordering_req = vec![PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions::default(), + }]; + let aggregates: Vec> = if is_first_acc { + vec![Arc::new(FirstValue::new( + col("b", &schema)?, + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + ordering_req.clone(), + vec![DataType::Float64], + ))] + } else { + vec![Arc::new(LastValue::new( + col("b", &schema)?, + "LAST_VALUE(b)".to_string(), + DataType::Float64, + ordering_req.clone(), + vec![DataType::Float64], + ))] + }; + + let memory_exec = Arc::new(MemoryExec::try_new( + &[ + vec![partition1], + vec![partition2], + vec![partition3], + vec![partition4], + ], + schema.clone(), + None, + )?); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], + vec![Some(ordering_req.clone())], + memory_exec, + schema.clone(), + )?); + let coalesce = if use_coalesce_batches { + let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); + Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc + } else { + Arc::new(CoalescePartitionsExec::new(aggregate_exec)) + as Arc + }; + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + groups, + aggregates.clone(), + vec![None], + vec![Some(ordering_req)], + coalesce, + schema, + )?) as Arc; + + let result = crate::collect(aggregate_final, task_ctx).await?; + if is_first_acc { + let expected = [ + "+---+----------------+", + "| a | FIRST_VALUE(b) |", + "+---+----------------+", + "| 2 | 0.0 |", + "| 3 | 1.0 |", + "| 4 | 3.0 |", + "+---+----------------+", + ]; + assert_batches_eq!(expected, &result); + } else { + let expected = [ + "+---+---------------+", + "| a | LAST_VALUE(b) |", + "+---+---------------+", + "| 2 | 3.0 |", + "| 3 | 5.0 |", + "| 4 | 6.0 |", + "+---+---------------+", + ]; + assert_batches_eq!(expected, &result); + }; + Ok(()) + } + #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; @@ -1851,67 +2114,72 @@ mod tests { descending: true, nulls_first: true, }; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - let col_a = Column::new("a", 0); - let col_b = Column::new("b", 1); - let col_c = Column::new("c", 2); - let col_d = Column::new("d", 3); - eq_properties.add_equal_conditions((&col_a, &col_b)); - let mut ordering_eq_properties = OrderingEquivalenceProperties::new(test_schema); - ordering_eq_properties.add_equal_conditions(( - &vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()) as _, - options: options1, - }], - &vec![PhysicalSortExpr { - expr: Arc::new(col_c.clone()) as _, - options: options2, - }], - )); + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(test_schema); + // Columns a and b are equal. + eq_properties.add_equal_conditions(col_a, col_b); + // Aggregate requirements are + // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let mut order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { - expr: Arc::new(col_a.clone()), - options: options1, - }]), - Some(vec![PhysicalSortExpr { - expr: Arc::new(col_b.clone()), + expr: col_a.clone(), options: options1, }]), - Some(vec![PhysicalSortExpr { - expr: Arc::new(col_c), - options: options2, - }]), Some(vec![ PhysicalSortExpr { - expr: Arc::new(col_a.clone()), + expr: col_a.clone(), options: options1, }, PhysicalSortExpr { - expr: Arc::new(col_d), + expr: col_b.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_c.clone(), + options: options1, + }, + ]), + Some(vec![ + PhysicalSortExpr { + expr: col_a.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_b.clone(), options: options1, }, ]), // Since aggregate expression is reversible (FirstValue), we should be able to resolve below // contradictory requirement by reversing it. Some(vec![PhysicalSortExpr { - expr: Arc::new(col_b.clone()), + expr: col_b.clone(), options: options2, }]), ]; + let common_requirement = Some(vec![ + PhysicalSortExpr { + expr: col_a.clone(), + options: options1, + }, + PhysicalSortExpr { + expr: col_c.clone(), + options: options1, + }, + ]); let aggr_expr = Arc::new(FirstValue::new( - Arc::new(col_a.clone()), + col_a.clone(), "first1", DataType::Int32, + vec![], + vec![], )) as _; let mut aggr_exprs = vec![aggr_expr; order_by_exprs.len()]; - let res = get_finest_requirement( - &mut aggr_exprs, - &mut order_by_exprs, - || eq_properties.clone(), - || ordering_eq_properties.clone(), - )?; - assert_eq!(res, order_by_exprs[4]); + let res = + get_finest_requirement(&mut aggr_exprs, &mut order_by_exprs, &eq_properties)?; + assert_eq!(res, common_requirement); Ok(()) } } diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs similarity index 92% rename from datafusion/core/src/physical_plan/aggregates/no_grouping.rs rename to datafusion/physical-plan/src/aggregates/no_grouping.rs index 89d392f0b67c5..90eb488a2ead2 100644 --- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -17,12 +17,12 @@ //! Aggregate without grouping columns -use crate::physical_plan::aggregates::{ +use crate::aggregates::{ aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem, AggregateMode, }; -use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; -use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use crate::metrics::{BaselineMetrics, RecordOutput}; +use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; @@ -33,7 +33,7 @@ use std::borrow::Cow; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::physical_plan::filter::batch_filter; +use crate::filter::batch_filter; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; @@ -79,7 +79,9 @@ impl AggregateStream { let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?; let filter_expressions = match agg.mode { - AggregateMode::Partial | AggregateMode::Single => agg_filter_expr, + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => agg_filter_expr, AggregateMode::Final | AggregateMode::FinalPartitioned => { vec![None; agg.aggr_expr.len()] } @@ -215,16 +217,18 @@ fn aggregate_batch( // 1.3 let values = &expr .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .map(|e| { + e.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) .collect::>>()?; // 1.4 let size_pre = accum.size(); let res = match mode { - AggregateMode::Partial | AggregateMode::Single => { - accum.update_batch(values) - } + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => accum.update_batch(values), AggregateMode::Final | AggregateMode::FinalPartitioned => { accum.merge_batch(values) } diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs new file mode 100644 index 0000000000000..f46ee687faf16 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_physical_expr::EmitTo; + +/// Tracks grouping state when the data is ordered entirely by its +/// group keys +/// +/// When the group values are sorted, as soon as we see group `n+1` we +/// know we will never see any rows for group `n again and thus they +/// can be emitted. +/// +/// For example, given `SUM(amt) GROUP BY id` if the input is sorted +/// by `id` as soon as a new `id` value is seen all previous values +/// can be emitted. +/// +/// The state is tracked like this: +/// +/// ```text +/// ┌─────┐ ┌──────────────────┐ +/// │┌───┐│ │ ┌──────────────┐ │ ┏━━━━━━━━━━━━━━┓ +/// ││ 0 ││ │ │ 123 │ │ ┌─────┃ 13 ┃ +/// │└───┘│ │ └──────────────┘ │ │ ┗━━━━━━━━━━━━━━┛ +/// │ ... │ │ ... │ │ +/// │┌───┐│ │ ┌──────────────┐ │ │ current +/// ││12 ││ │ │ 234 │ │ │ +/// │├───┤│ │ ├──────────────┤ │ │ +/// ││12 ││ │ │ 234 │ │ │ +/// │├───┤│ │ ├──────────────┤ │ │ +/// ││13 ││ │ │ 456 │◀┼───┘ +/// │└───┘│ │ └──────────────┘ │ +/// └─────┘ └──────────────────┘ +/// +/// group indices group_values current tracks the most +/// (in group value recent group index +/// order) +/// ``` +/// +/// In this diagram, the current group is `13`, and thus groups +/// `0..12` can be emitted. Note that `13` can not yet be emitted as +/// there may be more values in the next batch with the same group_id. +#[derive(Debug)] +pub(crate) struct GroupOrderingFull { + state: State, +} + +#[derive(Debug)] +enum State { + /// Seen no input yet + Start, + + /// Data is in progress. `current is the current group for which + /// values are being generated. Can emit `current` - 1 + InProgress { current: usize }, + + /// Seen end of input: all groups can be emitted + Complete, +} + +impl GroupOrderingFull { + pub fn new() -> Self { + Self { + state: State::Start, + } + } + + // How many groups be emitted, or None if no data can be emitted + pub fn emit_to(&self) -> Option { + match &self.state { + State::Start => None, + State::InProgress { current, .. } => { + if *current == 0 { + // Can not emit if still on the first row + None + } else { + // otherwise emit all rows prior to the current group + Some(EmitTo::First(*current)) + } + } + State::Complete { .. } => Some(EmitTo::All), + } + } + + /// remove the first n groups from the internal state, shifting + /// all existing indexes down by `n` + pub fn remove_groups(&mut self, n: usize) { + match &mut self.state { + State::Start => panic!("invalid state: start"), + State::InProgress { current } => { + // shift down by n + assert!(*current >= n); + *current -= n; + } + State::Complete { .. } => panic!("invalid state: complete"), + } + } + + /// Note that the input is complete so any outstanding groups are done as well + pub fn input_done(&mut self) { + self.state = State::Complete; + } + + /// Called when new groups are added in a batch. See documentation + /// on [`super::GroupOrdering::new_groups`] + pub fn new_groups(&mut self, total_num_groups: usize) { + assert_ne!(total_num_groups, 0); + + // Update state + let max_group_index = total_num_groups - 1; + self.state = match self.state { + State::Start => State::InProgress { + current: max_group_index, + }, + State::InProgress { current } => { + // expect to see new group indexes when called again + assert!(current <= max_group_index, "{current} <= {max_group_index}"); + State::InProgress { + current: max_group_index, + } + } + State::Complete { .. } => { + panic!("Saw new group after input was complete"); + } + }; + } + + pub(crate) fn size(&self) -> usize { + std::mem::size_of::() + } +} diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs new file mode 100644 index 0000000000000..b258b97a9e84f --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::ArrayRef; +use arrow_schema::Schema; +use datafusion_common::Result; +use datafusion_physical_expr::{EmitTo, PhysicalSortExpr}; + +mod full; +mod partial; + +use crate::InputOrderMode; +pub(crate) use full::GroupOrderingFull; +pub(crate) use partial::GroupOrderingPartial; + +/// Ordering information for each group in the hash table +#[derive(Debug)] +pub(crate) enum GroupOrdering { + /// Groups are not ordered + None, + /// Groups are ordered by some pre-set of the group keys + Partial(GroupOrderingPartial), + /// Groups are entirely contiguous, + Full(GroupOrderingFull), +} + +impl GroupOrdering { + /// Create a `GroupOrdering` for the the specified ordering + pub fn try_new( + input_schema: &Schema, + mode: &InputOrderMode, + ordering: &[PhysicalSortExpr], + ) -> Result { + match mode { + InputOrderMode::Linear => Ok(GroupOrdering::None), + InputOrderMode::PartiallySorted(order_indices) => { + GroupOrderingPartial::try_new(input_schema, order_indices, ordering) + .map(GroupOrdering::Partial) + } + InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), + } + } + + // How many groups be emitted, or None if no data can be emitted + pub fn emit_to(&self) -> Option { + match self { + GroupOrdering::None => None, + GroupOrdering::Partial(partial) => partial.emit_to(), + GroupOrdering::Full(full) => full.emit_to(), + } + } + + /// Updates the state the input is done + pub fn input_done(&mut self) { + match self { + GroupOrdering::None => {} + GroupOrdering::Partial(partial) => partial.input_done(), + GroupOrdering::Full(full) => full.input_done(), + } + } + + /// remove the first n groups from the internal state, shifting + /// all existing indexes down by `n` + pub fn remove_groups(&mut self, n: usize) { + match self { + GroupOrdering::None => {} + GroupOrdering::Partial(partial) => partial.remove_groups(n), + GroupOrdering::Full(full) => full.remove_groups(n), + } + } + + /// Called when new groups are added in a batch + /// + /// * `total_num_groups`: total number of groups (so max + /// group_index is total_num_groups - 1). + /// + /// * `group_values`: group key values for *each row* in the batch + /// + /// * `group_indices`: indices for each row in the batch + /// + /// * `hashes`: hash values for each row in the batch + pub fn new_groups( + &mut self, + batch_group_values: &[ArrayRef], + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + match self { + GroupOrdering::None => {} + GroupOrdering::Partial(partial) => { + partial.new_groups( + batch_group_values, + group_indices, + total_num_groups, + )?; + } + GroupOrdering::Full(full) => { + full.new_groups(total_num_groups); + } + }; + Ok(()) + } + + /// Return the size of memory used by the ordering state, in bytes + pub(crate) fn size(&self) -> usize { + std::mem::size_of::() + + match self { + GroupOrdering::None => 0, + GroupOrdering::Partial(partial) => partial.size(), + GroupOrdering::Full(full) => full.size(), + } + } +} diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs new file mode 100644 index 0000000000000..ff8a75b9b28be --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; +use arrow_array::ArrayRef; +use arrow_schema::Schema; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_physical_expr::EmitTo; +use datafusion_physical_expr::PhysicalSortExpr; + +/// Tracks grouping state when the data is ordered by some subset of +/// the group keys. +/// +/// Once the next *sort key* value is seen, never see groups with that +/// sort key again, so we can emit all groups with the previous sort +/// key and earlier. +/// +/// For example, given `SUM(amt) GROUP BY id, state` if the input is +/// sorted by `state, when a new value of `state` is seen, all groups +/// with prior values of `state` can be emitted. +/// +/// The state is tracked like this: +/// +/// ```text +/// ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓ +/// ┌─────┐ ┌───────────────────┐ ┌─────┃ 9 ┃ ┃ "MD" ┃ +/// │┌───┐│ │ ┌──────────────┐ │ │ ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛ +/// ││ 0 ││ │ │ 123, "MA" │ │ │ current_sort sort_key +/// │└───┘│ │ └──────────────┘ │ │ +/// │ ... │ │ ... │ │ current_sort tracks the +/// │┌───┐│ │ ┌──────────────┐ │ │ smallest group index that had +/// ││ 8 ││ │ │ 765, "MA" │ │ │ the same sort_key as current +/// │├───┤│ │ ├──────────────┤ │ │ +/// ││ 9 ││ │ │ 923, "MD" │◀─┼─┘ +/// │├───┤│ │ ├──────────────┤ │ ┏━━━━━━━━━━━━━━┓ +/// ││10 ││ │ │ 345, "MD" │ │ ┌─────┃ 11 ┃ +/// │├───┤│ │ ├──────────────┤ │ │ ┗━━━━━━━━━━━━━━┛ +/// ││11 ││ │ │ 124, "MD" │◀─┼──┘ current +/// │└───┘│ │ └──────────────┘ │ +/// └─────┘ └───────────────────┘ +/// +/// group indices +/// (in group value group_values current tracks the most +/// order) recent group index +///``` +#[derive(Debug)] +pub(crate) struct GroupOrderingPartial { + /// State machine + state: State, + + /// The indexes of the group by columns that form the sort key. + /// For example if grouping by `id, state` and ordered by `state` + /// this would be `[1]`. + order_indices: Vec, + + /// Converter for the sort key (used on the group columns + /// specified in `order_indexes`) + row_converter: RowConverter, +} + +#[derive(Debug, Default)] +enum State { + /// The ordering was temporarily taken. `Self::Taken` is left + /// when state must be temporarily taken to satisfy the borrow + /// checker. If an error happens before the state can be restored, + /// the ordering information is lost and execution can not + /// proceed, but there is no undefined behavior. + #[default] + Taken, + + /// Seen no input yet + Start, + + /// Data is in progress. + InProgress { + /// Smallest group index with the sort_key + current_sort: usize, + /// The sort key of group_index `current_sort` + sort_key: OwnedRow, + /// index of the current group for which values are being + /// generated + current: usize, + }, + + /// Seen end of input, all groups can be emitted + Complete, +} + +impl GroupOrderingPartial { + pub fn try_new( + input_schema: &Schema, + order_indices: &[usize], + ordering: &[PhysicalSortExpr], + ) -> Result { + assert!(!order_indices.is_empty()); + assert!(order_indices.len() <= ordering.len()); + + // get only the section of ordering, that consist of group by expressions. + let fields = ordering[0..order_indices.len()] + .iter() + .map(|sort_expr| { + Ok(SortField::new_with_options( + sort_expr.expr.data_type(input_schema)?, + sort_expr.options, + )) + }) + .collect::>>()?; + + Ok(Self { + state: State::Start, + order_indices: order_indices.to_vec(), + row_converter: RowConverter::new(fields)?, + }) + } + + /// Creates sort keys from the group values + /// + /// For example, if group_values had `A, B, C` but the input was + /// only sorted on `B` and `C` this should return rows for (`B`, + /// `C`) + fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Result { + // Take only the columns that are in the sort key + let sort_values: Vec<_> = self + .order_indices + .iter() + .map(|&idx| group_values[idx].clone()) + .collect(); + + Ok(self.row_converter.convert_columns(&sort_values)?) + } + + /// How many groups be emitted, or None if no data can be emitted + pub fn emit_to(&self) -> Option { + match &self.state { + State::Taken => unreachable!("State previously taken"), + State::Start => None, + State::InProgress { current_sort, .. } => { + // Can not emit if we are still on the first row sort + // row otherwise we can emit all groups that had earlier sort keys + // + if *current_sort == 0 { + None + } else { + Some(EmitTo::First(*current_sort)) + } + } + State::Complete => Some(EmitTo::All), + } + } + + /// remove the first n groups from the internal state, shifting + /// all existing indexes down by `n` + pub fn remove_groups(&mut self, n: usize) { + match &mut self.state { + State::Taken => unreachable!("State previously taken"), + State::Start => panic!("invalid state: start"), + State::InProgress { + current_sort, + current, + sort_key: _, + } => { + // shift indexes down by n + assert!(*current >= n); + *current -= n; + assert!(*current_sort >= n); + *current_sort -= n; + } + State::Complete { .. } => panic!("invalid state: complete"), + } + } + + /// Note that the input is complete so any outstanding groups are done as well + pub fn input_done(&mut self) { + self.state = match self.state { + State::Taken => unreachable!("State previously taken"), + _ => State::Complete, + }; + } + + /// Called when new groups are added in a batch. See documentation + /// on [`super::GroupOrdering::new_groups`] + pub fn new_groups( + &mut self, + batch_group_values: &[ArrayRef], + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + assert!(total_num_groups > 0); + assert!(!batch_group_values.is_empty()); + + let max_group_index = total_num_groups - 1; + + // compute the sort key values for each group + let sort_keys = self.compute_sort_keys(batch_group_values)?; + + let old_state = std::mem::take(&mut self.state); + let (mut current_sort, mut sort_key) = match &old_state { + State::Taken => unreachable!("State previously taken"), + State::Start => (0, sort_keys.row(0)), + State::InProgress { + current_sort, + sort_key, + .. + } => (*current_sort, sort_key.row()), + State::Complete => { + panic!("Saw new group after the end of input"); + } + }; + + // Find latest sort key + let iter = group_indices.iter().zip(sort_keys.iter()); + for (&group_index, group_sort_key) in iter { + // Does this group have seen a new sort_key? + if sort_key != group_sort_key { + current_sort = group_index; + sort_key = group_sort_key; + } + } + + self.state = State::InProgress { + current_sort, + sort_key: sort_key.owned(), + current: max_group_index, + }; + + Ok(()) + } + + /// Return the size of memory allocated by this structure + pub(crate) fn size(&self) -> usize { + std::mem::size_of::() + + self.order_indices.allocated_size() + + self.row_converter.size() + } +} diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs new file mode 100644 index 0000000000000..89614fd3020ce --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -0,0 +1,799 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Hash aggregation + +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::vec; + +use crate::aggregates::group_values::{new_group_values, GroupValues}; +use crate::aggregates::order::GroupOrderingFull; +use crate::aggregates::{ + evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, + PhysicalGroupBy, +}; +use crate::common::IPCWriter; +use crate::metrics::{BaselineMetrics, RecordOutput}; +use crate::sorts::sort::{read_spill_as_stream, sort_batch}; +use crate::sorts::streaming_merge; +use crate::stream::RecordBatchStreamAdapter; +use crate::{aggregates, ExecutionPlan, PhysicalExpr}; +use crate::{RecordBatchStream, SendableRecordBatchStream}; + +use arrow::array::*; +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use arrow_schema::SortOptions; +use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{ + AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr, +}; + +use futures::ready; +use futures::stream::{Stream, StreamExt}; +use log::debug; + +#[derive(Debug, Clone)] +/// This object tracks the aggregation phase (input/output) +pub(crate) enum ExecutionState { + ReadingInput, + /// When producing output, the remaining rows to output are stored + /// here and are sliced off as needed in batch_size chunks + ProducingOutput(RecordBatch), + Done, +} + +use super::order::GroupOrdering; +use super::AggregateExec; + +/// This encapsulates the spilling state +struct SpillState { + /// If data has previously been spilled, the locations of the + /// spill files (in Arrow IPC format) + spills: Vec, + + /// Sorting expression for spilling batches + spill_expr: Vec, + + /// Schema for spilling batches + spill_schema: SchemaRef, + + /// true when streaming merge is in progress + is_stream_merging: bool, + + /// aggregate_arguments for merging spilled data + merging_aggregate_arguments: Vec>>, + + /// GROUP BY expressions for merging spilled data + merging_group_by: PhysicalGroupBy, +} + +/// HashTable based Grouping Aggregator +/// +/// # Design Goals +/// +/// This structure is designed so that updating the aggregates can be +/// vectorized (done in a tight loop) without allocations. The +/// accumulator state is *not* managed by this operator (e.g in the +/// hash table) and instead is delegated to the individual +/// accumulators which have type specialized inner loops that perform +/// the aggregation. +/// +/// # Architecture +/// +/// ```text +/// +/// Assigns a consecutive group internally stores aggregate values +/// index for each unique set for all groups +/// of group values +/// +/// ┌────────────┐ ┌──────────────┐ ┌──────────────┐ +/// │ ┌────────┐ │ │┌────────────┐│ │┌────────────┐│ +/// │ │ "A" │ │ ││accumulator ││ ││accumulator ││ +/// │ ├────────┤ │ ││ 0 ││ ││ N ││ +/// │ │ "Z" │ │ ││ ┌────────┐ ││ ││ ┌────────┐ ││ +/// │ └────────┘ │ ││ │ state │ ││ ││ │ state │ ││ +/// │ │ ││ │┌─────┐ │ ││ ... ││ │┌─────┐ │ ││ +/// │ ... │ ││ │├─────┤ │ ││ ││ │├─────┤ │ ││ +/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ +/// │ │ ││ │ │ ││ ││ │ │ ││ +/// │ ┌────────┐ │ ││ │ ... │ ││ ││ │ ... │ ││ +/// │ │ "Q" │ │ ││ │ │ ││ ││ │ │ ││ +/// │ └────────┘ │ ││ │┌─────┐ │ ││ ││ │┌─────┐ │ ││ +/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││ +/// └────────────┘ ││ └────────┘ ││ ││ └────────┘ ││ +/// │└────────────┘│ │└────────────┘│ +/// └──────────────┘ └──────────────┘ +/// +/// group_values accumulators +/// +/// ``` +/// +/// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`, +/// [`group_values`] will store the distinct values of `z`. There will +/// be one accumulator for `COUNT(x)`, specialized for the data type +/// of `x` and one accumulator for `SUM(y)`, specialized for the data +/// type of `y`. +/// +/// # Description +/// +/// [`group_values`] does not store any aggregate state inline. It only +/// assigns "group indices", one for each (distinct) group value. The +/// accumulators manage the in-progress aggregate state for each +/// group, with the group values themselves are stored in +/// [`group_values`] at the corresponding group index. +/// +/// The accumulator state (e.g partial sums) is managed by and stored +/// by a [`GroupsAccumulator`] accumulator. There is one accumulator +/// per aggregate expression (COUNT, AVG, etc) in the +/// stream. Internally, each `GroupsAccumulator` manages the state for +/// multiple groups, and is passed `group_indexes` during update. Note +/// The accumulator state is not managed by this operator (e.g in the +/// hash table). +/// +/// [`group_values`]: Self::group_values +/// +/// # Spilling +/// +/// The sizes of group values and accumulators can become large. Before that causes out of memory, +/// this hash aggregator outputs partial states early for partial aggregation or spills to local +/// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory +/// manager checks whether the new input size meets the memory configuration. If not, outputting or +/// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling, +/// later stream-merge sort on reading back the spilled data does re-grouping. Note the rows cannot +/// be grouped once spilled onto disk, the read back data needs to be re-grouped again. In addition, +/// re-grouping may cause out of memory again. Thus, re-grouping has to be a sort based aggregation. +/// +/// ```text +/// Partial Aggregation [batch_size = 2] (max memory = 3 rows) +/// +/// INPUTS PARTIALLY AGGREGATED (UPDATE BATCH) OUTPUTS +/// ┌─────────┐ ┌─────────────────┐ ┌─────────────────┐ +/// │ a │ b │ │ a │ AVG(b) │ │ a │ AVG(b) │ +/// │---│-----│ │ │[count]│[sum]│ │ │[count]│[sum]│ +/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│ │---│-------│-----│ +/// │ 2 │ 2.0 │ │ 2 │ 1 │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1 │ 2.0 │ +/// └─────────┘ │ 3 │ 2 │ 7.0 │ │ │ 3 │ 2 │ 7.0 │ +/// ┌─────────┐ ─▶ │ 4 │ 1 │ 8.0 │ │ └─────────────────┘ +/// │ 3 │ 4.0 │ └─────────────────┘ └▶ ┌─────────────────┐ +/// │ 4 │ 8.0 │ ┌─────────────────┐ │ 4 │ 1 │ 8.0 │ +/// └─────────┘ │ a │ AVG(b) │ ┌▶ │ 1 │ 1 │ 1.0 │ +/// ┌─────────┐ │---│-------│-----│ │ └─────────────────┘ +/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1 │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐ +/// │ 3 │ 2.0 │ │ 3 │ 1 │ 2.0 │ │ 3 │ 1 │ 2.0 │ +/// └─────────┘ └─────────────────┘ └─────────────────┘ +/// +/// +/// Final Aggregation [batch_size = 2] (max memory = 3 rows) +/// +/// PARTIALLY INPUTS FINAL AGGREGATION (MERGE BATCH) RE-GROUPED (SORTED) +/// ┌─────────────────┐ [keep using the partial schema] [Real final aggregation +/// │ a │ AVG(b) │ ┌─────────────────┐ output] +/// │ │[count]│[sum]│ │ a │ AVG(b) │ ┌────────────┐ +/// │---│-------│-----│ ─▶ │ │[count]│[sum]│ │ a │ AVG(b) │ +/// │ 3 │ 3 │ 3.0 │ │---│-------│-----│ ─▶ spill ─┐ │---│--------│ +/// │ 2 │ 2 │ 1.0 │ │ 2 │ 2 │ 1.0 │ │ │ 1 │ 4.0 │ +/// └─────────────────┘ │ 3 │ 4 │ 8.0 │ ▼ │ 2 │ 1.0 │ +/// ┌─────────────────┐ ─▶ │ 4 │ 1 │ 7.0 │ Streaming ─▶ └────────────┘ +/// │ 3 │ 1 │ 5.0 │ └─────────────────┘ merge sort ─▶ ┌────────────┐ +/// │ 4 │ 1 │ 7.0 │ ┌─────────────────┐ ▲ │ a │ AVG(b) │ +/// └─────────────────┘ │ a │ AVG(b) │ │ │---│--------│ +/// ┌─────────────────┐ │---│-------│-----│ ─▶ memory ─┘ │ 3 │ 2.0 │ +/// │ 1 │ 2 │ 8.0 │ ─▶ │ 1 │ 2 │ 8.0 │ │ 4 │ 7.0 │ +/// │ 2 │ 2 │ 3.0 │ │ 2 │ 2 │ 3.0 │ └────────────┘ +/// └─────────────────┘ └─────────────────┘ +/// ``` +pub(crate) struct GroupedHashAggregateStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + mode: AggregateMode, + + /// Accumulators, one for each `AggregateExpr` in the query + /// + /// For example, if the query has aggregates, `SUM(x)`, + /// `COUNT(y)`, there will be two accumulators, each one + /// specialized for that particular aggregate and its input types + accumulators: Vec>, + + /// Arguments to pass to each accumulator. + /// + /// The arguments in `accumulator[i]` is passed `aggregate_arguments[i]` + /// + /// The argument to each accumulator is itself a `Vec` because + /// some aggregates such as `CORR` can accept more than one + /// argument. + aggregate_arguments: Vec>>, + + /// Optional filter expression to evaluate, one for each for + /// accumulator. If present, only those rows for which the filter + /// evaluate to true should be included in the aggregate results. + /// + /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, + /// the filter expression is `x > 100`. + filter_expressions: Vec>>, + + /// GROUP BY expressions + group_by: PhysicalGroupBy, + + /// The memory reservation for this grouping + reservation: MemoryReservation, + + /// An interning store of group keys + group_values: Box, + + /// scratch space for the current input [`RecordBatch`] being + /// processed. Reused across batches here to avoid reallocations + current_group_indices: Vec, + + /// Tracks if this stream is generating input or output + exec_state: ExecutionState, + + /// Execution metrics + baseline_metrics: BaselineMetrics, + + /// max rows in output RecordBatches + batch_size: usize, + + /// Optional ordering information, that might allow groups to be + /// emitted from the hash table prior to seeing the end of the + /// input + group_ordering: GroupOrdering, + + /// Have we seen the end of the input + input_done: bool, + + /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument + runtime: Arc, + + /// The spill state object + spill_state: SpillState, + + /// Optional soft limit on the number of `group_values` in a batch + /// If the number of `group_values` in a single batch exceeds this value, + /// the `GroupedHashAggregateStream` operation immediately switches to + /// output mode and emits all groups. + group_values_soft_limit: Option, +} + +impl GroupedHashAggregateStream { + /// Create a new GroupedHashAggregateStream + pub fn new( + agg: &AggregateExec, + context: Arc, + partition: usize, + ) -> Result { + debug!("Creating GroupedHashAggregateStream"); + let agg_schema = Arc::clone(&agg.schema); + let agg_group_by = agg.group_by.clone(); + let agg_filter_expr = agg.filter_expr.clone(); + + let batch_size = context.session_config().batch_size(); + let input = agg.input.execute(partition, Arc::clone(&context))?; + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + let timer = baseline_metrics.elapsed_compute().timer(); + + let aggregate_exprs = agg.aggr_expr.clone(); + + // arguments for each aggregate, one vec of expressions per + // aggregate + let aggregate_arguments = aggregates::aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg_group_by.expr.len(), + )?; + // arguments for aggregating spilled data is the same as the one for final aggregation + let merging_aggregate_arguments = aggregates::aggregate_expressions( + &agg.aggr_expr, + &AggregateMode::Final, + agg_group_by.expr.len(), + )?; + + let filter_expressions = match agg.mode { + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => agg_filter_expr, + AggregateMode::Final | AggregateMode::FinalPartitioned => { + vec![None; agg.aggr_expr.len()] + } + }; + + // Instantiate the accumulators + let accumulators: Vec<_> = aggregate_exprs + .iter() + .map(create_group_accumulator) + .collect::>()?; + + // we need to use original schema so RowConverter in group_values below + // will do the proper coversion of dictionaries into value types + let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len()); + let spill_expr = group_schema + .fields + .into_iter() + .enumerate() + .map(|(idx, field)| PhysicalSortExpr { + expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, + options: SortOptions::default(), + }) + .collect(); + + let name = format!("GroupedHashAggregateStream[{partition}]"); + let reservation = MemoryConsumer::new(name) + .with_can_spill(true) + .register(context.memory_pool()); + let (ordering, _) = agg + .equivalence_properties() + .find_longest_permutation(&agg_group_by.output_exprs()); + let group_ordering = GroupOrdering::try_new( + &group_schema, + &agg.input_order_mode, + ordering.as_slice(), + )?; + + let group_values = new_group_values(group_schema)?; + timer.done(); + + let exec_state = ExecutionState::ReadingInput; + + let spill_state = SpillState { + spills: vec![], + spill_expr, + spill_schema: agg_schema.clone(), + is_stream_merging: false, + merging_aggregate_arguments, + merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + }; + + Ok(GroupedHashAggregateStream { + schema: agg_schema, + input, + mode: agg.mode, + accumulators, + aggregate_arguments, + filter_expressions, + group_by: agg_group_by, + reservation, + group_values, + current_group_indices: Default::default(), + exec_state, + baseline_metrics, + batch_size, + group_ordering, + input_done: false, + runtime: context.runtime_env(), + spill_state, + group_values_soft_limit: agg.limit, + }) + } +} + +/// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if +/// that is supported by the aggregate, or a +/// [`GroupsAccumulatorAdapter`] if not. +pub(crate) fn create_group_accumulator( + agg_expr: &Arc, +) -> Result> { + if agg_expr.groups_accumulator_supported() { + agg_expr.create_groups_accumulator() + } else { + // Note in the log when the slow path is used + debug!( + "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", + agg_expr.name() + ); + let agg_expr_captured = agg_expr.clone(); + let factory = move || agg_expr_captured.create_accumulator(); + Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) + } +} + +/// Extracts a successful Ok(_) or returns Poll::Ready(Some(Err(e))) with errors +macro_rules! extract_ok { + ($RES: expr) => {{ + match $RES { + Ok(v) => v, + Err(e) => return Poll::Ready(Some(Err(e))), + } + }}; +} + +impl Stream for GroupedHashAggregateStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + match &self.exec_state { + ExecutionState::ReadingInput => 'reading_input: { + match ready!(self.input.poll_next_unpin(cx)) { + // new batch to aggregate + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + extract_ok!(self.emit_early_if_necessary()); + + timer.done(); + } + Some(Err(e)) => { + // inner had error, return to caller + return Poll::Ready(Some(Err(e))); + } + None => { + // inner is done, emit all rows and switch to producing output + extract_ok!(self.set_input_done_and_produce_output()); + } + } + } + + ExecutionState::ProducingOutput(batch) => { + // slice off a part of the batch, if needed + let output_batch; + let size = self.batch_size; + (self.exec_state, output_batch) = if batch.num_rows() <= size { + ( + if self.input_done { + ExecutionState::Done + } else { + ExecutionState::ReadingInput + }, + batch.clone(), + ) + } else { + // output first batch_size rows + let size = self.batch_size; + let num_remaining = batch.num_rows() - size; + let remaining = batch.slice(size, num_remaining); + let output = batch.slice(0, size); + (ExecutionState::ProducingOutput(remaining), output) + }; + return Poll::Ready(Some(Ok( + output_batch.record_output(&self.baseline_metrics) + ))); + } + + ExecutionState::Done => { + // release the memory reservation since sending back output batch itself needs + // some memory reservation, so make some room for it. + self.clear_all(); + let _ = self.update_memory_reservation(); + return Poll::Ready(None); + } + } + } + } +} + +impl RecordBatchStream for GroupedHashAggregateStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl GroupedHashAggregateStream { + /// Perform group-by aggregation for the given [`RecordBatch`]. + fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> { + // Evaluate the grouping expressions + let group_by_values = if self.spill_state.is_stream_merging { + evaluate_group_by(&self.spill_state.merging_group_by, &batch)? + } else { + evaluate_group_by(&self.group_by, &batch)? + }; + + // Evaluate the aggregation expressions. + let input_values = if self.spill_state.is_stream_merging { + evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)? + } else { + evaluate_many(&self.aggregate_arguments, &batch)? + }; + + // Evaluate the filter expressions, if any, against the inputs + let filter_values = if self.spill_state.is_stream_merging { + let filter_expressions = vec![None; self.accumulators.len()]; + evaluate_optional(&filter_expressions, &batch)? + } else { + evaluate_optional(&self.filter_expressions, &batch)? + }; + + for group_values in &group_by_values { + // calculate the group indices for each input row + let starting_num_groups = self.group_values.len(); + self.group_values + .intern(group_values, &mut self.current_group_indices)?; + let group_indices = &self.current_group_indices; + + // Update ordering information if necessary + let total_num_groups = self.group_values.len(); + if total_num_groups > starting_num_groups { + self.group_ordering.new_groups( + group_values, + group_indices, + total_num_groups, + )?; + } + + // Gather the inputs to call the actual accumulator + let t = self + .accumulators + .iter_mut() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + for ((acc, values), opt_filter) in t { + let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); + + // Call the appropriate method on each aggregator with + // the entire input row and the relevant group indexes + match self.mode { + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned + if !self.spill_state.is_stream_merging => + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } + _ => { + // if aggregation is over intermediate states, + // use merge + acc.merge_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } + } + } + } + + match self.update_memory_reservation() { + // Here we can ignore `insufficient_capacity_err` because we will spill later, + // but at least one batch should fit in the memory + Err(DataFusionError::ResourcesExhausted(_)) + if self.group_values.len() >= self.batch_size => + { + Ok(()) + } + other => other, + } + } + + fn update_memory_reservation(&mut self) -> Result<()> { + let acc = self.accumulators.iter().map(|x| x.size()).sum::(); + self.reservation.try_resize( + acc + self.group_values.size() + + self.group_ordering.size() + + self.current_group_indices.allocated_size(), + ) + } + + /// Create an output RecordBatch with the group keys and + /// accumulator states/values specified in emit_to + fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { + let schema = if spilling { + self.spill_state.spill_schema.clone() + } else { + self.schema() + }; + if self.group_values.is_empty() { + return Ok(RecordBatch::new_empty(schema)); + } + + let mut output = self.group_values.emit(emit_to)?; + if let EmitTo::First(n) = emit_to { + self.group_ordering.remove_groups(n); + } + + // Next output each aggregate value + for acc in self.accumulators.iter_mut() { + match self.mode { + AggregateMode::Partial => output.extend(acc.state(emit_to)?), + _ if spilling => { + // If spilling, output partial state because the spilled data will be + // merged and re-evaluated later. + output.extend(acc.state(emit_to)?) + } + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), + } + } + + // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is + // over the target memory size after emission, we can emit again rather than returning Err. + let _ = self.update_memory_reservation(); + let batch = RecordBatch::try_new(schema, output)?; + Ok(batch) + } + + /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly + /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the + /// memory. Currently only [`GroupOrdering::None`] is supported for spilling. + fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> { + // TODO: support group_ordering for spilling + if self.group_values.len() > 0 + && batch.num_rows() > 0 + && matches!(self.group_ordering, GroupOrdering::None) + && !matches!(self.mode, AggregateMode::Partial) + && !self.spill_state.is_stream_merging + && self.update_memory_reservation().is_err() + { + // Use input batch (Partial mode) schema for spilling because + // the spilled data will be merged and re-evaluated later. + self.spill_state.spill_schema = batch.schema(); + self.spill()?; + self.clear_shrink(batch); + } + Ok(()) + } + + /// Emit all rows, sort them, and store them on disk. + fn spill(&mut self) -> Result<()> { + let emit = self.emit(EmitTo::All, true)?; + let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; + let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; + let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?; + // TODO: slice large `sorted` and write to multiple files in parallel + let mut offset = 0; + let total_rows = sorted.num_rows(); + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, self.batch_size); + let batch = sorted.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + + writer.finish()?; + self.spill_state.spills.push(spillfile); + Ok(()) + } + + /// Clear memory and shirk capacities to the size of the batch. + fn clear_shrink(&mut self, batch: &RecordBatch) { + self.group_values.clear_shrink(batch); + self.current_group_indices.clear(); + self.current_group_indices.shrink_to(batch.num_rows()); + } + + /// Clear memory and shirk capacities to zero. + fn clear_all(&mut self) { + let s = self.schema(); + self.clear_shrink(&RecordBatch::new_empty(s)); + } + + /// Emit if the used memory exceeds the target for partial aggregation. + /// Currently only [`GroupOrdering::None`] is supported for early emitting. + /// TODO: support group_ordering for early emitting + fn emit_early_if_necessary(&mut self) -> Result<()> { + if self.group_values.len() >= self.batch_size + && matches!(self.group_ordering, GroupOrdering::None) + && matches!(self.mode, AggregateMode::Partial) + && self.update_memory_reservation().is_err() + { + let n = self.group_values.len() / self.batch_size * self.batch_size; + let batch = self.emit(EmitTo::First(n), false)?; + self.exec_state = ExecutionState::ProducingOutput(batch); + } + Ok(()) + } + + /// At this point, all the inputs are read and there are some spills. + /// Emit the remaining rows and create a batch. + /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully + /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. + fn update_merged_stream(&mut self) -> Result<()> { + let batch = self.emit(EmitTo::All, true)?; + // clear up memory for streaming_merge + self.clear_all(); + self.update_memory_reservation()?; + let mut streams: Vec = vec![]; + let expr = self.spill_state.spill_expr.clone(); + let schema = batch.schema(); + streams.push(Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(futures::future::lazy(move |_| { + sort_batch(&batch, &expr, None) + })), + ))); + for spill in self.spill_state.spills.drain(..) { + let stream = read_spill_as_stream(spill, schema.clone())?; + streams.push(stream); + } + self.spill_state.is_stream_merging = true; + self.input = streaming_merge( + streams, + schema, + &self.spill_state.spill_expr, + self.baseline_metrics.clone(), + self.batch_size, + None, + self.reservation.new_empty(), + )?; + self.input_done = false; + self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); + Ok(()) + } + + /// returns true if there is a soft groups limit and the number of distinct + /// groups we have seen is over that limit + fn hit_soft_group_limit(&self) -> bool { + let Some(group_values_soft_limit) = self.group_values_soft_limit else { + return false; + }; + group_values_soft_limit <= self.group_values.len() + } + + /// common function for signalling end of processing of the input stream + fn set_input_done_and_produce_output(&mut self) -> Result<()> { + self.input_done = true; + self.group_ordering.input_done(); + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); + self.exec_state = if self.spill_state.spills.is_empty() { + let batch = self.emit(EmitTo::All, false)?; + ExecutionState::ProducingOutput(batch) + } else { + // If spill files exist, stream-merge them. + self.update_merged_stream()?; + ExecutionState::ReadingInput + }; + timer.done(); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs new file mode 100644 index 0000000000000..808a068b28506 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index + +use crate::aggregates::group_values::primitive::HashValue; +use crate::aggregates::topk::heap::Comparable; +use ahash::RandomState; +use arrow::datatypes::i256; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{ + downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray, +}; +use arrow_schema::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use half::f16; +use hashbrown::raw::RawTable; +use std::fmt::Debug; +use std::sync::Arc; + +/// A "type alias" for Keys which are stored in our map +pub trait KeyType: Clone + Comparable + Debug {} + +impl KeyType for T where T: Clone + Comparable + Debug {} + +/// An entry in our hash table that: +/// 1. memoizes the hash +/// 2. contains the key (ID) +/// 3. contains the value (heap_idx - an index into the corresponding heap) +pub struct HashTableItem { + hash: u64, + pub id: ID, + pub heap_idx: usize, +} + +/// A custom wrapper around `hashbrown::RawTable` that: +/// 1. limits the number of entries to the top K +/// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing +/// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash +/// 4. Catches resize events to allow the corresponding heap to update it's indexes +struct TopKHashTable { + map: RawTable>, + limit: usize, +} + +/// An interface to hide the generic type signature of TopKHashTable behind arrow arrays +pub trait ArrowHashTable { + fn set_batch(&mut self, ids: ArrayRef); + fn len(&self) -> usize; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide valid indexes + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide a valid index + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize; + unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef; + + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide valid indexes + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + map: &mut Vec<(usize, usize)>, + ) -> (usize, bool); +} + +// An implementation of ArrowHashTable for String keys +pub struct StringHashTable { + owned: ArrayRef, + map: TopKHashTable>, + rnd: RandomState, +} + +// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key +struct PrimitiveHashTable +where + Option<::Native>: Comparable, +{ + owned: ArrayRef, + map: TopKHashTable>, + rnd: RandomState, +} + +impl StringHashTable { + pub fn new(limit: usize) -> Self { + let vals: Vec<&str> = Vec::new(); + let owned = Arc::new(StringArray::from(vals)); + Self { + owned, + map: TopKHashTable::new(limit, limit * 10), + rnd: ahash::RandomState::default(), + } + } +} + +impl ArrowHashTable for StringHashTable { + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { + let ids = self.map.take_all(indexes); + Arc::new(StringArray::from(ids)) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("StringArray required"); + let id = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash = self.rnd.hash_one(id); + if let Some(map_idx) = self + .map + .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) + { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let id = id.map(|id| id.to_string()); + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +impl PrimitiveHashTable +where + Option<::Native>: Comparable, + Option<::Native>: HashValue, +{ + pub fn new(limit: usize) -> Self { + let owned = Arc::new(PrimitiveArray::::builder(0).finish()); + Self { + owned, + map: TopKHashTable::new(limit, limit * 10), + rnd: ahash::RandomState::default(), + } + } +} + +impl ArrowHashTable for PrimitiveHashTable +where + Option<::Native>: Comparable, + Option<::Native>: HashValue, +{ + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { + let ids = self.map.take_all(indexes); + let mut builder: PrimitiveBuilder = PrimitiveArray::builder(ids.len()); + for id in ids.into_iter() { + match id { + None => builder.append_null(), + Some(id) => builder.append_value(id), + } + } + let ids = builder.finish(); + Arc::new(ids) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self.owned.as_primitive::(); + let id: Option = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash: u64 = id.hash(&self.rnd); + if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +impl TopKHashTable { + pub fn new(limit: usize, capacity: usize) -> Self { + Self { + map: RawTable::with_capacity(capacity), + limit, + } + } + + pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option { + let bucket = self.map.find(hash, |mi| eq(&mi.id))?; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: getting the index of a bucket we just found + let idx = unsafe { self.map.bucket_index(&bucket) }; + Some(idx) + } + + pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + let bucket = unsafe { self.map.bucket(map_idx) }; + bucket.as_ref().heap_idx + } + + pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize { + if self.map.len() >= self.limit { + self.map.erase(self.map.bucket(replace_idx)); + 0 // if full, always replace top node + } else { + self.map.len() // if we're not full, always append to end + } + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + for (m, h) in mapper { + self.map.bucket(*m).as_mut().heap_idx = *h + } + } + + pub fn insert( + &mut self, + hash: u64, + id: ID, + heap_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> usize { + let mi = HashTableItem::new(hash, id, heap_idx); + let bucket = self.map.try_insert_no_grow(hash, mi); + let bucket = match bucket { + Ok(bucket) => bucket, + Err(new_item) => { + let bucket = self.map.insert(hash, new_item, |mi| mi.hash); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: we're getting indexes of buckets, not dereferencing them + unsafe { + for bucket in self.map.iter() { + let heap_idx = bucket.as_ref().heap_idx; + let map_idx = self.map.bucket_index(&bucket); + mapper.push((heap_idx, map_idx)); + } + } + bucket + } + }; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: we're getting indexes of buckets, not dereferencing them + unsafe { self.map.bucket_index(&bucket) } + } + + pub fn len(&self) -> usize { + self.map.len() + } + + pub unsafe fn take_all(&mut self, idxs: Vec) -> Vec { + let ids = idxs + .into_iter() + .map(|idx| self.map.bucket(idx).as_ref().id.clone()) + .collect(); + self.map.clear(); + ids + } +} + +impl HashTableItem { + pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self { + Self { hash, id, heap_idx } + } +} + +impl HashValue for Option { + fn hash(&self, state: &RandomState) -> u64 { + state.hash_one(self) + } +} + +macro_rules! hash_float { + ($($t:ty),+) => { + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +macro_rules! has_integer { + ($($t:ty),+) => { + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +has_integer!(i8, i16, i32, i64, i128, i256); +has_integer!(u8, u16, u32, u64); +hash_float!(f16, f32, f64); + +pub fn new_hash_table(limit: usize, kt: DataType) -> Result> { + macro_rules! downcast_helper { + ($kt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) + }; + } + + downcast_primitive! { + kt => (downcast_helper, kt), + DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))), + _ => {} + } + + Err(DataFusionError::Execution(format!( + "Can't create HashTable for type: {kt:?}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::Result; + use std::collections::BTreeMap; + + #[test] + fn should_resize_properly() -> Result<()> { + let mut heap_to_map = BTreeMap::::new(); + let mut map = TopKHashTable::>::new(5, 3); + for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() { + let mut mapper = vec![]; + let hash = heap_idx as u64; + let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper); + let _ = heap_to_map.insert(heap_idx, map_idx); + if heap_idx == 3 { + assert_eq!( + mapper, + vec![(0, 0), (1, 1), (2, 2), (3, 3)], + "Pass {heap_idx} resized incorrectly!" + ); + for (heap_idx, map_idx) in mapper { + let _ = heap_to_map.insert(heap_idx, map_idx); + } + } else { + assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!"); + } + } + + let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); + let ids = unsafe { map.take_all(map_idxs) }; + assert_eq!( + format!("{:?}", ids), + r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# + ); + assert_eq!(map.len(), 0, "Map should have been cleared!"); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs new file mode 100644 index 0000000000000..bf95a42bde515 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -0,0 +1,627 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A custom binary heap implementation for performant top K aggregation + +use arrow::datatypes::i256; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_schema::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_physical_expr::aggregate::utils::adjust_output_array; +use half::f16; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +/// A custom version of `Ord` that only exists to we can implement it for the Values in our heap +pub trait Comparable { + fn comp(&self, other: &Self) -> Ordering; +} + +impl Comparable for Option { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } +} + +/// A "type alias" for Values which are stored in our heap +pub trait ValueType: Comparable + Clone + Debug {} + +impl ValueType for T where T: Comparable + Clone + Debug {} + +/// An entry in our heap, which contains both the value and a index into an external HashTable +struct HeapItem { + val: VAL, + map_idx: usize, +} + +/// A custom heap implementation that allows several things that couldn't be achieved with +/// `collections::BinaryHeap`: +/// 1. It allows values to be updated at arbitrary positions (when group values change) +/// 2. It can be either a min or max heap +/// 3. It can use our `HeapItem` type & `Comparable` trait +/// 4. It is specialized to grow to a certain limit, then always replace without grow & shrink +struct TopKHeap { + desc: bool, + len: usize, + capacity: usize, + heap: Vec>>, +} + +/// An interface to hide the generic type signature of TopKHeap behind arrow arrays +pub trait ArrowHeap { + fn set_batch(&mut self, vals: ArrayRef); + fn is_worse(&self, idx: usize) -> bool; + fn worst_map_idx(&self) -> usize; + fn renumber(&mut self, heap_to_map: &[(usize, usize)]); + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ); + fn drain(&mut self) -> (ArrayRef, Vec); +} + +/// An implementation of `ArrowHeap` that deals with primitive values +pub struct PrimitiveHeap +where + ::Native: Comparable, +{ + batch: ArrayRef, + heap: TopKHeap, + desc: bool, + data_type: DataType, +} + +impl PrimitiveHeap +where + ::Native: Comparable, +{ + pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { + let owned: ArrayRef = Arc::new(PrimitiveArray::::builder(0).finish()); + Self { + batch: owned, + heap: TopKHeap::new(limit, desc), + desc, + data_type, + } + } +} + +impl ArrowHeap for PrimitiveHeap +where + ::Native: Comparable, +{ + fn set_batch(&mut self, vals: ArrayRef) { + self.batch = vals; + } + + fn is_worse(&self, row_idx: usize) -> bool { + if !self.heap.is_full() { + return false; + } + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + let worst_val = self.heap.worst_val().expect("Missing root"); + (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val) + } + + fn worst_map_idx(&self) -> usize { + self.heap.worst_map_idx() + } + + fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + self.heap.renumber(heap_to_map); + } + + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + self.heap.append_or_replace(new_val, map_idx, map); + } + + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + self.heap.replace_if_better(heap_idx, new_val, map); + } + + fn drain(&mut self) -> (ArrayRef, Vec) { + let (vals, map_idxs) = self.heap.drain(); + let vals = Arc::new(PrimitiveArray::::from_iter_values(vals)); + let vals = adjust_output_array(&self.data_type, vals).expect("Type is incorrect"); + (vals, map_idxs) + } +} + +impl TopKHeap { + pub fn new(limit: usize, desc: bool) -> Self { + Self { + desc, + capacity: limit, + len: 0, + heap: (0..=limit).map(|_| None).collect::>(), + } + } + + pub fn worst_val(&self) -> Option<&VAL> { + let root = self.heap.first()?; + let hi = match root { + None => return None, + Some(hi) => hi, + }; + Some(&hi.val) + } + + pub fn worst_map_idx(&self) -> usize { + self.heap[0].as_ref().map(|hi| hi.map_idx).unwrap_or(0) + } + + pub fn is_full(&self) -> bool { + self.len >= self.capacity + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn append_or_replace( + &mut self, + new_val: VAL, + map_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + if self.is_full() { + self.replace_root(new_val, map_idx, map); + } else { + self.append(new_val, map_idx, map); + } + } + + fn append(&mut self, new_val: VAL, map_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let hi = HeapItem::new(new_val, map_idx); + self.heap[self.len] = Some(hi); + self.heapify_up(self.len, mapper); + self.len += 1; + } + + fn pop(&mut self, map: &mut Vec<(usize, usize)>) -> Option> { + if self.len() == 0 { + return None; + } + if self.len() == 1 { + self.len = 0; + return self.heap[0].take(); + } + self.swap(0, self.len - 1, map); + let former_root = self.heap[self.len - 1].take(); + self.len -= 1; + self.heapify_down(0, map); + former_root + } + + pub fn drain(&mut self) -> (Vec, Vec) { + let mut map = Vec::with_capacity(self.len); + let mut vals = Vec::with_capacity(self.len); + let mut map_idxs = Vec::with_capacity(self.len); + while let Some(worst_hi) = self.pop(&mut map) { + vals.push(worst_hi.val); + map_idxs.push(worst_hi.map_idx); + } + vals.reverse(); + map_idxs.reverse(); + (vals, map_idxs) + } + + fn replace_root( + &mut self, + new_val: VAL, + map_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) { + let hi = self.heap[0].as_mut().expect("No root"); + hi.val = new_val; + hi.map_idx = map_idx; + self.heapify_down(0, mapper); + } + + pub fn replace_if_better( + &mut self, + heap_idx: usize, + new_val: VAL, + mapper: &mut Vec<(usize, usize)>, + ) { + let existing = self.heap[heap_idx].as_mut().expect("Missing heap item"); + if (!self.desc && new_val.comp(&existing.val) != Ordering::Less) + || (self.desc && new_val.comp(&existing.val) != Ordering::Greater) + { + return; + } + existing.val = new_val; + self.heapify_down(heap_idx, mapper); + } + + pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + for (heap_idx, map_idx) in heap_to_map.iter() { + if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { + hi.map_idx = *map_idx; + } + } + } + + fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { + let desc = self.desc; + while idx != 0 { + let parent_idx = (idx - 1) / 2; + let node = self.heap[idx].as_ref().expect("No heap item"); + let parent = self.heap[parent_idx].as_ref().expect("No heap item"); + if (!desc && node.val.comp(&parent.val) != Ordering::Greater) + || (desc && node.val.comp(&parent.val) != Ordering::Less) + { + return; + } + self.swap(idx, parent_idx, mapper); + idx = parent_idx; + } + } + + fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let a_hi = self.heap[a_idx].take().expect("Missing heap entry"); + let b_hi = self.heap[b_idx].take().expect("Missing heap entry"); + + mapper.push((a_hi.map_idx, b_idx)); + mapper.push((b_hi.map_idx, a_idx)); + + self.heap[a_idx] = Some(b_hi); + self.heap[b_idx] = Some(a_hi); + } + + fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let left_child = node_idx * 2 + 1; + let desc = self.desc; + let entry = self.heap.get(node_idx).expect("Missing node!"); + let entry = entry.as_ref().expect("Missing node!"); + let mut best_idx = node_idx; + let mut best_val = &entry.val; + for child_idx in left_child..=left_child + 1 { + if let Some(Some(child)) = self.heap.get(child_idx) { + if (!desc && child.val.comp(best_val) == Ordering::Greater) + || (desc && child.val.comp(best_val) == Ordering::Less) + { + best_val = &child.val; + best_idx = child_idx; + } + } + } + if best_val.comp(&entry.val) != Ordering::Equal { + self.swap(best_idx, node_idx, mapper); + self.heapify_down(best_idx, mapper); + } + } + + #[cfg(test)] + fn _tree_print(&self, idx: usize) -> Option> { + let hi = self.heap.get(idx)?; + match hi { + None => None, + Some(hi) => { + let label = + format!("val={:?} idx={}, bucket={}", hi.val, idx, hi.map_idx); + let left = self._tree_print(idx * 2 + 1); + let right = self._tree_print(idx * 2 + 2); + let children = left.into_iter().chain(right); + let me = termtree::Tree::new(label).with_leaves(children); + Some(me) + } + } + } + + #[cfg(test)] + fn tree_print(&self) -> String { + match self._tree_print(0) { + None => "".to_string(), + Some(root) => format!("{}", root), + } + } +} + +impl HeapItem { + pub fn new(val: VAL, buk_idx: usize) -> Self { + Self { + val, + map_idx: buk_idx, + } + } +} + +impl Debug for HeapItem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("bucket=")?; + self.map_idx.fmt(f)?; + f.write_str(" val=")?; + self.val.fmt(f)?; + f.write_str("\n")?; + Ok(()) + } +} + +impl Eq for HeapItem {} + +impl PartialEq for HeapItem { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapItem { + fn cmp(&self, other: &Self) -> Ordering { + let res = self.val.comp(&other.val); + if res != Ordering::Equal { + return res; + } + self.map_idx.cmp(&other.map_idx) + } +} + +macro_rules! compare_float { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + match (self, other) { + (Some(me), Some(other)) => me.total_cmp(other), + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.total_cmp(other) + } + })+ + }; +} + +macro_rules! compare_integer { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + }; +} + +compare_integer!(i8, i16, i32, i64, i128, i256); +compare_integer!(u8, u16, u32, u64); +compare_float!(f16, f32, f64); + +pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result> { + macro_rules! downcast_helper { + ($vt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) + }; + } + + downcast_primitive! { + vt => (downcast_helper, vt), + _ => {} + } + + Err(DataFusionError::Execution(format!( + "Can't group type: {vt:?}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::Result; + + #[test] + fn should_append() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + heap.append_or_replace(1, 1, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +val=1 idx=0, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } + + #[test] + fn should_heapify_up() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + assert_eq!(map, vec![]); + + heap.append_or_replace(2, 2, &mut map); + assert_eq!(map, vec![(2, 0), (1, 1)]); + + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=2 +└── val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } + + #[test] + fn should_heapify_down() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(3, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + heap.append_or_replace(3, 3, &mut map); + let actual = heap.tree_print(); + let expected = r#" +val=3 idx=0, bucket=3 +├── val=1 idx=1, bucket=1 +└── val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let mut map = vec![]; + heap.append_or_replace(0, 0, &mut map); + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=2 +├── val=1 idx=1, bucket=1 +└── val=0 idx=2, bucket=0 + "#; + assert_eq!(actual.trim(), expected.trim()); + assert_eq!(map, vec![(2, 0), (0, 2)]); + + Ok(()) + } + + #[test] + fn should_replace() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(4, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + heap.append_or_replace(3, 3, &mut map); + heap.append_or_replace(4, 4, &mut map); + let actual = heap.tree_print(); + let expected = r#" +val=4 idx=0, bucket=4 +├── val=3 idx=1, bucket=3 +│ └── val=1 idx=3, bucket=1 +└── val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let mut map = vec![]; + heap.replace_if_better(1, 0, &mut map); + let actual = heap.tree_print(); + let expected = r#" +val=4 idx=0, bucket=4 +├── val=1 idx=1, bucket=1 +│ └── val=0 idx=3, bucket=3 +└── val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + assert_eq!(map, vec![(1, 1), (3, 3)]); + + Ok(()) + } + + #[test] + fn should_find_worst() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=2 +└── val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + assert_eq!(heap.worst_val(), Some(&2)); + assert_eq!(heap.worst_map_idx(), 2); + + Ok(()) + } + + #[test] + fn should_drain() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=2 +└── val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let (vals, map_idxs) = heap.drain(); + assert_eq!(vals, vec![1, 2]); + assert_eq!(map_idxs, vec![1, 2]); + assert_eq!(heap.len(), 0); + + Ok(()) + } + + #[test] + fn should_renumber() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=2 +└── val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let numbers = vec![(0, 1), (1, 2)]; + heap.renumber(numbers.as_slice()); + let actual = heap.tree_print(); + let expected = r#" +val=2 idx=0, bucket=1 +└── val=1 idx=1, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/topk/mod.rs b/datafusion/physical-plan/src/aggregates/topk/mod.rs new file mode 100644 index 0000000000000..c6a0f40cc8171 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/topk/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TopK functionality for aggregates + +pub mod hash_table; +pub mod heap; +pub mod priority_map; diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs new file mode 100644 index 0000000000000..ee72e4083bf46 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -0,0 +1,381 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` + +use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable}; +use crate::aggregates::topk::heap::{new_heap, ArrowHeap}; +use arrow_array::ArrayRef; +use arrow_schema::DataType; +use datafusion_common::Result; + +/// A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` +pub struct PriorityMap { + map: Box, + heap: Box, + capacity: usize, + mapper: Vec<(usize, usize)>, +} + +// JUSTIFICATION +// Benefit: ~15% speedup + required to index into RawTable from binary heap +// Soundness: it is only accessed by one thread at a time, and indexes are kept up to date +unsafe impl Send for PriorityMap {} + +impl PriorityMap { + pub fn new( + key_type: DataType, + val_type: DataType, + capacity: usize, + descending: bool, + ) -> Result { + Ok(Self { + map: new_hash_table(capacity, key_type)?, + heap: new_heap(capacity, descending, val_type)?, + capacity, + mapper: Vec::with_capacity(capacity), + }) + } + + pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) { + self.map.set_batch(ids); + self.heap.set_batch(vals); + } + + pub fn insert(&mut self, row_idx: usize) -> Result<()> { + assert!(self.map.len() <= self.capacity, "Overflow"); + + // if we're full, and the new val is worse than all our values, just bail + if self.heap.is_worse(row_idx) { + return Ok(()); + } + let map = &mut self.mapper; + + // handle new groups we haven't seen yet + map.clear(); + let replace_idx = self.heap.worst_map_idx(); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: replace_idx kept valid during resizes + let (map_idx, did_insert) = + unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; + if did_insert { + self.heap.renumber(map); + map.clear(); + self.heap.insert(row_idx, map_idx, map); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the map was created on the line above, so all the indexes should be valid + unsafe { self.map.update_heap_idx(map) }; + return Ok(()); + }; + + // this is a value for an existing group + map.clear(); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: map_idx was just found, so it is valid + let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; + self.heap.replace_if_better(heap_idx, row_idx, map); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the index map was just built, so it will be valid + unsafe { self.map.update_heap_idx(map) }; + + Ok(()) + } + + pub fn emit(&mut self) -> Result> { + let (vals, map_idxs) = self.heap.drain(); + let ids = unsafe { self.map.take_all(map_idxs) }; + Ok(vec![ids, vals]) + } + + pub fn is_empty(&self) -> bool { + self.map.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::util::pretty::pretty_format_batches; + use arrow_array::{Int64Array, RecordBatch, StringArray}; + use arrow_schema::Field; + use arrow_schema::Schema; + use arrow_schema::{DataType, SchemaRef}; + use datafusion_common::Result; + use std::sync::Arc; + + #[test] + fn should_append() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_higher_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_lower_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 2 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_higher_same_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_lower_same_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_lower_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_higher_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 2 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_lower_for_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_higher_for_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_handle_null_ids() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + agg.insert(2)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| | 3 | +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, true), + Field::new("timestamp_ms", DataType::Int64, true), + ])) + } +} diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs new file mode 100644 index 0000000000000..9f25473cb9b42 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A memory-conscious aggregation implementation that limits group buckets to a fixed number + +use crate::aggregates::topk::priority_map::PriorityMap; +use crate::aggregates::{ + aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec, + PhysicalGroupBy, +}; +use crate::{RecordBatchStream, SendableRecordBatchStream}; +use arrow::util::pretty::print_batches; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; +use futures::stream::{Stream, StreamExt}; +use log::{trace, Level}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub struct GroupedTopKAggregateStream { + partition: usize, + row_count: usize, + started: bool, + schema: SchemaRef, + input: SendableRecordBatchStream, + aggregate_arguments: Vec>>, + group_by: PhysicalGroupBy, + priority_map: PriorityMap, +} + +impl GroupedTopKAggregateStream { + pub fn new( + aggr: &AggregateExec, + context: Arc, + partition: usize, + limit: usize, + ) -> Result { + let agg_schema = Arc::clone(&aggr.schema); + let group_by = aggr.group_by.clone(); + let input = aggr.input.execute(partition, Arc::clone(&context))?; + let aggregate_arguments = + aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; + let (val_field, desc) = aggr + .get_minmax_desc() + .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?; + + let (expr, _) = &aggr.group_expr().expr()[0]; + let kt = expr.data_type(&aggr.input().schema())?; + let vt = val_field.data_type().clone(); + + let priority_map = PriorityMap::new(kt, vt, limit, desc)?; + + Ok(GroupedTopKAggregateStream { + partition, + started: false, + row_count: 0, + schema: agg_schema, + input, + aggregate_arguments, + group_by, + priority_map, + }) + } +} + +impl RecordBatchStream for GroupedTopKAggregateStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl GroupedTopKAggregateStream { + fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> { + let len = ids.len(); + self.priority_map.set_batch(ids, vals.clone()); + + let has_nulls = vals.null_count() > 0; + for row_idx in 0..len { + if has_nulls && vals.is_null(row_idx) { + continue; + } + self.priority_map.insert(row_idx)?; + } + Ok(()) + } +} + +impl Stream for GroupedTopKAggregateStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + while let Poll::Ready(res) = self.input.poll_next_unpin(cx) { + match res { + // got a batch, convert to rows and append to our TreeMap + Some(Ok(batch)) => { + self.started = true; + trace!( + "partition {} has {} rows and got batch with {} rows", + self.partition, + self.row_count, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 { + print_batches(&[batch.clone()])?; + } + self.row_count += batch.num_rows(); + let batches = &[batch]; + let group_by_values = + evaluate_group_by(&self.group_by, batches.first().unwrap())?; + assert_eq!( + group_by_values.len(), + 1, + "Exactly 1 group value required" + ); + assert_eq!( + group_by_values[0].len(), + 1, + "Exactly 1 group value required" + ); + let group_by_values = group_by_values[0][0].clone(); + let input_values = evaluate_many( + &self.aggregate_arguments, + batches.first().unwrap(), + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + let input_values = input_values[0][0].clone(); + + // iterate over each column of group_by values + (*self).intern(group_by_values, input_values)?; + } + // inner is done, emit all rows and switch to producing output + None => { + if self.priority_map.is_empty() { + trace!("partition {} emit None", self.partition); + return Poll::Ready(None); + } + let cols = self.priority_map.emit()?; + let batch = RecordBatch::try_new(self.schema.clone(), cols)?; + trace!( + "partition {} emit batch with {} rows", + self.partition, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) { + print_batches(&[batch.clone()])?; + } + return Poll::Ready(Some(Ok(batch))); + } + // inner had error, return to caller + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + } + } + Poll::Pending + } +} diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/physical-plan/src/analyze.rs similarity index 82% rename from datafusion/core/src/physical_plan/analyze.rs rename to datafusion/physical-plan/src/analyze.rs index 3923033d2e6ea..ded37983bb211 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -20,25 +20,27 @@ use std::sync::Arc; use std::{any::Any, time::Instant}; -use crate::physical_plan::{ - display::DisplayableExecutionPlan, DisplayFormatType, ExecutionPlan, Partitioning, - Statistics, -}; -use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::{DataFusionError, Result}; -use futures::StreamExt; - use super::expressions::PhysicalSortExpr; use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; -use super::{Distribution, SendableRecordBatchStream}; +use super::{DisplayAs, Distribution, SendableRecordBatchStream}; + +use crate::display::DisplayableExecutionPlan; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use futures::StreamExt; + /// `EXPLAIN ANALYZE` execution plan operator. This operator runs its input, /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] pub struct AnalyzeExec { /// control how much extra to print verbose: bool, + /// if statistics should be displayed + show_statistics: bool, /// The input plan (the plan being analyzed) pub(crate) input: Arc, /// The output schema for RecordBatches of this exec node @@ -47,13 +49,48 @@ pub struct AnalyzeExec { impl AnalyzeExec { /// Create a new AnalyzeExec - pub fn new(verbose: bool, input: Arc, schema: SchemaRef) -> Self { + pub fn new( + verbose: bool, + show_statistics: bool, + input: Arc, + schema: SchemaRef, + ) -> Self { AnalyzeExec { verbose, + show_statistics, input, schema, } } + + /// access to verbose + pub fn verbose(&self) -> bool { + self.verbose + } + + /// access to show_statistics + pub fn show_statistics(&self) -> bool { + self.show_statistics + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for AnalyzeExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "AnalyzeExec verbose={}", self.verbose) + } + } + } } impl ExecutionPlan for AnalyzeExec { @@ -79,9 +116,7 @@ impl ExecutionPlan for AnalyzeExec { /// If the plan does not support pipelining, but its input(s) are /// infinite, returns an error to indicate this. fn unbounded_output(&self, _children: &[bool]) -> Result { - Err(DataFusionError::Internal( - "Optimization not supported for ANALYZE".to_string(), - )) + internal_err!("Optimization not supported for ANALYZE") } /// Get the output partitioning of this plan @@ -99,6 +134,7 @@ impl ExecutionPlan for AnalyzeExec { ) -> Result> { Ok(Arc::new(Self::new( self.verbose, + self.show_statistics, children.pop().unwrap(), self.schema.clone(), ))) @@ -110,9 +146,9 @@ impl ExecutionPlan for AnalyzeExec { context: Arc, ) -> Result { if 0 != partition { - return Err(DataFusionError::Internal(format!( + return internal_err!( "AnalyzeExec invalid partition. Expected 0, got {partition}" - ))); + ); } // Gather futures that will run each input partition in @@ -131,6 +167,7 @@ impl ExecutionPlan for AnalyzeExec { let captured_input = self.input.clone(); let captured_schema = self.schema.clone(); let verbose = self.verbose; + let show_statistics = self.show_statistics; // future that gathers the results from all the tasks in the // JoinSet that computes the overall row count and final @@ -145,6 +182,7 @@ impl ExecutionPlan for AnalyzeExec { let duration = Instant::now() - start; create_output_batch( verbose, + show_statistics, total_rows, duration, captured_input, @@ -157,28 +195,12 @@ impl ExecutionPlan for AnalyzeExec { futures::stream::once(output), ))) } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "AnalyzeExec verbose={}", self.verbose) - } - } - } - - fn statistics(&self) -> Statistics { - // Statistics an an ANALYZE plan are not relevant - Statistics::default() - } } /// Creates the ouput of AnalyzeExec as a RecordBatch fn create_output_batch( verbose: bool, + show_statistics: bool, total_rows: usize, duration: std::time::Duration, input: Arc, @@ -191,7 +213,8 @@ fn create_output_batch( type_builder.append_value("Plan with Metrics"); let annotated_plan = DisplayableExecutionPlan::with_metrics(input.as_ref()) - .indent() + .set_show_statistics(show_statistics) + .indent(verbose) .to_string(); plan_builder.append_value(annotated_plan); @@ -201,7 +224,8 @@ fn create_output_batch( type_builder.append_value("Plan with Full Metrics"); let annotated_plan = DisplayableExecutionPlan::with_full_metrics(input.as_ref()) - .indent() + .set_show_statistics(show_statistics) + .indent(verbose) .to_string(); plan_builder.append_value(annotated_plan); @@ -209,7 +233,7 @@ fn create_output_batch( plan_builder.append_value(total_rows.to_string()); type_builder.append_value("Duration"); - plan_builder.append_value(format!("{:?}", duration)); + plan_builder.append_value(format!("{duration:?}")); } RecordBatch::try_new( @@ -227,9 +251,8 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; - use crate::prelude::SessionContext; use crate::{ - physical_plan::collect, + collect, test::{ assert_is_pending, exec::{assert_strong_count_converges_to_zero, BlockingExec}, @@ -240,14 +263,13 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); - let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, schema)); + let analyze_exec = Arc::new(AnalyzeExec::new(true, false, blocking_exec, schema)); let fut = collect(analyze_exec, task_ctx); let mut fut = fut.boxed(); diff --git a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs similarity index 83% rename from datafusion/core/src/physical_plan/coalesce_batches.rs rename to datafusion/physical-plan/src/coalesce_batches.rs index 0ca01aacfa19c..09d1ea87ca370 100644 --- a/datafusion/core/src/physical_plan/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -23,23 +23,24 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::physical_plan::{ - DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, +use super::expressions::PhysicalSortExpr; +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use super::{DisplayAs, Statistics}; +use crate::{ + DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, }; -use datafusion_common::Result; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; + use futures::stream::{Stream, StreamExt}; use log::trace; -use super::expressions::PhysicalSortExpr; -use super::metrics::{BaselineMetrics, MetricsSet}; -use super::{metrics::ExecutionPlanMetricsSet, Statistics}; - /// CoalesceBatchesExec combines small batches into larger batches for more efficient use of /// vectorized processing by upstream operators. #[derive(Debug)] @@ -73,6 +74,24 @@ impl CoalesceBatchesExec { } } +impl DisplayAs for CoalesceBatchesExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "CoalesceBatchesExec: target_batch_size={}", + self.target_batch_size + ) + } + } + } +} + impl ExecutionPlan for CoalesceBatchesExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -107,6 +126,14 @@ impl ExecutionPlan for CoalesceBatchesExec { self.input.output_ordering() } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + fn equivalence_properties(&self) -> EquivalenceProperties { self.input.equivalence_properties() } @@ -137,27 +164,11 @@ impl ExecutionPlan for CoalesceBatchesExec { })) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "CoalesceBatchesExec: target_batch_size={}", - self.target_batch_size - ) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } @@ -213,17 +224,17 @@ impl CoalesceBatchesStream { let _timer = cloned_time.timer(); match input_batch { Poll::Ready(x) => match x { - Some(Ok(ref batch)) => { + Some(Ok(batch)) => { if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() { - return Poll::Ready(Some(Ok(batch.clone()))); + return Poll::Ready(Some(Ok(batch))); } else if batch.num_rows() == 0 { // discard empty batches } else { // add to the buffered batches - self.buffer.push(batch.clone()); self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); // check to see if we have enough batches yet if self.buffered_rows >= self.target_batch_size { // combine the batches and return @@ -285,52 +296,16 @@ pub fn concat_batches( batches.len(), row_count ); - let b = arrow::compute::concat_batches(schema, batches)?; - Ok(b) + arrow::compute::concat_batches(schema, batches) } #[cfg(test)] mod tests { use super::*; - use crate::config::ConfigOptions; - use crate::datasource::MemTable; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; - use crate::prelude::SessionContext; - use crate::test::create_vec_batches; - use arrow::datatypes::{DataType, Field, Schema}; - - #[tokio::test] - async fn test_custom_batch_size() -> Result<()> { - let mut config = ConfigOptions::new(); - config.execution.batch_size = 1234; - - let ctx = SessionContext::with_config(config.into()); - let plan = create_physical_plan(ctx).await?; - let coalesce = plan.as_any().downcast_ref::().unwrap(); - assert_eq!(1234, coalesce.target_batch_size); - Ok(()) - } - - #[tokio::test] - async fn test_disable_coalesce() -> Result<()> { - let mut config = ConfigOptions::new(); - config.execution.coalesce_batches = false; - - let ctx = SessionContext::with_config(config.into()); - let plan = create_physical_plan(ctx).await?; - let _filter = plan.as_any().downcast_ref::().unwrap(); - Ok(()) - } + use crate::{memory::MemoryExec, repartition::RepartitionExec}; - async fn create_physical_plan(ctx: SessionContext) -> Result> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let table = MemTable::try_new(schema, vec![partition])?; - ctx.register_table("a", Arc::new(table))?; - let dataframe = ctx.sql("SELECT * FROM a WHERE c0 < 1").await?; - dataframe.create_physical_plan().await - } + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::UInt32Array; #[tokio::test(flavor = "multi_thread")] async fn test_concat_batches() -> Result<()> { @@ -372,10 +347,9 @@ mod tests { // execute and collect results let output_partition_count = exec.output_partitioning().partition_count(); let mut output_partitions = Vec::with_capacity(output_partition_count); - let session_ctx = SessionContext::new(); for i in 0..output_partition_count { // execute this *output* partition and collect all batches - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let mut stream = exec.execute(i, task_ctx.clone())?; let mut batches = vec![]; while let Some(result) = stream.next().await { @@ -385,4 +359,23 @@ mod tests { } Ok(output_partitions) } + + /// Create vector batches + fn create_vec_batches(schema: &Schema, n: usize) -> Vec { + let batch = create_batch(schema); + let mut vec = Vec::with_capacity(n); + for _ in 0..n { + vec.push(batch.clone()); + } + vec + } + + /// Create batch + fn create_batch(schema: &Schema) -> RecordBatch { + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], + ) + .unwrap() + } } diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs similarity index 85% rename from datafusion/core/src/physical_plan/coalesce_partitions.rs rename to datafusion/physical-plan/src/coalesce_partitions.rs index d05c413caf4f9..bfcff28535386 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -21,19 +21,17 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::SchemaRef; - use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::stream::{ObservedStream, RecordBatchReceiverStream}; -use super::Statistics; -use crate::physical_plan::{ - DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, -}; -use datafusion_common::{DataFusionError, Result}; +use super::{DisplayAs, SendableRecordBatchStream, Statistics}; -use super::SendableRecordBatchStream; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -60,6 +58,20 @@ impl CoalescePartitionsExec { } } +impl DisplayAs for CoalescePartitionsExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CoalescePartitionsExec") + } + } + } +} + impl ExecutionPlan for CoalescePartitionsExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -90,7 +102,14 @@ impl ExecutionPlan for CoalescePartitionsExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + let mut output_eq = self.input.equivalence_properties(); + // Coalesce partitions loses existing orderings. + output_eq.clear_orderings(); + output_eq + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] } fn with_new_children( @@ -107,16 +126,14 @@ impl ExecutionPlan for CoalescePartitionsExec { ) -> Result { // CoalescePartitionsExec produces a single partition if 0 != partition { - return Err(DataFusionError::Internal(format!( - "CoalescePartitionsExec invalid partition {partition}" - ))); + return internal_err!("CoalescePartitionsExec invalid partition {partition}"); } let input_partitions = self.input.output_partitioning().partition_count(); match input_partitions { - 0 => Err(DataFusionError::Internal( - "CoalescePartitionsExec requires at least one input partition".to_owned(), - )), + 0 => internal_err!( + "CoalescePartitionsExec requires at least one input partition" + ), 1 => { // bypass any threading / metrics if there is a single partition self.input.execute(0, context) @@ -146,23 +163,11 @@ impl ExecutionPlan for CoalescePartitionsExec { } } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "CoalescePartitionsExec") - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } @@ -174,20 +179,18 @@ mod tests { use futures::FutureExt; use super::*; - use crate::physical_plan::{collect, common}; - use crate::prelude::SessionContext; use crate::test::exec::{ assert_strong_count_converges_to_zero, BlockingExec, PanicExec, }; use crate::test::{self, assert_is_pending}; + use crate::{collect, common}; #[tokio::test] async fn merge() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let num_partitions = 4; - let csv = test::scan_partitioned_csv(num_partitions)?; + let csv = test::scan_partitioned(num_partitions); // input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); @@ -202,17 +205,16 @@ mod tests { let batches = common::collect(iter).await?; assert_eq!(batches.len(), num_partitions); - // there should be a total of 100 rows + // there should be a total of 400 rows (100 per each partition) let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); - assert_eq!(row_count, 100); + assert_eq!(row_count, 400); Ok(()) } #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -234,8 +236,7 @@ mod tests { #[tokio::test] #[should_panic(expected = "PanickingStream did panic")] async fn test_panic() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/physical-plan/src/common.rs similarity index 88% rename from datafusion/core/src/physical_plan/common.rs rename to datafusion/physical-plan/src/common.rs index 982bb4f2e6bae..649f3a31aa7ef 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -17,24 +17,28 @@ //! Defines common code used in execution plans +use std::fs; +use std::fs::{metadata, File}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::task::{Context, Poll}; + use super::SendableRecordBatchStream; -use crate::physical_plan::stream::RecordBatchReceiverStream; -use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::stream::RecordBatchReceiverStream; +use crate::{ColumnStatistics, ExecutionPlan, Statistics}; + use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::stats::Precision; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use futures::{Future, StreamExt, TryStreamExt}; use parking_lot::Mutex; use pin_project_lite::pin_project; -use std::fs; -use std::fs::{metadata, File}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::task::JoinHandle; /// [`MemoryReservation`] used across query execution streams @@ -50,9 +54,7 @@ pub fn build_checked_file_list(dir: &str, ext: &str) -> Result> { let mut filenames: Vec = Vec::new(); build_file_list_recurse(dir, &mut filenames, ext)?; if filenames.is_empty() { - return Err(DataFusionError::Plan(format!( - "No files found at {dir} with file extension {ext}" - ))); + return plan_err!("No files found at {dir} with file extension {ext}"); } Ok(filenames) } @@ -86,7 +88,7 @@ fn build_file_list_recurse( filenames.push(path_name.to_string()); } } else { - return Err(DataFusionError::Plan("Invalid path".to_string())); + return plan_err!("Invalid path"); } } } @@ -99,24 +101,31 @@ pub(crate) fn spawn_buffered( mut input: SendableRecordBatchStream, buffer: usize, ) -> SendableRecordBatchStream { - // Use tokio only if running from a tokio context (#2201) - if tokio::runtime::Handle::try_current().is_err() { - return input; - }; - - let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); + // Use tokio only if running from a multi-thread tokio context + match tokio::runtime::Handle::try_current() { + Ok(handle) + if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => + { + let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); + + let sender = builder.tx(); + + builder.spawn(async move { + while let Some(item) = input.next().await { + if sender.send(item).await.is_err() { + // receiver dropped when query is shutdown early (e.g., limit) or error, + // no need to return propagate the send error. + return Ok(()); + } + } - let sender = builder.tx(); + Ok(()) + }); - builder.spawn(async move { - while let Some(item) = input.next().await { - if sender.send(item).await.is_err() { - return; - } + builder.build() } - }); - - builder.build() + _ => input, + } } /// Computes the statistics for an in-memory RecordBatch @@ -130,29 +139,32 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches.iter().flatten().map(batch_byte_size).sum(); + let total_byte_size = batches + .iter() + .flatten() + .map(|b| b.get_array_memory_size()) + .sum(); let projection = match projection { Some(p) => p, None => (0..schema.fields().len()).collect(), }; - let mut column_statistics = vec![ColumnStatistics::default(); projection.len()]; + let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - *column_statistics[stat_index].null_count.get_or_insert(0) += - batch.column(*col_index).null_count(); + column_statistics[stat_index].null_count = + Precision::Exact(batch.column(*col_index).null_count()); } } } Statistics { - num_rows: Some(nb_rows), - total_byte_size: Some(total_byte_size), - column_statistics: Some(column_statistics), - is_exact: true, + num_rows: Precision::Exact(nb_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics, } } @@ -284,14 +296,92 @@ fn get_meet_of_orderings_helper( } } +/// Write in Arrow IPC format. +pub struct IPCWriter { + /// path + pub path: PathBuf, + /// inner writer + pub writer: FileWriter, + /// batches written + pub num_batches: u64, + /// rows written + pub num_rows: u64, + /// bytes written + pub num_bytes: u64, +} + +impl IPCWriter { + /// Create new writer + pub fn new(path: &Path, schema: &Schema) -> Result { + let file = File::create(path).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to create partition file at {path:?}: {e:?}" + )) + })?; + Ok(Self { + num_batches: 0, + num_rows: 0, + num_bytes: 0, + path: path.into(), + writer: FileWriter::try_new(file, schema)?, + }) + } + + /// Create new writer with IPC write options + pub fn new_with_options( + path: &Path, + schema: &Schema, + write_options: IpcWriteOptions, + ) -> Result { + let file = File::create(path).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to create partition file at {path:?}: {e:?}" + )) + })?; + Ok(Self { + num_batches: 0, + num_rows: 0, + num_bytes: 0, + path: path.into(), + writer: FileWriter::try_new_with_options(file, schema, write_options)?, + }) + } + /// Write one single batch + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(batch)?; + self.num_batches += 1; + self.num_rows += batch.num_rows() as u64; + let num_bytes: usize = batch.get_array_memory_size(); + self.num_bytes += num_bytes as u64; + Ok(()) + } + + /// Finish the writer + pub fn finish(&mut self) -> Result<()> { + self.writer.finish().map_err(Into::into) + } + + /// Path write to + pub fn path(&self) -> &Path { + &self.path + } +} + +/// Returns the total number of bytes of memory occupied physically by this batch. +#[deprecated(since = "28.0.0", note = "RecordBatch::get_array_memory_size")] +pub fn batch_byte_size(batch: &RecordBatch) -> usize { + batch.get_array_memory_size() +} + #[cfg(test)] mod tests { use std::ops::Not; use super::*; - use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::sorts::sort::SortExec; - use crate::physical_plan::union::UnionExec; + use crate::memory::MemoryExec; + use crate::sorts::sort::SortExec; + use crate::union::UnionExec; + use arrow::compute::SortOptions; use arrow::{ array::{Float32Array, Float64Array}, @@ -585,9 +675,8 @@ mod tests { ])); let stats = compute_record_batch_statistics(&[], &schema, Some(vec![0, 1])); - assert_eq!(stats.num_rows, Some(0)); - assert!(stats.is_exact); - assert_eq!(stats.total_byte_size, Some(0)); + assert_eq!(stats.num_rows, Precision::Exact(0)); + assert_eq!(stats.total_byte_size, Precision::Exact(0)); Ok(()) } @@ -608,27 +697,26 @@ mod tests { compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); let mut expected = Statistics { - is_exact: true, - num_rows: Some(3), - total_byte_size: Some(464), // this might change a bit if the way we compute the size changes - column_statistics: Some(vec![ + num_rows: Precision::Exact(3), + total_byte_size: Precision::Exact(464), // this might change a bit if the way we compute the size changes + column_statistics: vec![ ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: Some(0), + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: Some(0), + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(0), }, - ]), + ], }; // Prevent test flakiness due to undefined / changing implementation details - expected.total_byte_size = actual.total_byte_size; + expected.total_byte_size = actual.total_byte_size.clone(); assert_eq!(actual, expected); Ok(()) @@ -643,83 +731,3 @@ mod tests { Ok(()) } } - -/// Write in Arrow IPC format. -pub struct IPCWriter { - /// path - pub path: PathBuf, - /// inner writer - pub writer: FileWriter, - /// batches written - pub num_batches: u64, - /// rows written - pub num_rows: u64, - /// bytes written - pub num_bytes: u64, -} - -impl IPCWriter { - /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new(file, schema)?, - }) - } - - /// Create new writer with IPC write options - pub fn new_with_options( - path: &Path, - schema: &Schema, - write_options: IpcWriteOptions, - ) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new_with_options(file, schema, write_options)?, - }) - } - /// Write one single batch - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; - self.num_batches += 1; - self.num_rows += batch.num_rows() as u64; - let num_bytes: usize = batch_byte_size(batch); - self.num_bytes += num_bytes as u64; - Ok(()) - } - - /// Finish the writer - pub fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(Into::into) - } - - /// Path write to - pub fn path(&self) -> &Path { - &self.path - } -} - -/// Returns the total number of bytes of memory occupied physically by this batch. -pub fn batch_byte_size(batch: &RecordBatch) -> usize { - batch - .columns() - .iter() - .map(|array| array.get_array_memory_size()) - .sum() -} diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs new file mode 100644 index 0000000000000..612e164be0e2c --- /dev/null +++ b/datafusion/physical-plan/src/display.rs @@ -0,0 +1,438 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of physical plan display. See +//! [`crate::displayable`] for examples of how to format + +use std::fmt; + +use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; + +use arrow_schema::SchemaRef; +use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; +use datafusion_physical_expr::PhysicalSortExpr; + +/// Options for controlling how each [`ExecutionPlan`] should format itself +#[derive(Debug, Clone, Copy)] +pub enum DisplayFormatType { + /// Default, compact format. Example: `FilterExec: c12 < 10.0` + Default, + /// Verbose, showing all available details + Verbose, +} + +/// Wraps an `ExecutionPlan` with various ways to display this plan +pub struct DisplayableExecutionPlan<'a> { + inner: &'a dyn ExecutionPlan, + /// How to show metrics + show_metrics: ShowMetrics, + /// If statistics should be displayed + show_statistics: bool, +} + +impl<'a> DisplayableExecutionPlan<'a> { + /// Create a wrapper around an [`ExecutionPlan`] which can be + /// pretty printed in a variety of ways + pub fn new(inner: &'a dyn ExecutionPlan) -> Self { + Self { + inner, + show_metrics: ShowMetrics::None, + show_statistics: false, + } + } + + /// Create a wrapper around an [`ExecutionPlan`] which can be + /// pretty printed in a variety of ways that also shows aggregated + /// metrics + pub fn with_metrics(inner: &'a dyn ExecutionPlan) -> Self { + Self { + inner, + show_metrics: ShowMetrics::Aggregated, + show_statistics: false, + } + } + + /// Create a wrapper around an [`ExecutionPlan`] which can be + /// pretty printed in a variety of ways that also shows all low + /// level metrics + pub fn with_full_metrics(inner: &'a dyn ExecutionPlan) -> Self { + Self { + inner, + show_metrics: ShowMetrics::Full, + show_statistics: false, + } + } + + /// Enable display of statistics + pub fn set_show_statistics(mut self, show_statistics: bool) -> Self { + self.show_statistics = show_statistics; + self + } + + /// Return a `format`able structure that produces a single line + /// per node. + /// + /// ```text + /// ProjectionExec: expr=[a] + /// CoalesceBatchesExec: target_batch_size=8192 + /// FilterExec: a < 5 + /// RepartitionExec: partitioning=RoundRobinBatch(16) + /// CsvExec: source=...", + /// ``` + pub fn indent(&self, verbose: bool) -> impl fmt::Display + 'a { + let format_type = if verbose { + DisplayFormatType::Verbose + } else { + DisplayFormatType::Default + }; + struct Wrapper<'a> { + format_type: DisplayFormatType, + plan: &'a dyn ExecutionPlan, + show_metrics: ShowMetrics, + show_statistics: bool, + } + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut visitor = IndentVisitor { + t: self.format_type, + f, + indent: 0, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + }; + accept(self.plan, &mut visitor) + } + } + Wrapper { + format_type, + plan: self.inner, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + } + } + + /// Returns a `format`able structure that produces graphviz format for execution plan, which can + /// be directly visualized [here](https://dreampuf.github.io/GraphvizOnline). + /// + /// An example is + /// ```dot + /// strict digraph dot_plan { + // 0[label="ProjectionExec: expr=[id@0 + 2 as employee.id + Int32(2)]",tooltip=""] + // 1[label="EmptyExec",tooltip=""] + // 0 -> 1 + // } + /// ``` + pub fn graphviz(&self) -> impl fmt::Display + 'a { + struct Wrapper<'a> { + plan: &'a dyn ExecutionPlan, + show_metrics: ShowMetrics, + show_statistics: bool, + } + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let t = DisplayFormatType::Default; + + let mut visitor = GraphvizVisitor { + f, + t, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + graphviz_builder: GraphvizBuilder::default(), + parents: Vec::new(), + }; + + visitor.start_graph()?; + + accept(self.plan, &mut visitor)?; + + visitor.end_graph()?; + Ok(()) + } + } + + Wrapper { + plan: self.inner, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + } + } + + /// Return a single-line summary of the root of the plan + /// Example: `ProjectionExec: expr=[a@0 as a]`. + pub fn one_line(&self) -> impl fmt::Display + 'a { + struct Wrapper<'a> { + plan: &'a dyn ExecutionPlan, + show_metrics: ShowMetrics, + show_statistics: bool, + } + + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut visitor = IndentVisitor { + f, + t: DisplayFormatType::Default, + indent: 0, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + }; + visitor.pre_visit(self.plan)?; + Ok(()) + } + } + + Wrapper { + plan: self.inner, + show_metrics: self.show_metrics, + show_statistics: self.show_statistics, + } + } + + /// format as a `StringifiedPlan` + pub fn to_stringified(&self, verbose: bool, plan_type: PlanType) -> StringifiedPlan { + StringifiedPlan::new(plan_type, self.indent(verbose).to_string()) + } +} + +#[derive(Debug, Clone, Copy)] +enum ShowMetrics { + /// Do not show any metrics + None, + + /// Show aggregrated metrics across partition + Aggregated, + + /// Show full per-partition metrics + Full, +} + +/// Formats plans with a single line per node. +struct IndentVisitor<'a, 'b> { + /// How to format each node + t: DisplayFormatType, + /// Write to this formatter + f: &'a mut fmt::Formatter<'b>, + /// Indent size + indent: usize, + /// How to show metrics + show_metrics: ShowMetrics, + /// If statistics should be displayed + show_statistics: bool, +} + +impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { + type Error = fmt::Error; + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + write!(self.f, "{:indent$}", "", indent = self.indent * 2)?; + plan.fmt_as(self.t, self.f)?; + match self.show_metrics { + ShowMetrics::None => {} + ShowMetrics::Aggregated => { + if let Some(metrics) = plan.metrics() { + let metrics = metrics + .aggregate_by_name() + .sorted_for_display() + .timestamps_removed(); + + write!(self.f, ", metrics=[{metrics}]")?; + } else { + write!(self.f, ", metrics=[]")?; + } + } + ShowMetrics::Full => { + if let Some(metrics) = plan.metrics() { + write!(self.f, ", metrics=[{metrics}]")?; + } else { + write!(self.f, ", metrics=[]")?; + } + } + } + let stats = plan.statistics().map_err(|_e| fmt::Error)?; + if self.show_statistics { + write!(self.f, ", statistics=[{}]", stats)?; + } + writeln!(self.f)?; + self.indent += 1; + Ok(true) + } + + fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { + self.indent -= 1; + Ok(true) + } +} + +struct GraphvizVisitor<'a, 'b> { + f: &'a mut fmt::Formatter<'b>, + /// How to format each node + t: DisplayFormatType, + /// How to show metrics + show_metrics: ShowMetrics, + /// If statistics should be displayed + show_statistics: bool, + + graphviz_builder: GraphvizBuilder, + /// Used to record parent node ids when visiting a plan. + parents: Vec, +} + +impl GraphvizVisitor<'_, '_> { + fn start_graph(&mut self) -> fmt::Result { + self.graphviz_builder.start_graph(self.f) + } + + fn end_graph(&mut self) -> fmt::Result { + self.graphviz_builder.end_graph(self.f) + } +} + +impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { + type Error = fmt::Error; + + fn pre_visit( + &mut self, + plan: &dyn ExecutionPlan, + ) -> datafusion_common::Result { + let id = self.graphviz_builder.next_id(); + + struct Wrapper<'a>(&'a dyn ExecutionPlan, DisplayFormatType); + + impl<'a> std::fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt_as(self.1, f) + } + } + + let label = { format!("{}", Wrapper(plan, self.t)) }; + + let metrics = match self.show_metrics { + ShowMetrics::None => "".to_string(), + ShowMetrics::Aggregated => { + if let Some(metrics) = plan.metrics() { + let metrics = metrics + .aggregate_by_name() + .sorted_for_display() + .timestamps_removed(); + + format!("metrics=[{metrics}]") + } else { + "metrics=[]".to_string() + } + } + ShowMetrics::Full => { + if let Some(metrics) = plan.metrics() { + format!("metrics=[{metrics}]") + } else { + "metrics=[]".to_string() + } + } + }; + + let stats = plan.statistics().map_err(|_e| fmt::Error)?; + let statistics = if self.show_statistics { + format!("statistics=[{}]", stats) + } else { + "".to_string() + }; + + let delimiter = if !metrics.is_empty() && !statistics.is_empty() { + ", " + } else { + "" + }; + + self.graphviz_builder.add_node( + self.f, + id, + &label, + Some(&format!("{}{}{}", metrics, delimiter, statistics)), + )?; + + if let Some(parent_node_id) = self.parents.last() { + self.graphviz_builder + .add_edge(self.f, *parent_node_id, id)?; + } + + self.parents.push(id); + + Ok(true) + } + + fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { + self.parents.pop(); + Ok(true) + } +} + +/// Trait for types which could have additional details when formatted in `Verbose` mode +pub trait DisplayAs { + /// Format according to `DisplayFormatType`, used when verbose representation looks + /// different from the default one + /// + /// Should not include a newline + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result; +} + +/// A newtype wrapper to display `T` implementing`DisplayAs` using the `Default` mode +pub struct DefaultDisplay(pub T); + +impl fmt::Display for DefaultDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt_as(DisplayFormatType::Default, f) + } +} + +/// A newtype wrapper to display `T` implementing `DisplayAs` using the `Verbose` mode +pub struct VerboseDisplay(pub T); + +impl fmt::Display for VerboseDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt_as(DisplayFormatType::Verbose, f) + } +} + +/// A wrapper to customize partitioned file display +#[derive(Debug)] +pub struct ProjectSchemaDisplay<'a>(pub &'a SchemaRef); + +impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let parts: Vec<_> = self + .0 + .fields() + .iter() + .map(|x| x.name().to_owned()) + .collect::>(); + write!(f, "[{}]", parts.join(", ")) + } +} + +/// A wrapper to customize output ordering display. +#[derive(Debug)] +pub struct OutputOrderingDisplay<'a>(pub &'a [PhysicalSortExpr]); + +impl<'a> fmt::Display for OutputOrderingDisplay<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[")?; + for (i, e) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ", ")? + } + write!(f, "{e}")?; + } + write!(f, "]") + } +} diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs new file mode 100644 index 0000000000000..41c8dbed14536 --- /dev/null +++ b/datafusion/physical-plan/src/empty.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! EmptyRelation with produce_one_row=false execution plan + +use std::any::Any; +use std::sync::Arc; + +use super::expressions::PhysicalSortExpr; +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; + +use log::trace; + +/// Execution plan for empty relation with produce_one_row=false +#[derive(Debug)] +pub struct EmptyExec { + /// The schema for the produced row + schema: SchemaRef, + /// Number of partitions + partitions: usize, +} + +impl EmptyExec { + /// Create a new EmptyExec + pub fn new(schema: SchemaRef) -> Self { + EmptyExec { + schema, + partitions: 1, + } + } + + /// Create a new EmptyExec with specified partition number + pub fn with_partitions(mut self, partitions: usize) -> Self { + self.partitions = partitions; + self + } + + fn data(&self) -> Result> { + Ok(vec![]) + } +} + +impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "EmptyExec") + } + } + } +} + +impl ExecutionPlan for EmptyExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(self.schema.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start EmptyExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + + if partition >= self.partitions { + return internal_err!( + "EmptyExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + + Ok(Box::pin(MemoryStream::try_new( + self.data()?, + self.schema.clone(), + None, + )?)) + } + + fn statistics(&self) -> Result { + let batch = self + .data() + .expect("Create empty RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::with_new_children_if_necessary; + use crate::{common, test}; + + #[tokio::test] + async fn empty() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + + let empty = EmptyExec::new(schema.clone()); + assert_eq!(empty.schema(), schema); + + // we should have no results + let iter = empty.execute(0, task_ctx)?; + let batches = common::collect(iter).await?; + assert!(batches.is_empty()); + + Ok(()) + } + + #[test] + fn with_new_children() -> Result<()> { + let schema = test::aggr_test_schema(); + let empty = Arc::new(EmptyExec::new(schema.clone())); + + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); + assert_eq!(empty.schema(), empty2.schema()); + + let too_many_kids = vec![empty2]; + assert!( + with_new_children_if_necessary(empty, too_many_kids).is_err(), + "expected error when providing list of kids" + ); + Ok(()) + } + + #[tokio::test] + async fn invalid_execute() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let empty = EmptyExec::new(schema); + + // ask for the wrong partition + assert!(empty.execute(1, task_ctx.clone()).is_err()); + assert!(empty.execute(20, task_ctx).is_err()); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/explain.rs b/datafusion/physical-plan/src/explain.rs similarity index 89% rename from datafusion/core/src/physical_plan/explain.rs rename to datafusion/physical-plan/src/explain.rs index fc70626d9ba05..e4904ddd34100 100644 --- a/datafusion/core/src/physical_plan/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -20,19 +20,18 @@ use std::any::Any; use std::sync::Arc; -use datafusion_common::{DataFusionError, Result}; +use super::expressions::PhysicalSortExpr; +use super::{DisplayAs, SendableRecordBatchStream}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; -use crate::{ - logical_expr::StringifiedPlan, - physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}, -}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; -use log::trace; - -use super::{expressions::PhysicalSortExpr, SendableRecordBatchStream}; -use crate::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_common::display::StringifiedPlan; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use log::trace; + /// Explain execution plan operator. This operator contains the string /// values of the various plans it has when it is created, and passes /// them to its output. @@ -71,6 +70,20 @@ impl ExplainExec { } } +impl DisplayAs for ExplainExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ExplainExec") + } + } + } +} + impl ExecutionPlan for ExplainExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -109,9 +122,7 @@ impl ExecutionPlan for ExplainExec { ) -> Result { trace!("Start ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); if 0 != partition { - return Err(DataFusionError::Internal(format!( - "ExplainExec invalid partition {partition}" - ))); + return internal_err!("ExplainExec invalid partition {partition}"); } let mut type_builder = @@ -156,23 +167,6 @@ impl ExecutionPlan for ExplainExec { futures::stream::iter(vec![Ok(record_batch)]), ))) } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "ExplainExec") - } - } - } - - fn statistics(&self) -> Statistics { - // Statistics an EXPLAIN plan are not relevant - Statistics::default() - } } /// If this plan should be shown, given the previous plan that was diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs new file mode 100644 index 0000000000000..56a1b4e178219 --- /dev/null +++ b/datafusion/physical-plan/src/filter.rs @@ -0,0 +1,1060 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! FilterExec evaluates a boolean predicate against all input batches to determine which rows to +//! include in its output batches. + +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use super::expressions::PhysicalSortExpr; +use super::{ + ColumnStatistics, DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics, +}; +use crate::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + Column, DisplayFormatType, ExecutionPlan, Partitioning, +}; + +use arrow::compute::filter_record_batch; +use arrow::datatypes::{DataType, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::stats::Precision; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::intervals::utils::check_support; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{ + analyze, split_conjunction, AnalysisContext, EquivalenceProperties, ExprBoundaries, + PhysicalExpr, +}; + +use futures::stream::{Stream, StreamExt}; +use log::trace; + +/// FilterExec evaluates a boolean predicate against all input batches to determine which rows to +/// include in its output batches. +#[derive(Debug)] +pub struct FilterExec { + /// The expression to filter on. This expression must evaluate to a boolean value. + predicate: Arc, + /// The input plan + input: Arc, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Selectivity for statistics. 0 = no rows, 100 all rows + default_selectivity: u8, +} + +impl FilterExec { + /// Create a FilterExec on an input + pub fn try_new( + predicate: Arc, + input: Arc, + ) -> Result { + match predicate.data_type(input.schema().as_ref())? { + DataType::Boolean => Ok(Self { + predicate, + input: input.clone(), + metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: 20, + }), + other => { + plan_err!("Filter predicate must return boolean values, not {other:?}") + } + } + } + + pub fn with_default_selectivity( + mut self, + default_selectivity: u8, + ) -> Result { + if default_selectivity > 100 { + return plan_err!("Default flter selectivity needs to be less than 100"); + } + self.default_selectivity = default_selectivity; + Ok(self) + } + + /// The expression to filter on. This expression must evaluate to a boolean value. + pub fn predicate(&self) -> &Arc { + &self.predicate + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// The default selectivity + pub fn default_selectivity(&self) -> u8 { + self.default_selectivity + } +} + +impl DisplayAs for FilterExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "FilterExec: {}", self.predicate) + } + } + } +} + +impl ExecutionPlan for FilterExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + // The filter operator does not make any changes to the schema of its input + self.input.schema() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + self.input.output_partitioning() + } + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn maintains_input_order(&self) -> Vec { + // tell optimizer this operator doesn't reorder its input + vec![true] + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + let stats = self.statistics().unwrap(); + // Combine the equal predicates with the input equivalence properties + let mut result = self.input.equivalence_properties(); + let (equal_pairs, _) = collect_columns_from_predicate(&self.predicate); + for (lhs, rhs) in equal_pairs { + let lhs_expr = Arc::new(lhs.clone()) as _; + let rhs_expr = Arc::new(rhs.clone()) as _; + result.add_equal_conditions(&lhs_expr, &rhs_expr) + } + // Add the columns that have only one value (singleton) after filtering to constants. + let constants = collect_columns(self.predicate()) + .into_iter() + .filter(|column| stats.column_statistics[column.index()].is_singleton()) + .map(|column| Arc::new(column) as _); + result.add_constants(constants) + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + .and_then(|e| { + let selectivity = e.default_selectivity(); + e.with_default_selectivity(selectivity) + }) + .map(|e| Arc::new(e) as _) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + Ok(Box::pin(FilterExecStream { + schema: self.input.schema(), + predicate: self.predicate.clone(), + input: self.input.execute(partition, context)?, + baseline_metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + /// The output statistics of a filtering operation can be estimated if the + /// predicate's selectivity value can be determined for the incoming data. + fn statistics(&self) -> Result { + let predicate = self.predicate(); + + let input_stats = self.input.statistics()?; + let schema = self.schema(); + if !check_support(predicate, &schema) { + let selectivity = self.default_selectivity as f64 / 100.0; + let mut stats = input_stats.into_inexact(); + stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); + stats.total_byte_size = stats + .total_byte_size + .with_estimated_selectivity(selectivity); + return Ok(stats); + } + + let num_rows = input_stats.num_rows; + let total_byte_size = input_stats.total_byte_size; + let input_analysis_ctx = AnalysisContext::try_from_statistics( + &self.input.schema(), + &input_stats.column_statistics, + )?; + + let analysis_ctx = analyze(predicate, input_analysis_ctx, &self.schema())?; + + // Estimate (inexact) selectivity of predicate + let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); + let num_rows = num_rows.with_estimated_selectivity(selectivity); + let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); + + let column_statistics = collect_new_statistics( + &input_stats.column_statistics, + analysis_ctx.boundaries, + ); + Ok(Statistics { + num_rows, + total_byte_size, + column_statistics, + }) + } +} + +/// This function ensures that all bounds in the `ExprBoundaries` vector are +/// converted to closed bounds. If a lower/upper bound is initially open, it +/// is adjusted by using the next/previous value for its data type to convert +/// it into a closed bound. +fn collect_new_statistics( + input_column_stats: &[ColumnStatistics], + analysis_boundaries: Vec, +) -> Vec { + analysis_boundaries + .into_iter() + .enumerate() + .map( + |( + idx, + ExprBoundaries { + interval, + distinct_count, + .. + }, + )| { + let (lower, upper) = interval.into_bounds(); + let (min_value, max_value) = if lower.eq(&upper) { + (Precision::Exact(lower), Precision::Exact(upper)) + } else { + (Precision::Inexact(lower), Precision::Inexact(upper)) + }; + ColumnStatistics { + null_count: input_column_stats[idx].null_count.clone().to_inexact(), + max_value, + min_value, + distinct_count: distinct_count.to_inexact(), + } + }, + ) + .collect() +} + +/// The FilterExec streams wraps the input iterator and applies the predicate expression to +/// determine which rows to include in its output batches +struct FilterExecStream { + /// Output schema, which is the same as the input schema for this operator + schema: SchemaRef, + /// The expression to filter on. This expression must evaluate to a boolean value. + predicate: Arc, + /// The input partition to filter. + input: SendableRecordBatchStream, + /// runtime metrics recording + baseline_metrics: BaselineMetrics, +} + +pub(crate) fn batch_filter( + batch: &RecordBatch, + predicate: &Arc, +) -> Result { + predicate + .evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + .and_then(|array| { + Ok(as_boolean_array(&array)?) + // apply filter array to record batch + .and_then(|filter_array| Ok(filter_record_batch(batch, filter_array)?)) + }) +} + +impl Stream for FilterExecStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll; + loop { + match self.input.poll_next_unpin(cx) { + Poll::Ready(value) => match value { + Some(Ok(batch)) => { + let timer = self.baseline_metrics.elapsed_compute().timer(); + let filtered_batch = batch_filter(&batch, &self.predicate)?; + // skip entirely filtered batches + if filtered_batch.num_rows() == 0 { + continue; + } + timer.done(); + poll = Poll::Ready(Some(Ok(filtered_batch))); + break; + } + _ => { + poll = Poll::Ready(value); + break; + } + }, + Poll::Pending => { + poll = Poll::Pending; + break; + } + } + } + self.baseline_metrics.record_poll(poll) + } + + fn size_hint(&self) -> (usize, Option) { + // same number of record batches + self.input.size_hint() + } +} + +impl RecordBatchStream for FilterExecStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Return the equals Column-Pairs and Non-equals Column-Pairs +fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { + let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new(); + let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new(); + + let predicates = split_conjunction(predicate); + predicates.into_iter().for_each(|p| { + if let Some(binary) = p.as_any().downcast_ref::() { + if let (Some(left_column), Some(right_column)) = ( + binary.left().as_any().downcast_ref::(), + binary.right().as_any().downcast_ref::(), + ) { + match binary.op() { + Operator::Eq => { + eq_predicate_columns.push((left_column, right_column)) + } + Operator::NotEq => { + ne_predicate_columns.push((left_column, right_column)) + } + _ => {} + } + } + } + }); + + (eq_predicate_columns, ne_predicate_columns) +} +/// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates +pub type EqualAndNonEqual<'a> = + (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>); + +#[cfg(test)] +mod tests { + use std::iter::Iterator; + use std::sync::Arc; + + use super::*; + use crate::expressions::*; + use crate::test; + use crate::test::exec::StatisticsExec; + use crate::ExecutionPlan; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ColumnStatistics, ScalarValue}; + use datafusion_expr::Operator; + + #[tokio::test] + async fn collect_columns_predicates() -> Result<()> { + let schema = test::aggr_test_schema(); + let predicate: Arc = binary( + binary( + binary(col("c2", &schema)?, Operator::GtEq, lit(1u32), &schema)?, + Operator::And, + binary(col("c2", &schema)?, Operator::Eq, lit(4u32), &schema)?, + &schema, + )?, + Operator::And, + binary( + binary( + col("c2", &schema)?, + Operator::Eq, + col("c9", &schema)?, + &schema, + )?, + Operator::And, + binary( + col("c1", &schema)?, + Operator::NotEq, + col("c13", &schema)?, + &schema, + )?, + &schema, + )?, + &schema, + )?; + + let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate); + + assert_eq!(1, equal_pairs.len()); + assert_eq!(equal_pairs[0].0.name(), "c2"); + assert_eq!(equal_pairs[0].1.name(), "c9"); + + assert_eq!(1, ne_pairs.len()); + assert_eq!(ne_pairs[0].0.name(), "c1"); + assert_eq!(ne_pairs[0].1.name(), "c13"); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_basic_expr() -> Result<()> { + // Table: + // a: min=1, max=100 + let bytes_per_row = 4; + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(100 * bytes_per_row), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }], + }, + schema.clone(), + )); + + // a <= 25 + let predicate: Arc = + binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?; + + // WHERE a <= 25 + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(25)); + assert_eq!( + statistics.total_byte_size, + Precision::Inexact(25 * bytes_per_row) + ); + assert_eq!( + statistics.column_statistics, + vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), + ..Default::default() + }] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_column_level_nested() -> Result<()> { + // Table: + // a: min=1, max=100 + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }], + total_byte_size: Precision::Absent, + }, + schema.clone(), + )); + + // WHERE a <= 25 + let sub_filter: Arc = Arc::new(FilterExec::try_new( + binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?, + input, + )?); + + // Nested filters (two separate physical plans, instead of AND chain in the expr) + // WHERE a >= 10 + // WHERE a <= 25 + let filter: Arc = Arc::new(FilterExec::try_new( + binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, + sub_filter, + )?); + + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(16)); + assert_eq!( + statistics.column_statistics, + vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), + ..Default::default() + }] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_column_level_nested_multiple() -> Result<()> { + // Table: + // a: min=1, max=100 + // b: min=1, max=50 + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + ..Default::default() + }, + ], + total_byte_size: Precision::Absent, + }, + schema.clone(), + )); + + // WHERE a <= 25 + let a_lte_25: Arc = Arc::new(FilterExec::try_new( + binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?, + input, + )?); + + // WHERE b > 45 + let b_gt_5: Arc = Arc::new(FilterExec::try_new( + binary(col("b", &schema)?, Operator::Gt, lit(45i32), &schema)?, + a_lte_25, + )?); + + // WHERE a >= 10 + let filter: Arc = Arc::new(FilterExec::try_new( + binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, + b_gt_5, + )?); + let statistics = filter.statistics()?; + // On a uniform distribution, only fifteen rows will satisfy the + // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only + // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). + // + // Which would result with a selectivity of '15/100 * 5/50' or 0.015 + // and that means about %1.5 of the all rows (rounded up to 2 rows). + assert_eq!(statistics.num_rows, Precision::Inexact(2)); + assert_eq!( + statistics.column_statistics, + vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(46))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + ..Default::default() + } + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_when_input_stats_missing() -> Result<()> { + // Table: + // a: min=???, max=??? (missing) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema.clone(), + )); + + // a <= 25 + let predicate: Arc = + binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?; + + // WHERE a <= 25 + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Absent); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_multiple_columns() -> Result<()> { + // Table: + // a: min=1, max=100 + // b: min=1, max=3 + // c: min=1000.0 max=1100.0 + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Float32(Some(1000.0))), + max_value: Precision::Inexact(ScalarValue::Float32(Some(1100.0))), + ..Default::default() + }, + ], + }, + schema, + )); + // WHERE a<=53 AND (b=3 AND (c<=1075.0 AND a>b)) + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::LtEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(53)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::LtEq, + Arc::new(Literal::new(ScalarValue::Float32(Some(1075.0)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Column::new("b", 1)), + )), + )), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.statistics()?; + // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330... + // num_rows after ceil => 133.0... => 134 + // total_byte_size after ceil => 532.0... => 533 + assert_eq!(statistics.num_rows, Precision::Inexact(134)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(533)); + let exp_col_stats = vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(4))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(53))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Float32(Some(1000.0))), + max_value: Precision::Inexact(ScalarValue::Float32(Some(1075.0))), + ..Default::default() + }, + ]; + let _ = exp_col_stats + .into_iter() + .zip(statistics.column_statistics) + .map(|(expected, actual)| { + if let Some(val) = actual.min_value.get_value() { + if val.data_type().is_floating() { + // Windows rounds arithmetic operation results differently for floating point numbers. + // Therefore, we check if the actual values are in an epsilon range. + let actual_min = actual.min_value.get_value().unwrap(); + let actual_max = actual.max_value.get_value().unwrap(); + let expected_min = expected.min_value.get_value().unwrap(); + let expected_max = expected.max_value.get_value().unwrap(); + let eps = ScalarValue::Float32(Some(1e-6)); + + assert!(actual_min.sub(expected_min).unwrap() < eps); + assert!(actual_min.sub(expected_min).unwrap() < eps); + + assert!(actual_max.sub(expected_max).unwrap() < eps); + assert!(actual_max.sub(expected_max).unwrap() < eps); + } else { + assert_eq!(actual, expected); + } + } else { + assert_eq!(actual, expected); + } + }); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_full_selective() -> Result<()> { + // Table: + // a: min=1, max=100 + // b: min=1, max=3 + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ], + }, + schema, + )); + // WHERE a<200 AND 1<=b + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::LtEq, + Arc::new(Column::new("b", 1)), + )), + )); + // Since filter predicate passes all entries, statistics after filter shouldn't change. + let expected = input.statistics()?.column_statistics; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.statistics()?; + + assert_eq!(statistics.num_rows, Precision::Inexact(1000)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(4000)); + assert_eq!(statistics.column_statistics, expected); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_zero_selective() -> Result<()> { + // Table: + // a: min=1, max=100 + // b: min=1, max=3 + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ], + }, + schema, + )); + // WHERE a>200 AND 1<=b + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + Operator::LtEq, + Arc::new(Column::new("b", 1)), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.statistics()?; + + assert_eq!(statistics.num_rows, Precision::Inexact(0)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(0)); + assert_eq!( + statistics.column_statistics, + vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_more_inputs() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ], + }, + schema, + )); + // WHERE a<50 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.statistics()?; + + assert_eq!(statistics.num_rows, Precision::Inexact(490)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1960)); + assert_eq!( + statistics.column_statistics, + vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(49))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ] + ); + + Ok(()) + } + + #[tokio::test] + async fn test_empty_input_statistics() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a <= 10 AND 0 <= a - 5 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::LtEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + Operator::LtEq, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Minus, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + + let expected_filter_statistics = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + distinct_count: Precision::Absent, + }], + }; + + assert_eq!(filter_statistics, expected_filter_statistics); + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_with_constant_column() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let filter_statistics = filter.statistics()?; + // First column is "a", and it is a column with only one value after the filter. + assert!(filter_statistics.column_statistics[0].is_singleton()); + + Ok(()) + } + + #[tokio::test] + async fn test_validation_filter_selectivity() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&schema), + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + let filter = FilterExec::try_new(predicate, input)?; + assert!(filter.with_default_selectivity(120).is_err()); + Ok(()) + } + + #[tokio::test] + async fn test_custom_filter_selectivity() -> Result<()> { + // Need a decimal to trigger inexact selectivity + let schema = + Schema::new(vec![Field::new("a", DataType::Decimal128(2, 3), false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ColumnStatistics { + ..Default::default() + }], + }, + schema, + )); + // WHERE a = 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), + )); + let filter = FilterExec::try_new(predicate, input)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(200)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); + let filter = filter.with_default_selectivity(40)?; + let statistics = filter.statistics()?; + assert_eq!(statistics.num_rows, Precision::Inexact(400)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs new file mode 100644 index 0000000000000..81cdfd753fe69 --- /dev/null +++ b/datafusion/physical-plan/src/insert.rs @@ -0,0 +1,329 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for writing data to [`DataSink`]s + +use std::any::Any; +use std::fmt; +use std::fmt::Debug; +use std::sync::Arc; + +use super::expressions::PhysicalSortExpr; +use super::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +}; +use crate::metrics::MetricsSet; +use crate::stream::RecordBatchStreamAdapter; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{Distribution, PhysicalSortRequirement}; + +use async_trait::async_trait; +use futures::StreamExt; + +/// `DataSink` implements writing streams of [`RecordBatch`]es to +/// user defined destinations. +/// +/// The `Display` impl is used to format the sink for explain plan +/// output. +#[async_trait] +pub trait DataSink: DisplayAs + Debug + Send + Sync { + /// Returns the data sink as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Return a snapshot of the [MetricsSet] for this + /// [DataSink]. + /// + /// See [ExecutionPlan::metrics()] for more details + fn metrics(&self) -> Option; + + // TODO add desired input ordering + // How does this sink want its input ordered? + + /// Writes the data to the sink, returns the number of values written + /// + /// This method will be called exactly once during each DML + /// statement. Thus prior to return, the sink should do any commit + /// or rollback required. + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result; +} + +/// Execution plan for writing record batches to a [`DataSink`] +/// +/// Returns a single row with the number of values written +pub struct FileSinkExec { + /// Input plan that produces the record batches to be written. + input: Arc, + /// Sink to which to write + sink: Arc, + /// Schema of the sink for validating the input data + sink_schema: SchemaRef, + /// Schema describing the structure of the output data. + count_schema: SchemaRef, + /// Optional required sort order for output data. + sort_order: Option>, +} + +impl fmt::Debug for FileSinkExec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "FileSinkExec schema: {:?}", self.count_schema) + } +} + +impl FileSinkExec { + /// Create a plan to write to `sink` + pub fn new( + input: Arc, + sink: Arc, + sink_schema: SchemaRef, + sort_order: Option>, + ) -> Self { + Self { + input, + sink, + sink_schema, + count_schema: make_count_schema(), + sort_order, + } + } + + fn execute_input_stream( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input_stream = self.input.execute(partition, context)?; + + debug_assert_eq!( + self.sink_schema.fields().len(), + self.input.schema().fields().len() + ); + + // Find input columns that may violate the not null constraint. + let risky_columns: Vec<_> = self + .sink_schema + .fields() + .iter() + .zip(self.input.schema().fields().iter()) + .enumerate() + .filter_map(|(i, (sink_field, input_field))| { + if !sink_field.is_nullable() && input_field.is_nullable() { + Some(i) + } else { + None + } + }) + .collect(); + + if risky_columns.is_empty() { + Ok(input_stream) + } else { + // Check not null constraint on the input stream + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.sink_schema.clone(), + input_stream + .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + ))) + } + } + + /// Input execution plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns insert sink + pub fn sink(&self) -> &dyn DataSink { + self.sink.as_ref() + } + + /// Optional sort order for output data + pub fn sort_order(&self) -> &Option> { + &self.sort_order + } + + /// Returns the metrics of the underlying [DataSink] + pub fn metrics(&self) -> Option { + self.sink.metrics() + } +} + +impl DisplayAs for FileSinkExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "FileSinkExec: sink=")?; + self.sink.fmt_as(t, f) + } + } + } +} + +impl ExecutionPlan for FileSinkExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.count_schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn benefits_from_input_partitioning(&self) -> Vec { + // DataSink is responsible for dynamically partitioning its + // own input at execution time. + vec![false] + } + + fn required_input_distribution(&self) -> Vec { + // DataSink is responsible for dynamically partitioning its + // own input at execution time, and so requires a single input partition. + vec![Distribution::SinglePartition; self.children().len()] + } + + fn required_input_ordering(&self) -> Vec>> { + // The required input ordering is set externally (e.g. by a `ListingTable`). + // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). + vec![self.sort_order.as_ref().cloned()] + } + + fn maintains_input_order(&self) -> Vec { + // Maintains ordering in the sense that the written file will reflect + // the ordering of the input. For more context, see: + // + // https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 + vec![true] + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + input: children[0].clone(), + sink: self.sink.clone(), + sink_schema: self.sink_schema.clone(), + count_schema: self.count_schema.clone(), + sort_order: self.sort_order.clone(), + })) + } + + fn unbounded_output(&self, _children: &[bool]) -> Result { + Ok(_children[0]) + } + + /// Execute the plan and return a stream of `RecordBatch`es for + /// the specified partition. + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + if partition != 0 { + return internal_err!("FileSinkExec can only be called on partition 0!"); + } + let data = self.execute_input_stream(0, context.clone())?; + + let count_schema = self.count_schema.clone(); + let sink = self.sink.clone(); + + let stream = futures::stream::once(async move { + sink.write_all(data, &context).await.map(make_count_batch) + }) + .boxed(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + count_schema, + stream, + ))) + } +} + +/// Create a output record batch with a count +/// +/// ```text +/// +-------+, +/// | count |, +/// +-------+, +/// | 6 |, +/// +-------+, +/// ``` +fn make_count_batch(count: u64) -> RecordBatch { + let array = Arc::new(UInt64Array::from(vec![count])) as ArrayRef; + + RecordBatch::try_from_iter_with_nullable(vec![("count", array, false)]).unwrap() +} + +fn make_count_schema() -> SchemaRef { + // define a schema. + Arc::new(Schema::new(vec![Field::new( + "count", + DataType::UInt64, + false, + )])) +} + +fn check_not_null_contraits( + batch: RecordBatch, + column_indices: &Vec, +) -> Result { + for &index in column_indices { + if batch.num_columns() <= index { + return exec_err!( + "Invalid batch column count {} expected > {}", + batch.num_columns(), + index + ); + } + + if batch.column(index).null_count() > 0 { + return exec_err!( + "Invalid batch column at '{}' has null but schema specifies non-nullable", + index + ); + } + } + + Ok(batch) +} diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs similarity index 65% rename from datafusion/core/src/physical_plan/joins/cross_join.rs rename to datafusion/physical-plan/src/joins/cross_join.rs index eb567ee130cad..938c9e4d343d6 100644 --- a/datafusion/core/src/physical_plan/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -18,30 +18,31 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use futures::{ready, StreamExt}; -use futures::{Stream, TryStreamExt}; use std::{any::Any, sync::Arc, task::Poll}; -use arrow::datatypes::{Fields, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; - -use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::physical_plan::{ +use super::utils::{ + adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::DisplayAs; +use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, - ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalSortExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use async_trait::async_trait; -use datafusion_common::DataFusionError; -use datafusion_common::{Result, ScalarValue}; + +use arrow::datatypes::{Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow_array::RecordBatchOptions; +use datafusion_common::stats::Precision; +use datafusion_common::{plan_err, DataFusionError, JoinType, Result, ScalarValue}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::EquivalenceProperties; -use super::utils::{ - adjust_right_output_partitioning, cross_join_equivalence_properties, - BuildProbeJoinMetrics, OnceAsync, OnceFut, -}; +use async_trait::async_trait; +use futures::{ready, Stream, StreamExt, TryStreamExt}; /// Data of the left side type JoinLeftData = (RecordBatch, MemoryReservation); @@ -51,9 +52,9 @@ type JoinLeftData = (RecordBatch, MemoryReservation); #[derive(Debug)] pub struct CrossJoinExec { /// left (build) side which gets loaded in memory - pub(crate) left: Arc, + pub left: Arc, /// right (probe) side which are combined with left side - pub(crate) right: Arc, + pub right: Arc, /// The schema once the join is applied schema: SchemaRef, /// Build-side data @@ -104,12 +105,11 @@ async fn load_left_input( reservation: MemoryReservation, ) -> Result { // merge all left parts into a single stream - let merge = { - if left.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(left.clone())) - } else { - left.clone() - } + let left_schema = left.schema(); + let merge = if left.output_partitioning().partition_count() != 1 { + Arc::new(CoalescePartitionsExec::new(left)) + } else { + left }; let stream = merge.execute(0, context)?; @@ -134,11 +134,25 @@ async fn load_left_input( ) .await?; - let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?; + let merged_batch = concat_batches(&left_schema, &batches, num_rows)?; Ok((merged_batch, reservation)) } +impl DisplayAs for CrossJoinExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CrossJoinExec") + } + } + } +} + impl ExecutionPlan for CrossJoinExec { fn as_any(&self) -> &dyn Any { self @@ -161,10 +175,9 @@ impl ExecutionPlan for CrossJoinExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] || children[1] { - Err(DataFusionError::Plan( + plan_err!( "Cross Join Error: Cross join is not supported for the unbounded inputs." - .to_string(), - )) + ) } else { Ok(false) } @@ -202,12 +215,14 @@ impl ExecutionPlan for CrossJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - cross_join_equivalence_properties( + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, + &JoinType::Full, self.schema(), + &[false, false], + None, + &[], ) } @@ -243,89 +258,55 @@ impl ExecutionPlan for CrossJoinExec { })) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "CrossJoinExec") - } - } - } - - fn statistics(&self) -> Statistics { - stats_cartesian_product( - self.left.statistics(), - self.left.schema().fields().len(), - self.right.statistics(), - self.right.schema().fields().len(), - ) + fn statistics(&self) -> Result { + Ok(stats_cartesian_product( + self.left.statistics()?, + self.right.statistics()?, + )) } } /// [left/right]_col_count are required in case the column statistics are None fn stats_cartesian_product( left_stats: Statistics, - left_col_count: usize, right_stats: Statistics, - right_col_count: usize, ) -> Statistics { let left_row_count = left_stats.num_rows; let right_row_count = right_stats.num_rows; // calculate global stats - let is_exact = left_stats.is_exact && right_stats.is_exact; - let num_rows = left_stats - .num_rows - .zip(right_stats.num_rows) - .map(|(a, b)| a * b); + let num_rows = left_row_count.multiply(&right_row_count); // the result size is two times a*b because you have the columns of both left and right let total_byte_size = left_stats .total_byte_size - .zip(right_stats.total_byte_size) - .map(|(a, b)| 2 * a * b); - - // calculate column stats - let column_statistics = - // complete the column statistics if they are missing only on one side - match (left_stats.column_statistics, right_stats.column_statistics) { - (None, None) => None, - (None, Some(right_col_stat)) => Some(( - vec![ColumnStatistics::default(); left_col_count], - right_col_stat, - )), - (Some(left_col_stat), None) => Some(( - left_col_stat, - vec![ColumnStatistics::default(); right_col_count], - )), - (Some(left_col_stat), Some(right_col_stat)) => { - Some((left_col_stat, right_col_stat)) - } - } - .map(|(left_col_stats, right_col_stats)| { - // the null counts must be multiplied by the row counts of the other side (if defined) - // Min, max and distinct_count on the other hand are invariants. - left_col_stats.into_iter().map(|s| ColumnStatistics{ - null_count: s.null_count.zip(right_row_count).map(|(a, b)| a * b), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - }).chain( - right_col_stats.into_iter().map(|s| ColumnStatistics{ - null_count: s.null_count.zip(left_row_count).map(|(a, b)| a * b), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - })).collect() - }); + .multiply(&right_stats.total_byte_size) + .multiply(&Precision::Exact(2)); + + let left_col_stats = left_stats.column_statistics; + let right_col_stats = right_stats.column_statistics; + + // the null counts must be multiplied by the row counts of the other side (if defined) + // Min, max and distinct_count on the other hand are invariants. + let cross_join_stats = left_col_stats + .into_iter() + .map(|s| ColumnStatistics { + null_count: s.null_count.multiply(&right_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + }) + .chain(right_col_stats.into_iter().map(|s| ColumnStatistics { + null_count: s.null_count.multiply(&left_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + })) + .collect(); Statistics { - is_exact, num_rows, total_byte_size, - column_statistics, + column_statistics: cross_join_stats, } } @@ -363,17 +344,18 @@ fn build_batch( .iter() .map(|arr| { let scalar = ScalarValue::try_from_array(arr, left_index)?; - Ok(scalar.to_array_of_size(batch.num_rows())) + scalar.to_array_of_size(batch.num_rows()) }) .collect::>>()?; - RecordBatch::try_new( + RecordBatch::try_new_with_options( Arc::new(schema.clone()), arrays .iter() .chain(batch.columns().iter()) .cloned() .collect(), + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), ) .map_err(Into::into) } @@ -455,11 +437,10 @@ impl CrossJoinStream { #[cfg(test)] mod tests { use super::*; - use crate::assert_batches_sorted_eq; - use crate::common::assert_contains; - use crate::physical_plan::common; - use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::{build_table_scan_i32, columns}; + use crate::common; + use crate::test::build_table_scan_i32; + + use datafusion_common::{assert_batches_sorted_eq, assert_contains}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; async fn join_collect( @@ -484,63 +465,60 @@ mod tests { let right_bytes = 27; let left = Statistics { - is_exact: true, - num_rows: Some(left_row_count), - total_byte_size: Some(left_bytes), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(left_bytes), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: Some(right_row_count), - total_byte_size: Some(right_bytes), - column_statistics: Some(vec![ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2), - }]), + num_rows: Precision::Exact(right_row_count), + total_byte_size: Precision::Exact(right_bytes), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2), + }], }; - let result = stats_cartesian_product(left, 3, right, 2); + let result = stats_cartesian_product(left, right); let expected = Statistics { - is_exact: true, - num_rows: Some(left_row_count * right_row_count), - total_byte_size: Some(2 * left_bytes * right_bytes), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count * right_row_count), + total_byte_size: Precision::Exact(2 * left_bytes * right_bytes), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3 * right_row_count), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3 * right_row_count), }, ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2 * left_row_count), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2 * left_row_count), }, - ]), + ], }; assert_eq!(result, expected); @@ -551,63 +529,60 @@ mod tests { let left_row_count = 11; let left = Statistics { - is_exact: true, - num_rows: Some(left_row_count), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: None, // not defined! - total_byte_size: None, // not defined! - column_statistics: Some(vec![ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2), - }]), + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2), + }], }; - let result = stats_cartesian_product(left, 3, right, 2); + let result = stats_cartesian_product(left, right); let expected = Statistics { - is_exact: true, - num_rows: None, - total_byte_size: None, - column_statistics: Some(vec![ + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: None, // we don't know the row count on the right + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: None, // we don't know the row count on the right + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Absent, // we don't know the row count on the right }, ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(12))), - min_value: Some(ScalarValue::Int64(Some(0))), - null_count: Some(2 * left_row_count), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(12))), + min_value: Precision::Exact(ScalarValue::Int64(Some(0))), + null_count: Precision::Exact(2 * left_row_count), }, - ]), + ], }; assert_eq!(result, expected); @@ -615,8 +590,7 @@ mod tests { #[tokio::test] async fn test_join() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table_scan_i32( ("a1", &vec![1, 2, 3]), @@ -632,7 +606,7 @@ mod tests { let (columns, batches) = join_collect(left, right, task_ctx).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -654,9 +628,8 @@ mod tests { async fn test_overallocation() -> Result<()> { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_ctx = - SessionContext::with_config_rt(SessionConfig::default(), runtime); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let left = build_table_scan_i32( ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), @@ -679,4 +652,9 @@ mod tests { Ok(()) } + + /// Returns the column names on the schema + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } } diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs similarity index 74% rename from datafusion/core/src/physical_plan/joins/hash_join.rs rename to datafusion/physical-plan/src/joins/hash_join.rs index 0e62540d6d559..4846d0a5e046b 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -15,49 +15,21 @@ // specific language governing permissions and limitations // under the License. -//! Defines the join plan for executing partitions in parallel and then joining the results -//! into a set of partitions. +//! [`HashJoinExec`] Partitioned Hash Join Operator -use ahash::RandomState; -use arrow::array::Array; -use arrow::array::{ - Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - StringArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, -}; -use arrow::datatypes::{ArrowNativeType, DataType}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use arrow::{ - array::{ - ArrayData, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, - DictionaryArray, FixedSizeBinaryArray, LargeStringArray, PrimitiveArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampSecondArray, UInt32BufferBuilder, UInt64BufferBuilder, - }, - datatypes::{ - Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, - }, - util::bit_util, -}; -use futures::{ready, Stream, StreamExt, TryStreamExt}; -use hashbrown::raw::RawTable; -use smallvec::smallvec; use std::fmt; +use std::mem::size_of; use std::sync::Arc; use std::task::Poll; use std::{any::Any, usize, vec}; -use datafusion_common::cast::{as_dictionary_array, as_string_array}; -use datafusion_execution::memory_pool::MemoryReservation; - -use crate::physical_plan::joins::utils::{ +use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, - get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, + calculate_join_output_ordering, get_final_indices_from_bit_map, + need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::physical_plan::{ +use crate::DisplayAs; +use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, expressions::Column, @@ -65,64 +37,236 @@ use crate::physical_plan::{ hash_utils::create_hashes, joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, - combine_join_equivalence_properties, estimate_join_statistics, - partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, - JoinFilter, JoinOn, + estimate_join_statistics, partitioned_join_output_partitioning, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, - PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use arrow::array::BooleanBufferBuilder; -use arrow::datatypes::TimeUnit; -use datafusion_common::JoinType; -use datafusion_common::{DataFusionError, Result}; -use datafusion_execution::{memory_pool::MemoryConsumer, TaskContext}; use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, }; -use crate::physical_plan::joins::hash_join_utils::JoinHashMap; -type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation); +use arrow::array::{ + Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, + UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder, +}; +use arrow::compute::kernels::cmp::{eq, not_distinct}; +use arrow::compute::{and, take, FilterBuilder}; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use arrow::util::bit_util; +use arrow_array::cast::downcast_array; +use arrow_schema::ArrowError; +use datafusion_common::{ + exec_err, internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, +}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::EquivalenceProperties; + +use ahash::RandomState; +use futures::{ready, Stream, StreamExt, TryStreamExt}; + +/// HashTable and input data for the left (build side) of a join +struct JoinLeftData { + /// The hash table with indices into `batch` + hash_map: JoinHashMap, + /// The input rows for the build side + batch: RecordBatch, + /// Memory reservation that tracks memory used by `hash_map` hash table + /// `batch`. Cleared on drop. + #[allow(dead_code)] + reservation: MemoryReservation, +} + +impl JoinLeftData { + /// Create a new `JoinLeftData` from its parts + fn new( + hash_map: JoinHashMap, + batch: RecordBatch, + reservation: MemoryReservation, + ) -> Self { + Self { + hash_map, + batch, + reservation, + } + } -/// Join execution plan executes partitions in parallel and combines them into a set of -/// partitions. + /// Returns the number of rows in the build side + fn num_rows(&self) -> usize { + self.batch.num_rows() + } + + /// return a reference to the hash map + fn hash_map(&self) -> &JoinHashMap { + &self.hash_map + } + + /// returns a reference to the build side batch + fn batch(&self) -> &RecordBatch { + &self.batch + } +} + +/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple +/// partitions using a hash table and an optional filter list to apply post +/// join. +/// +/// # Join Expressions /// -/// Filter expression expected to contain non-equality predicates that can not be pushed -/// down to any of join inputs. -/// In case of outer join, filter applied to only matched rows. +/// This implementation is optimized for evaluating eqijoin predicates ( +/// ` = `) expressions, which are represented as a list of `Columns` +/// in [`Self::on`]. +/// +/// Non-equality predicates, which can not pushed down to a join inputs (e.g. +/// ` != `) are known as "filter expressions" and are evaluated +/// after the equijoin predicates. +/// +/// # "Build Side" vs "Probe Side" +/// +/// HashJoin takes two inputs, which are referred to as the "build" and the +/// "probe". The build side is the first child, and the probe side is the second +/// child. +/// +/// The two inputs are treated differently and it is VERY important that the +/// *smaller* input is placed on the build side to minimize the work of creating +/// the hash table. +/// +/// ```text +/// ┌───────────┐ +/// │ HashJoin │ +/// │ │ +/// └───────────┘ +/// │ │ +/// ┌─────┘ └─────┐ +/// ▼ ▼ +/// ┌────────────┐ ┌─────────────┐ +/// │ Input │ │ Input │ +/// │ [0] │ │ [1] │ +/// └────────────┘ └─────────────┘ +/// +/// "build side" "probe side" +/// ``` +/// +/// Execution proceeds in 2 stages: +/// +/// 1. the **build phase** where a hash table is created from the tuples of the +/// build side. +/// +/// 2. the **probe phase** where the tuples of the probe side are streamed +/// through, checking for matches of the join keys in the hash table. +/// +/// ```text +/// ┌────────────────┐ ┌────────────────┐ +/// │ ┌─────────┐ │ │ ┌─────────┐ │ +/// │ │ Hash │ │ │ │ Hash │ │ +/// │ │ Table │ │ │ │ Table │ │ +/// │ │(keys are│ │ │ │(keys are│ │ +/// │ │equi join│ │ │ │equi join│ │ Stage 2: batches from +/// Stage 1: the │ │columns) │ │ │ │columns) │ │ the probe side are +/// *entire* build │ │ │ │ │ │ │ │ streamed through, and +/// side is read │ └─────────┘ │ │ └─────────┘ │ checked against the +/// into the hash │ ▲ │ │ ▲ │ contents of the hash +/// table │ HashJoin │ │ HashJoin │ table +/// └──────┼─────────┘ └──────────┼─────┘ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// │ │ +/// +/// │ │ +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// ... ... +/// ┌────────────┐ ┌────────────┐ +/// │RecordBatch │ │RecordBatch │ +/// └────────────┘ └────────────┘ +/// +/// build side probe side +/// +/// ``` +/// +/// # Example "Optimal" Plans +/// +/// The differences in the inputs means that for classic "Star Schema Query", +/// the optimal plan will be a **"Right Deep Tree"** . A Star Schema Query is +/// one where there is one large table and several smaller "dimension" tables, +/// joined on `Foreign Key = Primary Key` predicates. +/// +/// A "Right Deep Tree" looks like this large table as the probe side on the +/// lowest join: +/// +/// ```text +/// ┌───────────┐ +/// │ HashJoin │ +/// │ │ +/// └───────────┘ +/// │ │ +/// ┌───────┘ └──────────┐ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────┐ +/// │ small table 1 │ │ HashJoin │ +/// │ "dimension" │ │ │ +/// └───────────────┘ └───┬───┬───┘ +/// ┌──────────┘ └───────┐ +/// │ │ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────┐ +/// │ small table 2 │ │ HashJoin │ +/// │ "dimension" │ │ │ +/// └───────────────┘ └───┬───┬───┘ +/// ┌────────┘ └────────┐ +/// │ │ +/// ▼ ▼ +/// ┌───────────────┐ ┌───────────────┐ +/// │ small table 3 │ │ large table │ +/// │ "dimension" │ │ "fact" │ +/// └───────────────┘ └───────────────┘ +/// ``` #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed - pub(crate) left: Arc, + pub left: Arc, /// right (probe) side which are filtered by the hash table - pub(crate) right: Arc, - /// Set of common columns used to join on - pub(crate) on: Vec<(Column, Column)>, + pub right: Arc, + /// Set of equijoin columns from the relations: `(left_col, right_col)` + pub on: Vec<(Column, Column)>, /// Filters which are applied while finding matching rows - pub(crate) filter: Option, - /// How the join is performed - pub(crate) join_type: JoinType, - /// The schema once the join is applied + pub filter: Option, + /// How the join is performed (`OUTER`, `INNER`, etc) + pub join_type: JoinType, + /// The output schema for the join schema: SchemaRef, - /// Build-side data + /// Future that consumes left input and builds the hash table left_fut: OnceAsync, - /// Shares the `RandomState` for the hashing algorithm + /// Shared the `RandomState` for the hashing algorithm random_state: RandomState, + /// Output order + output_order: Option>, /// Partitioning mode to use - pub(crate) mode: PartitionMode, + pub mode: PartitionMode, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, + /// Null matching behavior: If `null_equals_null` is true, rows that have + /// `null`s in both left and right equijoin columns will be matched. + /// Otherwise, rows that have `null`s in the join columns will not be + /// matched and thus will not appear in the output. + pub null_equals_null: bool, } impl HashJoinExec { /// Tries to create a new [HashJoinExec]. + /// /// # Error /// This function errors when it is not possible to join the left and right sides on keys `on`. pub fn try_new( @@ -137,9 +281,7 @@ impl HashJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); if on.is_empty() { - return Err(DataFusionError::Plan( - "On constraints in HashJoinExec should be non-empty".to_string(), - )); + return plan_err!("On constraints in HashJoinExec should be non-empty"); } check_join_is_valid(&left_schema, &right_schema, &on)?; @@ -149,6 +291,16 @@ impl HashJoinExec { let random_state = RandomState::with_seeds(0, 0, 0, 0); + let output_order = calculate_join_output_ordering( + left.output_ordering().unwrap_or(&[]), + right.output_ordering().unwrap_or(&[]), + *join_type, + &on, + left_schema.fields.len(), + &Self::maintains_input_order(*join_type), + Some(Self::probe_side()), + ); + Ok(HashJoinExec { left, right, @@ -162,6 +314,7 @@ impl HashJoinExec { metrics: ExecutionPlanMetricsSet::new(), column_indices, null_equals_null, + output_order, }) } @@ -199,6 +352,47 @@ impl HashJoinExec { pub fn null_equals_null(&self) -> bool { self.null_equals_null } + + /// Calculate order preservation flags for this hash join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner | JoinType::RightAnti | JoinType::RightSemi + ), + ] + } + + /// Get probe side information for the hash join. + pub fn probe_side() -> JoinSide { + // In current implementation right side is always probe side. + JoinSide::Right + } +} + +impl DisplayAs for HashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .collect::>() + .join(", "); + write!( + f, + "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}", + self.mode, self.join_type, on, display_filter + ) + } + } + } } impl ExecutionPlan for HashJoinExec { @@ -257,14 +451,14 @@ impl ExecutionPlan for HashJoinExec { )); if breaking { - Err(DataFusionError::Plan(format!( + plan_err!( "Join Error: The join with cannot be executed with unbounded inputs. {}", if left && right { "Currently, we do not support unbounded inputs on both sides." } else { "Please consider a different type of join or sources." } - ))) + ) } else { Ok(left || right) } @@ -300,21 +494,39 @@ impl ExecutionPlan for HashJoinExec { } } - // TODO Output ordering might be kept for some cases. - // For example if it is inner join then the stream side order can be kept fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + self.output_order.as_deref() + } + + // For [JoinType::Inner] and [JoinType::RightSemi] in hash joins, the probe phase initiates by + // applying the hash function to convert the join key(s) in each row into a hash value from the + // probe side table in the order they're arranged. The hash value is used to look up corresponding + // entries in the hash table that was constructed from the build side table during the build phase. + // + // Because of the immediate generation of result rows once a match is found, + // the output of the join tends to follow the order in which the rows were read from + // the probe side table. This is simply due to the sequence in which the rows were processed. + // Hence, it appears that the hash join is preserving the order of the probe side. + // + // Meanwhile, in the case of a [JoinType::RightAnti] hash join, + // the unmatched rows from the probe side are also kept in order. + // This is because the **`RightAnti`** join is designed to return rows from the right + // (probe side) table that have no match in the left (build side) table. Because the rows + // are processed sequentially in the probe phase, and unmatched rows are directly output + // as results, these results tend to retain the order of the probe side table. + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - self.on(), + &self.join_type, self.schema(), + &self.maintains_input_order(), + Some(Self::probe_side()), + self.on(), ) } @@ -346,12 +558,13 @@ impl ExecutionPlan for HashJoinExec { let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); + if self.mode == PartitionMode::Partitioned && left_partitions != right_partitions { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Invalid HashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ - consider using RepartitionExec", - ))); + consider using RepartitionExec" + ); } let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); @@ -385,10 +598,10 @@ impl ExecutionPlan for HashJoinExec { )) } PartitionMode::Auto => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid HashJoinExec, unsupported PartitionMode {:?} in execute()", PartitionMode::Auto - ))); + ); } }; @@ -417,27 +630,11 @@ impl ExecutionPlan for HashJoinExec { })) } - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default => { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - write!( - f, - "HashJoinExec: mode={:?}, join_type={:?}, on={:?}{}", - self.mode, self.join_type, self.on, display_filter - ) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` @@ -446,10 +643,13 @@ impl ExecutionPlan for HashJoinExec { self.right.clone(), self.on.clone(), &self.join_type, + &self.schema, ) } } +/// Reads the left (build) side of the input, buffering it in memory, to build a +/// hash table (`LeftJoinData`) async fn collect_left_input( partition: Option, random_state: RandomState, @@ -463,16 +663,10 @@ async fn collect_left_input( let (left_input, left_input_partition) = if let Some(partition) = partition { (left, partition) + } else if left.output_partitioning().partition_count() != 1 { + (Arc::new(CoalescePartitionsExec::new(left)) as _, 0) } else { - let merge = { - if left.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(left)) - } else { - left - } - }; - - (merge, 0) + (left, 0) }; // Depending on partition argument load single partition or whole left side in memory @@ -510,15 +704,16 @@ async fn collect_left_input( ) })? / 7) .next_power_of_two(); - // 32 bytes per `(u64, SmallVec<[u64; 1]>)` + // 16 bytes per `(u64, u64)` // + 1 byte for each bucket - // + 16 bytes fixed - let estimated_hastable_size = 32 * estimated_buckets + estimated_buckets + 16; + // + fixed size of JoinHashMap (RawTable + Vec) + let estimated_hastable_size = + 16 * estimated_buckets + estimated_buckets + size_of::(); reservation.try_grow(estimated_hastable_size)?; metrics.build_mem_used.add(estimated_hastable_size); - let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows)); + let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; for batch in batches.iter() { @@ -531,74 +726,101 @@ async fn collect_left_input( offset, &random_state, &mut hashes_buffer, + 0, )?; offset += batch.num_rows(); } // Merge all batches into a single batch, so we // can directly index into the arrays let single_batch = concat_batches(&schema, &batches, num_rows)?; + let data = JoinLeftData::new(hashmap, single_batch, reservation); - Ok((hashmap, single_batch, reservation)) + Ok(data) } /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th -pub fn update_hash( +pub fn update_hash( on: &[Column], batch: &RecordBatch, - hash_map: &mut JoinHashMap, + hash_map: &mut T, offset: usize, random_state: &RandomState, hashes_buffer: &mut Vec, -) -> Result<()> { + deleted_offset: usize, +) -> Result<()> +where + T: JoinHashMapType, +{ // evaluate the keys let keys_values = on .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; // calculate the hash values let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + // For usual JoinHashmap, the implementation is void. + hash_map.extend_zero(batch.num_rows()); + // insert hashes to key of the hashmap + let (mut_map, mut_list) = hash_map.get_mut(); for (row, hash_value) in hash_values.iter().enumerate() { - let item = hash_map - .0 - .get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, indices)) = item { - indices.push((row + offset) as u64); + let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash); + if let Some((_, index)) = item { + // Already exists: add index to next array + let prev_index = *index; + // Store new value inside hashmap + *index = (row + offset + 1) as u64; + // Update chained Vec at row + offset with previous value + mut_list[row + offset - deleted_offset] = prev_index; } else { - hash_map.0.insert( + mut_map.insert( *hash_value, - (*hash_value, smallvec![(row + offset) as u64]), + // store the value + 1 as 0 value reserved for end of list + (*hash_value, (row + offset + 1) as u64), |(hash, _)| *hash, ); + // chained list at (row + offset) is already initialized with 0 + // meaning end of list } } Ok(()) } -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +/// [`Stream`] for [`HashJoinExec`] that does the actual join. +/// +/// This stream: +/// +/// 1. Reads the entire left input (build) and constructs a hash table +/// +/// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins +/// them with the contents of the hash table struct HashJoinStream { /// Input schema schema: Arc, - /// columns from the left + /// equijoin columns from the left (build side) on_left: Vec, - /// columns from the right used to compute the hash + /// equijoin columns from the right (probe side) on_right: Vec, - /// join filter + /// optional join filter filter: Option, - /// type of the join + /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future for data from left side + /// future which builds hash table from left side left_fut: OnceFut, - /// Keeps track of the left side rows whether they are visited + /// Which left (probe) side rows have been matches while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. visited_left_side: Option, - /// right + /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// There is nothing to process anymore and left side is processed in case of left join + /// The join output is complete. For outer joins, this is used to + /// distinguish when the input stream is exhausted and when any unmatched + /// rows are output. is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, @@ -616,83 +838,54 @@ impl RecordBatchStream for HashJoinStream { } } -/// Gets build and probe indices which satisfy the on condition (including -/// the equality condition and the join filter) in the join. -#[allow(clippy::too_many_arguments)] -pub fn build_join_indices( - probe_batch: &RecordBatch, - build_hashmap: &JoinHashMap, - build_input_buffer: &RecordBatch, - on_build: &[Column], - on_probe: &[Column], - filter: Option<&JoinFilter>, - random_state: &RandomState, - null_equals_null: bool, - hashes_buffer: &mut Vec, - offset: Option, - build_side: JoinSide, -) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices that satisfy the equality condition, like `left.a1 = right.a2` - let (build_indices, probe_indices) = build_equal_condition_join_indices( - build_hashmap, - build_input_buffer, - probe_batch, - on_build, - on_probe, - random_state, - null_equals_null, - hashes_buffer, - offset, - )?; - if let Some(filter) = filter { - // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - build_input_buffer, - probe_batch, - build_indices, - probe_indices, - filter, - build_side, - ) - } else { - Ok((build_indices, probe_indices)) - } -} - -// Returns build/probe indices satisfying the equality condition. -// On LEFT.b1 = RIGHT.b2 -// LEFT Table: -// a1 b1 c1 -// 1 1 10 -// 3 3 30 -// 5 5 50 -// 7 7 70 -// 9 8 90 -// 11 8 110 -// 13 10 130 -// RIGHT Table: -// a2 b2 c2 -// 2 2 20 -// 4 4 40 -// 6 6 60 -// 8 8 80 -// 10 10 100 -// 12 10 120 -// The result is -// "+----+----+-----+----+----+-----+", -// "| a1 | b1 | c1 | a2 | b2 | c2 |", -// "+----+----+-----+----+----+-----+", -// "| 11 | 8 | 110 | 8 | 8 | 80 |", -// "| 13 | 10 | 130 | 10 | 10 | 100 |", -// "| 13 | 10 | 130 | 12 | 10 | 120 |", -// "| 9 | 8 | 90 | 8 | 8 | 80 |", -// "+----+----+-----+----+----+-----+" -// And the result of build and probe indices are: -// Build indices: 5, 6, 6, 4 -// Probe indices: 3, 4, 5, 3 +/// Returns build/probe indices satisfying the equality condition. +/// +/// # Example +/// +/// For `LEFT.b1 = RIGHT.b2`: +/// LEFT (build) Table: +/// ```text +/// a1 b1 c1 +/// 1 1 10 +/// 3 3 30 +/// 5 5 50 +/// 7 7 70 +/// 9 8 90 +/// 11 8 110 +/// 13 10 130 +/// ``` +/// +/// RIGHT (probe) Table: +/// ```text +/// a2 b2 c2 +/// 2 2 20 +/// 4 4 40 +/// 6 6 60 +/// 8 8 80 +/// 10 10 100 +/// 12 10 120 +/// ``` +/// +/// The result is +/// ```text +/// "+----+----+-----+----+----+-----+", +/// "| a1 | b1 | c1 | a2 | b2 | c2 |", +/// "+----+----+-----+----+----+-----+", +/// "| 9 | 8 | 90 | 8 | 8 | 80 |", +/// "| 11 | 8 | 110 | 8 | 8 | 80 |", +/// "| 13 | 10 | 130 | 10 | 10 | 100 |", +/// "| 13 | 10 | 130 | 12 | 10 | 120 |", +/// "+----+----+-----+----+----+-----+" +/// ``` +/// +/// And the result of build and probe indices are: +/// ```text +/// Build indices: 4, 5, 6, 6 +/// Probe indices: 3, 3, 4, 5 +/// ``` #[allow(clippy::too_many_arguments)] -pub fn build_equal_condition_join_indices( - build_hashmap: &JoinHashMap, +pub fn build_equal_condition_join_indices( + build_hashmap: &T, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, build_on: &[Column], @@ -700,17 +893,19 @@ pub fn build_equal_condition_join_indices( random_state: &RandomState, null_equals_null: bool, hashes_buffer: &mut Vec, - offset: Option, + filter: Option<&JoinFilter>, + build_side: JoinSide, + deleted_offset: Option, ) -> Result<(UInt64Array, UInt32Array)> { let keys_values = probe_on .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .map(|c| c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())) .collect::>>()?; let build_join_values = build_on .iter() .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) + c.evaluate(build_input_buffer)? + .into_array(build_input_buffer.num_rows()) }) .collect::>>()?; hashes_buffer.clear(); @@ -719,380 +914,152 @@ pub fn build_equal_condition_join_indices( // Using a buffer builder to avoid slower normal builder let mut build_indices = UInt64BufferBuilder::new(0); let mut probe_indices = UInt32BufferBuilder::new(0); - let offset_value = offset.unwrap_or(0); - // Visit all of the probe rows - for (row, hash_value) in hash_values.iter().enumerate() { + // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + // Build Indices: [5, 4, 3] + // Probe Indices: [1, 1, 1] + // + // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side. + // Let's consider probe rows [0,1] as an example: + // + // When the probe iteration sequence is reversed, the following pairings can be derived: + // + // For probe row 1: + // (5, 1) + // (4, 1) + // (3, 1) + // + // For probe row 0: + // (5, 0) + // (4, 0) + // (3, 0) + // + // After reversing both sets of indices, we obtain reversed indices: + // + // (3,0) + // (4,0) + // (5,0) + // (3,1) + // (4,1) + // (5,1) + // + // With this approach, the lexicographic order on both the probe side and the build side is preserved. + let hash_map = build_hashmap.get_map(); + let next_chain = build_hashmap.get_list(); + for (row, hash_value) in hash_values.iter().enumerate().rev() { // Get the hash and find it in the build index // For every item on the build and probe we check if it matches // This possibly contains rows with hash collisions, // So we have to check here whether rows are equal or not - if let Some((_, indices)) = build_hashmap - .0 - .get(*hash_value, |(hash, _)| *hash_value == *hash) + if let Some((_, index)) = + hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) { - for &i in indices { - // Check hash collisions - let offset_build_index = i as usize - offset_value; - // Check hash collisions - if equal_rows( - offset_build_index, - row, - &build_join_values, - &keys_values, - null_equals_null, - )? { - build_indices.append(offset_build_index as u64); - probe_indices.append(row as u32); + let mut i = *index - 1; + loop { + let build_row_value = if let Some(offset) = deleted_offset { + // This arguments means that we prune the next index way before here. + if i < offset as u64 { + // End of the list due to pruning + break; + } + i - offset as u64 + } else { + i + }; + build_indices.append(build_row_value); + probe_indices.append(row as u32); + // Follow the chain to get the next index value + let next = next_chain[build_row_value as usize]; + if next == 0 { + // end of list + break; } + i = next - 1; } } } - let build = ArrayData::builder(DataType::UInt64) - .len(build_indices.len()) - .add_buffer(build_indices.finish()) - .build()?; - let probe = ArrayData::builder(DataType::UInt32) - .len(probe_indices.len()) - .add_buffer(probe_indices.finish()) - .build()?; + // Reversing both sets of indices + build_indices.as_slice_mut().reverse(); + probe_indices.as_slice_mut().reverse(); - Ok(( - PrimitiveArray::::from(build), - PrimitiveArray::::from(probe), - )) -} + let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); + let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); -macro_rules! equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + let (left, right) = if let Some(filter) = filter { + // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` + apply_join_filter_to_indices( + build_input_buffer, + probe_batch, + left, + right, + filter, + build_side, + )? + } else { + (left, right) + }; - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - (true, true) => $null_equals_null, - _ => false, - } - }}; + equal_rows_arr( + &left, + &right, + &build_join_values, + &keys_values, + null_equals_null, + ) } -macro_rules! equal_rows_elem_with_string_dict { - ($key_array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{ - let left_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($l).unwrap(); - let right_array: &DictionaryArray<$key_array_type> = - as_dictionary_array::<$key_array_type>($r).unwrap(); - - let (left_values, left_values_index) = { - let keys_col = left_array.keys(); - if keys_col.is_valid($left) { - let values_index = keys_col - .value($left) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(left_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(left_array.values()).unwrap(), None) - } - }; - let (right_values, right_values_index) = { - let keys_col = right_array.keys(); - if keys_col.is_valid($right) { - let values_index = keys_col - .value($right) - .to_usize() - .expect("Can not convert index to usize in dictionary"); - - ( - as_string_array(right_array.values()).unwrap(), - Some(values_index), - ) - } else { - (as_string_array(right_array.values()).unwrap(), None) - } - }; - - match (left_values_index, right_values_index) { - (Some(left_values_index), Some(right_values_index)) => { - left_values.value(left_values_index) - == right_values.value(right_values_index) - } - (None, None) => $null_equals_null, - _ => false, - } - }}; +// version of eq_dyn supporting equality on null arrays +fn eq_dyn_null( + left: &dyn Array, + right: &dyn Array, + null_equals_null: bool, +) -> Result { + match (left.data_type(), right.data_type()) { + _ if null_equals_null => not_distinct(&left, &right), + _ => eq(&left, &right), + } } -/// Left and right row have equal values -/// If more data types are supported here, please also add the data types in can_hash function -/// to generate hash join logical plan. -fn equal_rows( - left: usize, - right: usize, +pub fn equal_rows_arr( + indices_left: &UInt64Array, + indices_right: &UInt32Array, left_arrays: &[ArrayRef], right_arrays: &[ArrayRef], null_equals_null: bool, -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => { - // lhs and rhs are both `DataType::Null`, so the equal result - // is dependent on `null_equals_null` - null_equals_null - } - DataType::Boolean => { - equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null) - } - DataType::Int8 => { - equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null) - } - DataType::Int16 => { - equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null) - } - DataType::Int32 => { - equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null) - } - DataType::Int64 => { - equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null) - } - DataType::UInt8 => { - equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null) - } - DataType::UInt16 => { - equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null) - } - DataType::UInt32 => { - equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null) - } - DataType::UInt64 => { - equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null) - } - DataType::Float32 => { - equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null) - } - DataType::Float64 => { - equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null) - } - DataType::Date32 => { - equal_rows_elem!(Date32Array, l, r, left, right, null_equals_null) - } - DataType::Date64 => { - equal_rows_elem!(Date64Array, l, r, left, right, null_equals_null) - } - DataType::Time32(time_unit) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!(Time32SecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Millisecond => { - equal_rows_elem!(Time32MillisecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Time64(time_unit) => match time_unit { - TimeUnit::Microsecond => { - equal_rows_elem!(Time64MicrosecondArray, l, r, left, right, null_equals_null) - } - TimeUnit::Nanosecond => { - equal_rows_elem!(Time64NanosecondArray, l, r, left, right, null_equals_null) - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => { - equal_rows_elem!( - TimestampSecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Millisecond => { - equal_rows_elem!( - TimestampMillisecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Microsecond => { - equal_rows_elem!( - TimestampMicrosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - TimeUnit::Nanosecond => { - equal_rows_elem!( - TimestampNanosecondArray, - l, - r, - left, - right, - null_equals_null - ) - } - }, - DataType::Utf8 => { - equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) - } - DataType::LargeUtf8 => { - equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null) - } - DataType::FixedSizeBinary(_) => { - equal_rows_elem!(FixedSizeBinaryArray, l, r, left, right, null_equals_null) - } - DataType::Decimal128(_, lscale) => match r.data_type() { - DataType::Decimal128(_, rscale) => { - if lscale == rscale { - equal_rows_elem!( - Decimal128Array, - l, - r, - left, - right, - null_equals_null - ) - } else { - err = Some(Err(DataFusionError::Internal( - "Inconsistent Decimal data type in hasher, the scale should be same".to_string(), - ))); - false - } - } - _ => { - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - match key_type.as_ref() { - DataType::Int8 => { - equal_rows_elem_with_string_dict!( - Int8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int16 => { - equal_rows_elem_with_string_dict!( - Int16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int32 => { - equal_rows_elem_with_string_dict!( - Int32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::Int64 => { - equal_rows_elem_with_string_dict!( - Int64Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt8 => { - equal_rows_elem_with_string_dict!( - UInt8Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt16 => { - equal_rows_elem_with_string_dict!( - UInt16Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt32 => { - equal_rows_elem_with_string_dict!( - UInt32Type, - l, - r, - left, - right, - null_equals_null - ) - } - DataType::UInt64 => { - equal_rows_elem_with_string_dict!( - UInt64Type, - l, - r, - left, - right, - null_equals_null - ) - } - _ => { - // should not happen - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - } - } - other => { - // This is internal because we should have caught this before. - err = Some(Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {other}" - )))); - false - } - }); +) -> Result<(UInt64Array, UInt32Array)> { + let mut iter = left_arrays.iter().zip(right_arrays.iter()); + + let (first_left, first_right) = iter.next().ok_or_else(|| { + DataFusionError::Internal( + "At least one array should be provided for both left and right".to_string(), + ) + })?; + + let arr_left = take(first_left.as_ref(), indices_left, None)?; + let arr_right = take(first_right.as_ref(), indices_right, None)?; + + let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?; + + // Use map and try_fold to iterate over the remaining pairs of arrays. + // In each iteration, take is used on the pair of arrays and their equality is determined. + // The results are then folded (combined) using the and function to get a final equality result. + equal = iter + .map(|(left, right)| { + let arr_left = take(left.as_ref(), indices_left, None)?; + let arr_right = take(right.as_ref(), indices_right, None)?; + eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null) + }) + .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; + + let filter_builder = FilterBuilder::new(&equal).optimize().build(); - err.unwrap_or(Ok(res)) + let left_filtered = filter_builder.filter(indices_left)?; + let right_filtered = filter_builder.filter(indices_right)?; + + Ok(( + downcast_array(left_filtered.as_ref()), + downcast_array(right_filtered.as_ref()), + )) } impl HashJoinStream { @@ -1103,32 +1070,33 @@ impl HashJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>> { let build_timer = self.join_metrics.build_time.timer(); + // build hash table from left (build) side, if not yet done let left_data = match ready!(self.left_fut.get(cx)) { Ok(left_data) => left_data, Err(e) => return Poll::Ready(Some(Err(e))), }; build_timer.done(); - // Reserving memory for visited_left_side bitmap in case it hasn't been initialied yet + // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it if self.visited_left_side.is_none() && need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 - let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 8); + let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); self.reservation.try_grow(visited_bitmap_size)?; self.join_metrics.build_mem_used.add(visited_bitmap_size); } let visited_left_side = self.visited_left_side.get_or_insert_with(|| { - let num_rows = left_data.1.num_rows(); + let num_rows = left_data.num_rows(); if need_produce_result_in_final(self.join_type) { - // these join type need the bitmap to identify which row has be matched or unmatched. - // For the `left semi` join, need to use the bitmap to produce the matched row in the left side - // For the `left` join, need to use the bitmap to produce the unmatched row in the left side with null - // For the `left anti` join, need to use the bitmap to produce the unmatched row in the left side - // For the `full` join, need to use the bitmap to produce the unmatched row in the left side with null + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null let mut buffer = BooleanBufferBuilder::new(num_rows); buffer.append_n(num_rows, false); buffer @@ -1137,6 +1105,7 @@ impl HashJoinStream { } }); let mut hashes_buffer = vec![]; + // get next right (probe) input batch self.right .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { @@ -1147,18 +1116,18 @@ impl HashJoinStream { let timer = self.join_metrics.join_time.timer(); // get the matched two indices for the on condition - let left_right_indices = build_join_indices( + let left_right_indices = build_equal_condition_join_indices( + left_data.hash_map(), + left_data.batch(), &batch, - &left_data.0, - &left_data.1, &self.on_left, &self.on_right, - self.filter.as_ref(), &self.random_state, self.null_equals_null, &mut hashes_buffer, - None, + self.filter.as_ref(), JoinSide::Left, + None, ); let result = match left_right_indices { @@ -1181,10 +1150,10 @@ impl HashJoinStream { let result = build_batch_from_indices( &self.schema, - &left_data.1, + left_data.batch(), &batch, - left_side, - right_side, + &left_side, + &right_side, &self.column_indices, JoinSide::Left, ); @@ -1192,9 +1161,9 @@ impl HashJoinStream { self.join_metrics.output_rows.add(batch.num_rows()); Some(result) } - Err(err) => Some(Err(DataFusionError::Execution(format!( - "Fail to build join indices in HashJoinExec, error:{err}", - )))), + Err(err) => Some(exec_err!( + "Fail to build join indices in HashJoinExec, error:{err}" + )), }; timer.done(); result @@ -1213,10 +1182,10 @@ impl HashJoinStream { // use the left and right indices to produce the batch result let result = build_batch_from_indices( &self.schema, - &left_data.1, + left_data.batch(), &empty_right_batch, - left_side, - right_side, + &left_side, + &right_side, &self.column_indices, JoinSide::Left, ); @@ -1256,34 +1225,22 @@ impl Stream for HashJoinStream { mod tests { use std::sync::Arc; - use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; - use arrow::datatypes::{DataType, Field, Schema}; - use smallvec::smallvec; - - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::Literal; - - use crate::execution::context::SessionConfig; - use crate::physical_expr::expressions::BinaryExpr; - use crate::prelude::SessionContext; + use super::*; use crate::{ - assert_batches_sorted_eq, - common::assert_contains, - physical_plan::{ - common, - expressions::Column, - hash_utils::create_hashes, - joins::{hash_join::build_equal_condition_join_indices, utils::JoinSide}, - memory::MemoryExec, - repartition::RepartitionExec, - }, - test::exec::MockExec, - test::{build_table_i32, columns}, + common, expressions::Column, hash_utils::create_hashes, + joins::hash_join::build_equal_condition_join_indices, memory::MemoryExec, + repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; + + use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use super::*; + use hashbrown::raw::RawTable; fn build_table( a: (&str, &Vec), @@ -1404,8 +1361,7 @@ mod tests { #[tokio::test] async fn join_inner_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1434,7 +1390,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1450,8 +1406,7 @@ mod tests { #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1479,7 +1434,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1495,8 +1450,7 @@ mod tests { #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1517,7 +1471,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1534,8 +1488,7 @@ mod tests { #[tokio::test] async fn join_inner_two() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1564,7 +1517,7 @@ mod tests { assert_eq!(batches.len(), 1); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1582,8 +1535,7 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] async fn join_inner_one_two_parts_left() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1619,7 +1571,7 @@ mod tests { assert_eq!(batches.len(), 1); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1637,8 +1589,7 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] async fn join_inner_one_two_parts_right() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1672,7 +1623,7 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1685,7 +1636,7 @@ mod tests { let stream = join.execute(1, task_ctx.clone())?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1713,8 +1664,7 @@ mod tests { #[tokio::test] async fn join_left_multi_batch() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1738,7 +1688,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1755,8 +1705,7 @@ mod tests { #[tokio::test] async fn join_full_multi_batch() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1781,7 +1730,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1800,8 +1749,7 @@ mod tests { #[tokio::test] async fn join_left_empty_right() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1822,7 +1770,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1837,8 +1785,7 @@ mod tests { #[tokio::test] async fn join_full_empty_right() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1859,7 +1806,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1874,8 +1821,7 @@ mod tests { #[tokio::test] async fn join_left_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1902,7 +1848,7 @@ mod tests { .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1918,8 +1864,7 @@ mod tests { #[tokio::test] async fn partitioned_join_left_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1946,7 +1891,7 @@ mod tests { .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1982,8 +1927,7 @@ mod tests { #[tokio::test] async fn join_left_semi() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left semi join right_table on left_table.b1 = right_table.b2 @@ -2001,7 +1945,7 @@ mod tests { let batches = common::collect(stream).await?; // ignore the order - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -2017,8 +1961,7 @@ mod tests { #[tokio::test] async fn join_left_semi_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2062,7 +2005,7 @@ mod tests { let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -2090,7 +2033,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -2104,8 +2047,7 @@ mod tests { #[tokio::test] async fn join_right_semi() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2123,7 +2065,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2139,8 +2081,7 @@ mod tests { #[tokio::test] async fn join_right_semi_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2184,7 +2125,7 @@ mod tests { let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2210,7 +2151,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2225,8 +2166,7 @@ mod tests { #[tokio::test] async fn join_left_anti() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 @@ -2243,7 +2183,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", "+----+----+----+", @@ -2259,8 +2199,7 @@ mod tests { #[tokio::test] async fn join_left_anti_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 @@ -2302,7 +2241,7 @@ mod tests { let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -2334,7 +2273,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -2353,8 +2292,7 @@ mod tests { #[tokio::test] async fn join_right_anti() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( @@ -2370,7 +2308,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2385,8 +2323,7 @@ mod tests { #[tokio::test] async fn join_right_anti_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 @@ -2429,7 +2366,7 @@ mod tests { let stream = join.execute(0, task_ctx.clone())?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2465,7 +2402,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -2482,8 +2419,7 @@ mod tests { #[tokio::test] async fn join_right_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2504,7 +2440,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -2521,8 +2457,7 @@ mod tests { #[tokio::test] async fn partitioned_join_right_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2544,7 +2479,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -2561,8 +2496,7 @@ mod tests { #[tokio::test] async fn join_full_one() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2586,7 +2520,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -2616,8 +2550,10 @@ mod tests { create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; // Create hash collisions (same hashes) - hashmap_left.insert(hashes[0], (hashes[0], smallvec![0, 1]), |(h, _)| *h); - hashmap_left.insert(hashes[1], (hashes[1], smallvec![0, 1]), |(h, _)| *h); + hashmap_left.insert(hashes[0], (hashes[0], 1), |(h, _)| *h); + hashmap_left.insert(hashes[1], (hashes[1], 1), |(h, _)| *h); + + let next = vec![2, 0]; let right = build_table_i32( ("a", &vec![10, 20]), @@ -2625,10 +2561,11 @@ mod tests { ("c", &vec![30, 40]), ); - let left_data = (JoinHashMap(hashmap_left), left); + let join_hash_map = JoinHashMap::new(hashmap_left, next); + let (l, r) = build_equal_condition_join_indices( - &left_data.0, - &left_data.1, + &join_hash_map, + &left, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], @@ -2636,6 +2573,8 @@ mod tests { false, &mut vec![0; right.num_rows()], None, + JoinSide::Left, + None, )?; let mut left_ids = UInt64Builder::with_capacity(0); @@ -2655,8 +2594,7 @@ mod tests { #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a", &vec![1, 2, 3]), ("b", &vec![4, 5, 7]), @@ -2681,7 +2619,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", "+---+---+---+----+---+----+", @@ -2720,8 +2658,7 @@ mod tests { #[tokio::test] async fn join_inner_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2746,7 +2683,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+---+", "| a | b | c | a | b | c |", "+---+---+---+----+---+---+", @@ -2761,8 +2698,7 @@ mod tests { #[tokio::test] async fn join_left_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2787,7 +2723,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+---+", "| a | b | c | a | b | c |", "+---+---+---+----+---+---+", @@ -2805,8 +2741,7 @@ mod tests { #[tokio::test] async fn join_right_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2831,7 +2766,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+---+", "| a | b | c | a | b | c |", "+---+---+---+----+---+---+", @@ -2848,8 +2783,7 @@ mod tests { #[tokio::test] async fn join_full_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -2874,7 +2808,7 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+---+", "| a | b | c | a | b | c |", "+---+---+---+----+---+---+", @@ -2917,12 +2851,11 @@ mod tests { let join = join(left, right, on, &JoinType::Inner, false)?; - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; - let expected = vec![ + let expected = [ "+------------+---+------------+---+", "| date | n | date | n |", "+------------+---+------------+---+", @@ -2946,7 +2879,7 @@ mod tests { // right input stream returns one good batch and then one error. // The error should be returned. - let err = Err(DataFusionError::Execution("bad data error".to_string())); + let err = exec_err!("bad data error"); let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); let on = vec![( @@ -2977,13 +2910,12 @@ mod tests { false, ) .unwrap(); - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let stream = join.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::physical_plan::common::collect(stream) + let result_string = crate::common::collect(stream) .await .unwrap_err() .to_string(); @@ -3025,9 +2957,8 @@ mod tests { for join_type in join_types { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_ctx = - SessionContext::with_config_rt(SessionConfig::default(), runtime); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let join = join(left.clone(), right.clone(), on.clone(), &join_type, false)?; @@ -3096,8 +3027,10 @@ mod tests { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); let session_config = SessionConfig::default().with_batch_size(50); - let session_ctx = SessionContext::with_config_rt(session_config, runtime); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let join = HashJoinExec::try_new( left.clone(), @@ -3123,4 +3056,9 @@ mod tests { Ok(()) } + + /// Returns the column names on the schema + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } } diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs similarity index 82% rename from datafusion/core/src/physical_plan/joins/mod.rs rename to datafusion/physical-plan/src/joins/mod.rs index 0a1bc147b80cd..6ddf19c511933 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -25,12 +25,15 @@ pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; -mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +#[cfg(test)] +pub mod test_utils; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] /// Partitioning mode to use for hash join pub enum PartitionMode { @@ -42,3 +45,12 @@ pub enum PartitionMode { /// It will also consider swapping the left and right inputs for the Join Auto, } + +/// Partitioning mode to use for symmetric hash join +#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq)] +pub enum StreamJoinPartitionMode { + /// Left/right children are partitioned using the left and right keys + Partitioned, + /// Both sides will collected into one partition + SinglePartition, +} diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs similarity index 90% rename from datafusion/core/src/physical_plan/joins/nested_loop_join.rs rename to datafusion/physical-plan/src/joins/nested_loop_join.rs index 82e677f7205d8..6951642ff8016 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -19,39 +19,39 @@ //! The nested loop join can execute in parallel by partitions and it is //! determined by the [`JoinType`]. -use crate::physical_plan::joins::utils::{ - adjust_right_output_partitioning, append_right_indices, apply_join_filter_to_indices, - build_batch_from_indices, build_join_schema, check_join_is_valid, - combine_join_equivalence_properties, estimate_join_statistics, get_anti_indices, +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; +use std::task::Poll; + +use crate::coalesce_batches::concat_batches; +use crate::joins::utils::{ + append_right_indices, apply_join_filter_to_indices, build_batch_from_indices, + build_join_schema, check_join_is_valid, estimate_join_statistics, get_anti_indices, get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, - get_semi_u64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinSide, - OnceAsync, OnceFut, + get_semi_u64_indices, partitioned_join_output_partitioning, BuildProbeJoinMetrics, + ColumnIndex, JoinFilter, OnceAsync, OnceFut, }; -use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::physical_plan::{ - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, }; + use arrow::array::{ BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, }; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{DataFusionError, Statistics}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_common::{exec_err, DataFusionError, JoinSide, Result, Statistics}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::TaskContext; use datafusion_expr::JoinType; +use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr}; -use futures::{ready, Stream, StreamExt, TryStreamExt}; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; -use std::task::Poll; -use crate::physical_plan::coalesce_batches::concat_batches; -use datafusion_common::Result; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_execution::TaskContext; +use futures::{ready, Stream, StreamExt, TryStreamExt}; /// Data of the inner table side type JoinLeftData = (RecordBatch, MemoryReservation); @@ -73,6 +73,7 @@ type JoinLeftData = (RecordBatch, MemoryReservation); /// |--------------------------------|--------------------------------------------|-------------| /// | Inner/Left/LeftSemi/LeftAnti | (UnspecifiedDistribution, SinglePartition) | right | /// | Right/RightSemi/RightAnti/Full | (SinglePartition, UnspecifiedDistribution) | left | +/// | Full | (SinglePartition, SinglePartition) | left | /// #[derive(Debug)] pub struct NestedLoopJoinExec { @@ -118,6 +119,44 @@ impl NestedLoopJoinExec { metrics: Default::default(), }) } + + /// left side + pub fn left(&self) -> &Arc { + &self.left + } + + /// right side + pub fn right(&self) -> &Arc { + &self.right + } + + /// Filters applied before join output + pub fn filter(&self) -> Option<&JoinFilter> { + self.filter.as_ref() + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } +} + +impl DisplayAs for NestedLoopJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + write!( + f, + "NestedLoopJoinExec: join_type={:?}{}", + self.join_type, display_filter + ) + } + } + } } impl ExecutionPlan for NestedLoopJoinExec { @@ -131,25 +170,15 @@ impl ExecutionPlan for NestedLoopJoinExec { fn output_partitioning(&self) -> Partitioning { // the partition of output is determined by the rule of `required_input_distribution` - // TODO we can replace it by `partitioned_join_output_partitioning` - match self.join_type { - // use the left partition - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full => self.left.output_partitioning(), - // use the right partition - JoinType::Right => { - // if the partition of right is hash, - // and the right partition should be adjusted the column index for the right expr - adjust_right_output_partitioning( - self.right.output_partitioning(), - self.left.schema().fields.len(), - ) - } - // use the right partition - JoinType::RightSemi | JoinType::RightAnti => self.right.output_partitioning(), + if self.join_type == JoinType::Full { + self.left.output_partitioning() + } else { + partitioned_join_output_partitioning( + self.join_type, + self.left.output_partitioning(), + self.right.output_partitioning(), + self.left.schema().fields.len(), + ) } } @@ -163,14 +192,15 @@ impl ExecutionPlan for NestedLoopJoinExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - &[], // empty join keys + &self.join_type, self.schema(), + &self.maintains_input_order(), + None, + // No on columns in nested loop join + &[], ) } @@ -249,32 +279,17 @@ impl ExecutionPlan for NestedLoopJoinExec { })) } - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={:?}", f.expression()), - ); - write!( - f, - "NestedLoopJoinExec: join_type={:?}{}", - self.join_type, display_filter - ) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { estimate_join_statistics( self.left.clone(), self.right.clone(), vec![], &self.join_type, + &self.schema, ) } } @@ -482,8 +497,8 @@ impl NestedLoopJoinStream { &self.schema, left_data, &empty_right_batch, - left_side, - right_side, + &left_side, + &right_side, &self.column_indices, JoinSide::Left, ); @@ -573,9 +588,9 @@ fn join_left_and_right_batch( let mut left_indices_builder = UInt64Builder::new(); let mut right_indices_builder = UInt32Builder::new(); let left_right_indices = match indices_result { - Err(err) => Err(DataFusionError::Execution(format!( - "Fail to build join indices in NestedLoopJoinExec, error:{err}" - ))), + Err(err) => { + exec_err!("Fail to build join indices in NestedLoopJoinExec, error:{err}") + } Ok(indices) => { for (left_side, right_side) in indices { left_indices_builder @@ -611,8 +626,8 @@ fn join_left_and_right_batch( schema, left_batch, right_batch, - left_side, - right_side, + &left_side, + &right_side, column_indices, JoinSide::Left, ) @@ -726,29 +741,20 @@ impl RecordBatchStream for NestedLoopJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; - use crate::physical_expr::expressions::BinaryExpr; use crate::{ - assert_batches_sorted_eq, - common::assert_contains, - execution::{ - context::SessionConfig, - runtime_env::{RuntimeConfig, RuntimeEnv}, - }, - physical_plan::{ - common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, - }, - test::{build_table_i32, columns}, + common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + test::build_table_i32, }; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; - - use crate::physical_plan::joins::utils::JoinSide; - use crate::prelude::SessionContext; - use datafusion_common::ScalarValue; - use datafusion_physical_expr::expressions::Literal; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::PhysicalExpr; - use std::sync::Arc; fn build_table( a: (&str, &Vec), @@ -871,8 +877,7 @@ mod tests { #[tokio::test] async fn join_inner_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); @@ -885,7 +890,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -900,8 +905,7 @@ mod tests { #[tokio::test] async fn join_left_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -915,7 +919,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+-----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+-----+----+----+----+", @@ -932,8 +936,7 @@ mod tests { #[tokio::test] async fn join_right_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -947,7 +950,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+----+----+-----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+-----+", @@ -964,8 +967,7 @@ mod tests { #[tokio::test] async fn join_full_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -979,7 +981,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+-----+----+----+-----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+-----+----+----+-----+", @@ -998,8 +1000,7 @@ mod tests { #[tokio::test] async fn join_left_semi_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -1013,7 +1014,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - let expected = vec![ + let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", "+----+----+----+", @@ -1028,8 +1029,7 @@ mod tests { #[tokio::test] async fn join_left_anti_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -1043,7 +1043,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - let expected = vec![ + let expected = [ "+----+----+-----+", "| a1 | b1 | c1 |", "+----+----+-----+", @@ -1059,8 +1059,7 @@ mod tests { #[tokio::test] async fn join_right_semi_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -1074,7 +1073,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+----+", "| a2 | b2 | c2 |", "+----+----+----+", @@ -1089,8 +1088,7 @@ mod tests { #[tokio::test] async fn join_right_anti_with_filter() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); @@ -1104,7 +1102,7 @@ mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - let expected = vec![ + let expected = [ "+----+----+-----+", "| a2 | b2 | c2 |", "+----+----+-----+", @@ -1146,9 +1144,8 @@ mod tests { for join_type in join_types { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_ctx = - SessionContext::with_config_rt(SessionConfig::default(), runtime); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let err = multi_partitioned_join_collect( left.clone(), @@ -1169,4 +1166,9 @@ mod tests { Ok(()) } + + /// Returns the column names on the schema + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } } diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs similarity index 92% rename from datafusion/core/src/physical_plan/joins/sort_merge_join.rs rename to datafusion/physical-plan/src/joins/sort_merge_join.rs index aa6a77925e0a4..f6fdc6d77c0cb 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,45 +30,44 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::expressions::{Column, PhysicalSortExpr}; +use crate::joins::utils::{ + build_join_schema, calculate_join_output_ordering, check_join_is_valid, + estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + use arrow::array::*; use arrow::compute::{concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -use datafusion_physical_expr::PhysicalSortRequirement; -use futures::{Stream, StreamExt}; - -use crate::physical_plan::expressions::Column; -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::joins::utils::{ - build_join_schema, check_join_is_valid, combine_join_equivalence_properties, - estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, -}; -use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::physical_plan::{ - metrics, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, - Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, +use datafusion_common::{ + internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; -use datafusion_common::DataFusionError; -use datafusion_common::JoinType; -use datafusion_common::Result; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use futures::{Stream, StreamExt}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. #[derive(Debug)] pub struct SortMergeJoinExec { /// Left sorted joining execution plan - pub(crate) left: Arc, + pub left: Arc, /// Right sorting joining execution plan - pub(crate) right: Arc, + pub right: Arc, /// Set of common columns used to join on - pub(crate) on: JoinOn, + pub on: JoinOn, /// How the join is performed - pub(crate) join_type: JoinType, + pub join_type: JoinType, /// The schema once the join is applied schema: SchemaRef, /// Execution metrics @@ -80,9 +79,9 @@ pub struct SortMergeJoinExec { /// The output ordering output_ordering: Option>, /// Sort options of join columns used in sorting left and right execution plans - pub(crate) sort_options: Vec, + pub sort_options: Vec, /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, + pub null_equals_null: bool, } impl SortMergeJoinExec { @@ -102,18 +101,18 @@ impl SortMergeJoinExec { let right_schema = right.schema(); if join_type == JoinType::RightSemi { - return Err(DataFusionError::NotImplemented( - "SortMergeJoinExec does not support JoinType::RightSemi".to_string(), - )); + return not_impl_err!( + "SortMergeJoinExec does not support JoinType::RightSemi" + ); } check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Expected number of sort options: {}, actual: {}", on.len(), sort_options.len() - ))); + ); } let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on @@ -132,49 +131,15 @@ impl SortMergeJoinExec { }) .unzip(); - let output_ordering = match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti => { - left.output_ordering().map(|sort_exprs| sort_exprs.to_vec()) - } - JoinType::RightSemi | JoinType::RightAnti => right - .output_ordering() - .map(|sort_exprs| sort_exprs.to_vec()), - JoinType::Right => { - let left_columns_len = left.schema().fields.len(); - right - .output_ordering() - .map(|sort_exprs| { - let new_sort_exprs: Result> = sort_exprs - .iter() - .map(|e| { - let new_expr = - e.expr.clone().transform_down(&|e| match e - .as_any() - .downcast_ref::( - ) { - Some(col) => { - Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - left_columns_len + col.index(), - )))) - } - None => Ok(Transformed::No(e)), - }); - Ok(PhysicalSortExpr { - expr: new_expr?, - options: e.options, - }) - }) - .collect(); - new_sort_exprs - }) - .map_or(Ok(None), |v| v.map(Some))? - } - JoinType::Full => None, - }; + let output_ordering = calculate_join_output_ordering( + left.output_ordering().unwrap_or(&[]), + right.output_ordering().unwrap_or(&[]), + join_type, + &on, + left_schema.fields.len(), + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + ); let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); @@ -194,10 +159,71 @@ impl SortMergeJoinExec { }) } + /// Get probe side (e.g streaming side) information for this sort merge join. + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + // When output schema contains only the right side, probe side is right. + // Otherwise probe side is the left side. + match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinSide::Right + } + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::LeftAnti + | JoinType::LeftSemi => JoinSide::Left, + } + } + + /// Calculate order preservation flags for this sort merge join. + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + JoinType::Inner => vec![true, false], + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + vec![false, true] + } + _ => vec![false, false], + } + } + /// Set of common columns used to join on pub fn on(&self) -> &[(Column, Column)] { &self.on } + + pub fn right(&self) -> &dyn ExecutionPlan { + self.right.as_ref() + } + + pub fn join_type(&self) -> JoinType { + self.join_type + } + + pub fn left(&self) -> &dyn ExecutionPlan { + self.left.as_ref() + } +} + +impl DisplayAs for SortMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .collect::>() + .join(", "); + write!( + f, + "SortMergeJoin: join_type={:?}, on=[{}]", + self.join_type, on + ) + } + } + } } impl ExecutionPlan for SortMergeJoinExec { @@ -252,25 +278,18 @@ impl ExecutionPlan for SortMergeJoinExec { } fn maintains_input_order(&self) -> Vec { - match self.join_type { - JoinType::Inner => vec![true, true], - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - vec![false, true] - } - _ => vec![false, false], - } + Self::maintains_input_order(self.join_type) } fn equivalence_properties(&self) -> EquivalenceProperties { - let left_columns_len = self.left.schema().fields.len(); - combine_join_equivalence_properties( - self.join_type, + join_equivalence_properties( self.left.equivalence_properties(), self.right.equivalence_properties(), - left_columns_len, - self.on(), + &self.join_type, self.schema(), + &self.maintains_input_order(), + Some(Self::probe_side(&self.join_type)), + self.on(), ) } @@ -291,9 +310,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.sort_options.clone(), self.null_equals_null, )?)), - _ => Err(DataFusionError::Internal( - "SortMergeJoin wrong number of children".to_string(), - )), + _ => internal_err!("SortMergeJoin wrong number of children"), } } @@ -305,30 +322,18 @@ impl ExecutionPlan for SortMergeJoinExec { let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); if left_partitions != right_partitions { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ - consider using RepartitionExec", - ))); + consider using RepartitionExec" + ); } - - let (streamed, buffered, on_streamed, on_buffered) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::LeftAnti - | JoinType::LeftSemi => ( - self.left.clone(), - self.right.clone(), - self.on.iter().map(|on| on.0.clone()).collect(), - self.on.iter().map(|on| on.1.clone()).collect(), - ), - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => ( - self.right.clone(), - self.left.clone(), - self.on.iter().map(|on| on.1.clone()).collect(), - self.on.iter().map(|on| on.0.clone()).collect(), - ), - }; + let (on_left, on_right) = self.on.iter().cloned().unzip(); + let (streamed, buffered, on_streamed, on_buffered) = + if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { + (self.left.clone(), self.right.clone(), on_left, on_right) + } else { + (self.right.clone(), self.left.clone(), on_right, on_left) + }; // execute children plans let streamed = streamed.execute(partition, context.clone())?; @@ -361,19 +366,7 @@ impl ExecutionPlan for SortMergeJoinExec { Some(self.metrics.clone_inner()) } - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "SortMergeJoin: join_type={:?}, on={:?}", - self.join_type, self.on - ) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` @@ -382,6 +375,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.right.clone(), self.on.clone(), &self.join_type, + &self.schema, ) } } @@ -1307,9 +1301,9 @@ fn compare_join_arrays( DataType::Date32 => compare_value!(Date32Array), DataType::Date64 => compare_value!(Date64Array), _ => { - return Err(DataFusionError::NotImplemented( - "Unsupported data type in sort merge join comparator".to_owned(), - )); + return not_impl_err!( + "Unsupported data type in sort merge join comparator" + ); } } if !res.is_eq() { @@ -1373,9 +1367,9 @@ fn is_join_arrays_equal( DataType::Date32 => compare_value!(Date32Array), DataType::Date64 => compare_value!(Date64Array), _ => { - return Err(DataFusionError::NotImplemented( - "Unsupported data type in sort merge join comparator".to_owned(), - )); + return not_impl_err!( + "Unsupported data type in sort merge join comparator" + ); } } if !is_equal { @@ -1389,23 +1383,23 @@ fn is_join_arrays_equal( mod tests { use std::sync::Arc; + use crate::expressions::Column; + use crate::joins::utils::JoinOn; + use crate::joins::SortMergeJoinExec; + use crate::memory::MemoryExec; + use crate::test::build_table_i32; + use crate::{common, ExecutionPlan}; + use arrow::array::{Date32Array, Date64Array, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - - use crate::common::assert_contains; - use crate::physical_plan::expressions::Column; - use crate::physical_plan::joins::utils::JoinOn; - use crate::physical_plan::joins::SortMergeJoinExec; - use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::{common, ExecutionPlan}; - use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::{build_table_i32, columns}; - use crate::{assert_batches_eq, assert_batches_sorted_eq}; - use datafusion_common::JoinType; - use datafusion_common::Result; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, + }; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::TaskContext; fn build_table( a: (&str, &Vec), @@ -1541,8 +1535,7 @@ mod tests { sort_options: Vec, null_equals_null: bool, ) -> Result<(Vec, Vec)> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let join = join_with_options( left, right, @@ -1564,9 +1557,9 @@ mod tests { on: JoinOn, join_type: JoinType, ) -> Result<(Vec, Vec)> { - let session_ctx = - SessionContext::with_config(SessionConfig::new().with_batch_size(2)); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(2)); + let task_ctx = Arc::new(task_ctx); let join = join(left, right, on, join_type)?; let columns = columns(&join.schema()); @@ -1595,7 +1588,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1633,7 +1626,7 @@ mod tests { ]; let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1671,7 +1664,7 @@ mod tests { ]; let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1710,7 +1703,7 @@ mod tests { ]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1761,7 +1754,7 @@ mod tests { true, ) .await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1801,7 +1794,7 @@ mod tests { let (_, batches) = join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1836,7 +1829,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1868,7 +1861,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+----+----+----+----+----+----+", @@ -1900,7 +1893,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -1932,7 +1925,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; - let expected = vec![ + let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", "+----+----+----+", @@ -1963,7 +1956,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; - let expected = vec![ + let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", "+----+----+----+", @@ -1996,7 +1989,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ + let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", "+---+---+---+----+---+----+", @@ -2029,15 +2022,13 @@ mod tests { let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ - "+------------+------------+------------+------------+------------+------------+", + let expected = ["+------------+------------+------------+------------+------------+------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+------------+------------+------------+------------+------------+------------+", "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "+------------+------------+------------+------------+------------+------------+", - ]; + "+------------+------------+------------+------------+------------+------------+"]; // The output order is important as SMJ preserves sortedness assert_batches_eq!(expected, &batches); Ok(()) @@ -2063,15 +2054,13 @@ mod tests { let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; - let expected = vec![ - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - ]; + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+"]; // The output order is important as SMJ preserves sortedness assert_batches_eq!(expected, &batches); Ok(()) @@ -2095,7 +2084,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -2131,7 +2120,7 @@ mod tests { )]; let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; - let expected = vec![ + let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", "+----+----+----+----+----+----+", @@ -2325,8 +2314,12 @@ mod tests { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); let session_config = SessionConfig::default().with_batch_size(50); - let session_ctx = SessionContext::with_config_rt(session_config, runtime); - let task_ctx = session_ctx.task_ctx(); + + let task_ctx = TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( left.clone(), right.clone(), @@ -2401,8 +2394,10 @@ mod tests { let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); let session_config = SessionConfig::default().with_batch_size(50); - let session_ctx = SessionContext::with_config_rt(session_config, runtime); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); let join = join_with_options( left.clone(), right.clone(), @@ -2424,4 +2419,8 @@ mod tests { Ok(()) } + /// Returns the column names on the schema + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs new file mode 100644 index 0000000000000..5083f96b01fb1 --- /dev/null +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -0,0 +1,1414 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file contains common subroutines for symmetric hash join +//! related functionality, used both in join calculations and optimization rules. + +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::usize; + +use crate::handle_async_state; +use crate::joins::utils::{JoinFilter, JoinHashMapType}; + +use arrow::compute::concat_batches; +use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; +use arrow_schema::{Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + +use futures::{ready, FutureExt, StreamExt}; +use hashbrown::raw::RawTable; +use hashbrown::HashSet; + +/// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. +impl JoinHashMapType for PruningJoinHashMap { + type NextType = VecDeque; + + // Extend with zero + fn extend_zero(&mut self, len: usize) { + self.next.resize(self.next.len() + len, 0) + } + + /// Get mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { + (&mut self.map, &mut self.next) + } + + /// Get a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)> { + &self.map + } + + /// Get a reference to the next. + fn get_list(&self) -> &Self::NextType { + &self.next + } +} + +/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with +/// the capability of pruning elements in an efficient manner. This structure +/// is particularly useful for cases where it's necessary to remove elements +/// from the map based on their buffer order. +/// +/// # Example +/// +/// ``` text +/// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would +/// handle the pruning scenario. +/// +/// Insert the pair (10,4) into the `PruningJoinHashMap`: +/// map: +/// ---------- +/// | 10 | 5 | +/// | 20 | 3 | +/// ---------- +/// list: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// +/// Now, let's prune 3 rows from `PruningJoinHashMap`: +/// map: +/// --------- +/// | 1 | 5 | +/// --------- +/// list: +/// --------- +/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0) +/// --------- +/// +/// After pruning, the | 2 | 3 | entry is deleted from `PruningJoinHashMap` since +/// there are no values left for this key. +/// ``` +pub struct PruningJoinHashMap { + /// Stores hash value to last row index + pub map: RawTable<(u64, u64)>, + /// Stores indices in chained list data structure + pub next: VecDeque, +} + +impl PruningJoinHashMap { + /// Constructs a new `PruningJoinHashMap` with the given capacity. + /// Both the map and the list are pre-allocated with the provided capacity. + /// + /// # Arguments + /// * `capacity`: The initial capacity of the hash map. + /// + /// # Returns + /// A new instance of `PruningJoinHashMap`. + pub(crate) fn with_capacity(capacity: usize) -> Self { + PruningJoinHashMap { + map: RawTable::with_capacity(capacity), + next: VecDeque::with_capacity(capacity), + } + } + + /// Shrinks the capacity of the hash map, if necessary, based on the + /// provided scale factor. + /// + /// # Arguments + /// * `scale_factor`: The scale factor that determines how conservative the + /// shrinking strategy is. The capacity will be reduced by 1/`scale_factor` + /// when necessary. + /// + /// # Note + /// Increasing the scale factor results in less aggressive capacity shrinking, + /// leading to potentially higher memory usage but fewer resizes. Conversely, + /// decreasing the scale factor results in more aggressive capacity shrinking, + /// potentially leading to lower memory usage but more frequent resizing. + pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) { + let capacity = self.map.capacity(); + + if capacity > scale_factor * self.map.len() { + let new_capacity = (capacity * (scale_factor - 1)) / scale_factor; + // Resize the map with the new capacity. + self.map.shrink_to(new_capacity, |(hash, _)| *hash) + } + } + + /// Calculates the size of the `PruningJoinHashMap` in bytes. + /// + /// # Returns + /// The size of the hash map in bytes. + pub(crate) fn size(&self) -> usize { + self.map.allocation_info().1.size() + + self.next.capacity() * std::mem::size_of::() + } + + /// Removes hash values from the map and the list based on the given pruning + /// length and deleting offset. + /// + /// # Arguments + /// * `prune_length`: The number of elements to remove from the list. + /// * `deleting_offset`: The offset used to determine which hash values to remove from the map. + /// + /// # Returns + /// A `Result` indicating whether the operation was successful. + pub(crate) fn prune_hash_values( + &mut self, + prune_length: usize, + deleting_offset: u64, + shrink_factor: usize, + ) -> Result<()> { + // Remove elements from the list based on the pruning length. + self.next.drain(0..prune_length); + + // Calculate the keys that should be removed from the map. + let removable_keys = unsafe { + self.map + .iter() + .map(|bucket| bucket.as_ref()) + .filter_map(|(hash, tail_index)| { + (*tail_index < prune_length as u64 + deleting_offset).then_some(*hash) + }) + .collect::>() + }; + + // Remove the keys from the map. + removable_keys.into_iter().for_each(|hash_value| { + self.map + .remove_entry(hash_value, |(hash, _)| hash_value == *hash); + }); + + // Shrink the map if necessary. + self.shrink_if_necessary(shrink_factor); + Ok(()) + } +} + +pub fn check_filter_expr_contains_sort_information( + expr: &Arc, + reference: &Arc, +) -> bool { + expr.eq(reference) + || expr + .children() + .iter() + .any(|e| check_filter_expr_contains_sort_information(e, reference)) +} + +/// Create a one to one mapping from main columns to filter columns using +/// filter column indices. A column index looks like: +/// ```text +/// ColumnIndex { +/// index: 0, // field index in main schema +/// side: JoinSide::Left, // child side +/// } +/// ``` +pub fn map_origin_col_to_filter_col( + filter: &JoinFilter, + schema: &SchemaRef, + side: &JoinSide, +) -> Result> { + let filter_schema = filter.schema(); + let mut col_to_col_map: HashMap = HashMap::new(); + for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { + if index.side.eq(side) { + // Get the main field from column index: + let main_field = schema.field(index.index); + // Create a column expression: + let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?; + // Since the order of by filter.column_indices() is the same with + // that of intermediate schema fields, we can get the column directly. + let filter_field = filter_schema.field(filter_schema_index); + let filter_col = Column::new(filter_field.name(), filter_schema_index); + // Insert mapping: + col_to_col_map.insert(main_col, filter_col); + } + } + Ok(col_to_col_map) +} + +/// This function analyzes [`PhysicalSortExpr`] graphs with respect to monotonicity +/// (sorting) properties. This is necessary since monotonically increasing and/or +/// decreasing expressions are required when using join filter expressions for +/// data pruning purposes. +/// +/// The method works as follows: +/// 1. Maps the original columns to the filter columns using the [`map_origin_col_to_filter_col`] function. +/// 2. Collects all columns in the sort expression using the [`collect_columns`] function. +/// 3. Checks if all columns are included in the map we obtain in the first step. +/// 4. If all columns are included, the sort expression is converted into a filter expression using +/// the [`convert_filter_columns`] function. +/// 5. Searches for the converted filter expression in the filter expression using the +/// [`check_filter_expr_contains_sort_information`] function. +/// 6. If an exact match is found, returns the converted filter expression as [`Some(Arc)`]. +/// 7. If all columns are not included or an exact match is not found, returns [`None`]. +/// +/// Examples: +/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100". +/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However, +/// there is no exact match, so this expression does not indicate pruning. +pub fn convert_sort_expr_with_filter_schema( + side: &JoinSide, + filter: &JoinFilter, + schema: &SchemaRef, + sort_expr: &PhysicalSortExpr, +) -> Result>> { + let column_map = map_origin_col_to_filter_col(filter, schema, side)?; + let expr = sort_expr.expr.clone(); + // Get main schema columns: + let expr_columns = collect_columns(&expr); + // Calculation is possible with `column_map` since sort exprs belong to a child. + let all_columns_are_included = + expr_columns.iter().all(|col| column_map.contains_key(col)); + if all_columns_are_included { + // Since we are sure that one to one column mapping includes all columns, we convert + // the sort expression into a filter expression. + let converted_filter_expr = expr.transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::Yes(transformed), + None => Transformed::No(p), + } + }) + })?; + // Search the converted `PhysicalExpr` in filter expression; if an exact + // match is found, use this sorted expression in graph traversals. + if check_filter_expr_contains_sort_information( + filter.expression(), + &converted_filter_expr, + ) { + return Ok(Some(converted_filter_expr)); + } + } + Ok(None) +} + +/// This function is used to build the filter expression based on the sort order of input columns. +/// +/// It first calls the [`convert_sort_expr_with_filter_schema`] method to determine if the sort +/// order of columns can be used in the filter expression. If it returns a [`Some`] value, the +/// method wraps the result in a [`SortedFilterExpr`] instance with the original sort expression and +/// the converted filter expression. Otherwise, this function returns an error. +/// +/// The `SortedFilterExpr` instance contains information about the sort order of columns that can +/// be used in the filter expression, which can be used to optimize the query execution process. +pub fn build_filter_input_order( + side: JoinSide, + filter: &JoinFilter, + schema: &SchemaRef, + order: &PhysicalSortExpr, +) -> Result> { + let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?; + opt_expr + .map(|filter_expr| { + SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema()) + }) + .transpose() +} + +/// Convert a physical expression into a filter expression using the given +/// column mapping information. +fn convert_filter_columns( + input: &dyn PhysicalExpr, + column_map: &HashMap, +) -> Result>> { + // Attempt to downcast the input expression to a Column type. + Ok(if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } else { + // If the downcast fails, return the input expression as is. + None + }) +} + +/// The [SortedFilterExpr] object represents a sorted filter expression. It +/// contains the following information: The origin expression, the filter +/// expression, an interval encapsulating expression bounds, and a stable +/// index identifying the expression in the expression DAG. +/// +/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides +/// and uses new column names. In this process, a column exchange is done so +/// we can utilize sorting information while traversing the filter expression +/// DAG for interval calculations. When evaluating the inner buffer, we use +/// `origin_sorted_expr`. +#[derive(Debug, Clone)] +pub struct SortedFilterExpr { + /// Sorted expression from a join side (i.e. a child of the join) + origin_sorted_expr: PhysicalSortExpr, + /// Expression adjusted for filter schema. + filter_expr: Arc, + /// Interval containing expression bounds + interval: Interval, + /// Node index in the expression DAG + node_index: usize, +} + +impl SortedFilterExpr { + /// Constructor + pub fn try_new( + origin_sorted_expr: PhysicalSortExpr, + filter_expr: Arc, + filter_schema: &Schema, + ) -> Result { + let dt = &filter_expr.data_type(filter_schema)?; + Ok(Self { + origin_sorted_expr, + filter_expr, + interval: Interval::make_unbounded(dt)?, + node_index: 0, + }) + } + /// Get origin expr information + pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { + &self.origin_sorted_expr + } + /// Get filter expr information + pub fn filter_expr(&self) -> &Arc { + &self.filter_expr + } + /// Get interval information + pub fn interval(&self) -> &Interval { + &self.interval + } + /// Sets interval + pub fn set_interval(&mut self, interval: Interval) { + self.interval = interval; + } + /// Node index in ExprIntervalGraph + pub fn node_index(&self) -> usize { + self.node_index + } + /// Node index setter in ExprIntervalGraph + pub fn set_node_index(&mut self, node_index: usize) { + self.node_index = node_index; + } +} + +/// Calculate the filter expression intervals. +/// +/// This function updates the `interval` field of each `SortedFilterExpr` based +/// on the first or the last value of the expression in `build_input_buffer` +/// and `probe_batch`. +/// +/// # Arguments +/// +/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. +/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. +/// * `probe_batch` - The `RecordBatch` on the probe side of the join. +/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. +/// +/// ### Note +/// ```text +/// +/// Interval arithmetic is used to calculate viable join ranges for build-side +/// pruning. This is done by first creating an interval for join filter values in +/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the +/// ordering (descending/ascending) of the filter expression. Here, FV denotes the +/// first value on the build side. This range is then compared with the probe side +/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering +/// (ascending/descending) of the probe side. Here, LV denotes the last value on +/// the probe side. +/// +/// As a concrete example, consider the following query: +/// +/// SELECT * FROM left_table, right_table +/// WHERE +/// left_key = right_key AND +/// a > b - 3 AND +/// a < b + 10 +/// +/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// respectively. When a new `RecordBatch` arrives at the right side, the +/// condition a > b - 3 will possibly indicate a prunable range for the left +/// side. Conversely, when a new `RecordBatch` arrives at the left side, the +/// condition a < b + 10 will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// side (i.e. when the left side is the build side): +/// +/// Build Probe +/// +-------+ +-------+ +/// | a | z | | b | y | +/// |+--|--+| |+--|--+| +/// | 1 | 2 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 3 | 1 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 5 | 7 | | 6 | 1 | +/// |+--|--+| |+--|--+| +/// | 7 | 1 | | 6 | 3 | +/// +-------+ +-------+ +/// +/// In this case, the interval representing viable (i.e. joinable) values for +/// column "a" is [1, ∞], and the interval representing possible future values +/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// intervals for the whole filter expression and propagate join constraint by +/// traversing the expression graph. +/// ``` +pub fn calculate_filter_expr_intervals( + build_input_buffer: &RecordBatch, + build_sorted_filter_expr: &mut SortedFilterExpr, + probe_batch: &RecordBatch, + probe_sorted_filter_expr: &mut SortedFilterExpr, +) -> Result<()> { + // If either build or probe side has no data, return early: + if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(()); + } + // Calculate the interval for the build side filter expression (if present): + update_filter_expr_interval( + &build_input_buffer.slice(0, 1), + build_sorted_filter_expr, + )?; + // Calculate the interval for the probe side filter expression (if present): + update_filter_expr_interval( + &probe_batch.slice(probe_batch.num_rows() - 1, 1), + probe_sorted_filter_expr, + ) +} + +/// This is a subroutine of the function [`calculate_filter_expr_intervals`]. +/// It constructs the current interval using the given `batch` and updates +/// the filter expression (i.e. `sorted_expr`) with this interval. +pub fn update_filter_expr_interval( + batch: &RecordBatch, + sorted_expr: &mut SortedFilterExpr, +) -> Result<()> { + // Evaluate the filter expression and convert the result to an array: + let array = sorted_expr + .origin_sorted_expr() + .expr + .evaluate(batch)? + .into_array(1)?; + // Convert the array to a ScalarValue: + let value = ScalarValue::try_from_array(&array, 0)?; + // Create a ScalarValue representing positive or negative infinity for the same data type: + let inf = ScalarValue::try_from(value.data_type())?; + // Update the interval with lower and upper bounds based on the sort option: + let interval = if sorted_expr.origin_sorted_expr().options.descending { + Interval::try_new(inf, value)? + } else { + Interval::try_new(value, inf)? + }; + // Set the calculated interval for the sorted filter expression: + sorted_expr.set_interval(interval); + Ok(()) +} + +/// Get the anti join indices from the visited hash set. +/// +/// This method returns the indices from the original input that were not present in the visited hash set. +/// +/// # Arguments +/// +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. +/// +/// # Returns +/// +/// A `PrimitiveArray` of the anti join indices. +pub fn get_pruning_anti_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set + for v in 0..prune_length { + let row = v + deleted_offset; + bitmap.set_bit(v, visited_rows.contains(&row)); + } + // get the anti index + (0..prune_length) + .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect() +} + +/// This method creates a boolean buffer from the visited rows hash set +/// and the indices of the pruned record batch slice. +/// +/// It gets the indices from the original input that were present in the visited hash set. +/// +/// # Arguments +/// +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. +/// +/// # Returns +/// +/// A [PrimitiveArray] of the specified type T, containing the semi indices. +pub fn get_pruning_semi_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set + (0..prune_length).for_each(|v| { + let row = &(v + deleted_offset); + bitmap.set_bit(v, visited_rows.contains(row)); + }); + // get the semi index + (0..prune_length) + .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect::>() +} + +pub fn combine_two_batches( + output_schema: &SchemaRef, + left_batch: Option, + right_batch: Option, +) -> Result> { + match (left_batch, right_batch) { + (Some(batch), None) | (None, Some(batch)) => { + // If only one of the batches are present, return it: + Ok(Some(batch)) + } + (Some(left_batch), Some(right_batch)) => { + // If both batches are present, concatenate them: + concat_batches(output_schema, &[left_batch, right_batch]) + .map_err(DataFusionError::ArrowError) + .map(Some) + } + (None, None) => { + // If neither is present, return an empty batch: + Ok(None) + } + } +} + +/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`. +/// This function will insert the indices (offset by `offset`) into the `visited` hash set. +/// +/// # Arguments +/// +/// * `visited` - A hash set to store the visited indices. +/// * `offset` - An offset to the indices in the `PrimitiveArray`. +/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded. +/// +pub fn record_visited_indices( + visited: &mut HashSet, + offset: usize, + indices: &PrimitiveArray, +) { + for i in indices.values() { + visited.insert(i.as_usize() + offset); + } +} + +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, typically encountered in implementations of `EagerJoinStream`. It +/// operates on a `StreamJoinStateResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. +/// +/// # Cases +/// +/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StreamJoinStateResult::Continue) => continue, + Ok(StreamJoinStateResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; +} + +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. +/// +/// # Arguments +/// +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of a stateful operation on `EagerJoinStream`. +/// +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StreamJoinStateResult { + Ready(T), + Continue, +} + +/// Represents the various states of an eager join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum EagerJoinStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. + LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + +/// `EagerJoinStream` is an asynchronous trait designed for managing incremental +/// join operations between two streams, such as those used in `SymmetricHashJoinExec` +/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan +/// one side of the join fully before proceeding, `EagerJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// Implementors of this trait can perform eager joins of data from two different +/// asynchronous streams, typically referred to as left and right streams. The +/// trait provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This trait's default implementations use a state machine approach to navigate +/// different stages of the join operation, handling data from both streams and +/// determining when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. +/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +#[async_trait] +pub trait EagerJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> + where + Self: Send, + { + loop { + return match self.state() { + EagerJoinStreamState::PullRight => { + handle_async_state!(self.fetch_next_from_right_stream(), cx) + } + EagerJoinStreamState::PullLeft => { + handle_async_state!(self.fetch_next_from_left_stream(), cx) + } + EagerJoinStreamState::RightExhausted => { + handle_async_state!(self.handle_right_stream_end(), cx) + } + EagerJoinStreamState::LeftExhausted => { + handle_async_state!(self.handle_left_stream_end(), cx) + } + EagerJoinStreamState::BothExhausted { + final_result: false, + } => { + handle_state!(self.prepare_for_final_results_after_exhaustion()) + } + EagerJoinStreamState::BothExhausted { final_result: true } => { + Poll::Ready(None) + } + }; + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_right_stream( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => { + self.set_state(EagerJoinStreamState::PullLeft); + self.process_batch_from_right(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::RightExhausted); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. + /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + async fn fetch_next_from_left_stream( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => { + self.set_state(EagerJoinStreamState::PullRight); + self.process_batch_from_left(batch) + } + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::LeftExhausted); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_right_stream_end( + &mut self, + ) -> Result>> { + match self.left_stream().next().await { + Some(Ok(batch)) => self.process_batch_after_right_end(batch), + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + async fn handle_left_stream_end( + &mut self, + ) -> Result>> { + match self.right_stream().next().await { + Some(Ok(batch)) => self.process_batch_after_left_end(batch), + Some(Err(e)) => Err(e), + None => { + self.set_state(EagerJoinStreamState::BothExhausted { + final_result: false, + }); + Ok(StreamJoinStateResult::Continue) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + + /// Handles a pulled batch from the right stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles a pulled batch from the left stream. + /// + /// # Arguments + /// + /// * `batch` - The pulled `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after processing the batch. + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the left stream is exhausted. + /// + /// # Arguments + /// + /// * `right_batch` - The `RecordBatch` from the right stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the left stream is exhausted. + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>>; + + /// Handles the situation when only the right stream is exhausted. + /// + /// # Arguments + /// + /// * `left_batch` - The `RecordBatch` from the left stream. + /// + /// # Returns + /// + /// * `Result>>` - The state result after the right stream is exhausted. + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>>; + + /// Handles the final state after both streams are exhausted. + /// + /// # Returns + /// + /// * `Result>>` - The final state result after processing. + fn process_batches_before_finalization( + &mut self, + ) -> Result>>; + + /// Provides mutable access to the right stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. + fn right_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Provides mutable access to the left stream. + /// + /// # Returns + /// + /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. + fn left_stream(&mut self) -> &mut SendableRecordBatchStream; + + /// Sets the current state of the join stream. + /// + /// # Arguments + /// + /// * `state` - The new state to be set. + fn set_state(&mut self, state: EagerJoinStreamState); + + /// Fetches the current state of the join stream. + /// + /// # Returns + /// + /// * `EagerJoinStreamState` - The current state of the join stream. + fn state(&mut self) -> EagerJoinStreamState; +} + +#[cfg(test)] +pub mod tests { + use std::sync::Arc; + + use super::*; + use crate::joins::stream_join_utils::{ + build_filter_input_order, check_filter_expr_contains_sort_information, + convert_sort_expr_with_filter_schema, PruningJoinHashMap, + }; + use crate::{ + expressions::{Column, PhysicalSortExpr}, + joins::utils::{ColumnIndex, JoinFilter}, + }; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{JoinSide, ScalarValue}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{binary, cast, col, lit}; + + /// Filter expr for a + b > c + 10 AND a + b < c + 100 + pub(crate) fn complicated_filter( + filter_schema: &Schema, + ) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) + } + + #[test] + fn test_column_exchange() -> Result<()> { + let left_child_schema = + Schema::new(vec![Field::new("left_1", DataType::Int32, true)]); + // Sorting information for the left side: + let left_child_sort_expr = PhysicalSortExpr { + expr: col("left_1", &left_child_schema)?, + options: SortOptions::default(), + }; + + let right_child_schema = Schema::new(vec![ + Field::new("right_1", DataType::Int32, true), + Field::new("right_2", DataType::Int32, true), + ]); + // Sorting information for the right side: + let right_child_sort_expr = PhysicalSortExpr { + expr: binary( + col("right_1", &right_child_schema)?, + Operator::Plus, + col("right_2", &right_child_schema)?, + &right_child_schema, + )?, + options: SortOptions::default(), + }; + + let intermediate_schema = Schema::new(vec![ + Field::new("filter_1", DataType::Int32, true), + Field::new("filter_2", DataType::Int32, true), + Field::new("filter_3", DataType::Int32, true), + ]); + // Our filter expression is: left_1 > right_1 + right_2. + let filter_left = col("filter_1", &intermediate_schema)?; + let filter_right = binary( + col("filter_2", &intermediate_schema)?, + Operator::Plus, + col("filter_3", &intermediate_schema)?, + &intermediate_schema, + )?; + let filter_expr = binary( + filter_left.clone(), + Operator::Gt, + filter_right.clone(), + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let left_sort_filter_expr = build_filter_input_order( + JoinSide::Left, + &filter, + &Arc::new(left_child_schema), + &left_child_sort_expr, + )? + .unwrap(); + assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr())); + + let right_sort_filter_expr = build_filter_input_order( + JoinSide::Right, + &filter, + &Arc::new(right_child_schema), + &right_child_sort_expr, + )? + .unwrap(); + assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr())); + + // Assert that adjusted (left) filter expression matches with `left_child_sort_expr`: + assert!(filter_left.eq(left_sort_filter_expr.filter_expr())); + // Assert that adjusted (right) filter expression matches with `right_child_sort_expr`: + assert!(filter_right.eq(right_sort_filter_expr.filter_expr())); + Ok(()) + } + + #[test] + fn test_column_collector() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&schema)?; + let columns = collect_columns(&filter_expr); + assert_eq!(columns.len(), 3); + Ok(()) + } + + #[test] + fn find_expr_inside_expr() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&schema)?; + + let expr_1 = Arc::new(Column::new("gnz", 0)) as _; + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_1 + )); + + let expr_2 = col("1", &schema)? as _; + + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_2 + )); + + let expr_3 = cast( + binary( + col("0", &schema)?, + Operator::Plus, + col("1", &schema)?, + &schema, + )?, + &schema, + DataType::Int64, + )?; + + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_3 + )); + + let expr_4 = Arc::new(Column::new("1", 42)) as _; + + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_4, + )); + Ok(()) + } + + #[test] + fn build_sorted_expr() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("la1", DataType::Int32, false), + Field::new("lb1", DataType::Int32, false), + Field::new("lc1", DataType::Int32, false), + Field::new("lt1", DataType::Int32, false), + Field::new("la2", DataType::Int32, false), + Field::new("la1_des", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![ + Field::new("ra1", DataType::Int32, false), + Field::new("rb1", DataType::Int32, false), + Field::new("rc1", DataType::Int32, false), + Field::new("rt1", DataType::Int32, false), + Field::new("ra2", DataType::Int32, false), + Field::new("ra1_des", DataType::Int32, false), + ]); + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let left_schema = Arc::new(left_schema); + let right_schema = Arc::new(right_schema); + + assert!(build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("la1", left_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_some()); + assert!(build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("lt1", left_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_none()); + assert!(build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("ra1", right_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_some()); + assert!(build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("rb1", right_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_none()); + + Ok(()) + } + + // Test the case when we have an "ORDER BY a + b", and join filter condition includes "a - b". + #[test] + fn sorted_filter_expr_build() -> Result<()> { + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + ]); + let filter_expr = binary( + col("0", &intermediate_schema)?, + Operator::Minus, + col("1", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let sorted = PhysicalSortExpr { + expr: binary( + col("a", &schema)?, + Operator::Plus, + col("b", &schema)?, + &schema, + )?, + options: SortOptions::default(), + }; + + let res = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + &filter, + &Arc::new(schema), + &sorted, + )?; + assert!(res.is_none()); + Ok(()) + } + + #[test] + fn test_shrink_if_necessary() { + let scale_factor = 4; + let mut join_hash_map = PruningJoinHashMap::with_capacity(100); + let data_size = 2000; + let deleted_part = 3 * data_size / 4; + // Add elements to the JoinHashMap + for hash_value in 0..data_size { + join_hash_map.map.insert( + hash_value, + (hash_value, hash_value), + |(hash, _)| *hash, + ); + } + + assert_eq!(join_hash_map.map.len(), data_size as usize); + assert!(join_hash_map.map.capacity() >= data_size as usize); + + // Remove some elements from the JoinHashMap + for hash_value in 0..deleted_part { + join_hash_map + .map + .remove_entry(hash_value, |(hash, _)| hash_value == *hash); + } + + assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize); + + // Old capacity + let old_capacity = join_hash_map.map.capacity(); + + // Test shrink_if_necessary + join_hash_map.shrink_if_necessary(scale_factor); + + // The capacity should be reduced by the scale factor + let new_expected_capacity = + join_hash_map.map.capacity() * (scale_factor - 1) / scale_factor; + assert!(join_hash_map.map.capacity() >= new_expected_capacity); + assert!(join_hash_map.map.capacity() <= old_capacity); + } +} diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs new file mode 100644 index 0000000000000..95f15877b9607 --- /dev/null +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -0,0 +1,2167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the symmetric hash join algorithm with range-based +//! data pruning to join two (potentially infinite) streams. +//! +//! A [`SymmetricHashJoinExec`] plan takes two children plan (with appropriate +//! output ordering) and produces the join output according to the given join +//! type and other options. +//! +//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations +//! for both its children. + +use std::any::Any; +use std::fmt::{self, Debug}; +use std::sync::Arc; +use std::task::Poll; +use std::{usize, vec}; + +use crate::common::SharedMemoryReservation; +use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash}; +use crate::joins::stream_join_utils::{ + calculate_filter_expr_intervals, combine_two_batches, + convert_sort_expr_with_filter_schema, get_pruning_anti_indices, + get_pruning_semi_indices, record_visited_indices, EagerJoinStream, + EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinStateResult, +}; +use crate::joins::utils::{ + build_batch_from_indices, build_join_schema, check_join_is_valid, + partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, + JoinOn, +}; +use crate::{ + expressions::{Column, PhysicalSortExpr}, + joins::StreamJoinPartitionMode, + metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, + Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + +use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder}; +use arrow::compute::concat_batches; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::utils::bisect; +use datafusion_common::{ + internal_err, plan_err, DataFusionError, JoinSide, JoinType, Result, +}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; + +use ahash::RandomState; +use futures::Stream; +use hashbrown::HashSet; +use parking_lot::Mutex; + +const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; + +/// A symmetric hash join with range conditions is when both streams are hashed on the +/// join key and the resulting hash tables are used to join the streams. +/// The join is considered symmetric because the hash table is built on the join keys from both +/// streams, and the matching of rows is based on the values of the join keys in both streams. +/// This type of join is efficient in streaming context as it allows for fast lookups in the hash +/// table, rather than having to scan through one or both of the streams to find matching rows, also it +/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions), +/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming +/// data without any memory issues. +/// +/// For each input stream, create a hash table. +/// - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets. +/// - Test if input is equal to a predefined set of other inputs. +/// - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch]. +/// - Try to prune other side (probe) with new [RecordBatch]. +/// - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.), +/// output the [RecordBatch] when a pruning happens or at the end of the data. +/// +/// +/// ``` text +/// +-------------------------+ +/// | | +/// left stream ---------| Left OneSideHashJoiner |---+ +/// | | | +/// +-------------------------+ | +/// | +/// |--------- Joined output +/// | +/// +-------------------------+ | +/// | | | +/// right stream ---------| Right OneSideHashJoiner |---+ +/// | | +/// +-------------------------+ +/// +/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic +/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range. +/// +/// +/// PROBE SIDE BUILD SIDE +/// BUFFER BUFFER +/// +-------------+ +------------+ +/// | | | | Unjoinable +/// | | | | Range +/// | | | | +/// | | |--------------------------------- +/// | | | | | +/// | | | | | +/// | | / | | +/// | | | | | +/// | | | | | +/// | | | | | +/// | | | | | +/// | | | | | Joinable +/// | |/ | | Range +/// | || | | +/// |+-----------+|| | | +/// || Record || | | +/// || Batch || | | +/// |+-----------+|| | | +/// +-------------+\ +------------+ +/// | +/// \ +/// |--------------------------------- +/// +/// This happens when range conditions are provided on sorted columns. E.g. +/// +/// SELECT * FROM left_table, right_table +/// ON +/// left_key = right_key AND +/// left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR +/// +/// or +/// SELECT * FROM left_table, right_table +/// ON +/// left_key = right_key AND +/// left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10 +/// +/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to +/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the +/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios) +/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning +/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" , +/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) +/// than that can be dropped from the inner buffer. +/// ``` +#[derive(Debug)] +pub struct SymmetricHashJoinExec { + /// Left side stream + pub(crate) left: Arc, + /// Right side stream + pub(crate) right: Arc, + /// Set of common columns used to join on + pub(crate) on: Vec<(Column, Column)>, + /// Filters applied when finding matching rows + pub(crate) filter: Option, + /// How the join is performed + pub(crate) join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Shares the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// If null_equals_null is true, null == null else null != null + pub(crate) null_equals_null: bool, + /// Partition Mode + mode: StreamJoinPartitionMode, +} + +#[derive(Debug)] +pub struct StreamJoinSideMetrics { + /// Number of batches consumed by this operator + pub(crate) input_batches: metrics::Count, + /// Number of rows consumed by this operator + pub(crate) input_rows: metrics::Count, +} + +/// Metrics for HashJoinExec +#[derive(Debug)] +pub struct StreamJoinMetrics { + /// Number of left batches/rows consumed by this operator + pub(crate) left: StreamJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + pub(crate) right: StreamJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, +} + +impl StreamJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let left = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = StreamJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + +impl SymmetricHashJoinExec { + /// Tries to create a new [SymmetricHashJoinExec]. + /// # Error + /// This function errors when: + /// - It is not possible to join the left and right sides on keys `on`, or + /// - It fails to construct `SortedFilterExpr`s, or + /// - It fails to create the [ExprIntervalGraph]. + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + null_equals_null: bool, + mode: StreamJoinPartitionMode, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Error out if no "on" contraints are given: + if on.is_empty() { + return plan_err!( + "On constraints in SymmetricHashJoinExec should be non-empty" + ); + } + + // Check if the join is valid with the given on constraints: + check_join_is_valid(&left_schema, &right_schema, &on)?; + + // Build the join schema from the left and right schemas: + let (schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + // Initialize the random state for the join operation: + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + Ok(SymmetricHashJoinExec { + left, + right, + on, + filter, + join_type: *join_type, + schema: Arc::new(schema), + random_state, + metrics: ExecutionPlanMetricsSet::new(), + column_indices, + null_equals_null, + mode, + }) + } + + /// left stream + pub fn left(&self) -> &Arc { + &self.left + } + + /// right stream + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> Option<&JoinFilter> { + self.filter.as_ref() + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// Get null_equals_null + pub fn null_equals_null(&self) -> bool { + self.null_equals_null + } + + /// Get partition mode + pub fn partition_mode(&self) -> StreamJoinPartitionMode { + self.mode + } + + /// Check if order information covers every column in the filter expression. + pub fn check_if_order_information_available(&self) -> Result { + if let Some(filter) = self.filter() { + let left = self.left(); + if let Some(left_ordering) = left.output_ordering() { + let right = self.right(); + if let Some(right_ordering) = right.output_ordering() { + let left_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + filter, + &left.schema(), + &left_ordering[0], + )? + .is_some(); + let right_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Right, + filter, + &right.schema(), + &right_ordering[0], + )? + .is_some(); + return Ok(left_convertible && right_convertible); + } + } + } + Ok(false) + } +} + +impl DisplayAs for SymmetricHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .collect::>() + .join(", "); + write!( + f, + "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}", + self.mode, self.join_type, on, display_filter + ) + } + } + } +} + +impl ExecutionPlan for SymmetricHashJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children.iter().any(|u| *u)) + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false, false] + } + + fn required_input_distribution(&self) -> Vec { + match self.mode { + StreamJoinPartitionMode::Partitioned => { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + StreamJoinPartitionMode::SinglePartition => { + vec![Distribution::SinglePartition, Distribution::SinglePartition] + } + } + } + + fn output_partitioning(&self) -> Partitioning { + let left_columns_len = self.left.schema().fields.len(); + partitioned_join_output_partitioning( + self.join_type, + self.left.output_partitioning(), + self.right.output_partitioning(), + left_columns_len, + ) + } + + // TODO: Output ordering might be kept for some cases. + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + join_equivalence_properties( + self.left.equivalence_properties(), + self.right.equivalence_properties(), + &self.join_type, + self.schema(), + &self.maintains_input_order(), + // Has alternating probe side + None, + self.on(), + ) + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(SymmetricHashJoinExec::try_new( + children[0].clone(), + children[1].clone(), + self.on.clone(), + self.filter.clone(), + &self.join_type, + self.null_equals_null, + self.mode, + )?)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + // TODO stats: it is not possible in general to know the output size of joins + Ok(Statistics::new_unknown(&self.schema())) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = self.right.output_partitioning().partition_count(); + if left_partitions != right_partitions { + return internal_err!( + "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ + consider using RepartitionExec" + ); + } + // If `filter_state` and `filter` are both present, then calculate sorted filter expressions + // for both sides, and build an expression graph. + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left.output_ordering(), + self.right.output_ordering(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None for all three values: + _ => (None, None, None), + }; + + let (on_left, on_right) = self.on.iter().cloned().unzip(); + + let left_side_joiner = + OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema()); + let right_side_joiner = + OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); + + let left_stream = self.left.execute(partition, context.clone())?; + + let right_stream = self.right.execute(partition, context.clone())?; + + let reservation = Arc::new(Mutex::new( + MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) + .register(context.memory_pool()), + )); + if let Some(g) = graph.as_ref() { + reservation.lock().try_grow(g.size())?; + } + + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: EagerJoinStreamState::PullRight, + reservation, + })) + } +} + +/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +struct SymmetricHashJoinStream { + /// Input streams + left_stream: SendableRecordBatchStream, + right_stream: SendableRecordBatchStream, + /// Input schema + schema: Arc, + /// join filter + filter: Option, + /// type of the join + join_type: JoinType, + // left hash joiner + left: OneSideHashJoiner, + /// right hash joiner + right: OneSideHashJoiner, + /// Information of index and left / right placement of columns + column_indices: Vec, + // Expression graph for range pruning. + graph: Option, + // Left globally sorted filter expr + left_sorted_filter_expr: Option, + // Right globally sorted filter expr + right_sorted_filter_expr: Option, + /// Random state used for hashing initialization + random_state: RandomState, + /// If null_equals_null is true, null == null else null != null + null_equals_null: bool, + /// Metrics + metrics: StreamJoinMetrics, + /// Memory reservation + reservation: SharedMemoryReservation, + /// State machine for input execution + state: EagerJoinStreamState, +} + +impl RecordBatchStream for SymmetricHashJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for SymmetricHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +/// Determine the pruning length for `buffer`. +/// +/// This function evaluates the build side filter expression, converts the +/// result into an array and determines the pruning length by performing a +/// binary search on the array. +/// +/// # Arguments +/// +/// * `buffer`: The record batch to be pruned. +/// * `build_side_filter_expr`: The filter expression on the build side used +/// to determine the pruning length. +/// +/// # Returns +/// +/// A [Result] object that contains the pruning length. The function will return +/// an error if +/// - there is an issue evaluating the build side filter expression; +/// - there is an issue converting the build side filter expression into an array +fn determine_prune_length( + buffer: &RecordBatch, + build_side_filter_expr: &SortedFilterExpr, +) -> Result { + let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr(); + let interval = build_side_filter_expr.interval(); + // Evaluate the build side filter expression and convert it into an array + let batch_arr = origin_sorted_expr + .expr + .evaluate(buffer)? + .into_array(buffer.num_rows())?; + + // Get the lower or upper interval based on the sort direction + let target = if origin_sorted_expr.options.descending { + interval.upper().clone() + } else { + interval.lower().clone() + }; + + // Perform binary search on the array to determine the length of the record batch to be pruned + bisect::(&[batch_arr], &[target], &[origin_sorted_expr.options]) +} + +/// This method determines if the result of the join should be produced in the final step or not. +/// +/// # Arguments +/// +/// * `build_side` - Enum indicating the side of the join used as the build side. +/// * `join_type` - Enum indicating the type of join to be performed. +/// +/// # Returns +/// +/// A boolean indicating whether the result of the join should be produced in the final step or not. +/// The result will be true if the build side is JoinSide::Left and the join type is one of +/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi. +/// If the build side is JoinSide::Right, the result will be true if the join type +/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi. +fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { + if build_side == JoinSide::Left { + matches!( + join_type, + JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + ) + } else { + matches!( + join_type, + JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + ) + } +} + +/// Calculate indices by join type. +/// +/// This method returns a tuple of two arrays: build and probe indices. +/// The length of both arrays will be the same. +/// +/// # Arguments +/// +/// * `build_side`: Join side which defines the build side. +/// * `prune_length`: Length of the prune data. +/// * `visited_rows`: Hash set of visited rows of the build side. +/// * `deleted_offset`: Deleted offset of the build side. +/// * `join_type`: The type of join to be performed. +/// +/// # Returns +/// +/// A tuple of two arrays of primitive types representing the build and probe indices. +/// +fn calculate_indices_by_join_type( + build_side: JoinSide, + prune_length: usize, + visited_rows: &HashSet, + deleted_offset: usize, + join_type: JoinType, +) -> Result<(PrimitiveArray, PrimitiveArray)> +where + NativeAdapter: From<::Native>, +{ + // Store the result in a tuple + let result = match (build_side, join_type) { + // In the case of `Left` or `Right` join, or `Full` join, get the anti indices + (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) + | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) + | (_, JoinType::Full) => { + let build_unmatched_indices = + get_pruning_anti_indices(prune_length, deleted_offset, visited_rows); + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + // In the case of `LeftSemi` or `RightSemi` join, get the semi indices + (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { + let build_unmatched_indices = + get_pruning_semi_indices(prune_length, deleted_offset, visited_rows); + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + // The case of other join types is not considered + _ => unreachable!(), + }; + Ok(result) +} + +/// This function produces unmatched record results based on the build side, +/// join type and other parameters. +/// +/// The method uses first `prune_length` rows from the build side input buffer +/// to produce results. +/// +/// # Arguments +/// +/// * `output_schema` - The schema of the final output record batch. +/// * `prune_length` - The length of the determined prune length. +/// * `probe_schema` - The schema of the probe [RecordBatch]. +/// * `join_type` - The type of join to be performed. +/// * `column_indices` - Indices of columns that are being joined. +/// +/// # Returns +/// +/// * `Option` - The final output record batch if required, otherwise [None]. +pub(crate) fn build_side_determined_results( + build_hash_joiner: &OneSideHashJoiner, + output_schema: &SchemaRef, + prune_length: usize, + probe_schema: SchemaRef, + join_type: JoinType, + column_indices: &[ColumnIndex], +) -> Result> { + // Check if we need to produce a result in the final output: + if prune_length > 0 + && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) + { + // Calculate the indices for build and probe sides based on join type and build side: + let (build_indices, probe_indices) = calculate_indices_by_join_type( + build_hash_joiner.build_side, + prune_length, + &build_hash_joiner.visited_rows, + build_hash_joiner.deleted_offset, + join_type, + )?; + + // Create an empty probe record batch: + let empty_probe_batch = RecordBatch::new_empty(probe_schema); + // Build the final result from the indices of build and probe sides: + build_batch_from_indices( + output_schema.as_ref(), + &build_hash_joiner.input_buffer, + &empty_probe_batch, + &build_indices, + &probe_indices, + column_indices, + build_hash_joiner.build_side, + ) + .map(|batch| (batch.num_rows() > 0).then_some(batch)) + } else { + // If we don't need to produce a result, return None + Ok(None) + } +} + +/// This method performs a join between the build side input buffer and the probe side batch. +/// +/// # Arguments +/// +/// * `build_hash_joiner` - Build side hash joiner +/// * `probe_hash_joiner` - Probe side hash joiner +/// * `schema` - A reference to the schema of the output record batch. +/// * `join_type` - The type of join to be performed. +/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. +/// * `filter` - An optional filter on the join condition. +/// * `probe_batch` - The second record batch to be joined. +/// * `column_indices` - An array of columns to be selected for the result of the join. +/// * `random_state` - The random state for the join. +/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// +/// # Returns +/// +/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. +/// If the join type is one of the above four, the function will return [None]. +#[allow(clippy::too_many_arguments)] +pub(crate) fn join_with_probe_batch( + build_hash_joiner: &mut OneSideHashJoiner, + probe_hash_joiner: &mut OneSideHashJoiner, + schema: &SchemaRef, + join_type: JoinType, + filter: Option<&JoinFilter>, + probe_batch: &RecordBatch, + column_indices: &[ColumnIndex], + random_state: &RandomState, + null_equals_null: bool, +) -> Result> { + if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(None); + } + let (build_indices, probe_indices) = build_equal_condition_join_indices( + &build_hash_joiner.hashmap, + &build_hash_joiner.input_buffer, + probe_batch, + &build_hash_joiner.on, + &probe_hash_joiner.on, + random_state, + null_equals_null, + &mut build_hash_joiner.hashes_buffer, + filter, + build_hash_joiner.build_side, + Some(build_hash_joiner.deleted_offset), + )?; + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + record_visited_indices( + &mut build_hash_joiner.visited_rows, + build_hash_joiner.deleted_offset, + &build_indices, + ); + } + if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) { + record_visited_indices( + &mut probe_hash_joiner.visited_rows, + probe_hash_joiner.offset, + &probe_indices, + ); + } + if matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + ) { + Ok(None) + } else { + build_batch_from_indices( + schema, + &build_hash_joiner.input_buffer, + probe_batch, + &build_indices, + &probe_indices, + column_indices, + build_hash_joiner.build_side, + ) + .map(|batch| (batch.num_rows() > 0).then_some(batch)) + } +} + +pub struct OneSideHashJoiner { + /// Build side + build_side: JoinSide, + /// Input record batch buffer + pub input_buffer: RecordBatch, + /// Columns from the side + pub(crate) on: Vec, + /// Hashmap + pub(crate) hashmap: PruningJoinHashMap, + /// Reuse the hashes buffer + pub(crate) hashes_buffer: Vec, + /// Matched rows + pub(crate) visited_rows: HashSet, + /// Offset + pub(crate) offset: usize, + /// Deleted offset + pub(crate) deleted_offset: usize, +} + +impl OneSideHashJoiner { + pub fn size(&self) -> usize { + let mut size = 0; + size += std::mem::size_of_val(self); + size += std::mem::size_of_val(&self.build_side); + size += self.input_buffer.get_array_memory_size(); + size += std::mem::size_of_val(&self.on); + size += self.hashmap.size(); + size += self.hashes_buffer.capacity() * std::mem::size_of::(); + size += self.visited_rows.capacity() * std::mem::size_of::(); + size += std::mem::size_of_val(&self.offset); + size += std::mem::size_of_val(&self.deleted_offset); + size + } + pub fn new(build_side: JoinSide, on: Vec, schema: SchemaRef) -> Self { + Self { + build_side, + input_buffer: RecordBatch::new_empty(schema), + on, + hashmap: PruningJoinHashMap::with_capacity(0), + hashes_buffer: vec![], + visited_rows: HashSet::new(), + offset: 0, + deleted_offset: 0, + } + } + + /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch. + /// + /// # Arguments + /// + /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer + /// * `random_state` - The random state used to hash values + /// + /// # Returns + /// + /// Returns a [Result] encapsulating any intermediate errors. + pub(crate) fn update_internal_state( + &mut self, + batch: &RecordBatch, + random_state: &RandomState, + ) -> Result<()> { + // Merge the incoming batch with the existing input buffer: + self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?; + // Resize the hashes buffer to the number of rows in the incoming batch: + self.hashes_buffer.resize(batch.num_rows(), 0); + // Get allocation_info before adding the item + // Update the hashmap with the join key values and hashes of the incoming batch: + update_hash( + &self.on, + batch, + &mut self.hashmap, + self.offset, + random_state, + &mut self.hashes_buffer, + self.deleted_offset, + )?; + Ok(()) + } + + /// Calculate prune length. + /// + /// # Arguments + /// + /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression.. + /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. + /// * `graph` - A mutable reference to the physical expression graph. + /// + /// # Returns + /// + /// A Result object that contains the pruning length. + pub(crate) fn calculate_prune_length_with_probe_batch( + &mut self, + build_side_sorted_filter_expr: &mut SortedFilterExpr, + probe_side_sorted_filter_expr: &mut SortedFilterExpr, + graph: &mut ExprIntervalGraph, + ) -> Result { + // Return early if the input buffer is empty: + if self.input_buffer.num_rows() == 0 { + return Ok(0); + } + // Process the build and probe side sorted filter expressions if both are present: + // Collect the sorted filter expressions into a vector of (node_index, interval) tuples: + let mut filter_intervals = vec![]; + for expr in [ + &build_side_sorted_filter_expr, + &probe_side_sorted_filter_expr, + ] { + filter_intervals.push((expr.node_index(), expr.interval().clone())) + } + // Update the physical expression graph using the join filter intervals: + graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)?; + // Extract the new join filter interval for the build side: + let calculated_build_side_interval = filter_intervals.remove(0).1; + // If the intervals have not changed, return early without pruning: + if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) { + return Ok(0); + } + // Update the build side interval and determine the pruning length: + build_side_sorted_filter_expr.set_interval(calculated_build_side_interval); + + determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr) + } + + pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> { + // Prune the hash values: + self.hashmap.prune_hash_values( + prune_length, + self.deleted_offset as u64, + HASHMAP_SHRINK_SCALE_FACTOR, + )?; + // Remove pruned rows from the visited rows set: + for row in self.deleted_offset..(self.deleted_offset + prune_length) { + self.visited_rows.remove(&row); + } + // Update the input buffer after pruning: + self.input_buffer = self + .input_buffer + .slice(prune_length, self.input_buffer.num_rows() - prune_length); + // Increment the deleted offset: + self.deleted_offset += prune_length; + Ok(()) + } +} + +impl EagerJoinStream for SymmetricHashJoinStream { + fn process_batch_from_right( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Right) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StreamJoinStateResult::Ready(maybe_batch) + } else { + StreamJoinStateResult::Continue + } + }) + } + + fn process_batch_from_left( + &mut self, + batch: RecordBatch, + ) -> Result>> { + self.perform_join_for_given_side(batch, JoinSide::Left) + .map(|maybe_batch| { + if maybe_batch.is_some() { + StreamJoinStateResult::Ready(maybe_batch) + } else { + StreamJoinStateResult::Continue + } + }) + } + + fn process_batch_after_left_end( + &mut self, + right_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_right(right_batch) + } + + fn process_batch_after_right_end( + &mut self, + left_batch: RecordBatch, + ) -> Result>> { + self.process_batch_from_left(left_batch) + } + + fn process_batches_before_finalization( + &mut self, + ) -> Result>> { + // Get the left side results: + let left_result = build_side_determined_results( + &self.left, + &self.schema, + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + // Get the right side results: + let right_result = build_side_determined_results( + &self.right, + &self.schema, + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + + // Combine the left and right results: + let result = combine_two_batches(&self.schema, left_result, right_result)?; + + // Update the metrics and return the result: + if let Some(batch) = &result { + // Update the metrics: + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Ok(StreamJoinStateResult::Ready(result)); + } + Ok(StreamJoinStateResult::Continue) + } + + fn right_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.right_stream + } + + fn left_stream(&mut self) -> &mut SendableRecordBatchStream { + &mut self.left_stream + } + + fn set_state(&mut self, state: EagerJoinStreamState) { + self.state = state; + } + + fn state(&mut self) -> EagerJoinStreamState { + self.state.clone() + } +} + +impl SymmetricHashJoinStream { + fn size(&self) -> usize { + let mut size = 0; + size += std::mem::size_of_val(&self.schema); + size += std::mem::size_of_val(&self.filter); + size += std::mem::size_of_val(&self.join_type); + size += self.left.size(); + size += self.right.size(); + size += std::mem::size_of_val(&self.column_indices); + size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); + size += std::mem::size_of_val(&self.left_sorted_filter_expr); + size += std::mem::size_of_val(&self.right_sorted_filter_expr); + size += std::mem::size_of_val(&self.random_state); + size += std::mem::size_of_val(&self.null_equals_null); + size += std::mem::size_of_val(&self.metrics); + size + } + + /// Performs a join operation for the specified `probe_side` (either left or right). + /// This function: + /// 1. Determines which side is the probe and which is the build side. + /// 2. Updates metrics based on the batch that was polled. + /// 3. Executes the join with the given `probe_batch`. + /// 4. Optionally computes anti-join results if all conditions are met. + /// 5. Combines the results and returns a combined batch or `None` if no batch was produced. + fn perform_join_for_given_side( + &mut self, + probe_batch: RecordBatch, + probe_side: JoinSide, + ) -> Result> { + let ( + probe_hash_joiner, + build_hash_joiner, + probe_side_sorted_filter_expr, + build_side_sorted_filter_expr, + probe_side_metrics, + ) = if probe_side.eq(&JoinSide::Left) { + ( + &mut self.left, + &mut self.right, + &mut self.left_sorted_filter_expr, + &mut self.right_sorted_filter_expr, + &mut self.metrics.left, + ) + } else { + ( + &mut self.right, + &mut self.left, + &mut self.right_sorted_filter_expr, + &mut self.left_sorted_filter_expr, + &mut self.metrics.right, + ) + }; + // Update the metrics for the stream that was polled: + probe_side_metrics.input_batches.add(1); + probe_side_metrics.input_rows.add(probe_batch.num_rows()); + // Update the internal state of the hash joiner for the build side: + probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)?; + // Join the two sides: + let equal_result = join_with_probe_batch( + build_hash_joiner, + probe_hash_joiner, + &self.schema, + self.join_type, + self.filter.as_ref(), + &probe_batch, + &self.column_indices, + &self.random_state, + self.null_equals_null, + )?; + // Increment the offset for the probe hash joiner: + probe_hash_joiner.offset += probe_batch.num_rows(); + + let anti_result = if let ( + Some(build_side_sorted_filter_expr), + Some(probe_side_sorted_filter_expr), + Some(graph), + ) = ( + build_side_sorted_filter_expr.as_mut(), + probe_side_sorted_filter_expr.as_mut(), + self.graph.as_mut(), + ) { + // Calculate filter intervals: + calculate_filter_expr_intervals( + &build_hash_joiner.input_buffer, + build_side_sorted_filter_expr, + &probe_batch, + probe_side_sorted_filter_expr, + )?; + let prune_length = build_hash_joiner + .calculate_prune_length_with_probe_batch( + build_side_sorted_filter_expr, + probe_side_sorted_filter_expr, + graph, + )?; + let result = build_side_determined_results( + build_hash_joiner, + &self.schema, + prune_length, + probe_batch.schema(), + self.join_type, + &self.column_indices, + )?; + build_hash_joiner.prune_internal_state(prune_length)?; + result + } else { + None + }; + + // Combine results: + let result = combine_two_batches(&self.schema, equal_result, anti_result)?; + let capacity = self.size(); + self.metrics.stream_memory_usage.set(capacity); + self.reservation.lock().try_resize(capacity)?; + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + } + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Mutex; + + use super::*; + use crate::joins::test_utils::{ + build_sides_record_batches, compare_batches, complicated_filter, + create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, + join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, + partitioned_sym_join_with_filter, split_record_batches, + }; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{binary, col, Column}; + + use once_cell::sync::Lazy; + use rstest::*; + + const TABLE_SIZE: i32 = 30; + + type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size) + type TableValue = (Vec, Vec); // (left, right) + + // Cache for storing tables + static TABLE_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + + fn get_or_create_table( + cardinality: (i32, i32), + batch_size: usize, + ) -> Result { + { + let cache = TABLE_CACHE.lock().unwrap(); + if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) { + return Ok(table.clone()); + } + } + + // If not, create the table + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + + let (left_partition, right_partition) = ( + split_record_batches(&left_batch, batch_size)?, + split_record_batches(&right_batch, batch_size)?, + ); + + // Lock the cache again and store the table + let mut cache = TABLE_CACHE.lock().unwrap(); + + // Store the table in the cache + cache.insert( + (cardinality.0, cardinality.1, batch_size), + (left_partition.clone(), right_partition.clone()), + ); + + Ok((left_partition, right_partition)) + } + + pub async fn experiment( + left: Arc, + right: Arc, + filter: Option, + join_type: JoinType, + on: JoinOn, + task_ctx: Arc, + ) -> Result<()> { + let first_batches = partitioned_sym_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + let second_batches = partitioned_hash_join_with_filter( + left, right, on, filter, &join_type, false, task_ctx, + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (12, 17), + )] + cardinality: (i32, i32), + ) -> Result<()> { + // a + b > c + 10 AND a + b < c + 100 + let task_ctx = Arc::new(TaskContext::default()); + + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + + let left_sorted = vec![PhysicalSortExpr { + expr: binary( + col("la1", left_schema)?, + Operator::Plus, + col("la2", left_schema)?, + left_schema, + )?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values(0, 1, 2, 3, 4, 5)] case_expr: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_without_sort_information( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values(0, 1, 2, 3, 4, 5)] case_expr: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let (left, right) = + create_memory_table(left_partition, right_partition, vec![], vec![])?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_without_filter( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let (left, right) = + create_memory_table(left_partition, right_partition, vec![], vec![])?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + experiment(left, right, None, join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_descending_numeric_particular( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values(0, 1, 2, 3, 4, 5)] case_expr: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1_des", left_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1_des", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first() -> Result<()> { + let join_type = JoinType::Full; + let case_expr = 1; + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_asc_null_first", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_asc_null_first", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 6, + side: JoinSide::Left, + }, + ColumnIndex { + index: 6, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_last() -> Result<()> { + let join_type = JoinType::Full; + let case_expr = 1; + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_asc_null_last", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_asc_null_last", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 7, + side: JoinSide::Left, + }, + ColumnIndex { + index: 7, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first_descending() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_desc_null_first", left_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_desc_null_first", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture_i32( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 8, + side: JoinSide::Left, + }, + ColumnIndex { + index: 8, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }]; + + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_equivalence() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + // let session_ctx = SessionContext::with_config(config); + // let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![ + vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }], + vec![PhysicalSortExpr { + expr: col("la2", left_schema)?, + options: SortOptions::default(), + }], + ]; + + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + + let (left, right) = create_memory_table( + left_partition, + right_partition, + left_sorted, + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn testing_with_temporal_columns( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (12, 17), + )] + cardinality: (i32, i32), + #[values(0, 1, 2)] case_expr: usize, + ) -> Result<()> { + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + let left_sorted = vec![PhysicalSortExpr { + expr: col("lt1", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("rt1", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + let intermediate_schema = Schema::new(vec![ + Field::new( + "left", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new( + "right", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ]); + let filter_expr = join_expr_tests_fixture_temporal( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 3, + side: JoinSide::Left, + }, + ColumnIndex { + index: 3, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn test_with_interval_columns( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (12, 17), + )] + cardinality: (i32, i32), + ) -> Result<()> { + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + let left_sorted = vec![PhysicalSortExpr { + expr: col("li1", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ri1", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("right", DataType::Interval(IntervalUnit::DayTime), false), + ]); + let filter_expr = join_expr_tests_fixture_temporal( + 0, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 9, + side: JoinSide::Left, + }, + ColumnIndex { + index: 9, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn testing_ascending_float_pruning( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (12, 17), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4, 5)] case_expr: usize, + ) -> Result<()> { + let session_config = SessionConfig::new().with_repartition_joins(false); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_float", left_schema)?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_float", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Float64, true), + Field::new("right", DataType::Float64, true), + ]); + let filter_expr = join_expr_tests_fixture_f64( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 10, // l_float + side: JoinSide::Left, + }, + ColumnIndex { + index: 10, // r_float + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs new file mode 100644 index 0000000000000..fbd52ddf0c704 --- /dev/null +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -0,0 +1,567 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file has test utils for hash joins + +use std::sync::Arc; +use std::usize; + +use crate::joins::utils::{JoinFilter, JoinOn}; +use crate::joins::{ + HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; +use crate::memory::MemoryExec; +use crate::repartition::RepartitionExec; +use crate::{common, ExecutionPlan, Partitioning}; + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::{ + ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, + TimestampMillisecondArray, +}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{binary, cast, col, lit}; +use datafusion_physical_expr::intervals::test_utils::{ + gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, +}; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; + +pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + // compare + let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); + let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); + + let mut first_formatted_sorted: Vec<&str> = first_formatted.trim().lines().collect(); + first_formatted_sorted.sort_unstable(); + + let mut second_formatted_sorted: Vec<&str> = + second_formatted.trim().lines().collect(); + second_formatted_sorted.sort_unstable(); + + for (i, (first_line, second_line)) in first_formatted_sorted + .iter() + .zip(&second_formatted_sorted) + .enumerate() + { + assert_eq!((i, first_line), (i, second_line)); + } +} + +pub async fn partitioned_sym_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, +) -> Result> { + let partition_count = 4; + + let left_expr = on + .iter() + .map(|(l, _)| Arc::new(l.clone()) as _) + .collect::>(); + + let right_expr = on + .iter() + .map(|(_, r)| Arc::new(r.clone()) as _) + .collect::>(); + + let join = SymmetricHashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + null_equals_null, + StreamJoinPartitionMode::Partitioned, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) +} + +pub async fn partitioned_hash_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, +) -> Result> { + let partition_count = 4; + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + + let join = Arc::new(HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + PartitionMode::Partitioned, + null_equals_null, + )?); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) +} + +pub fn split_record_batches( + batch: &RecordBatch, + batch_size: usize, +) -> Result> { + let row_num = batch.num_rows(); + let number_of_batch = row_num / batch_size; + let mut sizes = vec![batch_size; number_of_batch]; + sizes.push(row_num - (batch_size * number_of_batch)); + let mut result = vec![]; + for (i, size) in sizes.iter().enumerate() { + result.push(batch.slice(i * batch_size, *size)); + } + Ok(result) +} + +struct AscendingRandomFloatIterator { + prev: f64, + max: f64, + rng: StdRng, +} + +impl AscendingRandomFloatIterator { + fn new(min: f64, max: f64) -> Self { + let mut rng = StdRng::seed_from_u64(42); + let initial = rng.gen_range(min..max); + AscendingRandomFloatIterator { + prev: initial, + max, + rng, + } + } +} + +impl Iterator for AscendingRandomFloatIterator { + type Item = f64; + + fn next(&mut self) -> Option { + let value = self.rng.gen_range(self.prev..self.max); + self.prev = value; + Some(value) + } +} + +pub fn join_expr_tests_fixture_temporal( + expr_id: usize, + left_col: Arc, + right_col: Arc, + schema: &Schema, +) -> Result> { + match expr_id { + // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms')) + 0 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::new_interval_dt(0, 100), // 100 ms + ScalarValue::new_interval_dt(0, 200), // 200 ms + ScalarValue::new_interval_dt(0, 450), // 450 ms + ScalarValue::new_interval_dt(0, 300), // 300 ms + schema, + ), + // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02')) + 1 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03 + ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01 + ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00 + ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 + schema, + ), + // constructs ((left_col - DURATION '3 secs') > (right_col - DURATION '2 secs')) AND ((left_col - DURATION '5 secs') < (right_col - DURATION '4 secs')) + 2 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::DurationMillisecond(Some(3000)), // 3 secs + ScalarValue::DurationMillisecond(Some(2000)), // 2 secs + ScalarValue::DurationMillisecond(Some(5000)), // 5 secs + ScalarValue::DurationMillisecond(Some(4000)), // 4 secs + schema, + ), + _ => unreachable!(), + } +} + +// It creates join filters for different type of fields for testing. +macro_rules! join_expr_tests { + ($func_name:ident, $type:ty, $SCALAR:ident) => { + pub fn $func_name( + expr_id: usize, + left_col: Arc, + right_col: Arc, + ) -> Arc { + match expr_id { + // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 0 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 1 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 + 2 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 + 3 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 + 4 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(30 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + 5 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(2 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(7 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::GtEq, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + 6 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::Gt, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + 7 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::GtEq, Operator::Lt), + ), + _ => panic!("No case"), + } + } + }; +} + +join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); +join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); + +pub fn build_sides_record_batches( + table_size: i32, + key_cardinality: (i32, i32), +) -> Result<(RecordBatch, RecordBatch)> { + let null_ratio: f64 = 0.4; + let initial_range = 0..table_size; + let index = (table_size as f64 * null_ratio).round() as i32; + let rest_of = index..table_size; + let ordered: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )); + let ordered_des = Arc::new(Int32Array::from_iter( + initial_range.clone().rev().collect::>(), + )); + let cardinality = Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 4).collect::>(), + )); + let cardinality_key_left = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.0) + .collect::>(), + )); + let cardinality_key_right = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.1) + .collect::>(), + )); + let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.clone().map(Some)) + .collect::>>() + })); + let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ + rest_of + .clone() + .map(Some) + .chain(std::iter::repeat(None).take(index as usize)) + .collect::>>() + })); + + let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.rev().map(Some)) + .collect::>>() + })); + + let time = Arc::new(TimestampMillisecondArray::from( + initial_range + .clone() + .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00 + .collect::>(), + )); + let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( + initial_range + .map(|x| x as i64 * 100) // x * 100ms + .collect::>(), + )); + + let float_asc = Arc::new(Float64Array::from_iter_values( + AscendingRandomFloatIterator::new(0., table_size as f64) + .take(table_size as usize), + )); + + let left = RecordBatch::try_from_iter(vec![ + ("la1", ordered.clone()), + ("lb1", cardinality.clone()), + ("lc1", cardinality_key_left), + ("lt1", time.clone()), + ("la2", ordered.clone()), + ("la1_des", ordered_des.clone()), + ("l_asc_null_first", ordered_asc_null_first.clone()), + ("l_asc_null_last", ordered_asc_null_last.clone()), + ("l_desc_null_first", ordered_desc_null_first.clone()), + ("li1", interval_time.clone()), + ("l_float", float_asc.clone()), + ])?; + let right = RecordBatch::try_from_iter(vec![ + ("ra1", ordered.clone()), + ("rb1", cardinality), + ("rc1", cardinality_key_right), + ("rt1", time), + ("ra2", ordered), + ("ra1_des", ordered_des), + ("r_asc_null_first", ordered_asc_null_first), + ("r_asc_null_last", ordered_asc_null_last), + ("r_desc_null_first", ordered_desc_null_first), + ("ri1", interval_time), + ("r_float", float_asc), + ])?; + Ok((left, right)) +} + +pub fn create_memory_table( + left_partition: Vec, + right_partition: Vec, + left_sorted: Vec, + right_sorted: Vec, +) -> Result<(Arc, Arc)> { + let left_schema = left_partition[0].schema(); + let left = MemoryExec::try_new(&[left_partition], left_schema, None)? + .with_sort_information(left_sorted); + let right_schema = right_partition[0].schema(); + let right = MemoryExec::try_new(&[right_partition], right_schema, None)? + .with_sort_information(right_sorted); + Ok((Arc::new(left), Arc::new(right))) +} + +/// Filter expr for a + b > c + 10 AND a + b < c + 100 +pub(crate) fn complicated_filter( + filter_schema: &Schema, +) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) +} diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs similarity index 64% rename from datafusion/core/src/physical_plan/joins/utils.rs rename to datafusion/physical-plan/src/joins/utils.rs index f7e81b5add6ee..5e01ca227cf5a 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,6 +17,18 @@ //! Join related functionality used both on logical and physical plans +use std::collections::HashSet; +use std::fmt::{self, Debug}; +use std::future::Future; +use std::ops::IndexMut; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::usize; + +use crate::joins::stream_join_utils::{build_filter_input_order, SortedFilterExpr}; +use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; + use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, @@ -24,32 +36,151 @@ use arrow::array::{ use arrow::compute; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::stats::Precision; +use datafusion_common::{ + plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + SharedResult, +}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; +use datafusion_physical_expr::utils::merge_vectors; +use datafusion_physical_expr::{ + LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr, +}; + use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; +use hashbrown::raw::RawTable; use parking_lot::Mutex; -use std::cmp::max; -use std::collections::HashSet; -use std::fmt::{Display, Formatter}; -use std::future::Future; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::usize; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::{ScalarValue, SharedResult}; +/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. +/// +/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +/// +/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +/// As the key is a hash value, we need to check possible hash collisions in the probe stage +/// During this stage it might be the case that a row is contained the same hashmap value, +/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// +/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// +/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. +/// +/// The chain can be followed until the value "0" has been reached, meaning the end of the list. +/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) +/// +/// # Example +/// +/// ``` text +/// See the example below: +/// +/// Insert (10,1) <-- insert hash value 10 with row index 1 +/// map: +/// ---------- +/// | 10 | 2 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (20,2) +/// map: +/// ---------- +/// | 10 | 2 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well +/// map: +/// ---------- +/// | 10 | 4 | +/// | 20 | 3 | +/// ---------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1) +/// --------------------- +/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10 +/// map: +/// --------- +/// | 10 | 5 | +/// | 20 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// ``` +pub struct JoinHashMap { + // Stores hash value to last row index + map: RawTable<(u64, u64)>, + // Stores indices in chained list data structure + next: Vec, +} -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; +impl JoinHashMap { + #[cfg(test)] + pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec) -> Self { + Self { map, next } + } -use datafusion_common::JoinType; -use datafusion_common::{DataFusionError, Result}; + pub(crate) fn with_capacity(capacity: usize) -> Self { + JoinHashMap { + map: RawTable::with_capacity(capacity), + next: vec![0; capacity], + } + } +} -use crate::physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::physical_plan::SchemaRef; -use crate::physical_plan::{ - ColumnStatistics, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, -}; +// Trait defining methods that must be implemented by a hash map type to be used for joins. +pub trait JoinHashMapType { + /// The type of list used to store the next list + type NextType: IndexMut; + /// Extend with zero + fn extend_zero(&mut self, len: usize); + /// Returns mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType); + /// Returns a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)>; + /// Returns a reference to the next. + fn get_list(&self) -> &Self::NextType; +} + +/// Implementation of `JoinHashMapType` for `JoinHashMap`. +impl JoinHashMapType for JoinHashMap { + type NextType = Vec; + + // Void implementation + fn extend_zero(&mut self, _: usize) {} + + /// Get mutable references to the hash map and the next. + fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) { + (&mut self.map, &mut self.next) + } + + /// Get a reference to the hash map. + fn get_map(&self) -> &RawTable<(u64, u64)> { + &self.map + } + + /// Get a reference to the next. + fn get_list(&self) -> &Self::NextType { + &self.next + } +} + +impl fmt::Debug for JoinHashMap { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; @@ -89,9 +220,9 @@ fn check_join_set_is_valid( let right_missing = on_right.difference(right).collect::>(); if !left_missing.is_empty() | !right_missing.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}", - ))); + return plan_err!( + "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}" + ); }; Ok(()) @@ -131,131 +262,94 @@ pub fn adjust_right_output_partitioning( Partitioning::Hash(exprs, size) => { let new_exprs = exprs .into_iter() - .map(|expr| { - expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( - col.name(), - left_columns_len + col.index(), - )))), - None => Ok(Transformed::No(e)), - }) - .unwrap() - }) - .collect::>(); + .map(|expr| add_offset_to_expr(expr, left_columns_len)) + .collect(); Partitioning::Hash(new_exprs, size) } } } -/// Combine the Equivalence Properties for Join Node -pub fn combine_join_equivalence_properties( - join_type: JoinType, - left_properties: EquivalenceProperties, - right_properties: EquivalenceProperties, +/// Replaces the right column (first index in the `on_column` tuple) with +/// the left column (zeroth index in the tuple) inside `right_ordering`. +fn replace_on_columns_of_right_ordering( + on_columns: &[(Column, Column)], + right_ordering: &mut [PhysicalSortExpr], left_columns_len: usize, - on: &[(Column, Column)], - schema: SchemaRef, -) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(schema); - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - new_properties.extend(left_properties.classes().to_vec()); - let new_right_properties = right_properties - .classes() - .iter() - .map(|prop| { - let new_head = Column::new( - prop.head().name(), - left_columns_len + prop.head().index(), - ); - let new_others = prop - .others() - .iter() - .map(|col| { - Column::new(col.name(), left_columns_len + col.index()) - }) - .collect::>(); - EquivalentClass::new(new_head, new_others) - }) - .collect::>(); - - new_properties.extend(new_right_properties); - } - JoinType::LeftSemi | JoinType::LeftAnti => { - new_properties.extend(left_properties.classes().to_vec()) - } - JoinType::RightSemi | JoinType::RightAnti => { - new_properties.extend(right_properties.classes().to_vec()) +) { + for (left_col, right_col) in on_columns { + let right_col = + Column::new(right_col.name(), right_col.index() + left_columns_len); + for item in right_ordering.iter_mut() { + if let Some(col) = item.expr.as_any().downcast_ref::() { + if right_col.eq(col) { + item.expr = Arc::new(left_col.clone()) as _; + } + } } } - - if join_type == JoinType::Inner { - on.iter().for_each(|(column1, column2)| { - let new_column2 = - Column::new(column2.name(), left_columns_len + column2.index()); - new_properties.add_equal_conditions((column1, &new_column2)) - }) - } - new_properties } -/// Calculate the Equivalence Properties for CrossJoin Node -pub fn cross_join_equivalence_properties( - left_properties: EquivalenceProperties, - right_properties: EquivalenceProperties, +/// Calculate the output ordering of a given join operation. +pub fn calculate_join_output_ordering( + left_ordering: LexOrderingRef, + right_ordering: LexOrderingRef, + join_type: JoinType, + on_columns: &[(Column, Column)], left_columns_len: usize, - schema: SchemaRef, -) -> EquivalenceProperties { - let mut new_properties = EquivalenceProperties::new(schema); - new_properties.extend(left_properties.classes().to_vec()); - let new_right_properties = right_properties - .classes() - .iter() - .map(|prop| { - let new_head = - Column::new(prop.head().name(), left_columns_len + prop.head().index()); - let new_others = prop - .others() + maintains_input_order: &[bool], + probe_side: Option, +) -> Option { + let mut right_ordering = match join_type { + // In the case below, right ordering should be offseted with the left + // side length, since we append the right table to the left table. + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + right_ordering .iter() - .map(|col| Column::new(col.name(), left_columns_len + col.index())) - .collect::>(); - EquivalentClass::new(new_head, new_others) - }) - .collect::>(); - new_properties.extend(new_right_properties); - new_properties -} - -impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - JoinSide::Left => write!(f, "left"), - JoinSide::Right => write!(f, "right"), + .map(|sort_expr| PhysicalSortExpr { + expr: add_offset_to_expr(sort_expr.expr.clone(), left_columns_len), + options: sort_expr.options, + }) + .collect() } - } -} - -/// Used in ColumnIndex to distinguish which side the index is for -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum JoinSide { - /// Left side of the join - Left, - /// Right side of the join - Right, -} - -impl JoinSide { - /// Inverse the join side - pub fn negate(&self) -> Self { - match self { - JoinSide::Left => JoinSide::Right, - JoinSide::Right => JoinSide::Left, + _ => right_ordering.to_vec(), + }; + let output_ordering = match maintains_input_order { + [true, false] => { + // Special case, we can prefix ordering of right side with the ordering of left side. + if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { + replace_on_columns_of_right_ordering( + on_columns, + &mut right_ordering, + left_columns_len, + ); + merge_vectors(left_ordering, &right_ordering) + } else { + left_ordering.to_vec() + } } - } + [false, true] => { + // Special case, we can prefix ordering of left side with the ordering of right side. + if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { + replace_on_columns_of_right_ordering( + on_columns, + &mut right_ordering, + left_columns_len, + ); + merge_vectors(&right_ordering, left_ordering) + } else { + right_ordering.to_vec() + } + } + // Doesn't maintain ordering, output ordering is None. + [false, false] => return None, + [true, true] => unreachable!("Cannot maintain ordering of both sides"), + _ => unreachable!("Join operators can not have more than two children"), + }; + (!output_ordering.is_empty()).then_some(output_ordering) } /// Information about the index and placement (left or right) of the columns -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ColumnIndex { /// Index of the column pub index: usize, @@ -493,21 +587,21 @@ pub(crate) fn estimate_join_statistics( right: Arc, on: JoinOn, join_type: &JoinType, -) -> Statistics { - let left_stats = left.statistics(); - let right_stats = right.statistics(); + schema: &Schema, +) -> Result { + let left_stats = left.statistics()?; + let right_stats = right.statistics()?; let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on); let (num_rows, column_statistics) = match join_stats { - Some(stats) => (Some(stats.num_rows), Some(stats.column_statistics)), - None => (None, None), + Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), + None => (Precision::Absent, Statistics::unknown_column(schema)), }; - Statistics { + Ok(Statistics { num_rows, - total_byte_size: None, + total_byte_size: Precision::Absent, column_statistics, - is_exact: false, - } + }) } // Estimate the cardinality for the given join with input statistics. @@ -519,29 +613,27 @@ fn estimate_join_cardinality( ) -> Option { match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let left_num_rows = left_stats.num_rows?; - let right_num_rows = right_stats.num_rows?; - - // Take the left_col_stats and right_col_stats using the index - // obtained from index() method of the each element of 'on'. - let all_left_col_stats = left_stats.column_statistics?; - let all_right_col_stats = right_stats.column_statistics?; let (left_col_stats, right_col_stats) = on .iter() .map(|(left, right)| { ( - all_left_col_stats[left.index()].clone(), - all_right_col_stats[right.index()].clone(), + left_stats.column_statistics[left.index()].clone(), + right_stats.column_statistics[right.index()].clone(), ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); let ij_cardinality = estimate_inner_join_cardinality( - left_num_rows, - right_num_rows, - left_col_stats, - right_col_stats, - left_stats.is_exact && right_stats.is_exact, + Statistics { + num_rows: left_stats.num_rows.clone(), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: right_stats.num_rows.clone(), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, )?; // The cardinality for inner join can also be used to estimate @@ -550,25 +642,25 @@ fn estimate_join_cardinality( // joins (so that we don't underestimate the cardinality). let cardinality = match join_type { JoinType::Inner => ij_cardinality, - JoinType::Left => max(ij_cardinality, left_num_rows), - JoinType::Right => max(ij_cardinality, right_num_rows), - JoinType::Full => { - max(ij_cardinality, left_num_rows) - + max(ij_cardinality, right_num_rows) - - ij_cardinality - } + JoinType::Left => ij_cardinality.max(&left_stats.num_rows), + JoinType::Right => ij_cardinality.max(&right_stats.num_rows), + JoinType::Full => ij_cardinality + .max(&left_stats.num_rows) + .add(&ij_cardinality.max(&right_stats.num_rows)) + .sub(&ij_cardinality), _ => unreachable!(), }; Some(PartialJoinStatistics { - num_rows: cardinality, + num_rows: *cardinality.get_value()?, // We don't do anything specific here, just combine the existing // statistics which might yield subpar results (although it is // true, esp regarding min/max). For a better estimation, we need // filter selectivity analysis first. - column_statistics: all_left_col_stats + column_statistics: left_stats + .column_statistics .into_iter() - .chain(all_right_col_stats.into_iter()) + .chain(right_stats.column_statistics) .collect(), }) } @@ -585,30 +677,47 @@ fn estimate_join_cardinality( /// a very conservative implementation that can quickly give up if there is not /// enough input statistics. fn estimate_inner_join_cardinality( - left_num_rows: usize, - right_num_rows: usize, - left_col_stats: Vec, - right_col_stats: Vec, - is_exact: bool, -) -> Option { + left_stats: Statistics, + right_stats: Statistics, +) -> Option> { // The algorithm here is partly based on the non-histogram selectivity estimation // from Spark's Catalyst optimizer. - - let mut join_selectivity = None; - for (left_stat, right_stat) in left_col_stats.iter().zip(right_col_stats.iter()) { - if (left_stat.min_value.clone()? > right_stat.max_value.clone()?) - || (left_stat.max_value.clone()? < right_stat.min_value.clone()?) - { - // If there is no overlap in any of the join columns, that means the join - // itself is disjoint and the cardinality is 0. Though we can only assume - // this when the statistics are exact (since it is a very strong assumption). - return if is_exact { Some(0) } else { None }; + let mut join_selectivity = Precision::Absent; + for (left_stat, right_stat) in left_stats + .column_statistics + .iter() + .zip(right_stats.column_statistics.iter()) + { + // If there is no overlap in any of the join columns, this means the join + // itself is disjoint and the cardinality is 0. Though we can only assume + // this when the statistics are exact (since it is a very strong assumption). + if left_stat.min_value.get_value()? > right_stat.max_value.get_value()? { + return Some( + if left_stat.min_value.is_exact().unwrap_or(false) + && right_stat.max_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); + } + if left_stat.max_value.get_value()? < right_stat.min_value.get_value()? { + return Some( + if left_stat.max_value.is_exact().unwrap_or(false) + && right_stat.min_value.is_exact().unwrap_or(false) + { + Precision::Exact(0) + } else { + Precision::Inexact(0) + }, + ); } - let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone()); - let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone()); - let max_distinct = max(left_max_distinct, right_max_distinct); - if max_distinct > join_selectivity { + let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); + let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat); + let max_distinct = left_max_distinct.max(&right_max_distinct); + if max_distinct.get_value().is_some() { // Seems like there are a few implementations of this algorithm that implement // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs // further exploration. @@ -619,9 +728,14 @@ fn estimate_inner_join_cardinality( // With the assumption that the smaller input's domain is generally represented in the bigger // input's domain, we can estimate the inner join's cardinality by taking the cartesian product // of the two inputs and normalizing it by the selectivity factor. + let left_num_rows = left_stats.num_rows.get_value()?; + let right_num_rows = right_stats.num_rows.get_value()?; match join_selectivity { - Some(selectivity) if selectivity > 0 => { - Some((left_num_rows * right_num_rows) / selectivity) + Precision::Exact(value) if value > 0 => { + Some(Precision::Exact((left_num_rows * right_num_rows) / value)) + } + Precision::Inexact(value) if value > 0 => { + Some(Precision::Inexact((left_num_rows * right_num_rows) / value)) } // Since we don't have any information about the selectivity (which is derived // from the number of distinct rows information) we can give up here for now. @@ -632,47 +746,61 @@ fn estimate_inner_join_cardinality( } /// Estimate the number of maximum distinct values that can be present in the -/// given column from its statistics. -/// -/// If distinct_count is available, uses it directly. If the column numeric, and -/// has min/max values, then they might be used as a fallback option. Otherwise, -/// returns None. -fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option { - match (stats.distinct_count, stats.max_value, stats.min_value) { - (Some(_), _, _) => stats.distinct_count, - (_, Some(max), Some(min)) => { - // Note that float support is intentionally omitted here, since the computation - // of a range between two float values is not trivial and the result would be - // highly inaccurate. - let numeric_range = get_int_range(min, max)?; - - // The number can never be greater than the number of rows we have (minus - // the nulls, since they don't count as distinct values). - let ceiling = num_rows - stats.null_count.unwrap_or(0); - Some(numeric_range.min(ceiling)) - } - _ => None, - } -} +/// given column from its statistics. If distinct_count is available, uses it +/// directly. Otherwise, if the column is numeric and has min/max values, it +/// estimates the maximum distinct count from those. +fn max_distinct_count( + num_rows: &Precision, + stats: &ColumnStatistics, +) -> Precision { + match &stats.distinct_count { + dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc.clone(), + _ => { + // The number can never be greater than the number of rows we have + // minus the nulls (since they don't count as distinct values). + let result = match num_rows { + Precision::Absent => Precision::Absent, + Precision::Inexact(count) => { + Precision::Inexact(count - stats.null_count.get_value().unwrap_or(&0)) + } + Precision::Exact(count) => { + let count = count - stats.null_count.get_value().unwrap_or(&0); + if stats.null_count.is_exact().unwrap_or(false) { + Precision::Exact(count) + } else { + Precision::Inexact(count) + } + } + }; + // Cap the estimate using the number of possible values: + if let (Some(min), Some(max)) = + (stats.min_value.get_value(), stats.max_value.get_value()) + { + if let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) + .ok() + .and_then(|e| e.cardinality()) + { + let range_dc = range_dc as usize; + // Note that the `unwrap` calls in the below statement are safe. + return if matches!(result, Precision::Absent) + || &range_dc < result.get_value().unwrap() + { + if stats.min_value.is_exact().unwrap() + && stats.max_value.is_exact().unwrap() + { + Precision::Exact(range_dc) + } else { + Precision::Inexact(range_dc) + } + } else { + result + }; + } + } -/// Return the numeric range between the given min and max values. -fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option { - let delta = &max.sub(&min).ok()?; - match delta { - ScalarValue::Int8(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int16(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int32(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::Int64(Some(delta)) if *delta >= 0 => Some(*delta as usize), - ScalarValue::UInt8(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt16(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt32(Some(delta)) => Some(*delta as usize), - ScalarValue::UInt64(Some(delta)) => Some(*delta as usize), - _ => None, + result + } } - // The delta (directly) is not the real range, since it does not include the - // first term. - // E.g. (min=2, max=4) -> (4 - 2) -> 2, but the actual result should be 3 (1, 2, 3). - .map(|open_ended_range| open_ended_range + 1) } enum OnceFutState { @@ -784,15 +912,15 @@ pub(crate) fn apply_join_filter_to_indices( filter.schema(), build_input_buffer, probe_batch, - build_indices.clone(), - probe_indices.clone(), + &build_indices, + &probe_indices, filter.column_indices(), build_side, )?; let filter_result = filter .expression() .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows()); + .into_array(intermediate_batch.num_rows())?; let mask = as_boolean_array(&filter_result)?; let left_filtered = compute::filter(&build_indices, mask)?; @@ -809,8 +937,8 @@ pub(crate) fn build_batch_from_indices( schema: &Schema, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, - build_indices: UInt64Array, - probe_indices: UInt32Array, + build_indices: &UInt64Array, + probe_indices: &UInt32Array, column_indices: &[ColumnIndex], build_side: JoinSide, ) -> Result { @@ -841,7 +969,7 @@ pub(crate) fn build_batch_from_indices( assert_eq!(build_indices.null_count(), build_indices.len()); new_null_array(array.data_type(), build_indices.len()) } else { - compute::take(array.as_ref(), &build_indices, None)? + compute::take(array.as_ref(), build_indices, None)? } } else { let array = probe_batch.column(column_index.index); @@ -849,7 +977,7 @@ pub(crate) fn build_batch_from_indices( assert_eq!(probe_indices.null_count(), probe_indices.len()); new_null_array(array.data_type(), probe_indices.len()) } else { - compute::take(array.as_ref(), &probe_indices, None)? + compute::take(array.as_ref(), probe_indices, None)? } }; columns.push(array); @@ -1064,14 +1192,102 @@ impl BuildProbeJoinMetrics { } } +/// Updates sorted filter expressions with corresponding node indices from the +/// expression interval graph. +/// +/// This function iterates through the provided sorted filter expressions, +/// gathers the corresponding node indices from the expression interval graph, +/// and then updates the sorted expressions with these indices. It ensures +/// that these sorted expressions are aligned with the structure of the graph. +fn update_sorted_exprs_with_node_indices( + graph: &mut ExprIntervalGraph, + sorted_exprs: &mut [SortedFilterExpr], +) { + // Extract filter expressions from the sorted expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|expr| expr.filter_expr().clone()) + .collect::>(); + + // Gather corresponding node indices for the extracted filter expressions from the graph: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Iterate through the sorted expressions and the gathered node indices: + for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) { + // Update each sorted expression with the corresponding node index: + sorted_expr.set_node_index(index); + } +} + +/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// +/// # Arguments +/// +/// * `filter` - The join filter to base the sorting on. +/// * `left` - The left execution plan. +/// * `right` - The right execution plan. +/// * `left_sort_exprs` - The expressions to sort on the left side. +/// * `right_sort_exprs` - The expressions to sort on the right side. +/// +/// # Returns +/// +/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph. +pub fn prepare_sorted_exprs( + filter: &JoinFilter, + left: &Arc, + right: &Arc, + left_sort_exprs: &[PhysicalSortExpr], + right_sort_exprs: &[PhysicalSortExpr], +) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { + // Build the filter order for the left side + let err = || plan_datafusion_err!("Filter does not include the child order"); + + let left_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Left, + filter, + &left.schema(), + &left_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Build the filter order for the right side + let right_temp_sorted_filter_expr = build_filter_input_order( + JoinSide::Right, + filter, + &right.schema(), + &right_sort_exprs[0], + )? + .ok_or_else(err)?; + + // Collect the sorted expressions + let mut sorted_exprs = + vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr]; + + // Build the expression interval graph + let mut graph = + ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + + // Update sorted expressions with node indices + update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); + + // Swap and remove to get the final sorted filter expressions + let right_sorted_filter_expr = sorted_exprs.swap_remove(1); + let left_sorted_filter_expr = sorted_exprs.swap_remove(0); + + Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) +} + #[cfg(test)] mod tests { + use std::pin::Pin; + use super::*; - use arrow::datatypes::Fields; - use arrow::error::Result as ArrowResult; - use arrow::{datatypes::DataType, error::ArrowError}; + + use arrow::datatypes::{DataType, Fields}; + use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_schema::SortOptions; + use datafusion_common::ScalarValue; - use std::pin::Pin; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -1224,14 +1440,18 @@ mod tests { fn create_stats( num_rows: Option, - column_stats: Option>, + column_stats: Vec, is_exact: bool, ) -> Statistics { Statistics { - num_rows, + num_rows: if is_exact { + num_rows.map(Precision::Exact) + } else { + num_rows.map(Precision::Inexact) + } + .unwrap_or(Precision::Absent), column_statistics: column_stats, - is_exact, - ..Default::default() + total_byte_size: Precision::Absent, } } @@ -1241,9 +1461,15 @@ mod tests { distinct_count: Option, ) -> ColumnStatistics { ColumnStatistics { - distinct_count, - min_value: min.map(|size| ScalarValue::Int64(Some(size))), - max_value: max.map(|size| ScalarValue::Int64(Some(size))), + distinct_count: distinct_count + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + min_value: min + .map(|size| Precision::Inexact(ScalarValue::from(size))) + .unwrap_or(Precision::Absent), + max_value: max + .map(|size| Precision::Inexact(ScalarValue::from(size))) + .unwrap_or(Precision::Absent), ..Default::default() } } @@ -1255,7 +1481,7 @@ mod tests { // over the expected output (since it depends on join type to join type). #[test] fn test_inner_join_cardinality_single_column() -> Result<()> { - let cases: Vec<(PartialStats, PartialStats, Option)> = vec![ + let cases: Vec<(PartialStats, PartialStats, Option>)> = vec![ // ----------------------------------------------------------------------------- // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected | // ----------------------------------------------------------------------------- @@ -1267,70 +1493,70 @@ mod tests { ( (10, Some(1), Some(10), None), (10, Some(1), Some(10), None), - Some(10), + Some(Precision::Inexact(10)), ), // range(left) > range(right) ( (10, Some(6), Some(10), None), (10, Some(8), Some(10), None), - Some(20), + Some(Precision::Inexact(20)), ), // range(right) > range(left) ( (10, Some(8), Some(10), None), (10, Some(6), Some(10), None), - Some(20), + Some(Precision::Inexact(20)), ), // range(left) > len(left), range(right) > len(right) ( (10, Some(1), Some(15), None), (20, Some(1), Some(40), None), - Some(10), + Some(Precision::Inexact(10)), ), // When we have distinct count. ( (10, Some(1), Some(10), Some(10)), (10, Some(1), Some(10), Some(10)), - Some(10), + Some(Precision::Inexact(10)), ), // distinct(left) > distinct(right) ( (10, Some(1), Some(10), Some(5)), (10, Some(1), Some(10), Some(2)), - Some(20), + Some(Precision::Inexact(20)), ), // distinct(right) > distinct(left) ( (10, Some(1), Some(10), Some(2)), (10, Some(1), Some(10), Some(5)), - Some(20), + Some(Precision::Inexact(20)), ), // min(left) < 0 (range(left) > range(right)) ( (10, Some(-5), Some(5), None), (10, Some(1), Some(5), None), - Some(10), + Some(Precision::Inexact(10)), ), // min(right) < 0, max(right) < 0 (range(right) > range(left)) ( (10, Some(-25), Some(-20), None), (10, Some(-25), Some(-15), None), - Some(10), + Some(Precision::Inexact(10)), ), // range(left) < 0, range(right) >= 0 // (there isn't a case where both left and right ranges are negative // so one of them is always going to work, this just proves negative // ranges with bigger absolute values are not are not accidentally used). ( - (10, Some(10), Some(0), None), + (10, Some(-10), Some(0), None), (10, Some(0), Some(10), Some(5)), - Some(20), // It would have been ten if we have used abs(range(left)) + Some(Precision::Inexact(10)), ), // range(left) = 1, range(right) = 1 ( (10, Some(1), Some(1), None), (10, Some(1), Some(1), None), - Some(100), + Some(Precision::Inexact(100)), ), // // Edge cases @@ -1355,22 +1581,12 @@ mod tests { ( (10, Some(0), Some(10), None), (10, Some(11), Some(20), None), - None, + Some(Precision::Inexact(0)), ), ( (10, Some(11), Some(20), None), (10, Some(0), Some(10), None), - None, - ), - ( - (10, Some(5), Some(10), Some(10)), - (10, Some(11), Some(3), Some(10)), - None, - ), - ( - (10, Some(10), Some(5), Some(10)), - (10, Some(3), Some(7), Some(10)), - None, + Some(Precision::Inexact(0)), ), // distinct(left) = 0, distinct(right) = 0 ( @@ -1394,13 +1610,18 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( - left_num_rows, - right_num_rows, - left_col_stats.clone(), - right_col_stats.clone(), - false, + Statistics { + num_rows: Precision::Inexact(left_num_rows), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats.clone(), + }, + Statistics { + num_rows: Precision::Inexact(right_num_rows), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats.clone(), + }, ), - expected_cardinality + expected_cardinality.clone() ); // We should also be able to use join_cardinality to get the same results @@ -1408,18 +1629,22 @@ mod tests { let join_on = vec![(Column::new("a", 0), Column::new("b", 0))]; let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(left_num_rows), Some(left_col_stats.clone()), false), - create_stats(Some(right_num_rows), Some(right_col_stats.clone()), false), + create_stats(Some(left_num_rows), left_col_stats.clone(), false), + create_stats(Some(right_num_rows), right_col_stats.clone(), false), &join_on, ); assert_eq!( - partial_join_stats.clone().map(|s| s.num_rows), - expected_cardinality + partial_join_stats + .clone() + .map(|s| Precision::Inexact(s.num_rows)), + expected_cardinality.clone() ); assert_eq!( partial_join_stats.map(|s| s.column_statistics), - expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat()) + expected_cardinality + .clone() + .map(|_| [left_col_stats, right_col_stats].concat()) ); } Ok(()) @@ -1441,13 +1666,18 @@ mod tests { // count is 200, so we are going to pick it. assert_eq!( estimate_inner_join_cardinality( - 400, - 400, - left_col_stats, - right_col_stats, - false + Statistics { + num_rows: Precision::Inexact(400), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: Precision::Inexact(400), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, ), - Some((400 * 400) / 200) + Some(Precision::Inexact((400 * 400) / 200)) ); Ok(()) } @@ -1455,28 +1685,33 @@ mod tests { #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: None, - min_value: Some(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Some(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: None, - min_value: Some(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Some(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Precision::Absent, + min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( - 100, - 100, - left_col_stats, - right_col_stats, - false + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Absent, + column_statistics: left_col_stats, + }, + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Absent, + column_statistics: right_col_stats, + }, ), - None + Some(Precision::Inexact(100)) ); Ok(()) } @@ -1521,8 +1756,8 @@ mod tests { let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(1000), Some(left_col_stats.clone()), false), - create_stats(Some(2000), Some(right_col_stats.clone()), false), + create_stats(Some(1000), left_col_stats.clone(), false), + create_stats(Some(2000), right_col_stats.clone(), false), &join_on, ) .unwrap(); @@ -1586,8 +1821,8 @@ mod tests { for (join_type, expected_num_rows) in cases { let partial_join_stats = estimate_join_cardinality( &join_type, - create_stats(Some(1000), Some(left_col_stats.clone()), true), - create_stats(Some(2000), Some(right_col_stats.clone()), true), + create_stats(Some(1000), left_col_stats.clone(), true), + create_stats(Some(2000), right_col_stats.clone(), true), &join_on, ) .unwrap(); @@ -1600,4 +1835,104 @@ mod tests { Ok(()) } + + #[test] + fn test_calculate_join_output_ordering() -> Result<()> { + let options = SortOptions::default(); + let left_ordering = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options, + }, + ]; + let right_ordering = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("z", 2)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("y", 1)), + options, + }, + ]; + let join_type = JoinType::Inner; + let on_columns = [(Column::new("b", 1), Column::new("x", 0))]; + let left_columns_len = 5; + let maintains_input_orders = [[true, false], [false, true]]; + let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; + + let expected = [ + Some(vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("z", 7)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("y", 6)), + options, + }, + ]), + Some(vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("z", 7)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("y", 6)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("d", 3)), + options, + }, + ]), + ]; + + for (i, (maintains_input_order, probe_side)) in + maintains_input_orders.iter().zip(probe_sides).enumerate() + { + assert_eq!( + calculate_join_output_ordering( + &left_ordering, + &right_ordering, + join_type, + &on_columns, + left_columns_len, + maintains_input_order, + probe_side + ), + expected[i] + ); + } + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs new file mode 100644 index 0000000000000..6c9e97e03cb7f --- /dev/null +++ b/datafusion/physical-plan/src/lib.rs @@ -0,0 +1,584 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Traits for physical query plan, supporting parallel execution for partitioned relations. + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::display::DisplayableExecutionPlan; +use crate::metrics::MetricsSet; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::Transformed; +use datafusion_common::utils::DataPtr; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalSortExpr, PhysicalSortRequirement, +}; + +use futures::stream::TryStreamExt; +use tokio::task::JoinSet; + +mod topk; +mod visitor; + +pub mod aggregates; +pub mod analyze; +pub mod coalesce_batches; +pub mod coalesce_partitions; +pub mod common; +pub mod display; +pub mod empty; +pub mod explain; +pub mod filter; +pub mod insert; +pub mod joins; +pub mod limit; +pub mod memory; +pub mod metrics; +mod ordering; +pub mod placeholder_row; +pub mod projection; +pub mod repartition; +pub mod sorts; +pub mod stream; +pub mod streaming; +pub mod tree_node; +pub mod udaf; +pub mod union; +pub mod unnest; +pub mod values; +pub mod windows; + +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; +pub use crate::topk::TopK; +pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; + +use datafusion_common::config::ConfigOptions; +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +pub use datafusion_physical_expr::{ + expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, +}; + +// Backwards compatibility +pub use crate::stream::EmptyRecordBatchStream; +pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; + +/// Represent nodes in the DataFusion Physical Plan. +/// +/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of +/// [`RecordBatch`] that incrementally computes a partition of the +/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more +/// details on partitioning. +/// +/// Methods such as [`schema`] and [`output_partitioning`] communicate +/// properties of this output to the DataFusion optimizer, and methods such as +/// [`required_input_distribution`] and [`required_input_ordering`] express +/// requirements of the `ExecutionPlan` from its input. +/// +/// [`ExecutionPlan`] can be displayed in a simplified form using the +/// return value from [`displayable`] in addition to the (normally +/// quite verbose) `Debug` output. +/// +/// [`execute`]: ExecutionPlan::execute +/// [`schema`]: ExecutionPlan::schema +/// [`output_partitioning`]: ExecutionPlan::output_partitioning +/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution +/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering +pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { + /// Returns the execution plan as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef; + + /// Specifies how the output of this `ExecutionPlan` is split into + /// partitions. + fn output_partitioning(&self) -> Partitioning; + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, _children: &[bool]) -> Result { + if _children.iter().any(|&x| x) { + plan_err!("Plan does not support infinite stream from its children") + } else { + Ok(false) + } + } + + /// If the output of this `ExecutionPlan` within each partition is sorted, + /// returns `Some(keys)` with the description of how it was sorted. + /// + /// For example, Sort, (obviously) produces sorted output as does + /// SortPreservingMergeStream. Less obviously `Projection` + /// produces sorted output if its input was sorted as it does not + /// reorder the input rows, + /// + /// It is safe to return `None` here if your `ExecutionPlan` does not + /// have any particular output order here + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; + + /// Specifies the data distribution requirements for all the + /// children for this `ExecutionPlan`, By default it's [[Distribution::UnspecifiedDistribution]] for each child, + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution; self.children().len()] + } + + /// Specifies the ordering required for all of the children of this + /// `ExecutionPlan`. + /// + /// For each child, it's the local ordering requirement within + /// each partition rather than the global ordering + /// + /// NOTE that checking `!is_empty()` does **not** check for a + /// required input ordering. Instead, the correct check is that at + /// least one entry must be `Some` + fn required_input_ordering(&self) -> Vec>> { + vec![None; self.children().len()] + } + + /// Returns `false` if this `ExecutionPlan`'s implementation may reorder + /// rows within or between partitions. + /// + /// For example, Projection, Filter, and Limit maintain the order + /// of inputs -- they may transform values (Projection) or not + /// produce the same number of rows that went in (Filter and + /// Limit), but the rows that are produced go in the same way. + /// + /// DataFusion uses this metadata to apply certain optimizations + /// such as automatically repartitioning correctly. + /// + /// The default implementation returns `false` + /// + /// WARNING: if you override this default, you *MUST* ensure that + /// the `ExecutionPlan`'s maintains the ordering invariant or else + /// DataFusion may produce incorrect results. + fn maintains_input_order(&self) -> Vec { + vec![false; self.children().len()] + } + + /// Specifies whether the `ExecutionPlan` benefits from increased + /// parallelization at its input for each child. + /// + /// If returns `true`, the `ExecutionPlan` would benefit from partitioning + /// its corresponding child (and thus from more parallelism). For + /// `ExecutionPlan` that do very little work the overhead of extra + /// parallelism may outweigh any benefits + /// + /// The default implementation returns `true` unless this `ExecutionPlan` + /// has signalled it requires a single child input partition. + fn benefits_from_input_partitioning(&self) -> Vec { + // By default try to maximize parallelism with more CPUs if + // possible + self.required_input_distribution() + .into_iter() + .map(|dist| !matches!(dist, Distribution::SinglePartition)) + .collect() + } + + /// Get the [`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal, during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`Self::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new(self.schema()) + } + + /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. + /// The returned list will be empty for leaf nodes such as scans, will contain + /// a single value for unary nodes, or two values for binary nodes (such as + /// joins). + fn children(&self) -> Vec>; + + /// Returns a new `ExecutionPlan` where all existing children were replaced + /// by the `children`, oi order + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result>; + + /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to + /// produce `target_partitions` partitions. + /// + /// If the `ExecutionPlan` does not support changing its partitioning, + /// returns `Ok(None)` (the default). + /// + /// It is the `ExecutionPlan` can increase its partitioning, but not to the + /// `target_partitions`, it may return an ExecutionPlan with fewer + /// partitions. This might happen, for example, if each new partition would + /// be too small to be efficiently processed individually. + /// + /// The DataFusion optimizer attempts to use as many threads as possible by + /// repartitioning its inputs to match the target number of threads + /// available (`target_partitions`). Some data sources, such as the built in + /// CSV and Parquet readers, implement this method as they are able to read + /// from their input files in parallel, regardless of how the source data is + /// split amongst files. + fn repartitioned( + &self, + _target_partitions: usize, + _config: &ConfigOptions, + ) -> Result>> { + Ok(None) + } + + /// Begin execution of `partition`, returning a [`Stream`] of + /// [`RecordBatch`]es. + /// + /// # Notes + /// + /// The `execute` method itself is not `async` but it returns an `async` + /// [`futures::stream::Stream`]. This `Stream` should incrementally compute + /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). + /// Most `ExecutionPlan`s should not do any work before the first + /// `RecordBatch` is requested from the stream. + /// + /// [`RecordBatchStreamAdapter`] can be used to convert an `async` + /// [`Stream`] into a [`SendableRecordBatchStream`]. + /// + /// Using `async` `Streams` allows for network I/O during execution and + /// takes advantage of Rust's built in support for `async` continuations and + /// crate ecosystem. + /// + /// [`Stream`]: futures::stream::Stream + /// [`StreamExt`]: futures::stream::StreamExt + /// [`TryStreamExt`]: futures::stream::TryStreamExt + /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter + /// + /// # Implementation Examples + /// + /// While `async` `Stream`s have a non trivial learning curve, the + /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] + /// which help simplify many common operations. + /// + /// Here are some common patterns: + /// + /// ## Return Precomputed `RecordBatch` + /// + /// We can return a precomputed `RecordBatch` as a `Stream`: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// batch: RecordBatch, + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // use functions from futures crate convert the batch into a stream + /// let fut = futures::future::ready(Ok(self.batch.clone())); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) Compute `RecordBatch` + /// + /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// Returns a single batch when the returned stream is polled + /// async fn get_batch() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// let fut = get_batch(); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) create a Stream + /// + /// If you need to to create the return `Stream` using an `async` function, + /// you can do so by flattening the result: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use futures::TryStreamExt; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// async function that returns a stream + /// async fn get_batch_stream() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // A future that yields a stream + /// let fut = get_batch_stream(); + /// // Use TryStreamExt::try_flatten to flatten the stream of streams + /// let stream = futures::stream::once(fut).try_flatten(); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result; + + /// Return a snapshot of the set of [`Metric`]s for this + /// [`ExecutionPlan`]. If no `Metric`s are available, return None. + /// + /// While the values of the metrics in the returned + /// [`MetricsSet`]s may change as execution progresses, the + /// specific metrics will not. + /// + /// Once `self.execute()` has returned (technically the future is + /// resolved) for all available partitions, the set of metrics + /// should be complete. If this function is called prior to + /// `execute()` new metrics may appear in subsequent calls. + fn metrics(&self) -> Option { + None + } + + /// Returns statistics for this `ExecutionPlan` node. If statistics are not + /// available, should return [`Statistics::new_unknown`] (the default), not + /// an error. + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +/// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful +/// especially for the distributed engine to judge whether need to deal with shuffling. +/// Currently there are 3 kinds of execution plan which needs data exchange +/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s +/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee +/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee +pub fn need_data_exchange(plan: Arc) -> bool { + if let Some(repart) = plan.as_any().downcast_ref::() { + !matches!( + repart.output_partitioning(), + Partitioning::RoundRobinBatch(_) + ) + } else if let Some(coalesce) = plan.as_any().downcast_ref::() + { + coalesce.input().output_partitioning().partition_count() > 1 + } else if let Some(sort_preserving_merge) = + plan.as_any().downcast_ref::() + { + sort_preserving_merge + .input() + .output_partitioning() + .partition_count() + > 1 + } else { + false + } +} + +/// Returns a copy of this plan if we change any child according to the pointer comparison. +/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. +pub fn with_new_children_if_necessary( + plan: Arc, + children: Vec>, +) -> Result>> { + let old_children = plan.children(); + if children.len() != old_children.len() { + internal_err!("Wrong number of children") + } else if children.is_empty() + || children + .iter() + .zip(old_children.iter()) + .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) + { + Ok(Transformed::Yes(plan.with_new_children(children)?)) + } else { + Ok(Transformed::No(plan)) + } +} + +/// Return a [wrapper](DisplayableExecutionPlan) around an +/// [`ExecutionPlan`] which can be displayed in various easier to +/// understand ways. +pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { + DisplayableExecutionPlan::new(plan) +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect( + plan: Arc, + context: Arc, +) -> Result> { + let stream = execute_stream(plan, context)?; + common::collect(stream).await +} + +/// Execute the [ExecutionPlan] and return a single stream of results +pub fn execute_stream( + plan: Arc, + context: Arc, +) -> Result { + match plan.output_partitioning().partition_count() { + 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), + 1 => plan.execute(0, context), + _ => { + // merge into a single partition + let plan = CoalescePartitionsExec::new(plan.clone()); + // CoalescePartitionsExec must produce a single partition + assert_eq!(1, plan.output_partitioning().partition_count()); + plan.execute(0, context) + } + } +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect_partitioned( + plan: Arc, + context: Arc, +) -> Result>> { + let streams = execute_stream_partitioned(plan, context)?; + + let mut join_set = JoinSet::new(); + // Execute the plan and collect the results into batches. + streams.into_iter().enumerate().for_each(|(idx, stream)| { + join_set.spawn(async move { + let result: Result> = stream.try_collect().await; + (idx, result) + }); + }); + + let mut batches = vec![]; + // Note that currently this doesn't identify the thread that panicked + // + // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id + // once it is stable + while let Some(result) = join_set.join_next().await { + match result { + Ok((idx, res)) => batches.push((idx, res?)), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + batches.sort_by_key(|(idx, _)| *idx); + let batches = batches.into_iter().map(|(_, batch)| batch).collect(); + + Ok(batches) +} + +/// Execute the [ExecutionPlan] and return a vec with one stream per output partition +pub fn execute_stream_partitioned( + plan: Arc, + context: Arc, +) -> Result> { + let num_partitions = plan.output_partitioning().partition_count(); + let mut streams = Vec::with_capacity(num_partitions); + for i in 0..num_partitions { + streams.push(plan.execute(i, context.clone())?); + } + Ok(streams) +} + +// Get output (un)boundedness information for the given `plan`. +pub fn unbounded_output(plan: &Arc) -> bool { + let children_unbounded_output = plan + .children() + .iter() + .map(unbounded_output) + .collect::>(); + plan.unbounded_output(&children_unbounded_output) + .unwrap_or(true) +} + +/// Utility function yielding a string representation of the given [`ExecutionPlan`]. +pub fn get_plan_string(plan: &Arc) -> Vec { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} + +#[cfg(test)] +pub mod test; diff --git a/datafusion/core/src/physical_plan/limit.rs b/datafusion/physical-plan/src/limit.rs similarity index 62% rename from datafusion/core/src/physical_plan/limit.rs rename to datafusion/physical-plan/src/limit.rs index 132bae6141466..355561c36f35f 100644 --- a/datafusion/core/src/physical_plan/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -17,30 +17,28 @@ //! Defines the LIMIT plan -use futures::stream::Stream; -use futures::stream::StreamExt; -use log::trace; use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::physical_plan::{ +use super::expressions::PhysicalSortExpr; +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::{ DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, }; + use arrow::array::ArrayRef; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion_common::{DataFusionError, Result}; - -use super::expressions::PhysicalSortExpr; -use super::{ - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - +use datafusion_common::stats::Precision; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use futures::stream::{Stream, StreamExt}; +use log::trace; + /// Limit execution plan #[derive(Debug)] pub struct GlobalLimitExec { @@ -82,6 +80,25 @@ impl GlobalLimitExec { } } +impl DisplayAs for GlobalLimitExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "GlobalLimitExec: skip={}, fetch={}", + self.skip, + self.fetch.map_or("None".to_string(), |x| x.to_string()) + ) + } + } + } +} + impl ExecutionPlan for GlobalLimitExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -108,8 +125,8 @@ impl ExecutionPlan for GlobalLimitExec { vec![true] } - fn benefits_from_input_partitioning(&self) -> bool { - false + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -131,6 +148,10 @@ impl ExecutionPlan for GlobalLimitExec { ))) } + fn unbounded_output(&self, _children: &[bool]) -> Result { + Ok(false) + } + fn execute( &self, partition: usize, @@ -142,16 +163,12 @@ impl ExecutionPlan for GlobalLimitExec { ); // GlobalLimitExec has a single output partition if 0 != partition { - return Err(DataFusionError::Internal(format!( - "GlobalLimitExec invalid partition {partition}" - ))); + return internal_err!("GlobalLimitExec invalid partition {partition}"); } // GlobalLimitExec requires a single input partition if 1 != self.input.output_partitioning().partition_count() { - return Err(DataFusionError::Internal( - "GlobalLimitExec requires a single input partition".to_owned(), - )); + return internal_err!("GlobalLimitExec requires a single input partition"); } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -164,72 +181,85 @@ impl ExecutionPlan for GlobalLimitExec { ))) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "GlobalLimitExec: skip={}, fetch={}", - self.skip, - self.fetch.map_or("None".to_string(), |x| x.to_string()) - ) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stats = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stats = self.input.statistics()?; let skip = self.skip; - // the maximum row number needs to be fetched - let max_row_num = self - .fetch - .map(|fetch| { - if fetch >= usize::MAX - skip { - usize::MAX - } else { - fetch + skip - } - }) - .unwrap_or(usize::MAX); - match input_stats { + let col_stats = Statistics::unknown_column(&self.schema()); + let fetch = self.fetch.unwrap_or(usize::MAX); + + let mut fetched_row_number_stats = Statistics { + num_rows: Precision::Exact(fetch), + column_statistics: col_stats.clone(), + total_byte_size: Precision::Absent, + }; + + let stats = match input_stats { Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. + } + | Statistics { + num_rows: Precision::Inexact(nr), + .. } => { if nr <= skip { // if all input data will be skipped, return 0 - Statistics { - num_rows: Some(0), - is_exact: input_stats.is_exact, - ..Default::default() + let mut skip_all_rows_stats = Statistics { + num_rows: Precision::Exact(0), + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }; + if !input_stats.num_rows.is_exact().unwrap_or(false) { + // The input stats are inexact, so the output stats must be too. + skip_all_rows_stats = skip_all_rows_stats.into_inexact(); } - } else if nr <= max_row_num { - // if the input does not reach the "fetch" globally, return input stats + skip_all_rows_stats + } else if nr <= fetch && self.skip == 0 { + // if the input does not reach the "fetch" globally, and "skip" is zero + // (meaning the input and output are identical), return input stats. + // Can input_stats still be used, but adjusted, in the "skip != 0" case? input_stats + } else if nr - skip <= fetch { + // after "skip" input rows are skipped, the remaining rows are less than or equal to the + // "fetch" values, so `num_rows` must equal the remaining rows + let remaining_rows: usize = nr - skip; + let mut skip_some_rows_stats = Statistics { + num_rows: Precision::Exact(remaining_rows), + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }; + if !input_stats.num_rows.is_exact().unwrap_or(false) { + // The input stats are inexact, so the output stats must be too. + skip_some_rows_stats = skip_some_rows_stats.into_inexact(); + } + skip_some_rows_stats } else { - // if the input is greater than the "fetch", the num_row will be the "fetch", + // if the input is greater than "fetch+skip", the num_rows will be the "fetch", // but we won't be able to predict the other statistics - Statistics { - num_rows: Some(max_row_num), - is_exact: input_stats.is_exact, - ..Default::default() + if !input_stats.num_rows.is_exact().unwrap_or(false) + || self.fetch.is_none() + { + // If the input stats are inexact, the output stats must be too. + // If the fetch value is `usize::MAX` because no LIMIT was specified, + // we also can't represent it as an exact value. + fetched_row_number_stats = + fetched_row_number_stats.into_inexact(); } + fetched_row_number_stats } } - _ => Statistics { - // the result output row number will always be no greater than the limit number - num_rows: Some(max_row_num), - is_exact: false, - ..Default::default() - }, - } + _ => { + // The result output `num_rows` will always be no greater than the limit number. + // Should `num_rows` be marked as `Absent` here when the `fetch` value is large, + // as the actual `num_rows` may be far away from the `fetch` value? + fetched_row_number_stats.into_inexact() + } + }; + Ok(stats) } } @@ -265,6 +295,20 @@ impl LocalLimitExec { } } +impl DisplayAs for LocalLimitExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "LocalLimitExec: fetch={}", self.fetch) + } + } + } +} + impl ExecutionPlan for LocalLimitExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -283,8 +327,8 @@ impl ExecutionPlan for LocalLimitExec { self.input.output_partitioning() } - fn benefits_from_input_partitioning(&self) -> bool { - false + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] } // Local limit will not change the input plan's ordering @@ -300,6 +344,10 @@ impl ExecutionPlan for LocalLimitExec { self.input.equivalence_properties() } + fn unbounded_output(&self, _children: &[bool]) -> Result { + Ok(false) + } + fn with_new_children( self: Arc, children: Vec>, @@ -309,9 +357,7 @@ impl ExecutionPlan for LocalLimitExec { children[0].clone(), self.fetch, ))), - _ => Err(DataFusionError::Internal( - "LocalLimitExec wrong number of children".to_string(), - )), + _ => internal_err!("LocalLimitExec wrong number of children"), } } @@ -331,48 +377,57 @@ impl ExecutionPlan for LocalLimitExec { ))) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "LocalLimitExec: fetch={}", self.fetch) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stats = self.input.statistics(); - match input_stats { + fn statistics(&self) -> Result { + let input_stats = self.input.statistics()?; + let col_stats = Statistics::unknown_column(&self.schema()); + let stats = match input_stats { // if the input does not reach the limit globally, return input stats Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. + } + | Statistics { + num_rows: Precision::Inexact(nr), + .. } if nr <= self.fetch => input_stats, // if the input is greater than the limit, the num_row will be greater // than the limit because the partitions will be limited separatly // the statistic Statistics { - num_rows: Some(nr), .. + num_rows: Precision::Exact(nr), + .. + } if nr > self.fetch => Statistics { + num_rows: Precision::Exact(self.fetch), + // this is not actually exact, but will be when GlobalLimit is applied + // TODO stats: find a more explicit way to vehiculate this information + column_statistics: col_stats, + total_byte_size: Precision::Absent, + }, + Statistics { + num_rows: Precision::Inexact(nr), + .. } if nr > self.fetch => Statistics { - num_rows: Some(self.fetch), + num_rows: Precision::Inexact(self.fetch), // this is not actually exact, but will be when GlobalLimit is applied // TODO stats: find a more explicit way to vehiculate this information - is_exact: input_stats.is_exact, - ..Default::default() + column_statistics: col_stats, + total_byte_size: Precision::Absent, }, _ => Statistics { // the result output row number will always be no greater than the limit number - num_rows: Some(self.fetch * self.output_partitioning().partition_count()), - is_exact: false, - ..Default::default() + num_rows: Precision::Inexact( + self.fetch * self.output_partitioning().partition_count(), + ), + + column_statistics: col_stats, + total_byte_size: Precision::Absent, }, - } + }; + Ok(stats) } } @@ -428,7 +483,7 @@ impl LimitStream { match &poll { Poll::Ready(Some(Ok(batch))) => { - if batch.num_rows() > 0 && self.skip == 0 { + if batch.num_rows() > 0 { break poll; } else { // continue to poll input stream @@ -514,22 +569,22 @@ impl RecordBatchStream for LimitStream { #[cfg(test)] mod tests { - - use common::collect; - use super::*; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::common; - use crate::prelude::SessionContext; - use crate::test; + use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::common::collect; + use crate::{common, test}; + + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use arrow_schema::Schema; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalExpr; #[tokio::test] async fn limit() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let num_partitions = 4; - let csv = test::scan_partitioned_csv(num_partitions)?; + let csv = test::scan_partitioned(num_partitions); // input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); @@ -613,9 +668,9 @@ mod tests { #[tokio::test] async fn limit_no_column() -> Result<()> { let batches = vec![ - test::make_batch_no_column(6), - test::make_batch_no_column(6), - test::make_batch_no_column(6), + make_batch_no_column(6), + make_batch_no_column(6), + make_batch_no_column(6), ]; let input = test::exec::TestStream::new(batches); @@ -642,11 +697,11 @@ mod tests { // test cases for "skip" async fn skip_and_fetch(skip: usize, fetch: Option) -> Result { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); + // 4 partitions @ 100 rows apiece let num_partitions = 4; - let csv = test::scan_partitioned_csv(num_partitions)?; + let csv = test::scan_partitioned(num_partitions); assert_eq!(csv.output_partitioning().partition_count(), num_partitions); @@ -662,7 +717,7 @@ mod tests { #[tokio::test] async fn skip_none_fetch_none() -> Result<()> { let row_count = skip_and_fetch(0, None).await?; - assert_eq!(row_count, 100); + assert_eq!(row_count, 400); Ok(()) } @@ -675,14 +730,14 @@ mod tests { #[tokio::test] async fn skip_3_fetch_none() -> Result<()> { - // there are total of 100 rows, we skipped 3 rows (offset = 3) + // there are total of 400 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, None).await?; - assert_eq!(row_count, 97); + assert_eq!(row_count, 397); Ok(()) } #[tokio::test] - async fn skip_3_fetch_10() -> Result<()> { + async fn skip_3_fetch_10_stats() -> Result<()> { // there are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); @@ -690,23 +745,24 @@ mod tests { } #[tokio::test] - async fn skip_100_fetch_none() -> Result<()> { - let row_count = skip_and_fetch(100, None).await?; + async fn skip_400_fetch_none() -> Result<()> { + let row_count = skip_and_fetch(400, None).await?; assert_eq!(row_count, 0); Ok(()) } #[tokio::test] - async fn skip_100_fetch_1() -> Result<()> { - let row_count = skip_and_fetch(100, Some(1)).await?; + async fn skip_400_fetch_1() -> Result<()> { + // there are a total of 400 rows + let row_count = skip_and_fetch(400, Some(1)).await?; assert_eq!(row_count, 0); Ok(()) } #[tokio::test] - async fn skip_101_fetch_none() -> Result<()> { - // there are total of 100 rows, we skipped 101 rows (offset = 3) - let row_count = skip_and_fetch(101, None).await?; + async fn skip_401_fetch_none() -> Result<()> { + // there are total of 400 rows, we skipped 401 rows (offset = 3) + let row_count = skip_and_fetch(401, None).await?; assert_eq!(row_count, 0); Ok(()) } @@ -714,10 +770,61 @@ mod tests { #[tokio::test] async fn test_row_number_statistics_for_global_limit() -> Result<()> { let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?; - assert_eq!(row_count, Some(10)); + assert_eq!(row_count, Precision::Exact(10)); let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?; - assert_eq!(row_count, Some(15)); + assert_eq!(row_count, Precision::Exact(10)); + + let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?; + assert_eq!(row_count, Precision::Exact(0)); + + let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?; + assert_eq!(row_count, Precision::Exact(1)); + + let row_count = row_number_statistics_for_global_limit(398, None).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = + row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Exact(400)); + + let row_count = + row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Exact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(0, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(10)); + + let row_count = + row_number_inexact_statistics_for_global_limit(5, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(10)); + + let row_count = + row_number_inexact_statistics_for_global_limit(400, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(0)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(10)).await?; + assert_eq!(row_count, Precision::Inexact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(1)).await?; + assert_eq!(row_count, Precision::Inexact(1)); + + let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?; + assert_eq!(row_count, Precision::Inexact(2)); + + let row_count = + row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Inexact(400)); + + let row_count = + row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?; + assert_eq!(row_count, Precision::Inexact(2)); Ok(()) } @@ -725,7 +832,7 @@ mod tests { #[tokio::test] async fn test_row_number_statistics_for_local_limit() -> Result<()> { let row_count = row_number_statistics_for_local_limit(4, 10).await?; - assert_eq!(row_count, Some(40)); + assert_eq!(row_count, Precision::Exact(10)); Ok(()) } @@ -733,28 +840,77 @@ mod tests { async fn row_number_statistics_for_global_limit( skip: usize, fetch: Option, - ) -> Result> { + ) -> Result> { let num_partitions = 4; - let csv = test::scan_partitioned_csv(num_partitions)?; + let csv = test::scan_partitioned(num_partitions); assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - Ok(offset.statistics().num_rows) + Ok(offset.statistics()?.num_rows) + } + + pub fn build_group_by( + input_schema: &SchemaRef, + columns: Vec, + ) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) + } + + async fn row_number_inexact_statistics_for_global_limit( + skip: usize, + fetch: Option, + ) -> Result> { + let num_partitions = 4; + let csv = test::scan_partitioned(num_partitions); + + assert_eq!(csv.output_partitioning().partition_count(), num_partitions); + + // Adding a "GROUP BY i" changes the input stats from Exact to Inexact. + let agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&csv.schema().clone(), vec!["i".to_string()]), + vec![], + vec![None], + vec![None], + csv.clone(), + csv.schema().clone(), + )?; + let agg_exec: Arc = Arc::new(agg); + + let offset = GlobalLimitExec::new( + Arc::new(CoalescePartitionsExec::new(agg_exec)), + skip, + fetch, + ); + + Ok(offset.statistics()?.num_rows) } async fn row_number_statistics_for_local_limit( num_partitions: usize, fetch: usize, - ) -> Result> { - let csv = test::scan_partitioned_csv(num_partitions)?; + ) -> Result> { + let csv = test::scan_partitioned(num_partitions); assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let offset = LocalLimitExec::new(csv, fetch); - Ok(offset.statistics().num_rows) + Ok(offset.statistics()?.num_rows) + } + + /// Return a RecordBatch with a single array with row_count sz + fn make_batch_no_column(sz: usize) -> RecordBatch { + let schema = Arc::new(Schema::empty()); + + let options = RecordBatchOptions::new().with_row_count(Option::from(sz)); + RecordBatch::try_new_with_options(schema, vec![], &options).unwrap() } } diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/physical-plan/src/memory.rs similarity index 55% rename from datafusion/core/src/physical_plan/memory.rs rename to datafusion/physical-plan/src/memory.rs index 38fa5d549cba4..7de474fda11c3 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -17,21 +17,23 @@ //! Execution plan for reading in-memory batches of data +use std::any::Any; +use std::fmt; +use std::sync::Arc; +use std::task::{Context, Poll}; + use super::expressions::PhysicalSortExpr; use super::{ - common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; + use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use core::fmt; -use datafusion_common::Result; -use std::any::Any; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use datafusion_common::DataFusionError; +use datafusion_common::{internal_err, project_schema, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; + use futures::Stream; /// Execution plan for reading in-memory batches of data @@ -44,15 +46,51 @@ pub struct MemoryExec { projected_schema: SchemaRef, /// Optional projection projection: Option>, - // Optional sort information - sort_information: Option>, + // Sort information: one or more equivalent orderings + sort_information: Vec, } impl fmt::Debug for MemoryExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; - write!(f, "projection: {:?}", self.projection) + write!(f, "projection: {:?}", self.projection)?; + if let Some(sort_info) = &self.sort_information.first() { + write!(f, ", output_ordering: {:?}", sort_info)?; + } + Ok(()) + } +} + +impl DisplayAs for MemoryExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let partition_sizes: Vec<_> = + self.partitions.iter().map(|b| b.len()).collect(); + + let output_ordering = self + .sort_information + .first() + .map(|output_ordering| { + format!( + ", output_ordering={}", + PhysicalSortExpr::format_list(output_ordering) + ) + }) + .unwrap_or_default(); + + write!( + f, + "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}", + partition_sizes.len(), + ) + } + } } } @@ -78,16 +116,25 @@ impl ExecutionPlan for MemoryExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.sort_information.as_deref() + self.sort_information + .first() + .map(|ordering| ordering.as_slice()) + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings(self.schema(), &self.sort_information) } fn with_new_children( self: Arc, - _: Vec>, + children: Vec>, ) -> Result> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {self:?}" - ))) + // MemoryExec has no children + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } } fn execute( @@ -102,32 +149,13 @@ impl ExecutionPlan for MemoryExec { )?)) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let partitions: Vec<_> = - self.partitions.iter().map(|b| b.len()).collect(); - write!( - f, - "MemoryExec: partitions={}, partition_sizes={:?}", - partitions.len(), - partitions - ) - } - } - } - /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so - fn statistics(&self) -> Statistics { - common::compute_record_batch_statistics( + fn statistics(&self) -> Result { + Ok(common::compute_record_batch_statistics( &self.partitions, &self.schema, self.projection.clone(), - ) + )) } } @@ -145,35 +173,42 @@ impl MemoryExec { schema, projected_schema, projection, - sort_information: None, + sort_information: vec![], }) } - /// Create a new execution plan for reading in-memory record batches - /// The provided `schema` should not have the projection applied. - pub fn try_new_owned_data( - partitions: Vec>, - schema: SchemaRef, - projection: Option>, - ) -> Result { - let projected_schema = project_schema(&schema, projection.as_ref())?; - Ok(Self { - partitions, - schema, - projected_schema, - projection, - sort_information: None, - }) + pub fn partitions(&self) -> &[Vec] { + &self.partitions } - /// Set sort information - pub fn with_sort_information( - mut self, - sort_information: Vec, - ) -> Self { - self.sort_information = Some(sort_information); + pub fn projection(&self) -> &Option> { + &self.projection + } + + /// A memory table can be ordered by multiple expressions simultaneously. + /// [`EquivalenceProperties`] keeps track of expressions that describe the + /// global ordering of the schema. These columns are not necessarily same; e.g. + /// ```text + /// ┌-------┐ + /// | a | b | + /// |---|---| + /// | 1 | 9 | + /// | 2 | 8 | + /// | 3 | 7 | + /// | 5 | 5 | + /// └---┴---┘ + /// ``` + /// where both `a ASC` and `b DESC` can describe the table ordering. With + /// [`EquivalenceProperties`], we can keep track of these equivalences + /// and treat `a ASC` and `b DESC` as the same ordering requirement. + pub fn with_sort_information(mut self, sort_information: Vec) -> Self { + self.sort_information = sort_information; self } + + pub fn original_schema(&self) -> SchemaRef { + self.schema.clone() + } } /// Iterator over batches @@ -238,3 +273,47 @@ impl RecordBatchStream for MemoryStream { self.schema.clone() } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::memory::MemoryExec; + use crate::ExecutionPlan; + + use arrow_schema::{DataType, Field, Schema, SortOptions}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalSortExpr; + + #[test] + fn test_memory_order_eq() -> datafusion_common::Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + Field::new("c", DataType::Int64, false), + ])); + let expected_output_order = vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions::default(), + }, + ]; + let expected_order_eq = vec![PhysicalSortExpr { + expr: col("c", &schema)?, + options: SortOptions::default(), + }]; + let sort_information = + vec![expected_output_order.clone(), expected_order_eq.clone()]; + let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)? + .with_sort_information(sort_information); + + assert_eq!(mem_exec.output_ordering().unwrap(), expected_output_order); + let eq_properties = mem_exec.equivalence_properties(); + assert!(eq_properties.oeq_class().contains(&expected_order_eq)); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs similarity index 98% rename from datafusion/core/src/physical_plan/metrics/baseline.rs rename to datafusion/physical-plan/src/metrics/baseline.rs index 7d72a6a9fae17..dc345cd8cdcd6 100644 --- a/datafusion/core/src/physical_plan/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -29,7 +29,7 @@ use datafusion_common::Result; /// /// Example: /// ``` -/// use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; +/// use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; /// let metrics = ExecutionPlanMetricsSet::new(); /// /// let partition = 2; diff --git a/datafusion/core/src/physical_plan/metrics/builder.rs b/datafusion/physical-plan/src/metrics/builder.rs similarity index 99% rename from datafusion/core/src/physical_plan/metrics/builder.rs rename to datafusion/physical-plan/src/metrics/builder.rs index 30e9764c64460..beecc13e0029b 100644 --- a/datafusion/core/src/physical_plan/metrics/builder.rs +++ b/datafusion/physical-plan/src/metrics/builder.rs @@ -29,7 +29,7 @@ use super::{ /// case of constant strings /// /// ```rust -/// use datafusion::physical_plan::metrics::*; +/// use datafusion_physical_plan::metrics::*; /// /// let metrics = ExecutionPlanMetricsSet::new(); /// let partition = 1; diff --git a/datafusion/core/src/physical_plan/metrics/mod.rs b/datafusion/physical-plan/src/metrics/mod.rs similarity index 99% rename from datafusion/core/src/physical_plan/metrics/mod.rs rename to datafusion/physical-plan/src/metrics/mod.rs index 652c0af5c2e44..b2e0086f69e9a 100644 --- a/datafusion/core/src/physical_plan/metrics/mod.rs +++ b/datafusion/physical-plan/src/metrics/mod.rs @@ -43,7 +43,7 @@ pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// [`ExecutionPlanMetricsSet`]. /// /// ``` -/// use datafusion::physical_plan::metrics::*; +/// use datafusion_physical_plan::metrics::*; /// /// let metrics = ExecutionPlanMetricsSet::new(); /// assert!(metrics.clone_inner().output_rows().is_none()); diff --git a/datafusion/core/src/physical_plan/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs similarity index 99% rename from datafusion/core/src/physical_plan/metrics/value.rs rename to datafusion/physical-plan/src/metrics/value.rs index 59b012f25a27d..899ceb60b49f7 100644 --- a/datafusion/core/src/physical_plan/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -430,11 +430,13 @@ impl MetricValue { Self::Time { time, .. } => time.value(), Self::StartTimestamp(timestamp) => timestamp .value() - .map(|ts| ts.timestamp_nanos() as usize) + .and_then(|ts| ts.timestamp_nanos_opt()) + .map(|nanos| nanos as usize) .unwrap_or(0), Self::EndTimestamp(timestamp) => timestamp .value() - .map(|ts| ts.timestamp_nanos() as usize) + .and_then(|ts| ts.timestamp_nanos_opt()) + .map(|nanos| nanos as usize) .unwrap_or(0), } } diff --git a/datafusion/physical-plan/src/ordering.rs b/datafusion/physical-plan/src/ordering.rs new file mode 100644 index 0000000000000..047f89eef1932 --- /dev/null +++ b/datafusion/physical-plan/src/ordering.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Specifies how the input to an aggregation or window operator is ordered +/// relative to their `GROUP BY` or `PARTITION BY` expressions. +/// +/// For example, if the existing ordering is `[a ASC, b ASC, c ASC]` +/// +/// ## Window Functions +/// - A `PARTITION BY b` clause can use `Linear` mode. +/// - A `PARTITION BY a, c` or a `PARTITION BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `PARTITION BY a, b` or a `PARTITION BY b, a` can use `Sorted` mode. +/// +/// ## Aggregations +/// - A `GROUP BY b` clause can use `Linear` mode. +/// - A `GROUP BY a, c` or a `GROUP BY BY c, a` can use +/// `PartiallySorted([0])` or `PartiallySorted([1])` modes, respectively. +/// (The vector stores the index of `a` in the respective PARTITION BY expression.) +/// - A `GROUP BY a, b` or a `GROUP BY b, a` can use `Sorted` mode. +/// +/// Note these are the same examples as above, but with `GROUP BY` instead of +/// `PARTITION BY` to make the examples easier to read. +#[derive(Debug, Clone, PartialEq)] +pub enum InputOrderMode { + /// There is no partial permutation of the expressions satisfying the + /// existing ordering. + Linear, + /// There is a partial permutation of the expressions satisfying the + /// existing ordering. Indices describing the longest partial permutation + /// are stored in the vector. + PartiallySorted(Vec), + /// There is a (full) permutation of the expressions satisfying the + /// existing ordering. + Sorted, +} diff --git a/datafusion/core/src/physical_plan/empty.rs b/datafusion/physical-plan/src/placeholder_row.rs similarity index 52% rename from datafusion/core/src/physical_plan/empty.rs rename to datafusion/physical-plan/src/placeholder_row.rs index 627444ffd94d9..94f32788530be 100644 --- a/datafusion/core/src/physical_plan/empty.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -15,59 +15,49 @@ // specific language governing permissions and limitations // under the License. -//! EmptyRelation execution plan +//! EmptyRelation produce_one_row=true execution plan use std::any::Any; use std::sync::Arc; -use crate::physical_plan::{ - memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning, -}; +use super::expressions::PhysicalSortExpr; +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning}; + use arrow::array::{ArrayRef, NullArray}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; -use log::trace; - -use super::expressions::PhysicalSortExpr; -use super::{common, SendableRecordBatchStream, Statistics}; - +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; -/// Execution plan for empty relation (produces no rows) +use log::trace; + +/// Execution plan for empty relation with produce_one_row=true #[derive(Debug)] -pub struct EmptyExec { - /// Specifies whether this exec produces a row or not - produce_one_row: bool, +pub struct PlaceholderRowExec { /// The schema for the produced row schema: SchemaRef, /// Number of partitions partitions: usize, } -impl EmptyExec { - /// Create a new EmptyExec - pub fn new(produce_one_row: bool, schema: SchemaRef) -> Self { - EmptyExec { - produce_one_row, +impl PlaceholderRowExec { + /// Create a new PlaceholderRowExec + pub fn new(schema: SchemaRef) -> Self { + PlaceholderRowExec { schema, partitions: 1, } } - /// Create a new EmptyExec with specified partition number + /// Create a new PlaceholderRowExecPlaceholderRowExec with specified partition number pub fn with_partitions(mut self, partitions: usize) -> Self { self.partitions = partitions; self } - /// Specifies whether this exec produces a row or not - pub fn produce_one_row(&self) -> bool { - self.produce_one_row - } - fn data(&self) -> Result> { - let batch = if self.produce_one_row { + Ok({ let n_field = self.schema.fields.len(); // hack for https://github.com/apache/arrow-datafusion/pull/3242 let n_field = if n_field == 0 { 1 } else { n_field }; @@ -86,15 +76,25 @@ impl EmptyExec { }) .collect(), )?] - } else { - vec![] - }; + }) + } +} - Ok(batch) +impl DisplayAs for PlaceholderRowExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PlaceholderRowExec") + } + } } } -impl ExecutionPlan for EmptyExec { +impl ExecutionPlan for PlaceholderRowExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -121,10 +121,7 @@ impl ExecutionPlan for EmptyExec { self: Arc, _: Vec>, ) -> Result> { - Ok(Arc::new(EmptyExec::new( - self.produce_one_row, - self.schema.clone(), - ))) + Ok(Arc::new(PlaceholderRowExec::new(self.schema.clone()))) } fn execute( @@ -132,13 +129,14 @@ impl ExecutionPlan for EmptyExec { partition: usize, context: Arc, ) -> Result { - trace!("Start EmptyExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); if partition >= self.partitions { - return Err(DataFusionError::Internal(format!( - "EmptyExec invalid partition {} (expected less than {})", - partition, self.partitions - ))); + return internal_err!( + "PlaceholderRowExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); } Ok(Box::pin(MemoryStream::try_new( @@ -148,66 +146,37 @@ impl ExecutionPlan for EmptyExec { )?)) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "EmptyExec: produce_one_row={}", self.produce_one_row) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = self .data() - .expect("Create empty RecordBatch should not fail"); - common::compute_record_batch_statistics(&[batch], &self.schema, None) + .expect("Create single row placeholder RecordBatch should not fail"); + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) } } #[cfg(test)] mod tests { use super::*; - use crate::physical_plan::with_new_children_if_necessary; - use crate::prelude::SessionContext; - use crate::{physical_plan::common, test_util}; - - #[tokio::test] - async fn empty() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - - let empty = EmptyExec::new(false, schema.clone()); - assert_eq!(empty.schema(), schema); - - // we should have no results - let iter = empty.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; - assert!(batches.is_empty()); - - Ok(()) - } + use crate::with_new_children_if_necessary; + use crate::{common, test}; #[test] fn with_new_children() -> Result<()> { - let schema = test_util::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(false, schema.clone())); - let empty_with_row = Arc::new(EmptyExec::new(true, schema)); + let schema = test::aggr_test_schema(); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); - assert_eq!(empty.schema(), empty2.schema()); + let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let empty_with_row_2 = - with_new_children_if_necessary(empty_with_row.clone(), vec![])?.into(); - assert_eq!(empty_with_row.schema(), empty_with_row_2.schema()); + let placeholder_2 = + with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + assert_eq!(placeholder.schema(), placeholder_2.schema()); - let too_many_kids = vec![empty2]; + let too_many_kids = vec![placeholder_2]; assert!( - with_new_children_if_necessary(empty, too_many_kids).is_err(), + with_new_children_if_necessary(placeholder, too_many_kids).is_err(), "expected error when providing list of kids" ); Ok(()) @@ -215,25 +184,23 @@ mod tests { #[tokio::test] async fn invalid_execute() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - let empty = EmptyExec::new(false, schema); + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); // ask for the wrong partition - assert!(empty.execute(1, task_ctx.clone()).is_err()); - assert!(empty.execute(20, task_ctx).is_err()); + assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(20, task_ctx).is_err()); Ok(()) } #[tokio::test] async fn produce_one_row() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - let empty = EmptyExec::new(true, schema); + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); + let placeholder = PlaceholderRowExec::new(schema); - let iter = empty.execute(0, task_ctx)?; + let iter = placeholder.execute(0, task_ctx)?; let batches = common::collect(iter).await?; // should have one item @@ -244,14 +211,13 @@ mod tests { #[tokio::test] async fn produce_one_row_multiple_partition() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); + let task_ctx = Arc::new(TaskContext::default()); + let schema = test::aggr_test_schema(); let partitions = 3; - let empty = EmptyExec::new(true, schema).with_partitions(partitions); + let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); for n in 0..partitions { - let iter = empty.execute(n, task_ctx.clone())?; + let iter = placeholder.execute(n, task_ctx.clone())?; let batches = common::collect(iter).await?; // should have one item diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs new file mode 100644 index 0000000000000..cc2ab62049ed5 --- /dev/null +++ b/datafusion/physical-plan/src/projection.rs @@ -0,0 +1,487 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the projection execution plan. A projection determines which columns or expressions +//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example +//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the +//! projection expressions. `SELECT` without `FROM` will only evaluate expressions. + +use std::any::Any; +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use super::expressions::{Column, PhysicalSortExpr}; +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::{ + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, +}; + +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_common::stats::Precision; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::{Literal, UnKnownColumn}; +use datafusion_physical_expr::EquivalenceProperties; + +use futures::stream::{Stream, StreamExt}; +use log::trace; + +/// Execution plan for a projection +#[derive(Debug, Clone)] +pub struct ProjectionExec { + /// The projection expressions stored as tuples of (expression, output column name) + pub(crate) expr: Vec<(Arc, String)>, + /// The schema once the projection has been applied to the input + schema: SchemaRef, + /// The input plan + input: Arc, + /// The output ordering + output_ordering: Option>, + /// The mapping used to normalize expressions like Partitioning and + /// PhysicalSortExpr that maps input to output + projection_mapping: ProjectionMapping, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl ProjectionExec { + /// Create a projection on an input + pub fn try_new( + expr: Vec<(Arc, String)>, + input: Arc, + ) -> Result { + let input_schema = input.schema(); + + let fields: Result> = expr + .iter() + .map(|(e, name)| { + let mut field = Field::new( + name, + e.data_type(&input_schema)?, + e.nullable(&input_schema)?, + ); + field.set_metadata( + get_field_metadata(e, &input_schema).unwrap_or_default(), + ); + + Ok(field) + }) + .collect(); + + let schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + // construct a map from the input expressions to the output expression of the Projection + let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; + + let input_eqs = input.equivalence_properties(); + let project_eqs = input_eqs.project(&projection_mapping, schema.clone()); + let output_ordering = project_eqs.oeq_class().output_ordering(); + + Ok(Self { + expr, + schema, + input, + output_ordering, + projection_mapping, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// The projection expressions stored as tuples of (expression, output column name) + pub fn expr(&self) -> &[(Arc, String)] { + &self.expr + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for ProjectionExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let expr: Vec = self + .expr + .iter() + .map(|(e, alias)| { + let e = e.to_string(); + if &e != alias { + format!("{e} as {alias}") + } else { + e + } + }) + .collect(); + + write!(f, "ProjectionExec: expr=[{}]", expr.join(", ")) + } + } + } +} + +impl ExecutionPlan for ProjectionExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + // Output partition need to respect the alias + let input_partition = self.input.output_partitioning(); + let input_eq_properties = self.input.equivalence_properties(); + if let Partitioning::Hash(exprs, part) = input_partition { + let normalized_exprs = exprs + .into_iter() + .map(|expr| { + input_eq_properties + .project_expr(&expr, &self.projection_mapping) + .unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) + }) + .collect(); + Partitioning::Hash(normalized_exprs, part) + } else { + input_partition + } + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + fn maintains_input_order(&self) -> Vec { + // tell optimizer this operator doesn't reorder its input + vec![true] + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + self.input + .equivalence_properties() + .project(&self.projection_mapping, self.schema()) + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) + } + + fn benefits_from_input_partitioning(&self) -> Vec { + let all_simple_exprs = self + .expr + .iter() + .all(|(e, _)| e.as_any().is::() || e.as_any().is::()); + // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, + // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. + vec![!all_simple_exprs] + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + Ok(Box::pin(ProjectionStream { + schema: self.schema.clone(), + expr: self.expr.iter().map(|x| x.0.clone()).collect(), + input: self.input.execute(partition, context)?, + baseline_metrics: BaselineMetrics::new(&self.metrics, partition), + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(stats_projection( + self.input.statistics()?, + self.expr.iter().map(|(e, _)| Arc::clone(e)), + self.schema.clone(), + )) + } +} + +/// If e is a direct column reference, returns the field level +/// metadata for that field, if any. Otherwise returns None +fn get_field_metadata( + e: &Arc, + input_schema: &Schema, +) -> Option> { + // Look up field by index in schema (not NAME as there can be more than one + // column with the same name) + e.as_any() + .downcast_ref::() + .map(|column| input_schema.field(column.index()).metadata()) + .cloned() +} + +fn stats_projection( + mut stats: Statistics, + exprs: impl Iterator>, + schema: SchemaRef, +) -> Statistics { + let mut primitive_row_size = 0; + let mut primitive_row_size_possible = true; + let mut column_statistics = vec![]; + for expr in exprs { + let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { + stats.column_statistics[col.index()].clone() + } else { + // TODO stats: estimate more statistics from expressions + // (expressions should compute their statistics themselves) + ColumnStatistics::new_unknown() + }; + column_statistics.push(col_stats); + if let Ok(data_type) = expr.data_type(&schema) { + if let Some(value) = data_type.primitive_width() { + primitive_row_size += value; + continue; + } + } + primitive_row_size_possible = false; + } + + if primitive_row_size_possible { + stats.total_byte_size = + Precision::Exact(primitive_row_size).multiply(&stats.num_rows); + } + stats.column_statistics = column_statistics; + stats +} + +impl ProjectionStream { + fn batch_project(&self, batch: &RecordBatch) -> Result { + // records time on drop + let _timer = self.baseline_metrics.elapsed_compute().timer(); + let arrays = self + .expr + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + + if arrays.is_empty() { + let options = + RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(self.schema.clone(), arrays, &options) + .map_err(Into::into) + } else { + RecordBatch::try_new(self.schema.clone(), arrays).map_err(Into::into) + } + } +} + +/// Projection iterator +struct ProjectionStream { + schema: SchemaRef, + expr: Vec>, + input: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, +} + +impl Stream for ProjectionStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.input.poll_next_unpin(cx).map(|x| match x { + Some(Ok(batch)) => Some(self.batch_project(&batch)), + other => other, + }); + + self.baseline_metrics.record_poll(poll) + } + + fn size_hint(&self) -> (usize, Option) { + // same number of record batches + self.input.size_hint() + } +} + +impl RecordBatchStream for ProjectionStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::collect; + use crate::expressions; + use crate::test; + + use arrow_schema::DataType; + use datafusion_common::ScalarValue; + + #[tokio::test] + async fn project_no_column() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + + let exec = test::scan_partitioned(1); + let expected = collect(exec.execute(0, task_ctx.clone())?).await.unwrap(); + + let projection = ProjectionExec::try_new(vec![], exec)?; + let stream = projection.execute(0, task_ctx.clone())?; + let output = collect(stream).await.unwrap(); + assert_eq!(output.len(), expected.len()); + + Ok(()) + } + + fn get_stats() -> Statistics { + Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), + }, + ColumnStatistics { + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), + }, + ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, + }, + ], + } + } + + fn get_schema() -> Schema { + let field_0 = Field::new("col0", DataType::Int64, false); + let field_1 = Field::new("col1", DataType::Utf8, false); + let field_2 = Field::new("col2", DataType::Float32, false); + Schema::new(vec![field_0, field_1, field_2]) + } + #[tokio::test] + async fn test_stats_projection_columns_only() { + let source = get_stats(); + let schema = get_schema(); + + let exprs: Vec> = vec![ + Arc::new(expressions::Column::new("col1", 1)), + Arc::new(expressions::Column::new("col0", 0)), + ]; + + let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + + let expected = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), + }, + ], + }; + + assert_eq!(result, expected); + } + + #[tokio::test] + async fn test_stats_projection_column_with_primitive_width_only() { + let source = get_stats(); + let schema = get_schema(); + + let exprs: Vec> = vec![ + Arc::new(expressions::Column::new("col2", 2)), + Arc::new(expressions::Column::new("col0", 0)), + ]; + + let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + + let expected = Statistics { + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(60), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), + }, + ], + }; + + assert_eq!(result, expected); + } +} diff --git a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs similarity index 97% rename from datafusion/core/src/physical_plan/repartition/distributor_channels.rs rename to datafusion/physical-plan/src/repartition/distributor_channels.rs index d9466d647cf33..e71b88467bccd 100644 --- a/datafusion/core/src/physical_plan/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -83,6 +83,19 @@ pub fn channels( (senders, receivers) } +type PartitionAwareSenders = Vec>>; +type PartitionAwareReceivers = Vec>>; + +/// Create `n_out` empty channels for each of the `n_in` inputs. +/// This way, each distinct partition will communicate via a dedicated channel. +/// This SPSC structure enables us to track which partition input data comes from. +pub fn partition_aware_channels( + n_in: usize, + n_out: usize, +) -> (PartitionAwareSenders, PartitionAwareReceivers) { + (0..n_in).map(|_| channels(n_out)).unzip() +} + /// Erroring during [send](DistributionSender::send). /// /// This occurs when the [receiver](DistributionReceiver) is gone. diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs similarity index 64% rename from datafusion/core/src/physical_plan/repartition/mod.rs rename to datafusion/physical-plan/src/repartition/mod.rs index d7dc54afd6e14..24f227d8a5352 100644 --- a/datafusion/core/src/physical_plan/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -15,44 +15,49 @@ // specific language governing permissions and limitations // under the License. -//! The repartition operator maps N input partitions to M output partitions based on a -//! partitioning scheme. +//! This file implements the [`RepartitionExec`] operator, which maps N input +//! partitions to M output partitions based on a partitioning scheme, optionally +//! maintaining the order of the input rows in the output. use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{any::Any, vec}; -use crate::physical_plan::hash_utils::create_hashes; -use crate::physical_plan::repartition::distributor_channels::channels; -use crate::physical_plan::{ - DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, -}; use arrow::array::{ArrayRef, UInt64Builder}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; -use datafusion_execution::memory_pool::MemoryConsumer; +use futures::stream::Stream; +use futures::{FutureExt, StreamExt}; +use hashbrown::HashMap; use log::trace; +use parking_lot::Mutex; +use tokio::task::JoinHandle; -use self::distributor_channels::{DistributionReceiver, DistributionSender}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; + +use crate::common::transpose; +use crate::hash_utils::create_hashes; +use crate::metrics::BaselineMetrics; +use crate::repartition::distributor_channels::{channels, partition_aware_channels}; +use crate::sorts::streaming_merge; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; use super::common::{AbortOnDropMany, AbortOnDropSingle, SharedMemoryReservation}; use super::expressions::PhysicalSortExpr; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalExpr; -use futures::stream::Stream; -use futures::{FutureExt, StreamExt}; -use hashbrown::HashMap; -use parking_lot::Mutex; -use tokio::task::JoinHandle; +use self::distributor_channels::{DistributionReceiver, DistributionSender}; mod distributor_channels; type MaybeBatch = Option>; +type InputPartitionsToCurrentPartitionSender = Vec>; +type InputPartitionsToCurrentPartitionReceiver = Vec>; /// Inner state of [`RepartitionExec`]. #[derive(Debug)] @@ -62,8 +67,8 @@ struct RepartitionExecState { channels: HashMap< usize, ( - DistributionSender, - DistributionReceiver, + InputPartitionsToCurrentPartitionSender, + InputPartitionsToCurrentPartitionReceiver, SharedMemoryReservation, ), >, @@ -110,11 +115,7 @@ impl BatchPartitioner { random_state: ahash::RandomState::with_seeds(0, 0, 0, 0), hash_buffer: vec![], }, - other => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported repartitioning scheme {other:?}" - ))) - } + other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"), }; Ok(Self { state, timer }) @@ -168,9 +169,7 @@ impl BatchPartitioner { let arrays = exprs .iter() - .map(|expr| { - Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())) - }) + .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; hash_buffer.clear(); @@ -230,8 +229,68 @@ impl BatchPartitioner { } } -/// The repartition operator maps N input partitions to M output partitions based on a -/// partitioning scheme. No guarantees are made about the order of the resulting partitions. +/// Maps `N` input partitions to `M` output partitions based on a +/// [`Partitioning`] scheme. +/// +/// # Background +/// +/// DataFusion, like most other commercial systems, with the +/// notable exception of DuckDB, uses the "Exchange Operator" based +/// approach to parallelism which works well in practice given +/// sufficient care in implementation. +/// +/// DataFusion's planner picks the target number of partitions and +/// then `RepartionExec` redistributes [`RecordBatch`]es to that number +/// of output partitions. +/// +/// For example, given `target_partitions=3` (trying to use 3 cores) +/// but scanning an input with 2 partitions, `RepartitionExec` can be +/// used to get 3 even streams of `RecordBatch`es +/// +/// +///```text +/// ▲ ▲ ▲ +/// │ │ │ +/// │ │ │ +/// │ │ │ +///┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +///│ GroupBy │ │ GroupBy │ │ GroupBy │ +///│ (Partial) │ │ (Partial) │ │ (Partial) │ +///└───────────────┘ └───────────────┘ └───────────────┘ +/// ▲ ▲ ▲ +/// └──────────────────┼──────────────────┘ +/// │ +/// ┌─────────────────────────┐ +/// │ RepartitionExec │ +/// │ (hash/round robin) │ +/// └─────────────────────────┘ +/// ▲ ▲ +/// ┌───────────┘ └───────────┐ +/// │ │ +/// │ │ +/// .─────────. .─────────. +/// ,─' '─. ,─' '─. +/// ; Input : ; Input : +/// : Partition 0 ; : Partition 1 ; +/// ╲ ╱ ╲ ╱ +/// '─. ,─' '─. ,─' +/// `───────' `───────' +///``` +/// +/// # Output Ordering +/// +/// If more than one stream is being repartitioned, the output will be some +/// arbitrary interleaving (and thus unordered) unless +/// [`Self::with_preserve_order`] specifies otherwise. +/// +/// # Footnote +/// +/// The "Exchange Operator" was first described in the 1989 paper +/// [Encapsulation of parallelism in the Volcano query processing +/// system +/// Paper](https://w6113.github.io/files/papers/volcanoparallelism-89.pdf) +/// which uses the term "Exchange" for the concept of repartitioning +/// data across threads. #[derive(Debug)] pub struct RepartitionExec { /// Input execution plan @@ -245,6 +304,10 @@ pub struct RepartitionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, + + /// Boolean flag to decide whether to preserve ordering. If true means + /// `SortPreservingRepartitionExec`, false means `RepartitionExec`. + preserve_order: bool, } #[derive(Debug, Clone)] @@ -252,7 +315,7 @@ struct RepartitionMetrics { /// Time in nanos to execute child operator and fetch batches fetch_time: metrics::Time, /// Time in nanos to perform repartitioning - repart_time: metrics::Time, + repartition_time: metrics::Time, /// Time in nanos for sending resulting batches to channels send_time: metrics::Time, } @@ -282,7 +345,7 @@ impl RepartitionMetrics { Self { fetch_time, - repart_time, + repartition_time: repart_time, send_time, } } @@ -298,6 +361,50 @@ impl RepartitionExec { pub fn partitioning(&self) -> &Partitioning { &self.partitioning } + + /// Get preserve_order flag of the RepartitionExecutor + /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec` + pub fn preserve_order(&self) -> bool { + self.preserve_order + } + + /// Get name used to display this Exec + pub fn name(&self) -> &str { + if self.preserve_order { + "SortPreservingRepartitionExec" + } else { + "RepartitionExec" + } + } +} + +impl DisplayAs for RepartitionExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "{}: partitioning={}, input_partitions={}", + self.name(), + self.partitioning, + self.input.output_partitioning().partition_count() + )?; + + if let Some(sort_exprs) = self.sort_exprs() { + write!( + f, + ", sort_exprs={}", + PhysicalSortExpr::format_list(sort_exprs) + )?; + } + Ok(()) + } + } + } } impl ExecutionPlan for RepartitionExec { @@ -317,12 +424,14 @@ impl ExecutionPlan for RepartitionExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - Ok(Arc::new(RepartitionExec::try_new( - children[0].clone(), - self.partitioning.clone(), - )?)) + let mut repartition = + RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?; + if self.preserve_order { + repartition = repartition.with_preserve_order(); + } + Ok(Arc::new(repartition)) } /// Specifies whether this plan generates an infinite stream of records. @@ -332,6 +441,10 @@ impl ExecutionPlan for RepartitionExec { Ok(children[0]) } + fn benefits_from_input_partitioning(&self) -> Vec { + vec![matches!(self.partitioning, Partitioning::Hash(_, _))] + } + fn output_partitioning(&self) -> Partitioning { self.partitioning.clone() } @@ -345,12 +458,21 @@ impl ExecutionPlan for RepartitionExec { } fn maintains_input_order(&self) -> Vec { - // We preserve ordering when input partitioning is 1 - vec![self.input().output_partitioning().partition_count() <= 1] + if self.preserve_order { + vec![true] + } else { + // We preserve ordering when input partitioning is 1 + vec![self.input().output_partitioning().partition_count() <= 1] + } } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + let mut result = self.input.equivalence_properties(); + // If the ordering is lost, reset the ordering equivalence class. + if !self.maintains_input_order()[0] { + result.clear_orderings(); + } + result } fn execute( @@ -359,7 +481,8 @@ impl ExecutionPlan for RepartitionExec { context: Arc, ) -> Result { trace!( - "Start RepartitionExec::execute for partition: {}", + "Start {}::execute for partition: {}", + self.name(), partition ); // lock mutexes @@ -370,13 +493,29 @@ impl ExecutionPlan for RepartitionExec { // if this is the first partition to be invoked then we need to set up initial state if state.channels.is_empty() { - // create one channel per *output* partition - // note we use a custom channel that ensures there is always data for each receiver - // but limits the amount of buffering if required. - let (txs, rxs) = channels(num_output_partitions); + let (txs, rxs) = if self.preserve_order { + let (txs, rxs) = + partition_aware_channels(num_input_partitions, num_output_partitions); + // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition + let txs = transpose(txs); + let rxs = transpose(rxs); + (txs, rxs) + } else { + // create one channel per *output* partition + // note we use a custom channel that ensures there is always data for each receiver + // but limits the amount of buffering if required. + let (txs, rxs) = channels(num_output_partitions); + // Clone sender for each input partitions + let txs = txs + .into_iter() + .map(|item| vec![item; num_input_partitions]) + .collect::>(); + let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); + (txs, rxs) + }; for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("RepartitionExec[{partition}]")) + MemoryConsumer::new(format!("{}[{partition}]", self.name())) .register(context.memory_pool()), )); state.channels.insert(partition, (tx, rx, reservation)); @@ -389,7 +528,7 @@ impl ExecutionPlan for RepartitionExec { .channels .iter() .map(|(partition, (tx, _rx, reservation))| { - (*partition, (tx.clone(), Arc::clone(reservation))) + (*partition, (tx[i].clone(), Arc::clone(reservation))) }) .collect(); @@ -420,54 +559,76 @@ impl ExecutionPlan for RepartitionExec { } trace!( - "Before returning stream in RepartitionExec::execute for partition: {}", + "Before returning stream in {}::execute for partition: {}", + self.name(), partition ); // now return stream for the specified *output* partition which will // read from the channel - let (_tx, rx, reservation) = state + let (_tx, mut rx, reservation) = state .channels .remove(&partition) .expect("partition not used yet"); - Ok(Box::pin(RepartitionStream { - num_input_partitions, - num_input_partitions_processed: 0, - schema: self.input.schema(), - input: rx, - drop_helper: Arc::clone(&state.abort_helper), - reservation, - })) + + if self.preserve_order { + // Store streams from all the input partitions: + let input_streams = rx + .into_iter() + .map(|receiver| { + Box::pin(PerPartitionStream { + schema: self.schema(), + receiver, + drop_helper: Arc::clone(&state.abort_helper), + reservation: reservation.clone(), + }) as SendableRecordBatchStream + }) + .collect::>(); + // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. + + // Get existing ordering to use for merging + let sort_exprs = self.sort_exprs().unwrap_or(&[]); + + // Merge streams (while preserving ordering) coming from + // input partitions to this partition: + let fetch = None; + let merge_reservation = + MemoryConsumer::new(format!("{}[Merge {partition}]", self.name())) + .register(context.memory_pool()); + streaming_merge( + input_streams, + self.schema(), + sort_exprs, + BaselineMetrics::new(&self.metrics, partition), + context.session_config().batch_size(), + fetch, + merge_reservation, + ) + } else { + Ok(Box::pin(RepartitionStream { + num_input_partitions, + num_input_partitions_processed: 0, + schema: self.input.schema(), + input: rx.swap_remove(0), + drop_helper: Arc::clone(&state.abort_helper), + reservation, + })) + } } fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "RepartitionExec: partitioning={:?}, input_partitions={}", - self.partitioning, - self.input.output_partitioning().partition_count() - ) - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } impl RepartitionExec { - /// Create a new RepartitionExec + /// Create a new RepartitionExec, that produces output `partitioning`, and + /// does not preserve the order of the input (see [`Self::with_preserve_order`] + /// for more details) pub fn try_new( input: Arc, partitioning: Partitioning, @@ -480,40 +641,64 @@ impl RepartitionExec { abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])), })), metrics: ExecutionPlanMetricsSet::new(), + preserve_order: false, }) } + /// Specify if this reparititoning operation should preserve the order of + /// rows from its input when producing output. Preserving order is more + /// expensive at runtime, so should only be set if the output of this + /// operator can take advantage of it. + /// + /// If the input is not ordered, or has only one partition, this is a no op, + /// and the node remains a `RepartitionExec`. + pub fn with_preserve_order(mut self) -> Self { + self.preserve_order = + // If the input isn't ordered, there is no ordering to preserve + self.input.output_ordering().is_some() && + // if there is only one input partition, merging is not required + // to maintain order + self.input.output_partitioning().partition_count() > 1; + self + } + + /// Return the sort expressions that are used to merge + fn sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + if self.preserve_order { + self.input.output_ordering() + } else { + None + } + } + /// Pulls data from the specified input plan, feeding it to the /// output partitions based on the desired partitioning /// - /// i is the input partition index - /// /// txs hold the output sending channels for each output partition async fn pull_from_input( input: Arc, - i: usize, - mut txs: HashMap< + partition: usize, + mut output_channels: HashMap< usize, (DistributionSender, SharedMemoryReservation), >, partitioning: Partitioning, - r_metrics: RepartitionMetrics, + metrics: RepartitionMetrics, context: Arc, ) -> Result<()> { let mut partitioner = - BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?; + BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; // execute the child operator - let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i, context)?; + let timer = metrics.fetch_time.timer(); + let mut stream = input.execute(partition, context)?; timer.done(); - // While there are still outputs to send to, keep - // pulling inputs + // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); - while !txs.is_empty() { + while !output_channels.is_empty() { // fetch the next batch - let timer = r_metrics.fetch_time.timer(); + let timer = metrics.fetch_time.timer(); let result = stream.next().await; timer.done(); @@ -527,15 +712,15 @@ impl RepartitionExec { let (partition, batch) = res?; let size = batch.get_array_memory_size(); - let timer = r_metrics.send_time.timer(); + let timer = metrics.send_time.timer(); // if there is still a receiver, send to it - if let Some((tx, reservation)) = txs.get_mut(&partition) { + if let Some((tx, reservation)) = output_channels.get_mut(&partition) { reservation.lock().try_grow(size)?; if tx.send(Some(Ok(batch))).await.is_err() { // If the other end has hung up, it was an early shutdown (e.g. LIMIT) reservation.lock().shrink(size); - txs.remove(&partition); + output_channels.remove(&partition); } } timer.done(); @@ -575,7 +760,7 @@ impl RepartitionExec { /// channels. async fn wait_for_task( input_task: AbortOnDropSingle>, - txs: HashMap>>>, + txs: HashMap>, ) { // wait for completion, and propagate error // note we ignore errors on send (.ok) as that means the receiver has already shutdown. @@ -681,15 +866,72 @@ impl RecordBatchStream for RepartitionStream { } } +/// This struct converts a receiver to a stream. +/// Receiver receives data on an SPSC channel. +struct PerPartitionStream { + /// Schema wrapped by Arc + schema: SchemaRef, + + /// channel containing the repartitioned batches + receiver: DistributionReceiver, + + /// Handle to ensure background tasks are killed when no longer needed. + #[allow(dead_code)] + drop_helper: Arc>, + + /// Memory reservation. + reservation: SharedMemoryReservation, +} + +impl Stream for PerPartitionStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.receiver.recv().poll_unpin(cx) { + Poll::Ready(Some(Some(v))) => { + if let Ok(batch) = &v { + self.reservation + .lock() + .shrink(batch.get_array_memory_size()); + } + Poll::Ready(Some(v)) + } + Poll::Ready(Some(None)) => { + // Input partition has finished sending batches + Poll::Ready(None) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for PerPartitionStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { - use super::*; - use crate::execution::context::SessionConfig; - use crate::prelude::SessionContext; - use crate::test::create_vec_batches; + use std::collections::HashSet; + + use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::UInt32Array; + use futures::FutureExt; + use tokio::task::JoinHandle; + + use datafusion_common::cast::as_string_array; + use datafusion_common::{assert_batches_sorted_eq, exec_err}; + use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::{ - assert_batches_sorted_eq, - physical_plan::{collect, expressions::col, memory::MemoryExec}, test::{ assert_is_pending, exec::{ @@ -697,20 +939,16 @@ mod tests { ErrorExec, MockExec, }, }, + {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{ArrayRef, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_string_array; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use futures::FutureExt; - use std::collections::HashSet; + + use super::*; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let partitions = vec![partition]; // repartition from 1 input to 4 output @@ -730,7 +968,7 @@ mod tests { async fn many_to_one_round_robin() -> Result<()> { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let partitions = vec![partition.clone(), partition.clone(), partition.clone()]; // repartition from 3 input to 1 output @@ -747,7 +985,7 @@ mod tests { async fn many_to_many_round_robin() -> Result<()> { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let partitions = vec![partition.clone(), partition.clone(), partition.clone()]; // repartition from 3 input to 5 output @@ -768,7 +1006,7 @@ mod tests { async fn many_to_many_hash_partition() -> Result<()> { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let partitions = vec![partition.clone(), partition.clone(), partition.clone()]; let output_partitions = repartition( @@ -798,8 +1036,7 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; @@ -824,7 +1061,7 @@ mod tests { tokio::spawn(async move { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let partitions = vec![partition.clone(), partition.clone(), partition.clone()]; @@ -846,8 +1083,7 @@ mod tests { #[tokio::test] async fn unsupported_partitioning() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); // have to send at least one batch through to provoke error let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", @@ -865,7 +1101,7 @@ mod tests { let output_stream = exec.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::physical_plan::common::collect(output_stream) + let result_string = crate::common::collect(output_stream) .await .unwrap_err() .to_string(); @@ -881,8 +1117,7 @@ mod tests { // This generates an error on a call to execute. The error // should be returned and no results produced. - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let input = ErrorExec::new(); let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); @@ -892,7 +1127,7 @@ mod tests { let output_stream = exec.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::physical_plan::common::collect(output_stream) + let result_string = crate::common::collect(output_stream) .await .unwrap_err() .to_string(); @@ -904,8 +1139,7 @@ mod tests { #[tokio::test] async fn repartition_with_error_in_stream() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, @@ -914,7 +1148,7 @@ mod tests { // input stream returns one good batch and then one error. The // error should be returned. - let err = Err(DataFusionError::Execution("bad data error".to_string())); + let err = exec_err!("bad data error"); let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch), err], schema); @@ -926,7 +1160,7 @@ mod tests { let output_stream = exec.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::physical_plan::common::collect(output_stream) + let result_string = crate::common::collect(output_stream) .await .unwrap_err() .to_string(); @@ -938,8 +1172,7 @@ mod tests { #[tokio::test] async fn repartition_with_delayed_stream() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, @@ -975,17 +1208,14 @@ mod tests { assert_batches_sorted_eq!(&expected, &expected_batches); let output_stream = exec.execute(0, task_ctx).unwrap(); - let batches = crate::physical_plan::common::collect(output_stream) - .await - .unwrap(); + let batches = crate::common::collect(output_stream).await.unwrap(); assert_batches_sorted_eq!(&expected, &batches); } #[tokio::test] async fn robin_repartition_with_dropping_output_stream() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let partitioning = Partitioning::RoundRobinBatch(2); // The barrier exec waits to be pinged // requires the input to wait at least once) @@ -1005,9 +1235,7 @@ mod tests { input.wait().await; // output stream 1 should *not* error and have one of the input batches - let batches = crate::physical_plan::common::collect(output_stream1) - .await - .unwrap(); + let batches = crate::common::collect(output_stream1).await.unwrap(); let expected = vec![ "+------------------+", @@ -1028,10 +1256,9 @@ mod tests { // wiht different compilers, we will compare the same execution with // and without droping the output stream. async fn hash_repartition_with_dropping_output_stream() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let partitioning = Partitioning::Hash( - vec![Arc::new(crate::physical_plan::expressions::Column::new( + vec![Arc::new(crate::expressions::Column::new( "my_awesome_field", 0, ))], @@ -1043,9 +1270,7 @@ mod tests { let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); input.wait().await; - let batches_without_drop = crate::physical_plan::common::collect(output_stream1) - .await - .unwrap(); + let batches_without_drop = crate::common::collect(output_stream1).await.unwrap(); // run some checks on the result let items_vec = str_batches_to_vec(&batches_without_drop); @@ -1067,9 +1292,7 @@ mod tests { // *before* any outputs are produced std::mem::drop(output_stream0); input.wait().await; - let batches_with_drop = crate::physical_plan::common::collect(output_stream1) - .await - .unwrap(); + let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); assert_eq!(batches_without_drop, batches_with_drop); } @@ -1124,8 +1347,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1148,30 +1370,23 @@ mod tests { #[tokio::test] async fn hash_repartition_avoid_empty_batch() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let batch = RecordBatch::try_from_iter(vec![( "a", Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, )]) .unwrap(); let partitioning = Partitioning::Hash( - vec![Arc::new(crate::physical_plan::expressions::Column::new( - "a", 0, - ))], + vec![Arc::new(crate::expressions::Column::new("a", 0))], 2, ); let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); - let batch0 = crate::physical_plan::common::collect(output_stream0) - .await - .unwrap(); + let batch0 = crate::common::collect(output_stream0).await.unwrap(); let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); - let batch1 = crate::physical_plan::common::collect(output_stream1) - .await - .unwrap(); + let batch1 = crate::common::collect(output_stream1).await.unwrap(); assert!(batch0.is_empty() || batch1.is_empty()); Ok(()) } @@ -1180,19 +1395,17 @@ mod tests { async fn oom() -> Result<()> { // define input partitions let schema = test_schema(); - let partition = create_vec_batches(&schema, 50); + let partition = create_vec_batches(50); let input_partitions = vec![partition]; let partitioning = Partitioning::RoundRobinBatch(4); // setup up context - let session_ctx = SessionContext::with_config_rt( - SessionConfig::default(), - Arc::new( - RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)) - .unwrap(), - ), + let runtime = Arc::new( + RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)).unwrap(), ); - let task_ctx = session_ctx.task_ctx(); + + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; @@ -1213,4 +1426,146 @@ mod tests { Ok(()) } + + /// Create vector batches + fn create_vec_batches(n: usize) -> Vec { + let batch = create_batch(); + (0..n).map(|_| batch.clone()).collect() + } + + /// Create batch + fn create_batch() -> RecordBatch { + let schema = test_schema(); + RecordBatch::try_new( + schema, + vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], + ) + .unwrap() + } +} + +#[cfg(test)] +mod test { + use arrow_schema::{DataType, Field, Schema, SortOptions}; + + use datafusion_physical_expr::expressions::col; + + use crate::memory::MemoryExec; + use crate::union::UnionExec; + + use super::*; + + /// Asserts that the plan is as expected + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$PLAN`: the plan to optimized + /// + macro_rules! assert_plan { + ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => { + let physical_plan = $PLAN; + let formatted = crate::displayable(&physical_plan).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + }; + } + + #[tokio::test] + async fn test_preserve_order() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source1 = sorted_memory_exec(&schema, sort_exprs.clone()); + let source2 = sorted_memory_exec(&schema, sort_exprs); + // output has multiple partitions, and is sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should preserve order + let expected_plan = [ + "SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, sort_exprs=c0@0 ASC", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_one_partition() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source = sorted_memory_exec(&schema, sort_exprs); + // output is sorted, but has only a single partition, so no need to sort + let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + #[tokio::test] + async fn test_preserve_order_input_not_sorted() -> Result<()> { + let schema = test_schema(); + let source1 = memory_exec(&schema); + let source2 = memory_exec(&schema); + // output has multiple partitions, but is not sorted + let union = UnionExec::new(vec![source1, source2]); + let exec = + RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) + .unwrap() + .with_preserve_order(); + + // Repartition should not preserve order, as there is no order to preserve + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " MemoryExec: partitions=1, partition_sizes=[0]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_plan!(expected_plan, exec); + Ok(()) + } + + fn test_schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) + } + + fn sort_exprs(schema: &Schema) -> Vec { + let options = SortOptions::default(); + vec![PhysicalSortExpr { + expr: col("c0", schema).unwrap(), + options, + }] + } + + fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + } + + fn sorted_memory_exec( + schema: &SchemaRef, + sort_exprs: Vec, + ) -> Arc { + Arc::new( + MemoryExec::try_new(&[vec![]], schema.clone(), None) + .unwrap() + .with_sort_information(vec![sort_exprs]), + ) + } } diff --git a/datafusion/core/src/physical_plan/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs similarity index 88% rename from datafusion/core/src/physical_plan/sorts/builder.rs rename to datafusion/physical-plan/src/sorts/builder.rs index 1c5ec356eed9f..3527d57382230 100644 --- a/datafusion/core/src/physical_plan/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -19,6 +19,7 @@ use arrow::compute::interleave; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; #[derive(Debug, Copy, Clone, Default)] struct BatchCursor { @@ -37,6 +38,9 @@ pub struct BatchBuilder { /// Maintain a list of [`RecordBatch`] and their corresponding stream batches: Vec<(usize, RecordBatch)>, + /// Accounts for memory used by buffered batches + reservation: MemoryReservation, + /// The current [`BatchCursor`] for each stream cursors: Vec, @@ -47,23 +51,31 @@ pub struct BatchBuilder { impl BatchBuilder { /// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size` - pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self { + pub fn new( + schema: SchemaRef, + stream_count: usize, + batch_size: usize, + reservation: MemoryReservation, + ) -> Self { Self { schema, batches: Vec::with_capacity(stream_count * 2), cursors: vec![BatchCursor::default(); stream_count], indices: Vec::with_capacity(batch_size), + reservation, } } /// Append a new batch in `stream_idx` - pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) { + pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> { + self.reservation.try_grow(batch.get_array_memory_size())?; let batch_idx = self.batches.len(); self.batches.push((stream_idx, batch)); self.cursors[stream_idx] = BatchCursor { batch_idx, row_idx: 0, - } + }; + Ok(()) } /// Append the next row from `stream_idx` @@ -119,7 +131,7 @@ impl BatchBuilder { // We can therefore drop all but the last batch for each stream let mut batch_idx = 0; let mut retained = 0; - self.batches.retain(|(stream_idx, _)| { + self.batches.retain(|(stream_idx, batch)| { let stream_cursor = &mut self.cursors[*stream_idx]; let retain = stream_cursor.batch_idx == batch_idx; batch_idx += 1; @@ -127,6 +139,8 @@ impl BatchBuilder { if retain { stream_cursor.batch_idx = retained; retained += 1; + } else { + self.reservation.shrink(batch.get_array_memory_size()); } retain }); diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs new file mode 100644 index 0000000000000..df90c97faf68e --- /dev/null +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -0,0 +1,462 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp::Ordering; + +use arrow::buffer::ScalarBuffer; +use arrow::compute::SortOptions; +use arrow::datatypes::ArrowNativeTypeOp; +use arrow::row::Rows; +use arrow_array::types::ByteArrayType; +use arrow_array::{ + Array, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow_buffer::{Buffer, OffsetBuffer}; +use datafusion_execution::memory_pool::MemoryReservation; + +/// A comparable collection of values for use with [`Cursor`] +/// +/// This is a trait as there are several specialized implementations, such as for +/// single columns or for normalized multi column keys ([`Rows`]) +pub trait CursorValues { + fn len(&self) -> usize; + + /// Returns true if `l[l_idx] == r[r_idx]` + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool; + + /// Returns comparison of `l[l_idx]` and `r[r_idx]` + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering; +} + +/// A comparable cursor, used by sort operations +/// +/// A `Cursor` is a pointer into a collection of rows, stored in +/// [`CursorValues`] +/// +/// ```text +/// +/// ┌───────────────────────┐ +/// │ │ ┌──────────────────────┐ +/// │ ┌─────────┐ ┌─────┐ │ ─ ─ ─ ─│ Cursor │ +/// │ │ 1 │ │ A │ │ │ └──────────────────────┘ +/// │ ├─────────┤ ├─────┤ │ +/// │ │ 2 │ │ A │◀─ ┼ ─ ┘ Cursor tracks an +/// │ └─────────┘ └─────┘ │ offset within a +/// │ ... ... │ CursorValues +/// │ │ +/// │ ┌─────────┐ ┌─────┐ │ +/// │ │ 3 │ │ E │ │ +/// │ └─────────┘ └─────┘ │ +/// │ │ +/// │ CursorValues │ +/// └───────────────────────┘ +/// +/// +/// Store logical rows using +/// one of several formats, +/// with specialized +/// implementations +/// depending on the column +/// types +#[derive(Debug)] +pub struct Cursor { + offset: usize, + values: T, +} + +impl Cursor { + /// Create a [`Cursor`] from the given [`CursorValues`] + pub fn new(values: T) -> Self { + Self { offset: 0, values } + } + + /// Returns true if there are no more rows in this cursor + pub fn is_finished(&self) -> bool { + self.offset == self.values.len() + } + + /// Advance the cursor, returning the previous row index + pub fn advance(&mut self) -> usize { + let t = self.offset; + self.offset += 1; + t + } +} + +impl PartialEq for Cursor { + fn eq(&self, other: &Self) -> bool { + T::eq(&self.values, self.offset, &other.values, other.offset) + } +} + +impl Eq for Cursor {} + +impl PartialOrd for Cursor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Cursor { + fn cmp(&self, other: &Self) -> Ordering { + T::compare(&self.values, self.offset, &other.values, other.offset) + } +} + +/// Implements [`CursorValues`] for [`Rows`] +/// +/// Used for sorting when there are multiple columns in the sort key +#[derive(Debug)] +pub struct RowValues { + rows: Rows, + + /// Tracks for the memory used by in the `Rows` of this + /// cursor. Freed on drop + #[allow(dead_code)] + reservation: MemoryReservation, +} + +impl RowValues { + /// Create a new [`RowValues`] from `rows` and a `reservation` + /// that tracks its memory. There must be at least one row + /// + /// Panics if the reservation is not for exactly `rows.size()` + /// bytes or if `rows` is empty. + pub fn new(rows: Rows, reservation: MemoryReservation) -> Self { + assert_eq!( + rows.size(), + reservation.size(), + "memory reservation mismatch" + ); + assert!(rows.num_rows() > 0); + Self { rows, reservation } + } +} + +impl CursorValues for RowValues { + fn len(&self) -> usize { + self.rows.num_rows() + } + + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.rows.row(l_idx) == r.rows.row(r_idx) + } + + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.rows.row(l_idx).cmp(&r.rows.row(r_idx)) + } +} + +/// An [`Array`] that can be converted into [`CursorValues`] +pub trait CursorArray: Array + 'static { + type Values: CursorValues; + + fn values(&self) -> Self::Values; +} + +impl CursorArray for PrimitiveArray { + type Values = PrimitiveValues; + + fn values(&self) -> Self::Values { + PrimitiveValues(self.values().clone()) + } +} + +#[derive(Debug)] +pub struct PrimitiveValues(ScalarBuffer); + +impl CursorValues for PrimitiveValues { + fn len(&self) -> usize { + self.0.len() + } + + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.0[l_idx].is_eq(r.0[r_idx]) + } + + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.0[l_idx].compare(r.0[r_idx]) + } +} + +pub struct ByteArrayValues { + offsets: OffsetBuffer, + values: Buffer, +} + +impl ByteArrayValues { + fn value(&self, idx: usize) -> &[u8] { + assert!(idx < self.len()); + // Safety: offsets are valid and checked bounds above + unsafe { + let start = self.offsets.get_unchecked(idx).as_usize(); + let end = self.offsets.get_unchecked(idx + 1).as_usize(); + self.values.get_unchecked(start..end) + } + } +} + +impl CursorValues for ByteArrayValues { + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + l.value(l_idx) == r.value(r_idx) + } + + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + l.value(l_idx).cmp(r.value(r_idx)) + } +} + +impl CursorArray for GenericByteArray { + type Values = ByteArrayValues; + + fn values(&self) -> Self::Values { + ByteArrayValues { + offsets: self.offsets().clone(), + values: self.values().clone(), + } + } +} + +/// A collection of sorted, nullable [`CursorValues`] +/// +/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering +#[derive(Debug)] +pub struct ArrayValues { + values: T, + // If nulls first, the first non-null index + // Otherwise, the first null index + null_threshold: usize, + options: SortOptions, +} + +impl ArrayValues { + /// Create a new [`ArrayValues`] from the provided `values` sorted according + /// to `options`. + /// + /// Panics if the array is empty + pub fn new>(options: SortOptions, array: &A) -> Self { + assert!(array.len() > 0, "Empty array passed to FieldCursor"); + let null_threshold = match options.nulls_first { + true => array.null_count(), + false => array.len() - array.null_count(), + }; + + Self { + values: array.values(), + null_threshold, + options, + } + } + + fn is_null(&self, idx: usize) -> bool { + (idx < self.null_threshold) == self.options.nulls_first + } +} + +impl CursorValues for ArrayValues { + fn len(&self) -> usize { + self.values.len() + } + + fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { + match (l.is_null(l_idx), r.is_null(r_idx)) { + (true, true) => true, + (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx), + _ => false, + } + } + + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { + match (l.is_null(l_idx), r.is_null(r_idx)) { + (true, true) => Ordering::Equal, + (true, false) => match l.options.nulls_first { + true => Ordering::Less, + false => Ordering::Greater, + }, + (false, true) => match l.options.nulls_first { + true => Ordering::Greater, + false => Ordering::Less, + }, + (false, false) => match l.options.descending { + true => T::compare(&r.values, r_idx, &l.values, l_idx), + false => T::compare(&l.values, l_idx, &r.values, r_idx), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn new_primitive( + options: SortOptions, + values: ScalarBuffer, + null_count: usize, + ) -> Cursor>> { + let null_threshold = match options.nulls_first { + true => null_count, + false => values.len() - null_count, + }; + + let values = ArrayValues { + values: PrimitiveValues(values), + null_threshold, + options, + }; + + Cursor::new(values) + } + + #[test] + fn test_primitive_nulls_first() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]); + let mut a = new_primitive(options, buffer, 1); + let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]); + let mut b = new_primitive(options, buffer, 2); + + // NULL == NULL + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL == NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL < -2 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 1 > -2 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 1 > -1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 1 == 1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // 9 > 1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 9 > 2 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]); + let mut a = new_primitive(options, buffer, 2); + let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]); + let mut b = new_primitive(options, buffer, 2); + + // 0 > -1 + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 0 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 1 < NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // NULL = NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + let options = SortOptions { + descending: true, + nulls_first: false, + }; + + let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]); + let mut a = new_primitive(options, buffer, 3); + let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]); + let mut b = new_primitive(options, buffer, 2); + + // 6 > 67 + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 6 < -3 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // NULL == NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + let options = SortOptions { + descending: true, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]); + let mut a = new_primitive(options, buffer, 2); + let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]); + let mut b = new_primitive(options, buffer, 1); + + // NULL == NULL + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL == NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL < 4546 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 > 4546 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 6 < -3 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + } +} diff --git a/datafusion/core/src/physical_plan/sorts/index.rs b/datafusion/physical-plan/src/sorts/index.rs similarity index 100% rename from datafusion/core/src/physical_plan/sorts/index.rs rename to datafusion/physical-plan/src/sorts/index.rs diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs similarity index 77% rename from datafusion/core/src/physical_plan/sorts/merge.rs rename to datafusion/physical-plan/src/sorts/merge.rs index d8a3cdef4d686..422ff3aebdb39 100644 --- a/datafusion/core/src/physical_plan/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -15,77 +15,27 @@ // specific language governing permissions and limitations // under the License. -use crate::physical_plan::metrics::BaselineMetrics; -use crate::physical_plan::sorts::builder::BatchBuilder; -use crate::physical_plan::sorts::cursor::Cursor; -use crate::physical_plan::sorts::stream::{ - FieldCursorStream, PartitionedStream, RowCursorStream, -}; -use crate::physical_plan::{ - PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, -}; -use arrow::datatypes::{DataType, SchemaRef}; +//! Merge that deals with an arbitrary size of streaming inputs. +//! This is an order-preserving merge. + +use crate::metrics::BaselineMetrics; +use crate::sorts::builder::BatchBuilder; +use crate::sorts::cursor::{Cursor, CursorValues}; +use crate::sorts::stream::PartitionedStream; +use crate::RecordBatchStream; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow_array::*; use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; -macro_rules! primitive_merge_helper { - ($t:ty, $($v:ident),+) => { - merge_helper!(PrimitiveArray<$t>, $($v),+) - }; -} - -macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{ - let streams = FieldCursorStream::<$t>::new($sort, $streams); - return Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - $schema, - $tracking_metrics, - $batch_size, - ))); - }}; -} - -/// Perform a streaming merge of [`SendableRecordBatchStream`] -pub(crate) fn streaming_merge( - streams: Vec, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, -) -> Result { - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive! { - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size) - _ => {} - } - } - - let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?; - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - ))) -} - /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; #[derive(Debug)] -struct SortPreservingMergeStream { +pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, /// The sorted input streams to merge together @@ -138,21 +88,29 @@ struct SortPreservingMergeStream { /// target batch size batch_size: usize, - /// Vector that holds cursors for each non-exhausted input partition - cursors: Vec>, + /// Cursors for each input partition. `None` means the input is exhausted + cursors: Vec>>, + + /// Optional number of rows to fetch + fetch: Option, + + /// number of rows produced + produced: usize, } -impl SortPreservingMergeStream { - fn new( +impl SortPreservingMergeStream { + pub(crate) fn new( streams: CursorStream, schema: SchemaRef, metrics: BaselineMetrics, batch_size: usize, + fetch: Option, + reservation: MemoryReservation, ) -> Self { let stream_count = streams.partitions(); Self { - in_progress: BatchBuilder::new(schema, stream_count, batch_size), + in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation), streams, metrics, aborted: false, @@ -160,6 +118,8 @@ impl SortPreservingMergeStream { loser_tree: vec![], loser_tree_adjusted: false, batch_size, + fetch, + produced: 0, } } @@ -180,9 +140,8 @@ impl SortPreservingMergeStream { None => Poll::Ready(Ok(())), Some(Err(e)) => Poll::Ready(Err(e)), Some(Ok((cursor, batch))) => { - self.cursors[idx] = Some(cursor); - self.in_progress.push_batch(idx, batch); - Poll::Ready(Ok(())) + self.cursors[idx] = Some(Cursor::new(cursor)); + Poll::Ready(self.in_progress.push_batch(idx, batch)) } } } @@ -227,15 +186,27 @@ impl SortPreservingMergeStream { if self.advance(stream_idx) { self.loser_tree_adjusted = false; self.in_progress.push_row(stream_idx); - if self.in_progress.len() < self.batch_size { + + // stop sorting if fetch has been reached + if self.fetch_reached() { + self.aborted = true; + } else if self.in_progress.len() < self.batch_size { continue; } } + self.produced += self.in_progress.len(); + return Poll::Ready(self.in_progress.build_record_batch().transpose()); } } + fn fetch_reached(&mut self) -> bool { + self.fetch + .map(|fetch| self.produced + self.in_progress.len() >= fetch) + .unwrap_or(false) + } + fn advance(&mut self, stream_idx: usize) -> bool { let slot = &mut self.cursors[stream_idx]; match slot.as_mut() { @@ -339,7 +310,7 @@ impl SortPreservingMergeStream { } } -impl Stream for SortPreservingMergeStream { +impl Stream for SortPreservingMergeStream { type Item = Result; fn poll_next( @@ -351,7 +322,7 @@ impl Stream for SortPreservingMergeStream { } } -impl RecordBatchStream for SortPreservingMergeStream { +impl RecordBatchStream for SortPreservingMergeStream { fn schema(&self) -> SchemaRef { self.in_progress.schema().clone() } diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs similarity index 92% rename from datafusion/core/src/physical_plan/sorts/mod.rs rename to datafusion/physical-plan/src/sorts/mod.rs index 567de96c1cfdf..8a1184d3c2b5d 100644 --- a/datafusion/core/src/physical_plan/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -24,6 +24,7 @@ mod merge; pub mod sort; pub mod sort_preserving_merge; mod stream; +pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use merge::streaming_merge; +pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs similarity index 56% rename from datafusion/core/src/physical_plan/sorts/sort.rs rename to datafusion/physical-plan/src/sorts/sort.rs index 58c257c97f996..2d8237011fff6 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -19,39 +19,42 @@ //! It will do in-memory sorting if it has enough memory budget //! but spills to disk if needed. -use crate::physical_plan::common::{batch_byte_size, spawn_buffered, IPCWriter}; -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::metrics::{ +use std::any::Any; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use crate::common::{spawn_buffered, IPCWriter}; +use crate::expressions::PhysicalSortExpr; +use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::physical_plan::sorts::merge::streaming_merge; -use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; -use crate::physical_plan::{ - DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning, - SendableRecordBatchStream, Statistics, +use crate::sorts::streaming_merge::streaming_merge; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::topk::TopK; +use crate::{ + DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, + Partitioning, SendableRecordBatchStream, Statistics, }; -pub use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryReservation, }; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; + use futures::{StreamExt, TryStreamExt}; use log::{debug, error, trace}; -use std::any::Any; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::fs::File; -use std::io::BufReader; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tempfile::NamedTempFile; use tokio::sync::mpsc::Sender; use tokio::task; @@ -76,37 +79,171 @@ impl ExternalSorterMetrics { } } -/// Sort arbitrary size of data to get a total order (may spill several times during sorting based on free memory available). +/// Sorts an arbitrary sized, unsorted, stream of [`RecordBatch`]es to +/// a total order. Depending on the input size and memory manager +/// configuration, writes intermediate results to disk ("spills") +/// using Arrow IPC format. +/// +/// # Algorithm /// -/// The basic architecture of the algorithm: /// 1. get a non-empty new batch from input -/// 2. check with the memory manager if we could buffer the batch in memory -/// 2.1 if memory sufficient, then buffer batch in memory, go to 1. -/// 2.2 if the memory threshold is reached, sort all buffered batches and spill to file. -/// buffer the batch in memory, go to 1. -/// 3. when input is exhausted, merge all in memory batches and spills to get a total order. +/// +/// 2. check with the memory manager there is sufficient space to +/// buffer the batch in memory 2.1 if memory sufficient, buffer +/// batch in memory, go to 1. +/// +/// 2.2 if no more memory is available, sort all buffered batches and +/// spill to file. buffer the next batch in memory, go to 1. +/// +/// 3. when input is exhausted, merge all in memory batches and spills +/// to get a total order. +/// +/// # When data fits in available memory +/// +/// If there is sufficient memory, data is sorted in memory to produce the output +/// +/// ```text +/// ┌─────┐ +/// │ 2 │ +/// │ 3 │ +/// │ 1 │─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ 4 │ +/// │ 2 │ │ +/// └─────┘ ▼ +/// ┌─────┐ +/// │ 1 │ In memory +/// │ 4 │─ ─ ─ ─ ─ ─▶ sort/merge ─ ─ ─ ─ ─▶ total sorted output +/// │ 1 │ +/// └─────┘ ▲ +/// ... │ +/// +/// ┌─────┐ │ +/// │ 4 │ +/// │ 3 │─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// └─────┘ +/// +/// in_mem_batches +/// +/// ``` +/// +/// # When data does not fit in available memory +/// +/// When memory is exhausted, data is first sorted and written to one +/// or more spill files on disk: +/// +/// ```text +/// ┌─────┐ .─────────────────. +/// │ 2 │ ( ) +/// │ 3 │ │`─────────────────'│ +/// │ 1 │─ ─ ─ ─ ─ ─ ─ │ ┌────┐ │ +/// │ 4 │ │ │ │ 1 │░ │ +/// │ 2 │ │ │... │░ │ +/// └─────┘ ▼ │ │ 4 │░ ┌ ─ ─ │ +/// ┌─────┐ │ └────┘░ 1 │░ │ +/// │ 1 │ In memory │ ░░░░░░ │ ░░ │ +/// │ 4 │─ ─ ▶ sort/merge ─ ─ ─ ─ ┼ ─ ─ ─ ─ ─▶ ... │░ │ +/// │ 1 │ and write to file │ │ ░░ │ +/// └─────┘ │ 4 │░ │ +/// ... ▲ │ └░─░─░░ │ +/// │ │ ░░░░░░ │ +/// ┌─────┐ │.─────────────────.│ +/// │ 4 │ │ ( ) +/// │ 3 │─ ─ ─ ─ ─ ─ ─ `─────────────────' +/// └─────┘ +/// +/// in_mem_batches spills +/// (file on disk in Arrow +/// IPC format) +/// ``` +/// +/// Once the input is completely read, the spill files are read and +/// merged with any in memory batches to produce a single total sorted +/// output: +/// +/// ```text +/// .─────────────────. +/// ( ) +/// │`─────────────────'│ +/// │ ┌────┐ │ +/// │ │ 1 │░ │ +/// │ │... │─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ +/// │ │ 4 │░ ┌────┐ │ │ +/// │ └────┘░ │ 1 │░ │ ▼ +/// │ ░░░░░░ │ │░ │ +/// │ │... │─ ─│─ ─ ─ ▶ merge ─ ─ ─▶ total sorted output +/// │ │ │░ │ +/// │ │ 4 │░ │ ▲ +/// │ └────┘░ │ │ +/// │ ░░░░░░ │ +/// │.─────────────────.│ │ +/// ( ) +/// `─────────────────' │ +/// spills +/// │ +/// +/// │ +/// +/// ┌─────┐ │ +/// │ 1 │ +/// │ 4 │─ ─ ─ ─ │ +/// └─────┘ │ +/// ... In memory +/// └ ─ ─ ─▶ sort/merge +/// ┌─────┐ +/// │ 4 │ ▲ +/// │ 3 │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// └─────┘ +/// +/// in_mem_batches +/// ``` struct ExternalSorter { + /// schema of the output (and the input) schema: SchemaRef, + /// Potentially unsorted in memory buffer in_mem_batches: Vec, + /// if `Self::in_mem_batches` are sorted in_mem_batches_sorted: bool, - spills: Vec, + /// If data has previously been spilled, the locations of the + /// spill files (in Arrow IPC format) + spills: Vec, /// Sort expressions expr: Arc<[PhysicalSortExpr]>, + /// Runtime metrics metrics: ExternalSorterMetrics, + /// If Some, the maximum number of output rows that will be + /// produced. fetch: Option, + /// Reservation for in_mem_batches reservation: MemoryReservation, - partition_id: usize, + /// Reservation for the merging of in-memory batches. If the sort + /// might spill, `sort_spill_reservation_bytes` will be + /// pre-reserved to ensure there is some space for this sort/merge. + merge_reservation: MemoryReservation, + /// A handle to the runtime to get spill files runtime: Arc, + /// The target number of rows for output batches batch_size: usize, + /// How much memory to reserve for performing in-memory sort/merges + /// prior to spilling. + sort_spill_reservation_bytes: usize, + /// If the in size of buffered memory batches is below this size, + /// the data will be concated and sorted in place rather than + /// sort/merged. + sort_in_place_threshold_bytes: usize, } impl ExternalSorter { + // TOOD: make a builder or some other nicer API to avoid the + // clippy warning + #[allow(clippy::too_many_arguments)] pub fn new( partition_id: usize, schema: SchemaRef, expr: Vec, batch_size: usize, fetch: Option, + sort_spill_reservation_bytes: usize, + sort_in_place_threshold_bytes: usize, metrics: &ExecutionPlanMetricsSet, runtime: Arc, ) -> Self { @@ -115,6 +252,10 @@ impl ExternalSorter { .with_can_spill(true) .register(&runtime.memory_pool); + let merge_reservation = + MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]")) + .register(&runtime.memory_pool); + Self { schema, in_mem_batches: vec![], @@ -124,9 +265,11 @@ impl ExternalSorter { metrics, fetch, reservation, - partition_id, + merge_reservation, runtime, batch_size, + sort_spill_reservation_bytes, + sort_in_place_threshold_bytes, } } @@ -137,12 +280,13 @@ impl ExternalSorter { if input.num_rows() == 0 { return Ok(()); } + self.reserve_memory_for_merge()?; - let size = batch_byte_size(&input); + let size = input.get_array_memory_size(); if self.reservation.try_grow(size).is_err() { let before = self.reservation.size(); self.in_mem_sort().await?; - // Sorting may have freed memory, especially if fetch is not `None` + // Sorting may have freed memory, especially if fetch is `Some` // // As such we check again, and if the memory usage has dropped by // a factor of 2, and we can allocate the necessary capacity, @@ -168,7 +312,15 @@ impl ExternalSorter { !self.spills.is_empty() } - /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. + /// Returns the final sorted output of all batches inserted via + /// [`Self::insert_batch`] as a stream of [`RecordBatch`]es. + /// + /// This process could either be: + /// + /// 1. An in-memory sort/merge (if the input fit in memory) + /// + /// 2. A combined streaming merge incorporating both in-memory + /// batches and data from spill files on disk. fn sort(&mut self) -> Result { if self.spilled_before() { let mut streams = vec![]; @@ -179,6 +331,12 @@ impl ExternalSorter { } for spill in self.spills.drain(..) { + if !spill.path().exists() { + return Err(DataFusionError::Internal(format!( + "Spill file {:?} does not exist", + spill.path() + ))); + } let stream = read_spill_as_stream(spill, self.schema.clone())?; streams.push(stream); } @@ -189,29 +347,35 @@ impl ExternalSorter { &self.expr, self.metrics.baseline.clone(), self.batch_size, + self.fetch, + self.reservation.new_empty(), ) } else if !self.in_mem_batches.is_empty() { - let result = self.in_mem_sort_stream(self.metrics.baseline.clone()); - // Report to the memory manager we are no longer using memory - self.reservation.free(); - result + self.in_mem_sort_stream(self.metrics.baseline.clone()) } else { Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) } } + /// How much memory is buffered in this `ExternalSorter`? fn used(&self) -> usize { self.reservation.size() } + /// How many bytes have been spilled to disk? fn spilled_bytes(&self) -> usize { self.metrics.spilled_bytes.value() } + /// How many spill files have been created? fn spill_count(&self) -> usize { self.metrics.spill_count.value() } + /// Writes any `in_memory_batches` to a spill file and clears + /// the batches. The contents of the spil file are sorted. + /// + /// Returns the amount of memory freed. async fn spill(&mut self) -> Result { // we could always get a chance to free some memory as long as we are holding some if self.in_mem_batches.is_empty() { @@ -238,6 +402,11 @@ impl ExternalSorter { return Ok(()); } + // Release the memory reserved for merge back to the pool so + // there is some left when `in_memo_sort_stream` requests an + // allocation. + self.merge_reservation.free(); + self.in_mem_batches = self .in_mem_sort_stream(self.metrics.baseline.intermediate())? .try_collect() @@ -249,12 +418,72 @@ impl ExternalSorter { .map(|x| x.get_array_memory_size()) .sum(); - self.reservation.resize(size); + // Reserve headroom for next sort/merge + self.reserve_memory_for_merge()?; + + self.reservation.try_resize(size)?; self.in_mem_batches_sorted = true; Ok(()) } - /// Consumes in_mem_batches returning a sorted stream + /// Consumes in_mem_batches returning a sorted stream of + /// batches. This proceeds in one of two ways: + /// + /// # Small Datasets + /// + /// For "smaller" datasets, the data is first concatenated into a + /// single batch and then sorted. This is often faster than + /// sorting and then merging. + /// + /// ```text + /// ┌─────┐ + /// │ 2 │ + /// │ 3 │ + /// │ 1 │─ ─ ─ ─ ┐ ┌─────┐ + /// │ 4 │ │ 2 │ + /// │ 2 │ │ │ 3 │ + /// └─────┘ │ 1 │ sorted output + /// ┌─────┐ ▼ │ 4 │ stream + /// │ 1 │ │ 2 │ + /// │ 4 │─ ─▶ concat ─ ─ ─ ─ ▶│ 1 │─ ─ ▶ sort ─ ─ ─ ─ ─▶ + /// │ 1 │ │ 4 │ + /// └─────┘ ▲ │ 1 │ + /// ... │ │ ... │ + /// │ 4 │ + /// ┌─────┐ │ │ 3 │ + /// │ 4 │ └─────┘ + /// │ 3 │─ ─ ─ ─ ┘ + /// └─────┘ + /// in_mem_batches + /// ``` + /// + /// # Larger datasets + /// + /// For larger datasets, the batches are first sorted individually + /// and then merged together. + /// + /// ```text + /// ┌─────┐ ┌─────┐ + /// │ 2 │ │ 1 │ + /// │ 3 │ │ 2 │ + /// │ 1 │─ ─▶ sort ─ ─▶│ 2 │─ ─ ─ ─ ─ ┐ + /// │ 4 │ │ 3 │ + /// │ 2 │ │ 4 │ │ + /// └─────┘ └─────┘ sorted output + /// ┌─────┐ ┌─────┐ ▼ stream + /// │ 1 │ │ 1 │ + /// │ 4 │─ ▶ sort ─ ─ ▶│ 1 ├ ─ ─ ▶ merge ─ ─ ─ ─▶ + /// │ 1 │ │ 4 │ + /// └─────┘ └─────┘ ▲ + /// ... ... ... │ + /// + /// ┌─────┐ ┌─────┐ │ + /// │ 4 │ │ 3 │ + /// │ 3 │─ ▶ sort ─ ─ ▶│ 4 │─ ─ ─ ─ ─ ┘ + /// └─────┘ └─────┘ + /// + /// in_mem_batches + /// ``` fn in_mem_sort_stream( &mut self, metrics: BaselineMetrics, @@ -262,65 +491,80 @@ impl ExternalSorter { assert_ne!(self.in_mem_batches.len(), 0); if self.in_mem_batches.len() == 1 { let batch = self.in_mem_batches.remove(0); - let stream = self.sort_batch_stream(batch, metrics)?; - self.in_mem_batches.clear(); - return Ok(stream); + let reservation = self.reservation.take(); + return self.sort_batch_stream(batch, metrics, reservation); } - // If less than 1MB of in-memory data, concatenate and sort in place - // - // This is a very rough heuristic and likely could be refined further - if self.reservation.size() < 1048576 { + // If less than sort_in_place_threshold_bytes, concatenate and sort in place + if self.reservation.size() < self.sort_in_place_threshold_bytes { // Concatenate memory batches together and sort let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); - return self.sort_batch_stream(batch, metrics); + self.reservation.try_resize(batch.get_array_memory_size())?; + let reservation = self.reservation.take(); + return self.sort_batch_stream(batch, metrics, reservation); } let streams = std::mem::take(&mut self.in_mem_batches) .into_iter() .map(|batch| { let metrics = self.metrics.baseline.intermediate(); - Ok(spawn_buffered(self.sort_batch_stream(batch, metrics)?, 1)) + let reservation = self.reservation.split(batch.get_array_memory_size()); + let input = self.sort_batch_stream(batch, metrics, reservation)?; + Ok(spawn_buffered(input, 1)) }) .collect::>()?; - // TODO: Pushdown fetch to streaming merge (#6000) - streaming_merge( streams, self.schema.clone(), &self.expr, metrics, self.batch_size, + self.fetch, + self.merge_reservation.new_empty(), ) } + /// Sorts a single `RecordBatch` into a single stream. + /// + /// `reservation` accounts for the memory used by this batch and + /// is released when the sort is complete fn sort_batch_stream( &self, batch: RecordBatch, metrics: BaselineMetrics, + reservation: MemoryReservation, ) -> Result { + assert_eq!(batch.get_array_memory_size(), reservation.size()); let schema = batch.schema(); - let mut reservation = - MemoryConsumer::new(format!("sort_batch_stream{}", self.partition_id)) - .register(&self.runtime.memory_pool); - - // TODO: This should probably be try_grow (#5885) - reservation.resize(batch.get_array_memory_size()); - let fetch = self.fetch; let expressions = self.expr.clone(); let stream = futures::stream::once(futures::future::lazy(move |_| { let sorted = sort_batch(&batch, &expressions, fetch)?; metrics.record_output(sorted.num_rows()); drop(batch); - reservation.free(); + drop(reservation); Ok(sorted) })); Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } + + /// If this sort may spill, pre-allocates + /// `sort_spill_reservation_bytes` of memory to gurarantee memory + /// left for the in memory sort/merge. + fn reserve_memory_for_merge(&mut self) -> Result<()> { + // Reserve headroom for next merge sort + if self.runtime.disk_manager.tmp_files_enabled() { + let size = self.sort_spill_reservation_bytes; + if self.merge_reservation.size() != size { + self.merge_reservation.try_resize(size)?; + } + } + + Ok(()) + } } impl Debug for ExternalSorter { @@ -333,7 +577,7 @@ impl Debug for ExternalSorter { } } -fn sort_batch( +pub(crate) fn sort_batch( batch: &RecordBatch, expressions: &[PhysicalSortExpr], fetch: Option, @@ -363,23 +607,23 @@ async fn spill_sorted_batches( let handle = task::spawn_blocking(move || write_sorted(batches, path, schema)); match handle.await { Ok(r) => r, - Err(e) => Err(DataFusionError::Execution(format!( - "Error occurred while spilling {e}" - ))), + Err(e) => exec_err!("Error occurred while spilling {e}"), } } -fn read_spill_as_stream( - path: NamedTempFile, +pub(crate) fn read_spill_as_stream( + path: RefCountedTempFile, schema: SchemaRef, ) -> Result { let mut builder = RecordBatchReceiverStream::builder(schema, 2); let sender = builder.tx(); builder.spawn_blocking(move || { - if let Err(e) = read_spill(sender, path.path()) { + let result = read_spill(sender, path.path()); + if let Err(e) = &result { error!("Failure while reading spill file: {:?}. Error: {}", path, e); } + result }); Ok(builder.build()) @@ -417,8 +661,8 @@ fn read_spill(sender: Sender>, path: &Path) -> Result<()> { /// Sort execution plan. /// -/// This operator supports sorting datasets that are larger than the -/// memory allotted by the memory manager, by spilling to disk. +/// Support sorting datasets that are larger than the memory allotted +/// by the memory manager, by spilling to disk. #[derive(Debug)] pub struct SortExec { /// Input schema @@ -491,7 +735,13 @@ impl SortExec { self } - /// Whether this `SortExec` preserves partitioning of the children + /// Modify how many rows to include in the result + /// + /// If None, then all rows will be returned, in sorted order. + /// If Some, then only the top `fetch` rows will be returned. + /// This can reduce the memory pressure required by the sort + /// operation since rows that are not going to be included + /// can be dropped. pub fn with_fetch(mut self, fetch: Option) -> Self { self.fetch = fetch; self @@ -513,6 +763,26 @@ impl SortExec { } } +impl DisplayAs for SortExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let expr = PhysicalSortExpr::format_list(&self.expr); + match self.fetch { + Some(fetch) => { + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}]",) + } + None => write!(f, "SortExec: expr=[{expr}]"), + } + } + } + } +} + impl ExecutionPlan for SortExec { fn as_any(&self) -> &dyn Any { self @@ -536,9 +806,7 @@ impl ExecutionPlan for SortExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - Err(DataFusionError::Plan( - "Sort Error: Can not sort unbounded inputs.".to_string(), - )) + plan_err!("Sort Error: Can not sort unbounded inputs.") } else { Ok(false) } @@ -558,8 +826,8 @@ impl ExecutionPlan for SortExec { vec![self.input.clone()] } - fn benefits_from_input_partitioning(&self) -> bool { - false + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -567,7 +835,10 @@ impl ExecutionPlan for SortExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + // Reset the ordering equivalence class with the new ordering: + self.input + .equivalence_properties() + .with_reorder(self.expr.to_vec()) } fn with_new_children( @@ -590,127 +861,113 @@ impl ExecutionPlan for SortExec { let mut input = self.input.execute(partition, context.clone())?; + let execution_options = &context.session_config().options().execution; + trace!("End SortExec's input.execute for partition: {}", partition); - let mut sorter = ExternalSorter::new( - partition, - input.schema(), - self.expr.clone(), - context.session_config().batch_size(), - self.fetch, - &self.metrics_set, - context.runtime_env(), - ); + if let Some(fetch) = self.fetch.as_ref() { + let mut topk = TopK::try_new( + partition, + input.schema(), + self.expr.clone(), + *fetch, + context.session_config().batch_size(), + context.runtime_env(), + &self.metrics_set, + partition, + )?; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + topk.insert_batch(batch)?; + } + topk.emit() + }) + .try_flatten(), + ))) + } else { + let mut sorter = ExternalSorter::new( + partition, + input.schema(), + self.expr.clone(), + context.session_config().batch_size(), + self.fetch, + execution_options.sort_spill_reservation_bytes, + execution_options.sort_in_place_threshold_bytes, + &self.metrics_set, + context.runtime_env(), + ); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - futures::stream::once(async move { - while let Some(batch) = input.next().await { - let batch = batch?; - sorter.insert_batch(batch).await?; - } - sorter.sort() - }) - .try_flatten(), - ))) + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch).await?; + } + sorter.sort() + }) + .try_flatten(), + ))) + } } fn metrics(&self) -> Option { Some(self.metrics_set.clone_inner()) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); - match self.fetch { - Some(fetch) => { - write!(f, "SortExec: fetch={fetch}, expr=[{}]", expr.join(",")) - } - None => write!(f, "SortExec: expr=[{}]", expr.join(",")), - } - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; - use crate::execution::context::SessionConfig; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::collect; - use crate::physical_plan::expressions::col; - use crate::physical_plan::memory::MemoryExec; - use crate::prelude::SessionContext; + use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::collect; + use crate::expressions::col; + use crate::memory::MemoryExec; use crate::test; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::*; - use datafusion_common::cast::{as_primitive_array, as_string_array}; + use datafusion_common::cast::as_primitive_array; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeConfig; + use futures::FutureExt; - use std::collections::HashMap; #[tokio::test] async fn test_in_mem_sort() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; + let csv = test::scan_partitioned(partitions); let schema = csv.schema(); let sort_exec = Arc::new(SortExec::new( - vec![ - // c1 string column - PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }, - // c2 uin32 column - PhysicalSortExpr { - expr: col("c2", &schema)?, - options: SortOptions::default(), - }, - // c7 uin8 column - PhysicalSortExpr { - expr: col("c7", &schema)?, - options: SortOptions::default(), - }, - ], + vec![PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }], Arc::new(CoalescePartitionsExec::new(csv)), )); - let result = collect(sort_exec, task_ctx).await?; + let result = collect(sort_exec, task_ctx.clone()).await?; assert_eq!(result.len(), 1); - - let columns = result[0].columns(); - - let c1 = as_string_array(&columns[0])?; - assert_eq!(c1.value(0), "a"); - assert_eq!(c1.value(c1.len() - 1), "e"); - - let c2 = as_primitive_array::(&columns[1])?; - assert_eq!(c2.value(0), 1); - assert_eq!(c2.value(c2.len() - 1), 5,); - - let c7 = as_primitive_array::(&columns[6])?; - assert_eq!(c7.value(0), 15); - assert_eq!(c7.value(c7.len() - 1), 254,); + assert_eq!(result[0].num_rows(), 400); assert_eq!( - session_ctx.runtime_env().memory_pool.reserved(), + task_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); @@ -720,65 +977,53 @@ mod tests { #[tokio::test] async fn test_sort_spill() -> Result<()> { - // trigger spill there will be 4 batches with 5.5KB for each - let config = RuntimeConfig::new().with_memory_limit(12288, 1.0); - let runtime = Arc::new(RuntimeEnv::new(config)?); - let session_ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime); + // trigger spill w/ 100 batches + let session_config = SessionConfig::new(); + let sort_spill_reservation_bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + let rt_config = RuntimeConfig::new() + .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0); + let runtime = Arc::new(RuntimeEnv::new(rt_config)?); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); - let partitions = 4; - let csv = test::scan_partitioned_csv(partitions)?; - let schema = csv.schema(); + let partitions = 100; + let input = test::scan_partitioned(partitions); + let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - vec![ - // c1 string column - PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }, - // c2 uin32 column - PhysicalSortExpr { - expr: col("c2", &schema)?, - options: SortOptions::default(), - }, - // c7 uin8 column - PhysicalSortExpr { - expr: col("c7", &schema)?, - options: SortOptions::default(), - }, - ], - Arc::new(CoalescePartitionsExec::new(csv)), + vec![PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }], + Arc::new(CoalescePartitionsExec::new(input)), )); - let task_ctx = session_ctx.task_ctx(); - let result = collect(sort_exec.clone(), task_ctx).await?; + let result = collect(sort_exec.clone(), task_ctx.clone()).await?; - assert_eq!(result.len(), 1); + assert_eq!(result.len(), 2); // Now, validate metrics let metrics = sort_exec.metrics().unwrap(); - assert_eq!(metrics.output_rows().unwrap(), 100); + assert_eq!(metrics.output_rows().unwrap(), 10000); assert!(metrics.elapsed_compute().unwrap() > 0); assert!(metrics.spill_count().unwrap() > 0); assert!(metrics.spilled_bytes().unwrap() > 0); let columns = result[0].columns(); - let c1 = as_string_array(&columns[0])?; - assert_eq!(c1.value(0), "a"); - assert_eq!(c1.value(c1.len() - 1), "e"); - - let c2 = as_primitive_array::(&columns[1])?; - assert_eq!(c2.value(0), 1); - assert_eq!(c2.value(c2.len() - 1), 5,); - - let c7 = as_primitive_array::(&columns[6])?; - assert_eq!(c7.value(0), 15); - assert_eq!(c7.value(c7.len() - 1), 254,); + let i = as_primitive_array::(&columns[0])?; + assert_eq!(i.value(0), 0); + assert_eq!(i.value(i.len() - 1), 81); assert_eq!( - session_ctx.runtime_env().memory_pool.reserved(), + task_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); @@ -789,7 +1034,7 @@ mod tests { #[tokio::test] async fn test_sort_fetch_memory_calculation() -> Result<()> { // This test mirrors down the size from the example above. - let avg_batch_size = 4000; + let avg_batch_size = 400; let partitions = 4; // A tuple of (fetch, expect_spillage) @@ -803,45 +1048,42 @@ mod tests { ]; for (fetch, expect_spillage) in test_options { - let config = RuntimeConfig::new() - .with_memory_limit(avg_batch_size * (partitions - 1), 1.0); - let runtime = Arc::new(RuntimeEnv::new(config)?); - let session_ctx = - SessionContext::with_config_rt(SessionConfig::new(), runtime); + let session_config = SessionConfig::new(); + let sort_spill_reservation_bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + + let rt_config = RuntimeConfig::new().with_memory_limit( + sort_spill_reservation_bytes + avg_batch_size * (partitions - 1), + 1.0, + ); + let runtime = Arc::new(RuntimeEnv::new(rt_config)?); + let task_ctx = Arc::new( + TaskContext::default() + .with_runtime(runtime) + .with_session_config(session_config), + ); - let csv = test::scan_partitioned_csv(partitions)?; + let csv = test::scan_partitioned(partitions); let schema = csv.schema(); let sort_exec = Arc::new( SortExec::new( - vec![ - // c1 string column - PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }, - // c2 uin32 column - PhysicalSortExpr { - expr: col("c2", &schema)?, - options: SortOptions::default(), - }, - // c7 uin8 column - PhysicalSortExpr { - expr: col("c7", &schema)?, - options: SortOptions::default(), - }, - ], + vec![PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }], Arc::new(CoalescePartitionsExec::new(csv)), ) .with_fetch(fetch), ); - let task_ctx = session_ctx.task_ctx(); - let result = collect(sort_exec.clone(), task_ctx).await?; + let result = collect(sort_exec.clone(), task_ctx.clone()).await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); - let did_it_spill = metrics.spill_count().unwrap() > 0; + let did_it_spill = metrics.spill_count().unwrap_or(0) > 0; assert_eq!(did_it_spill, expect_spillage, "with fetch: {fetch:?}"); } Ok(()) @@ -849,8 +1091,7 @@ mod tests { #[tokio::test] async fn test_sort_metadata() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let field_metadata: HashMap = vec![("foo".to_string(), "bar".to_string())] .into_iter() @@ -899,8 +1140,7 @@ mod tests { #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float64, true), @@ -1005,8 +1245,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1020,7 +1259,7 @@ mod tests { blocking_exec, )); - let fut = collect(sort_exec, task_ctx); + let fut = collect(sort_exec, task_ctx.clone()); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1028,7 +1267,7 @@ mod tests { assert_strong_count_converges_to_zero(refs).await; assert_eq!( - session_ctx.runtime_env().memory_pool.reserved(), + task_ctx.runtime_env().memory_pool.reserved(), 0, "The sort should have returned all memory used back to the memory manager" ); diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs similarity index 83% rename from datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs rename to datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 1195959a89b6a..f4b57e8bfb45c 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -20,22 +20,23 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::SchemaRef; -use log::{debug, trace}; - -use crate::physical_plan::common::spawn_buffered; -use crate::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, -}; -use crate::physical_plan::sorts::streaming_merge; -use crate::physical_plan::{ - expressions::PhysicalSortExpr, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, SendableRecordBatchStream, Statistics, +use crate::common::spawn_buffered; +use crate::expressions::PhysicalSortExpr; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::sorts::streaming_merge; +use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, }; -use datafusion_common::{DataFusionError, Result}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; +use log::{debug, trace}; + /// Sort preserving merge execution plan /// /// This takes an input execution plan and a list of sort expressions, and @@ -71,6 +72,8 @@ pub struct SortPreservingMergeExec { expr: Vec, /// Execution metrics metrics: ExecutionPlanMetricsSet, + /// Optional number of rows to fetch. Stops producing rows after this fetch + fetch: Option, } impl SortPreservingMergeExec { @@ -80,8 +83,14 @@ impl SortPreservingMergeExec { input, expr, metrics: ExecutionPlanMetricsSet::new(), + fetch: None, } } + /// Sets the number of rows to fetch + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } /// Input schema pub fn input(&self) -> &Arc { @@ -92,6 +101,34 @@ impl SortPreservingMergeExec { pub fn expr(&self) -> &[PhysicalSortExpr] { &self.expr } + + /// Fetch + pub fn fetch(&self) -> Option { + self.fetch + } +} + +impl DisplayAs for SortPreservingMergeExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SortPreservingMergeExec: [{}]", + PhysicalSortExpr::format_list(&self.expr) + )?; + if let Some(fetch) = self.fetch { + write!(f, ", fetch={fetch}")?; + }; + + Ok(()) + } + } + } } impl ExecutionPlan for SortPreservingMergeExec { @@ -109,10 +146,21 @@ impl ExecutionPlan for SortPreservingMergeExec { Partitioning::UnknownPartitioning(1) } + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + fn required_input_ordering(&self) -> Vec>> { vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))] } @@ -137,10 +185,10 @@ impl ExecutionPlan for SortPreservingMergeExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(SortPreservingMergeExec::new( - self.expr.clone(), - children[0].clone(), - ))) + Ok(Arc::new( + SortPreservingMergeExec::new(self.expr.clone(), children[0].clone()) + .with_fetch(self.fetch), + )) } fn execute( @@ -153,9 +201,9 @@ impl ExecutionPlan for SortPreservingMergeExec { partition ); if 0 != partition { - return Err(DataFusionError::Internal(format!( + return internal_err!( "SortPreservingMergeExec invalid partition {partition}" - ))); + ); } let input_partitions = self.input.output_partitioning().partition_count(); @@ -165,11 +213,14 @@ impl ExecutionPlan for SortPreservingMergeExec { ); let schema = self.schema(); + let reservation = + MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]")) + .register(&context.runtime_env().memory_pool); + match input_partitions { - 0 => Err(DataFusionError::Internal( + 0 => internal_err!( "SortPreservingMergeExec requires at least one input partition" - .to_owned(), - )), + ), 1 => { // bypass if there is only one partition to merge (no metrics in this case either) let result = self.input.execute(0, context); @@ -192,6 +243,8 @@ impl ExecutionPlan for SortPreservingMergeExec { &self.expr, BaselineMetrics::new(&self.metrics, partition), context.session_config().batch_size(), + self.fetch, + reservation, )?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); @@ -201,24 +254,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); - write!(f, "SortPreservingMergeExec: [{}]", expr.join(",")) - } - } - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { self.input.statistics() } } @@ -227,31 +267,29 @@ impl ExecutionPlan for SortPreservingMergeExec { mod tests { use std::iter::FromIterator; - use arrow::array::ArrayRef; + use super::*; + use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::expressions::col; + use crate::memory::MemoryExec; + use crate::metrics::{MetricValue, Timestamp}; + use crate::sorts::sort::SortExec; + use crate::stream::RecordBatchReceiverStream; + use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test::{self, assert_is_pending, make_partition}; + use crate::{collect, common}; + + use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - use futures::{FutureExt, StreamExt}; + use datafusion_common::{assert_batches_eq, assert_contains}; + use datafusion_execution::config::SessionConfig; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::expressions::col; - use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::metrics::MetricValue; - use crate::physical_plan::sorts::sort::SortExec; - use crate::physical_plan::stream::RecordBatchReceiverStream; - use crate::physical_plan::{collect, common}; - use crate::prelude::{SessionConfig, SessionContext}; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; - use crate::test::{self, assert_is_pending}; - use crate::{assert_batches_eq, test_util}; - use arrow::array::{Int32Array, StringArray, TimestampNanosecondArray}; - - use super::*; + use futures::{FutureExt, StreamExt}; #[tokio::test] async fn test_merge_interleave() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -297,10 +335,28 @@ mod tests { .await; } + #[tokio::test] + async fn test_merge_no_exprs() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); + + let schema = batch.schema(); + let sort = vec![]; // no sort expressions + let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None) + .unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); + + let res = collect(merge, task_ctx).await.unwrap_err(); + assert_contains!( + res.to_string(), + "Internal error: Sort expressions cannot be empty for streaming merge" + ); + } + #[tokio::test] async fn test_merge_some_overlap() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -348,8 +404,7 @@ mod tests { #[tokio::test] async fn test_merge_no_overlap() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -397,8 +452,7 @@ mod tests { #[tokio::test] async fn test_merge_three_partitions() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -518,34 +572,19 @@ mod tests { } #[tokio::test] - async fn test_partition_sort() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + async fn test_partition_sort() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); let partitions = 4; - let csv = test::scan_partitioned_csv(partitions).unwrap(); + let csv = test::scan_partitioned(partitions); let schema = csv.schema(); - let sort = vec![ - PhysicalSortExpr { - expr: col("c1", &schema).unwrap(), - options: SortOptions { - descending: true, - nulls_first: true, - }, - }, - PhysicalSortExpr { - expr: col("c2", &schema).unwrap(), - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("c7", &schema).unwrap(), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: col("c12", &schema).unwrap(), - options: SortOptions::default(), + let sort = vec![PhysicalSortExpr { + expr: col("i", &schema).unwrap(), + options: SortOptions { + descending: true, + nulls_first: true, }, - ]; + }]; let basic = basic_sort(csv.clone(), sort.clone(), task_ctx.clone()).await; let partition = partition_sort(csv, sort, task_ctx.clone()).await; @@ -561,6 +600,8 @@ mod tests { basic, partition, "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n" ); + + Ok(()) } // Split the provided record batch into multiple batch_size record batches @@ -590,51 +631,35 @@ mod tests { sort: Vec, sizes: &[usize], context: Arc, - ) -> Arc { + ) -> Result> { let partitions = 4; - let csv = test::scan_partitioned_csv(partitions).unwrap(); + let csv = test::scan_partitioned(partitions); let sorted = basic_sort(csv, sort, context).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); - Arc::new(MemoryExec::try_new(&split, sorted.schema(), None).unwrap()) + Ok(Arc::new( + MemoryExec::try_new(&split, sorted.schema(), None).unwrap(), + )) } #[tokio::test] - async fn test_partition_sort_streaming_input() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); - let sort = vec![ - // uint8 - PhysicalSortExpr { - expr: col("c7", &schema).unwrap(), - options: Default::default(), - }, - // int16 - PhysicalSortExpr { - expr: col("c4", &schema).unwrap(), - options: Default::default(), - }, - // utf-8 - PhysicalSortExpr { - expr: col("c1", &schema).unwrap(), - options: SortOptions::default(), - }, - // utf-8 - PhysicalSortExpr { - expr: col("c13", &schema).unwrap(), - options: SortOptions::default(), - }, - ]; + async fn test_partition_sort_streaming_input() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = make_partition(11).schema(); + let sort = vec![PhysicalSortExpr { + expr: col("i", &schema).unwrap(), + options: Default::default(), + }]; let input = - sorted_partitioned_input(sort.clone(), &[10, 3, 11], task_ctx.clone()).await; + sorted_partitioned_input(sort.clone(), &[10, 3, 11], task_ctx.clone()) + .await?; let basic = basic_sort(input.clone(), sort.clone(), task_ctx.clone()).await; let partition = sorted_merge(input, sort, task_ctx.clone()).await; - assert_eq!(basic.num_rows(), 300); - assert_eq!(partition.num_rows(), 300); + assert_eq!(basic.num_rows(), 1200); + assert_eq!(partition.num_rows(), 1200); let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -644,42 +669,37 @@ mod tests { .to_string(); assert_eq!(basic, partition); + + Ok(()) } #[tokio::test] - async fn test_partition_sort_streaming_input_output() { - let schema = test_util::aggr_test_schema(); - - let sort = vec![ - // float64 - PhysicalSortExpr { - expr: col("c12", &schema).unwrap(), - options: Default::default(), - }, - // utf-8 - PhysicalSortExpr { - expr: col("c13", &schema).unwrap(), - options: Default::default(), - }, - ]; + async fn test_partition_sort_streaming_input_output() -> Result<()> { + let schema = make_partition(11).schema(); + let sort = vec![PhysicalSortExpr { + expr: col("i", &schema).unwrap(), + options: Default::default(), + }]; - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + // Test streaming with default batch size + let task_ctx = Arc::new(TaskContext::default()); let input = - sorted_partitioned_input(sort.clone(), &[10, 5, 13], task_ctx.clone()).await; + sorted_partitioned_input(sort.clone(), &[10, 5, 13], task_ctx.clone()) + .await?; let basic = basic_sort(input.clone(), sort.clone(), task_ctx).await; - let session_ctx_bs_23 = - SessionContext::with_config(SessionConfig::new().with_batch_size(23)); + // batch size of 23 + let task_ctx = TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(23)); + let task_ctx = Arc::new(task_ctx); let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let task_ctx = session_ctx_bs_23.task_ctx(); let merged = collect(merge, task_ctx).await.unwrap(); - assert_eq!(merged.len(), 14); + assert_eq!(merged.len(), 53); - assert_eq!(basic.num_rows(), 300); - assert_eq!(merged.iter().map(|x| x.num_rows()).sum::(), 300); + assert_eq!(basic.num_rows(), 1200); + assert_eq!(merged.iter().map(|x| x.num_rows()).sum::(), 1200); let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -689,12 +709,13 @@ mod tests { .to_string(); assert_eq!(basic, partition); + + Ok(()) } #[tokio::test] async fn test_nulls() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ None, @@ -774,17 +795,16 @@ mod tests { } #[tokio::test] - async fn test_async() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let schema = test_util::aggr_test_schema(); + async fn test_async() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = make_partition(11).schema(); let sort = vec![PhysicalSortExpr { - expr: col("c12", &schema).unwrap(), + expr: col("i", &schema).unwrap(), options: SortOptions::default(), }]; let batches = - sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await; + sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await?; let partition_count = batches.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(partition_count); @@ -801,19 +821,26 @@ mod tests { // This causes the MergeStream to wait for more input tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; } + + Ok(()) }); streams.push(builder.build()); } let metrics = ExecutionPlanMetricsSet::new(); + let reservation = + MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool); + let fetch = None; let merge_stream = streaming_merge( streams, batches.schema(), sort.as_slice(), BaselineMetrics::new(&metrics, 0), task_ctx.session_config().batch_size(), + fetch, + reservation, ) .unwrap(); @@ -834,12 +861,13 @@ mod tests { basic, partition, "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n" ); + + Ok(()) } #[tokio::test] async fn test_merge_metrics() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); @@ -857,7 +885,7 @@ mod tests { let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); let collected = collect(merge.clone(), task_ctx).await.unwrap(); - let expected = vec![ + let expected = [ "+----+---+", "| a | b |", "+----+---+", @@ -880,11 +908,11 @@ mod tests { metrics.iter().for_each(|m| match m.value() { MetricValue::StartTimestamp(ts) => { saw_start = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); + assert!(nanos_from_timestamp(ts) > 0); } MetricValue::EndTimestamp(ts) => { saw_end = true; - assert!(ts.value().unwrap().timestamp_nanos() > 0); + assert!(nanos_from_timestamp(ts) > 0); } _ => {} }); @@ -893,10 +921,13 @@ mod tests { assert!(saw_end); } + fn nanos_from_timestamp(ts: &Timestamp) -> i64 { + ts.value().unwrap().timestamp_nanos_opt().unwrap() + } + #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -922,8 +953,7 @@ mod tests { #[tokio::test] async fn test_stable_sort() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); // Create record batches like: // batch_number |value diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs similarity index 81% rename from datafusion/core/src/physical_plan/sorts/stream.rs rename to datafusion/physical-plan/src/sorts/stream.rs index 97a3b85fa5353..135b4fbdece49 100644 --- a/datafusion/core/src/physical_plan/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::physical_plan::sorts::cursor::{FieldArray, FieldCursor, RowCursor}; -use crate::physical_plan::SendableRecordBatchStream; -use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr}; +use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; +use crate::SendableRecordBatchStream; +use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::Array; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; use futures::stream::{Fuse, StreamExt}; use std::marker::PhantomData; use std::sync::Arc; @@ -75,7 +76,7 @@ impl FusedStreams { } /// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`] -/// and computes [`RowCursor`] based on the provided [`PhysicalSortExpr`] +/// and computes [`RowValues`] based on the provided [`PhysicalSortExpr`] #[derive(Debug)] pub struct RowCursorStream { /// Converter to convert output of physical expressions @@ -84,6 +85,8 @@ pub struct RowCursorStream { column_expressions: Vec>, /// Input streams streams: FusedStreams, + /// Tracks the memory used by `converter` + reservation: MemoryReservation, } impl RowCursorStream { @@ -91,6 +94,7 @@ impl RowCursorStream { schema: &Schema, expressions: &[PhysicalSortExpr], streams: Vec, + reservation: MemoryReservation, ) -> Result { let sort_fields = expressions .iter() @@ -104,25 +108,31 @@ impl RowCursorStream { let converter = RowConverter::new(sort_fields)?; Ok(Self { converter, + reservation, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), streams: FusedStreams(streams), }) } - fn convert_batch(&mut self, batch: &RecordBatch) -> Result { + fn convert_batch(&mut self, batch: &RecordBatch) -> Result { let cols = self .column_expressions .iter() - .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; let rows = self.converter.convert_columns(&cols)?; - Ok(RowCursor::new(rows)) + self.reservation.try_resize(self.converter.size())?; + + // track the memory in the newly created Rows. + let mut rows_reservation = self.reservation.new_empty(); + rows_reservation.try_grow(rows.size())?; + Ok(RowValues::new(rows, rows_reservation)) } } impl PartitionedStream for RowCursorStream { - type Output = Result<(RowCursor, RecordBatch)>; + type Output = Result<(RowValues, RecordBatch)>; fn partitions(&self) -> usize { self.streams.0.len() @@ -143,7 +153,7 @@ impl PartitionedStream for RowCursorStream { } /// Specialized stream for sorts on single primitive columns -pub struct FieldCursorStream { +pub struct FieldCursorStream { /// The physical expressions to sort by sort: PhysicalSortExpr, /// Input streams @@ -151,7 +161,7 @@ pub struct FieldCursorStream { phantom: PhantomData T>, } -impl std::fmt::Debug for FieldCursorStream { +impl std::fmt::Debug for FieldCursorStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PrimitiveCursorStream") .field("num_streams", &self.streams) @@ -159,7 +169,7 @@ impl std::fmt::Debug for FieldCursorStream { } } -impl FieldCursorStream { +impl FieldCursorStream { pub fn new(sort: PhysicalSortExpr, streams: Vec) -> Self { let streams = streams.into_iter().map(|s| s.fuse()).collect(); Self { @@ -169,16 +179,16 @@ impl FieldCursorStream { } } - fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { + fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; - let array = value.into_array(batch.num_rows()); + let array = value.into_array(batch.num_rows())?; let array = array.as_any().downcast_ref::().expect("field values"); - Ok(FieldCursor::new(self.sort.options, array)) + Ok(ArrayValues::new(self.sort.options, array)) } } -impl PartitionedStream for FieldCursorStream { - type Output = Result<(FieldCursor, RecordBatch)>; +impl PartitionedStream for FieldCursorStream { + type Output = Result<(ArrayValues, RecordBatch)>; fn partitions(&self) -> usize { self.streams.0.len() diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs new file mode 100644 index 0000000000000..4f8d8063853b3 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Merge that deals with an arbitrary size of streaming inputs. +//! This is an order-preserving merge. + +use crate::metrics::BaselineMetrics; +use crate::sorts::{ + merge::SortPreservingMergeStream, + stream::{FieldCursorStream, RowCursorStream}, +}; +use crate::{PhysicalSortExpr, SendableRecordBatchStream}; +use arrow::datatypes::{DataType, SchemaRef}; +use arrow_array::*; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::MemoryReservation; + +macro_rules! primitive_merge_helper { + ($t:ty, $($v:ident),+) => { + merge_helper!(PrimitiveArray<$t>, $($v),+) + }; +} + +macro_rules! merge_helper { + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ + let streams = FieldCursorStream::<$t>::new($sort, $streams); + return Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + $schema, + $tracking_metrics, + $batch_size, + $fetch, + $reservation, + ))); + }}; +} + +/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions +/// while preserving order. +pub fn streaming_merge( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + metrics: BaselineMetrics, + batch_size: usize, + fetch: Option, + reservation: MemoryReservation, +) -> Result { + // If there are no sort expressions, preserving the order + // doesn't mean anything (and result in infinite loops) + if expressions.is_empty() { + return internal_err!("Sort expressions cannot be empty for streaming merge"); + } + // Special case single column comparisons with optimized cursor implementations + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive! { + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) + _ => {} + } + } + + let streams = RowCursorStream::try_new( + schema.as_ref(), + expressions, + streams, + reservation.new_empty(), + )?; + + Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + schema, + metrics, + batch_size, + fetch, + reservation, + ))) +} diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/physical-plan/src/stream.rs similarity index 74% rename from datafusion/core/src/physical_plan/stream.rs rename to datafusion/physical-plan/src/stream.rs index bdc2050b24646..fdf32620ca50e 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -17,12 +17,16 @@ //! Stream wrappers for physical operators +use std::pin::Pin; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; -use crate::physical_plan::displayable; +use crate::displayable; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::{exec_err, internal_err}; use datafusion_execution::TaskContext; use futures::stream::BoxStream; use futures::{Future, Stream, StreamExt}; @@ -34,6 +38,124 @@ use tokio::task::JoinSet; use super::metrics::BaselineMetrics; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; +/// Creates a stream from a collection of producing tasks, routing panics to the stream. +/// +/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being: +/// +/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`). +/// +/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver. +/// +/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped. +/// +/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html + +pub(crate) struct ReceiverStreamBuilder { + tx: Sender>, + rx: Receiver>, + join_set: JoinSet>, +} + +impl ReceiverStreamBuilder { + /// create new channels with the specified buffer size + pub fn new(capacity: usize) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + + Self { + tx, + rx, + join_set: JoinSet::new(), + } + } + + /// Get a handle for sending data to the output + pub fn tx(&self) -> Sender> { + self.tx.clone() + } + + /// Spawn task that will be aborted if this builder (or the stream + /// built from it) are dropped + pub fn spawn(&mut self, task: F) + where + F: Future>, + F: Send + 'static, + { + self.join_set.spawn(task); + } + + /// Spawn a blocking task that will be aborted if this builder (or the stream + /// built from it) are dropped + /// + /// this is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx` + pub fn spawn_blocking(&mut self, f: F) + where + F: FnOnce() -> Result<()>, + F: Send + 'static, + { + self.join_set.spawn_blocking(f); + } + + /// Create a stream of all data written to `tx` + pub fn build(self) -> BoxStream<'static, Result> { + let Self { + tx, + rx, + mut join_set, + } = self; + + // don't need tx + drop(tx); + + // future that checks the result of the join set, and propagates panic if seen + let check = async move { + while let Some(result) = join_set.join_next().await { + match result { + Ok(task_result) => { + match task_result { + // nothing to report + Ok(_) => continue, + // This means a blocking task error + Err(e) => { + return Some(exec_err!("Spawned Task error: {e}")); + } + } + } + // This means a tokio task error, likely a panic + Err(e) => { + if e.is_panic() { + // resume on the main thread + std::panic::resume_unwind(e.into_panic()); + } else { + // This should only occur if the task is + // cancelled, which would only occur if + // the JoinSet were aborted, which in turn + // would imply that the receiver has been + // dropped and this code is not running + return Some(internal_err!("Non Panic Task error: {e}")); + } + } + } + } + None + }; + + let check_stream = futures::stream::once(check) + // unwrap Option / only return the error + .filter_map(|item| async move { item }); + + // Convert the receiver into a stream + let rx_stream = futures::stream::unfold(rx, |mut rx| async move { + let next_item = rx.recv().await; + next_item.map(|next_item| (next_item, rx)) + }); + + // Merge the streams together so whichever is ready first + // produces the batch + futures::stream::select(rx_stream, check_stream).boxed() + } +} + /// Builder for [`RecordBatchReceiverStream`] that propagates errors /// and panic's correctly. /// @@ -43,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; /// /// This also handles propagating panic`s and canceling the tasks. pub struct RecordBatchReceiverStreamBuilder { - tx: Sender>, - rx: Receiver>, schema: SchemaRef, - join_set: JoinSet<()>, + inner: ReceiverStreamBuilder, } impl RecordBatchReceiverStreamBuilder { /// create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { - let (tx, rx) = tokio::sync::mpsc::channel(capacity); - Self { - tx, - rx, schema, - join_set: JoinSet::new(), + inner: ReceiverStreamBuilder::new(capacity), } } - /// Get a handle for sending [`RecordBatch`]es to the output + /// Get a handle for sending [`RecordBatch`] to the output pub fn tx(&self) -> Sender> { - self.tx.clone() + self.inner.tx() } /// Spawn task that will be aborted if this builder (or the stream @@ -74,10 +190,10 @@ impl RecordBatchReceiverStreamBuilder { /// retrieved from `Self::tx` pub fn spawn(&mut self, task: F) where - F: Future, + F: Future>, F: Send + 'static, { - self.join_set.spawn(task); + self.inner.spawn(task) } /// Spawn a blocking task that will be aborted if this builder (or the stream @@ -87,10 +203,10 @@ impl RecordBatchReceiverStreamBuilder { /// retrieved from `Self::tx` pub fn spawn_blocking(&mut self, f: F) where - F: FnOnce(), + F: FnOnce() -> Result<()>, F: Send + 'static, { - self.join_set.spawn_blocking(f); + self.inner.spawn_blocking(f) } /// runs the input_partition of the `input` ExecutionPlan on the @@ -106,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder { ) { let output = self.tx(); - self.spawn(async move { + self.inner.spawn(async move { let mut stream = match input.execute(partition, context) { Err(e) => { // If send fails, the plan being torn down, there @@ -116,7 +232,7 @@ impl RecordBatchReceiverStreamBuilder { "Stopping execution: error executing input: {}", displayable(input.as_ref()).one_line() ); - return; + return Ok(()); } Ok(stream) => stream, }; @@ -133,7 +249,7 @@ impl RecordBatchReceiverStreamBuilder { "Stopping execution: output is gone, plan cancelling: {}", displayable(input.as_ref()).one_line() ); - return; + return Ok(()); } // stop after the first error is encontered (don't @@ -143,79 +259,25 @@ impl RecordBatchReceiverStreamBuilder { "Stopping execution: plan returned error: {}", displayable(input.as_ref()).one_line() ); - return; + return Ok(()); } } + + Ok(()) }); } - /// Create a stream of all `RecordBatch`es written to `tx` + /// Create a stream of all [`RecordBatch`] written to `tx` pub fn build(self) -> SendableRecordBatchStream { - let Self { - tx, - rx, - schema, - mut join_set, - } = self; - - // don't need tx - drop(tx); - - // future that checks the result of the join set, and propagates panic if seen - let check = async move { - while let Some(result) = join_set.join_next().await { - match result { - Ok(()) => continue, // nothing to report - // This means a tokio task error, likely a panic - Err(e) => { - if e.is_panic() { - // resume on the main thread - std::panic::resume_unwind(e.into_panic()); - } else { - // This should only occur if the task is - // cancelled, which would only occur if - // the JoinSet were aborted, which in turn - // would imply that the receiver has been - // dropped and this code is not running - return Some(Err(DataFusionError::Internal(format!( - "Non Panic Task error: {e}" - )))); - } - } - } - } - None - }; - - let check_stream = futures::stream::once(check) - // unwrap Option / only return the error - .filter_map(|item| async move { item }); - - // Convert the receiver into a stream - let rx_stream = futures::stream::unfold(rx, |mut rx| async move { - let next_item = rx.recv().await; - next_item.map(|next_item| (next_item, rx)) - }); - - // Merge the streams together so whichever is ready first - // produces the batch - let inner = futures::stream::select(rx_stream, check_stream).boxed(); - - Box::pin(RecordBatchReceiverStream { schema, inner }) + Box::pin(RecordBatchStreamAdapter::new( + self.schema, + self.inner.build(), + )) } } -/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs, -/// on new tokio Tasks, increasing the potential parallelism. -/// -/// This structure also handles propagating panics and cancelling the -/// underlying tasks correctly. -/// -/// Use [`Self::builder`] to construct one. -pub struct RecordBatchReceiverStream { - schema: SchemaRef, - inner: BoxStream<'static, Result>, -} +#[doc(hidden)] +pub struct RecordBatchReceiverStream {} impl RecordBatchReceiverStream { /// Create a builder with an internal buffer of capacity batches. @@ -227,23 +289,6 @@ impl RecordBatchReceiverStream { } } -impl Stream for RecordBatchReceiverStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_next_unpin(cx) - } -} - -impl RecordBatchStream for RecordBatchReceiverStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - pin_project! { /// Combines a [`Stream`] with a [`SchemaRef`] implementing /// [`RecordBatchStream`] for the combination @@ -276,10 +321,7 @@ where { type Item = Result; - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().stream.poll_next(cx) } @@ -297,6 +339,37 @@ where } } +/// EmptyRecordBatchStream can be used to create a RecordBatchStream +/// that will produce no results +pub struct EmptyRecordBatchStream { + /// Schema wrapped by Arc + schema: SchemaRef, +} + +impl EmptyRecordBatchStream { + /// Create an empty RecordBatchStream + pub fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +impl RecordBatchStream for EmptyRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for EmptyRecordBatchStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(None) + } +} + /// Stream wrapper that records `BaselineMetrics` for a particular /// `[SendableRecordBatchStream]` (likely a partition) pub(crate) struct ObservedStream { @@ -326,9 +399,9 @@ impl futures::Stream for ObservedStream { type Item = Result; fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let poll = self.inner.poll_next_unpin(cx); self.baseline_metrics.record_poll(poll) } @@ -338,12 +411,10 @@ impl futures::Stream for ObservedStream { mod test { use super::*; use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::exec_err; - use crate::{ - execution::context::SessionContext, - test::exec::{ - assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, - }, + use crate::test::exec::{ + assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, }; fn schema() -> SchemaRef { @@ -382,8 +453,7 @@ mod test { #[tokio::test] async fn record_batch_receiver_stream_drop_cancel() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = schema(); // Make an input that never proceeds @@ -408,19 +478,13 @@ mod test { /// `RecordBatchReceiverStream` stops early and does not drive /// other streams to completion. async fn record_batch_receiver_stream_error_does_not_drive_completion() { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let schema = schema(); // make an input that will error twice - let error_stream = MockExec::new( - vec![ - Err(DataFusionError::Execution("Test1".to_string())), - Err(DataFusionError::Execution("Test2".to_string())), - ], - schema.clone(), - ) - .with_use_task(false); + let error_stream = + MockExec::new(vec![exec_err!("Test1"), exec_err!("Test2")], schema.clone()) + .with_use_task(false); let mut builder = RecordBatchReceiverStream::builder(schema, 2); builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); @@ -429,7 +493,7 @@ mod test { // get the first result, which should be an error let first_batch = stream.next().await.unwrap(); let first_err = first_batch.unwrap_err(); - assert_eq!(first_err.to_string(), "Execution error: Test1"); + assert_eq!(first_err.strip_backtrace(), "Execution error: Test1"); // There should be no more batches produced (should not get the second error) assert!(stream.next().await.is_none()); @@ -440,8 +504,7 @@ mod test { /// /// panic's if more than max_batches is seen, async fn consume(input: PanicExec, max_batches: usize) { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); let input = Arc::new(input); let num_partitions = input.output_partitioning().partition_count(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs new file mode 100644 index 0000000000000..59819c6921fb4 --- /dev/null +++ b/datafusion/physical-plan/src/streaming.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic plans for deferred execution: [`StreamingTableExec`] and [`PartitionStream`] + +use std::any::Any; +use std::sync::Arc; + +use super::{DisplayAs, DisplayFormatType}; +use crate::display::{OutputOrderingDisplay, ProjectSchemaDisplay}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; + +use arrow::datatypes::SchemaRef; +use arrow_schema::Schema; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; + +use async_trait::async_trait; +use futures::stream::StreamExt; +use log::debug; + +/// A partition that can be converted into a [`SendableRecordBatchStream`] +/// +/// Combined with [`StreamingTableExec`], you can use this trait to implement +/// [`ExecutionPlan`] for a custom source with less boiler plate than +/// implementing `ExecutionPlan` directly for many use cases. +pub trait PartitionStream: Send + Sync { + /// Returns the schema of this partition + fn schema(&self) -> &SchemaRef; + + /// Returns a stream yielding this partitions values + fn execute(&self, ctx: Arc) -> SendableRecordBatchStream; +} + +/// An [`ExecutionPlan`] for one or more [`PartitionStream`]s. +/// +/// If your source can be represented as one or more [`PartitionStream`]s, you can +/// use this struct to implement [`ExecutionPlan`]. +pub struct StreamingTableExec { + partitions: Vec>, + projection: Option>, + projected_schema: SchemaRef, + projected_output_ordering: Vec, + infinite: bool, +} + +impl StreamingTableExec { + /// Try to create a new [`StreamingTableExec`] returning an error if the schema is incorrect + pub fn try_new( + schema: SchemaRef, + partitions: Vec>, + projection: Option<&Vec>, + projected_output_ordering: impl IntoIterator, + infinite: bool, + ) -> Result { + for x in partitions.iter() { + let partition_schema = x.schema(); + if !schema.eq(partition_schema) { + debug!( + "Target schema does not match with partition schema. \ + Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + ); + return plan_err!("Mismatch between schema and batches"); + } + } + + let projected_schema = match projection { + Some(p) => Arc::new(schema.project(p)?), + None => schema, + }; + + Ok(Self { + partitions, + projected_schema, + projection: projection.cloned().map(Into::into), + projected_output_ordering: projected_output_ordering.into_iter().collect(), + infinite, + }) + } + + pub fn partitions(&self) -> &Vec> { + &self.partitions + } + + pub fn partition_schema(&self) -> &SchemaRef { + self.partitions[0].schema() + } + + pub fn projection(&self) -> &Option> { + &self.projection + } + + pub fn projected_schema(&self) -> &Schema { + &self.projected_schema + } + + pub fn projected_output_ordering(&self) -> impl IntoIterator { + self.projected_output_ordering.clone() + } + + pub fn is_infinite(&self) -> bool { + self.infinite + } +} + +impl std::fmt::Debug for StreamingTableExec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LazyMemTableExec").finish_non_exhaustive() + } +} + +impl DisplayAs for StreamingTableExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "StreamingTableExec: partition_sizes={:?}", + self.partitions.len(), + )?; + if !self.projected_schema.fields().is_empty() { + write!( + f, + ", projection={}", + ProjectSchemaDisplay(&self.projected_schema) + )?; + } + if self.infinite { + write!(f, ", infinite_source=true")?; + } + + self.projected_output_ordering + .first() + .map_or(Ok(()), |ordering| { + if !ordering.is_empty() { + write!( + f, + ", output_ordering={}", + OutputOrderingDisplay(ordering) + )?; + } + Ok(()) + }) + } + } + } +} + +#[async_trait] +impl ExecutionPlan for StreamingTableExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.projected_schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.partitions.len()) + } + + fn unbounded_output(&self, _children: &[bool]) -> Result { + Ok(self.infinite) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.projected_output_ordering + .first() + .map(|ordering| ordering.as_slice()) + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + EquivalenceProperties::new_with_orderings( + self.schema(), + &self.projected_output_ordering, + ) + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(self) + } else { + internal_err!("Children cannot be replaced in {self:?}") + } + } + + fn execute( + &self, + partition: usize, + ctx: Arc, + ) -> Result { + let stream = self.partitions[partition].execute(ctx); + Ok(match self.projection.clone() { + Some(projection) => Box::pin(RecordBatchStreamAdapter::new( + self.projected_schema.clone(), + stream.map(move |x| { + x.and_then(|b| b.project(projection.as_ref()).map_err(Into::into)) + }), + )), + None => stream, + }) + } +} diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs new file mode 100644 index 0000000000000..9e6312284c08f --- /dev/null +++ b/datafusion/physical-plan/src/test.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for testing datafusion-physical-plan + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; + +use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use futures::{Future, FutureExt}; + +use crate::memory::MemoryExec; +use crate::ExecutionPlan; + +pub mod exec; + +/// Asserts that given future is pending. +pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send + 'a>>) { + let waker = futures::task::noop_waker(); + let mut cx = futures::task::Context::from_waker(&waker); + let poll = fut.poll_unpin(&mut cx); + + assert!(poll.is_pending()); +} + +/// Get the schema for the aggregate_test_* csv files +pub fn aggr_test_schema() -> SchemaRef { + let mut f1 = Field::new("c1", DataType::Utf8, false); + f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())])); + let schema = Schema::new(vec![ + f1, + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + ]); + + Arc::new(schema) +} + +/// returns record batch with 3 columns of i32 in memory +pub fn build_table_i32( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap() +} + +/// returns memory table scan wrapped around record batch with 3 columns of i32 +pub fn build_table_scan_i32( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) +} + +/// Return a RecordBatch with a single Int32 array with values (0..sz) in a field named "i" +pub fn make_partition(sz: i32) -> RecordBatch { + let seq_start = 0; + let seq_end = sz; + let values = (seq_start..seq_end).collect::>(); + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let arr = Arc::new(Int32Array::from(values)); + let arr = arr as ArrayRef; + + RecordBatch::try_new(schema, vec![arr]).unwrap() +} + +/// Returns a `MemoryExec` that scans `partitions` of 100 batches each +pub fn scan_partitioned(partitions: usize) -> Arc { + Arc::new(mem_exec(partitions)) +} + +/// Returns a `MemoryExec` that scans `partitions` of 100 batches each +pub fn mem_exec(partitions: usize) -> MemoryExec { + let data: Vec> = (0..partitions).map(|_| vec![make_partition(100)]).collect(); + + let schema = data[0][0].schema(); + let projection = None; + MemoryExec::try_new(&data, schema, projection).unwrap() +} diff --git a/datafusion/core/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs similarity index 91% rename from datafusion/core/src/test/exec.rs rename to datafusion/physical-plan/src/test/exec.rs index 41a0a1b4d084a..1f6ee1f117aa4 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -23,27 +23,22 @@ use std::{ sync::{Arc, Weak}, task::{Context, Poll}, }; -use tokio::sync::Barrier; - -use arrow::{ - datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use futures::Stream; -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::{ - common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, -}; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::{ - error::{DataFusionError, Result}, - physical_plan::stream::RecordBatchReceiverStream, -}; -use crate::{ - execution::context::TaskContext, physical_plan::stream::RecordBatchStreamAdapter, + common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalSortExpr; + +use futures::Stream; +use tokio::sync::Barrier; + /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] pub struct BatchIndex { @@ -153,6 +148,20 @@ impl MockExec { } } +impl DisplayAs for MockExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "MockExec") + } + } + } +} + impl ExecutionPlan for MockExec { fn as_any(&self) -> &dyn Any { self @@ -212,6 +221,8 @@ impl ExecutionPlan for MockExec { println!("ERROR batch via delayed stream: {e}"); } } + + Ok(()) }); // returned stream simply reads off the rx stream Ok(builder.build()) @@ -225,20 +236,8 @@ impl ExecutionPlan for MockExec { } } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "MockExec") - } - } - } - // Panics if one of the batches is an error - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let data: Result> = self .data .iter() @@ -248,9 +247,13 @@ impl ExecutionPlan for MockExec { }) .collect(); - let data = data.unwrap(); + let data = data?; - common::compute_record_batch_statistics(&[data], &self.schema, None) + Ok(common::compute_record_batch_statistics( + &[data], + &self.schema, + None, + )) } } @@ -295,6 +298,20 @@ impl BarrierExec { } } +impl DisplayAs for BarrierExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BarrierExec") + } + } + } +} + impl ExecutionPlan for BarrierExec { fn as_any(&self) -> &dyn Any { self @@ -346,26 +363,20 @@ impl ExecutionPlan for BarrierExec { println!("ERROR batch via barrier stream stream: {e}"); } } + + Ok(()) }); // returned stream simply reads off the rx stream Ok(builder.build()) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "BarrierExec") - } - } - } - - fn statistics(&self) -> Statistics { - common::compute_record_batch_statistics(&self.data, &self.schema, None) + fn statistics(&self) -> Result { + Ok(common::compute_record_batch_statistics( + &self.data, + &self.schema, + None, + )) } } @@ -392,6 +403,20 @@ impl ErrorExec { } } +impl DisplayAs for ErrorExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ErrorExec") + } + } + } +} + impl ExecutionPlan for ErrorExec { fn as_any(&self) -> &dyn Any { self @@ -426,25 +451,7 @@ impl ExecutionPlan for ErrorExec { partition: usize, _context: Arc, ) -> Result { - Err(DataFusionError::Internal(format!( - "ErrorExec, unsurprisingly, errored in partition {partition}" - ))) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "ErrorExec") - } - } - } - - fn statistics(&self) -> Statistics { - Statistics::default() + internal_err!("ErrorExec, unsurprisingly, errored in partition {partition}") } } @@ -456,12 +463,9 @@ pub struct StatisticsExec { } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { - assert!( + assert_eq!( stats - .column_statistics - .as_ref() - .map(|cols| cols.len() == schema.fields().len()) - .unwrap_or(true), + .column_statistics.len(), schema.fields().len(), "if defined, the column statistics vector length should be the number of fields" ); Self { @@ -470,6 +474,26 @@ impl StatisticsExec { } } } + +impl DisplayAs for StatisticsExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "StatisticsExec: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } +} + impl ExecutionPlan for StatisticsExec { fn as_any(&self) -> &dyn Any { self @@ -506,25 +530,8 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Statistics { - self.stats.clone() - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!( - f, - "StatisticsExec: col_count={}, row_count={:?}", - self.schema.fields().len(), - self.stats.num_rows, - ) - } - } + fn statistics(&self) -> Result { + Ok(self.stats.clone()) } } @@ -563,6 +570,20 @@ impl BlockingExec { } } +impl DisplayAs for BlockingExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BlockingExec",) + } + } + } +} + impl ExecutionPlan for BlockingExec { fn as_any(&self) -> &dyn Any { self @@ -589,9 +610,7 @@ impl ExecutionPlan for BlockingExec { self: Arc, _: Vec>, ) -> Result> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {self:?}" - ))) + internal_err!("Children cannot be replaced in {self:?}") } fn execute( @@ -604,22 +623,6 @@ impl ExecutionPlan for BlockingExec { _refs: Arc::clone(&self.refs), })) } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "BlockingExec",) - } - } - } - - fn statistics(&self) -> Statistics { - unimplemented!() - } } /// A [`RecordBatchStream`] that is pending forever. @@ -697,6 +700,20 @@ impl PanicExec { } } +impl DisplayAs for PanicExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "PanickingExec",) + } + } + } +} + impl ExecutionPlan for PanicExec { fn as_any(&self) -> &dyn Any { self @@ -724,10 +741,7 @@ impl ExecutionPlan for PanicExec { self: Arc, _: Vec>, ) -> Result> { - Err(DataFusionError::Internal(format!( - "Children cannot be replaced in {:?}", - self - ))) + internal_err!("Children cannot be replaced in {:?}", self) } fn execute( @@ -742,22 +756,6 @@ impl ExecutionPlan for PanicExec { ready: false, })) } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "PanickingExec",) - } - } - } - - fn statistics(&self) -> Statistics { - unimplemented!() - } } /// A [`RecordBatchStream`] that yields every other batch and panics @@ -792,7 +790,7 @@ impl Stream for PanicStream { } else { self.ready = true; // get called again - cx.waker().clone().wake(); + cx.waker().wake_by_ref(); return Poll::Pending; } } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs new file mode 100644 index 0000000000000..9120566273d35 --- /dev/null +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -0,0 +1,644 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TopK: Combination of Sort / LIMIT + +use arrow::{ + compute::interleave, + row::{RowConverter, Rows, SortField}, +}; +use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; + +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::{ + memory_pool::{MemoryConsumer, MemoryReservation}, + runtime_env::RuntimeEnv, +}; +use datafusion_physical_expr::PhysicalSortExpr; +use hashbrown::HashMap; + +use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; + +use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; + +/// Global TopK +/// +/// # Background +/// +/// "Top K" is a common query optimization used for queries such as +/// "find the top 3 customers by revenue". The (simplified) SQL for +/// such a query might be: +/// +/// ```sql +/// SELECT customer_id, revenue FROM 'sales.csv' ORDER BY revenue DESC limit 3; +/// ``` +/// +/// The simple plan would be: +/// +/// ```sql +/// > explain SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3; +/// +--------------+----------------------------------------+ +/// | plan_type | plan | +/// +--------------+----------------------------------------+ +/// | logical_plan | Limit: 3 | +/// | | Sort: revenue DESC NULLS FIRST | +/// | | Projection: customer_id, revenue | +/// | | TableScan: sales | +/// +--------------+----------------------------------------+ +/// ``` +/// +/// While this plan produces the correct answer, it will fully sorts the +/// input before discarding everything other than the top 3 elements. +/// +/// The same answer can be produced by simply keeping track of the top +/// K=3 elements, reducing the total amount of required buffer memory. +/// +/// # Structure +/// +/// This operator tracks the top K items using a `TopKHeap`. +pub struct TopK { + /// schema of the output (and the input) + schema: SchemaRef, + /// Runtime metrics + metrics: TopKMetrics, + /// Reservation + reservation: MemoryReservation, + /// The target number of rows for output batches + batch_size: usize, + /// sort expressions + expr: Arc<[PhysicalSortExpr]>, + /// row converter, for sort keys + row_converter: RowConverter, + /// scratch space for converting rows + scratch_rows: Rows, + /// stores the top k values and their sort key values, in order + heap: TopKHeap, +} + +impl TopK { + /// Create a new [`TopK`] that stores the top `k` values, as + /// defined by the sort expressions in `expr`. + // TOOD: make a builder or some other nicer API to avoid the + // clippy warning + #[allow(clippy::too_many_arguments)] + pub fn try_new( + partition_id: usize, + schema: SchemaRef, + expr: Vec, + k: usize, + batch_size: usize, + runtime: Arc, + metrics: &ExecutionPlanMetricsSet, + partition: usize, + ) -> Result { + let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) + .register(&runtime.memory_pool); + + let expr: Arc<[PhysicalSortExpr]> = expr.into(); + + let sort_fields: Vec<_> = expr + .iter() + .map(|e| { + Ok(SortField::new_with_options( + e.expr.data_type(&schema)?, + e.options, + )) + }) + .collect::>()?; + + // TODO there is potential to add special cases for single column sort fields + // to improve performance + let row_converter = RowConverter::new(sort_fields)?; + let scratch_rows = row_converter.empty_rows( + batch_size, + 20 * batch_size, // guestimate 20 bytes per row + ); + + Ok(Self { + schema: schema.clone(), + metrics: TopKMetrics::new(metrics, partition), + reservation, + batch_size, + expr, + row_converter, + scratch_rows, + heap: TopKHeap::new(k, batch_size, schema), + }) + } + + /// Insert `batch`, remembering if any of its values are among + /// the top k seen so far. + pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { + // Updates on drop + let _timer = self.metrics.baseline.elapsed_compute().timer(); + + let sort_keys: Vec = self + .expr + .iter() + .map(|expr| { + let value = expr.expr.evaluate(&batch)?; + value.into_array(batch.num_rows()) + }) + .collect::>>()?; + + // reuse existing `Rows` to avoid reallocations + let rows = &mut self.scratch_rows; + rows.clear(); + self.row_converter.append(rows, &sort_keys)?; + + // TODO make this algorithmically better?: + // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`) + // this avoids some work and also might be better vectorizable. + let mut batch_entry = self.heap.register_batch(batch); + for (index, row) in rows.iter().enumerate() { + match self.heap.max() { + // heap has k items, and the new row is greater than the + // current max in the heap ==> it is not a new topk + Some(max_row) if row.as_ref() >= max_row.row() => {} + // don't yet have k items or new item is lower than the currently k low values + None | Some(_) => { + self.heap.add(&mut batch_entry, row, index); + self.metrics.row_replacements.add(1); + } + } + } + self.heap.insert_batch_entry(batch_entry); + + // conserve memory + self.heap.maybe_compact()?; + + // update memory reservation + self.reservation.try_resize(self.size())?; + Ok(()) + } + + /// Returns the top k results broken into `batch_size` [`RecordBatch`]es, consuming the heap + pub fn emit(self) -> Result { + let Self { + schema, + metrics, + reservation: _, + batch_size, + expr: _, + row_converter: _, + scratch_rows: _, + mut heap, + } = self; + let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop + + let mut batch = heap.emit()?; + metrics.baseline.output_rows().add(batch.num_rows()); + + // break into record batches as needed + let mut batches = vec![]; + loop { + if batch.num_rows() < batch_size { + batches.push(Ok(batch)); + break; + } else { + batches.push(Ok(batch.slice(0, batch_size))); + let remaining_length = batch.num_rows() - batch_size; + batch = batch.slice(batch_size, remaining_length); + } + } + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(batches), + ))) + } + + /// return the size of memory used by this operator, in bytes + fn size(&self) -> usize { + std::mem::size_of::() + + self.row_converter.size() + + self.scratch_rows.size() + + self.heap.size() + } +} + +struct TopKMetrics { + /// metrics + pub baseline: BaselineMetrics, + + /// count of how many rows were replaced in the heap + pub row_replacements: Count, +} + +impl TopKMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline: BaselineMetrics::new(metrics, partition), + row_replacements: MetricBuilder::new(metrics) + .counter("row_replacements", partition), + } + } +} + +/// This structure keeps at most the *smallest* k items, using the +/// [arrow::row] format for sort keys. While it is called "topK" for +/// values like `1, 2, 3, 4, 5` the "top 3" really means the +/// *smallest* 3 , `1, 2, 3`, not the *largest* 3 `3, 4, 5`. +/// +/// Using the `Row` format handles things such as ascending vs +/// descending and nulls first vs nulls last. +struct TopKHeap { + /// The maximum number of elemenents to store in this heap. + k: usize, + /// The target number of rows for output batches + batch_size: usize, + /// Storage for up at most `k` items using a BinaryHeap. Reverserd + /// so that the smallest k so far is on the top + inner: BinaryHeap, + /// Storage the original row values (TopKRow only has the sort key) + store: RecordBatchStore, + /// The size of all owned data held by this heap + owned_bytes: usize, +} + +impl TopKHeap { + fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self { + assert!(k > 0); + Self { + k, + batch_size, + inner: BinaryHeap::new(), + store: RecordBatchStore::new(schema), + owned_bytes: 0, + } + } + + /// Register a [`RecordBatch`] with the heap, returning the + /// appropriate entry + pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry { + self.store.register(batch) + } + + /// Insert a [`RecordBatchEntry`] created by a previous call to + /// [`Self::register_batch`] into storage. + pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) { + self.store.insert(entry) + } + + /// Returns the largest value stored by the heap if there are k + /// items, otherwise returns None. Remember this structure is + /// keeping the "smallest" k values + fn max(&self) -> Option<&TopKRow> { + if self.inner.len() < self.k { + None + } else { + self.inner.peek() + } + } + + /// Adds `row` to this heap. If inserting this new item would + /// increase the size past `k`, removes the previously smallest + /// item. + fn add( + &mut self, + batch_entry: &mut RecordBatchEntry, + row: impl AsRef<[u8]>, + index: usize, + ) { + let batch_id = batch_entry.id; + batch_entry.uses += 1; + + assert!(self.inner.len() <= self.k); + let row = row.as_ref(); + + // Reuse storage for evicted item if possible + let new_top_k = if self.inner.len() == self.k { + let prev_min = self.inner.pop().unwrap(); + + // Update batch use + if prev_min.batch_id == batch_entry.id { + batch_entry.uses -= 1; + } else { + self.store.unuse(prev_min.batch_id); + } + + // update memory accounting + self.owned_bytes -= prev_min.owned_size(); + prev_min.with_new_row(row, batch_id, index) + } else { + TopKRow::new(row, batch_id, index) + }; + + self.owned_bytes += new_top_k.owned_size(); + + // put the new row into the heap + self.inner.push(new_top_k) + } + + /// Returns the values stored in this heap, from values low to + /// high, as a single [`RecordBatch`], resetting the inner heap + pub fn emit(&mut self) -> Result { + Ok(self.emit_with_state()?.0) + } + + /// Returns the values stored in this heap, from values low to + /// high, as a single [`RecordBatch`], and a sorted vec of the + /// current heap's contents + pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec)> { + let schema = self.store.schema().clone(); + + // generate sorted rows + let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); + + if self.store.is_empty() { + return Ok((RecordBatch::new_empty(schema), topk_rows)); + } + + // Indices for each row within its respective RecordBatch + let indices: Vec<_> = topk_rows + .iter() + .enumerate() + .map(|(i, k)| (i, k.index)) + .collect(); + + let num_columns = schema.fields().len(); + + // build the output columns one at time, using the + // `interleave` kernel to pick rows from different arrays + let output_columns: Vec<_> = (0..num_columns) + .map(|col| { + let input_arrays: Vec<_> = topk_rows + .iter() + .map(|k| { + let entry = + self.store.get(k.batch_id).expect("invalid stored batch id"); + entry.batch.column(col) as &dyn Array + }) + .collect(); + + // at this point `indices` contains indexes within the + // rows and `input_arrays` contains a reference to the + // relevant Array for that index. `interleave` pulls + // them together into a single new array + Ok(interleave(&input_arrays, &indices)?) + }) + .collect::>()?; + + let new_batch = RecordBatch::try_new(schema, output_columns)?; + Ok((new_batch, topk_rows)) + } + + /// Compact this heap, rewriting all stored batches into a single + /// input batch + pub fn maybe_compact(&mut self) -> Result<()> { + // we compact if the number of "unused" rows in the store is + // past some pre-defined threshold. Target holding up to + // around 20 batches, but handle cases of large k where some + // batches might be partially full + let max_unused_rows = (20 * self.batch_size) + self.k; + let unused_rows = self.store.unused_rows(); + + // don't compact if the store has one extra batch or + // unused rows is under the threshold + if self.store.len() <= 2 || unused_rows < max_unused_rows { + return Ok(()); + } + // at first, compact the entire thing always into a new batch + // (maybe we can get fancier in the future about ignoring + // batches that have a high usage ratio already + + // Note: new batch is in the same order as inner + let num_rows = self.inner.len(); + let (new_batch, mut topk_rows) = self.emit_with_state()?; + + // clear all old entires in store (this invalidates all + // store_ids in `inner`) + self.store.clear(); + + let mut batch_entry = self.register_batch(new_batch); + batch_entry.uses = num_rows; + + // rewrite all existing entries to use the new batch, and + // remove old entries. The sortedness and their relative + // position do not change + for (i, topk_row) in topk_rows.iter_mut().enumerate() { + topk_row.batch_id = batch_entry.id; + topk_row.index = i; + } + self.insert_batch_entry(batch_entry); + // restore the heap + self.inner = BinaryHeap::from(topk_rows); + + Ok(()) + } + + /// return the size of memory used by this heap, in bytes + fn size(&self) -> usize { + std::mem::size_of::() + + (self.inner.capacity() * std::mem::size_of::()) + + self.store.size() + + self.owned_bytes + } +} + +/// Represents one of the top K rows held in this heap. Orders +/// according to memcmp of row (e.g. the arrow Row format, but could +/// also be primtive values) +/// +/// Reuses allocations to minimize runtime overhead of creating new Vecs +#[derive(Debug, PartialEq)] +struct TopKRow { + /// the value of the sort key for this row. This contains the + /// bytes that could be stored in `OwnedRow` but uses `Vec` to + /// reuse allocations. + row: Vec, + /// the RecordBatch this row came from: an id into a [`RecordBatchStore`] + batch_id: u32, + /// the index in this record batch the row came from + index: usize, +} + +impl TopKRow { + /// Create a new TopKRow with new allocation + fn new(row: impl AsRef<[u8]>, batch_id: u32, index: usize) -> Self { + Self { + row: row.as_ref().to_vec(), + batch_id, + index, + } + } + + /// Create a new TopKRow reusing the existing allocation + fn with_new_row( + self, + new_row: impl AsRef<[u8]>, + batch_id: u32, + index: usize, + ) -> Self { + let Self { + mut row, + batch_id: _, + index: _, + } = self; + row.clear(); + row.extend_from_slice(new_row.as_ref()); + + Self { + row, + batch_id, + index, + } + } + + /// Returns the number of bytes owned by this row in the heap (not + /// including itself) + fn owned_size(&self) -> usize { + self.row.capacity() + } + + /// Returns a slice to the owned row value + fn row(&self) -> &[u8] { + self.row.as_slice() + } +} + +impl Eq for TopKRow {} + +impl PartialOrd for TopKRow { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for TopKRow { + fn cmp(&self, other: &Self) -> Ordering { + self.row.cmp(&other.row) + } +} + +#[derive(Debug)] +struct RecordBatchEntry { + id: u32, + batch: RecordBatch, + // for this batch, how many times has it been used + uses: usize, +} + +/// This structure tracks [`RecordBatch`] by an id so that: +/// +/// 1. The baches can be tracked via an id that can be copied cheaply +/// 2. The total memory held by all batches is tracked +#[derive(Debug)] +struct RecordBatchStore { + /// id generator + next_id: u32, + /// storage + batches: HashMap, + /// total size of all record batches tracked by this store + batches_size: usize, + /// schema of the batches + schema: SchemaRef, +} + +impl RecordBatchStore { + fn new(schema: SchemaRef) -> Self { + Self { + next_id: 0, + batches: HashMap::new(), + batches_size: 0, + schema, + } + } + + /// Register this batch with the store and assign an ID. No + /// attempt is made to compare this batch to other batches + pub fn register(&mut self, batch: RecordBatch) -> RecordBatchEntry { + let id = self.next_id; + self.next_id += 1; + RecordBatchEntry { id, batch, uses: 0 } + } + + /// Insert a record batch entry into this store, tracking its + /// memory use, if it has any uses + pub fn insert(&mut self, entry: RecordBatchEntry) { + // uses of 0 means that none of the rows in the batch were stored in the topk + if entry.uses > 0 { + self.batches_size += entry.batch.get_array_memory_size(); + self.batches.insert(entry.id, entry); + } + } + + /// Clear all values in this store, invalidating all previous batch ids + fn clear(&mut self) { + self.batches.clear(); + self.batches_size = 0; + } + + fn get(&self, id: u32) -> Option<&RecordBatchEntry> { + self.batches.get(&id) + } + + /// returns the total number of batches stored in this store + fn len(&self) -> usize { + self.batches.len() + } + + /// Returns the total number of rows in batches minus the number + /// which are in use + fn unused_rows(&self) -> usize { + self.batches + .values() + .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses) + .sum() + } + + /// returns true if the store has nothing stored + fn is_empty(&self) -> bool { + self.batches.is_empty() + } + + /// return the schema of batches stored + fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// remove a use from the specified batch id. If the use count + /// reaches zero the batch entry is removed from the store + /// + /// panics if there were no remaining uses of id + pub fn unuse(&mut self, id: u32) { + let remove = if let Some(batch_entry) = self.batches.get_mut(&id) { + batch_entry.uses = batch_entry.uses.checked_sub(1).expect("underflow"); + batch_entry.uses == 0 + } else { + panic!("No entry for id {id}"); + }; + + if remove { + let old_entry = self.batches.remove(&id).unwrap(); + self.batches_size = self + .batches_size + .checked_sub(old_entry.batch.get_array_memory_size()) + .unwrap(); + } + } + + /// returns the size of memory used by this store, including all + /// referenced `RecordBatch`es, in bytes + pub fn size(&self) -> usize { + std::mem::size_of::() + + self.batches.capacity() + * (std::mem::size_of::() + std::mem::size_of::()) + + self.batches_size + } +} diff --git a/datafusion/core/src/physical_plan/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs similarity index 94% rename from datafusion/core/src/physical_plan/tree_node.rs rename to datafusion/physical-plan/src/tree_node.rs index fad6508fdabef..bce906a00c4d8 100644 --- a/datafusion/core/src/physical_plan/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -17,7 +17,7 @@ //! This module provides common traits for visiting or rewriting tree nodes easily. -use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::{with_new_children_if_necessary, ExecutionPlan}; use datafusion_common::tree_node::{DynTreeNode, Transformed}; use datafusion_common::Result; use std::sync::Arc; diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/physical-plan/src/udaf.rs similarity index 58% rename from datafusion/core/src/physical_plan/udaf.rs rename to datafusion/physical-plan/src/udaf.rs index d9f52eba77d0c..94017efe97aa1 100644 --- a/datafusion/core/src/physical_plan/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -27,9 +27,9 @@ use arrow::{ }; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; -use crate::physical_plan::PhysicalExpr; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, DataFusionError, Result}; pub use datafusion_expr::AggregateUDF; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -50,7 +50,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), - data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(), + data_type: fun.return_type(&input_exprs_types)?, name: name.into(), })) } @@ -83,7 +83,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - let fields = (self.fun.state_type)(&self.data_type)? + let fields = self + .fun + .state_type(&self.data_type)? .iter() .enumerate() .map(|(i, data_type)| { @@ -103,7 +105,62 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - (self.fun.accumulator)(&self.data_type) + self.fun.accumulator(&self.data_type) + } + + fn create_sliding_accumulator(&self) -> Result> { + let accumulator = self.fun.accumulator(&self.data_type)?; + + // Accumulators that have window frame startings different + // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to + // implement retract_batch method in order to run correctly + // currently in DataFusion. + // + // If this `retract_batches` is not present, there is no way + // to calculate result correctly. For example, the query + // + // ```sql + // SELECT + // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a + // FROM + // t + // ``` + // + // 1. First sum value will be the sum of rows between `[0, 1)`, + // + // 2. Second sum value will be the sum of rows between `[0, 2)` + // + // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. + // + // Since the accumulator keeps the running sum: + // + // 1. First sum we add to the state sum value between `[0, 1)` + // + // 2. Second sum we add to the state sum value between `[1, 2)` + // (`[0, 1)` is already in the state sum, hence running sum will + // cover `[0, 2)` range) + // + // 3. Third sum we add to the state sum value between `[2, 3)` + // (`[0, 2)` is already in the state sum). Also we need to + // retract values between `[0, 1)` by this way we can obtain sum + // between [1, 3) which is indeed the apropriate range. + // + // When we use `UNBOUNDED PRECEDING` in the query starting + // index will always be 0 for the desired range, and hence the + // `retract_batch` method will not be called. In this case + // having retract_batch is not a requirement. + // + // This approach is a a bit different than window function + // approach. In window function (when they use a window frame) + // they get all the desired range during evaluation. + if !accumulator.supports_retract_batch() { + return not_impl_err!( + "Aggregate can not be used as a sliding accumulator because \ + `retract_batch` is not implemented: {}", + self.name + ); + } + Ok(accumulator) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/physical-plan/src/union.rs similarity index 61% rename from datafusion/core/src/physical_plan/union.rs rename to datafusion/physical-plan/src/union.rs index a81c43398cde6..14ef9c2ec27bf 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -25,26 +25,26 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::{any::Any, sync::Arc}; -use arrow::{ - datatypes::{Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use datafusion_common::{DFSchemaRef, DataFusionError}; -use futures::Stream; -use itertools::Itertools; -use log::{debug, trace, warn}; - use super::{ expressions::PhysicalSortExpr, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::physical_plan::common::get_meet_of_orderings; -use crate::physical_plan::stream::ObservedStream; -use crate::physical_plan::{expressions, metrics::BaselineMetrics}; -use datafusion_common::Result; +use crate::common::get_meet_of_orderings; +use crate::metrics::BaselineMetrics; +use crate::stream::ObservedStream; + +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::stats::Precision; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::EquivalenceProperties; + +use futures::Stream; +use itertools::Itertools; +use log::{debug, trace, warn}; use tokio::macros::support::thread_rng_n; /// `UnionExec`: `UNION ALL` execution plan. @@ -95,38 +95,6 @@ pub struct UnionExec { } impl UnionExec { - /// Create a new UnionExec with specified schema. - /// The `schema` should always be a subset of the schema of `inputs`, - /// otherwise, an error will be returned. - pub fn try_new_with_schema( - inputs: Vec>, - schema: DFSchemaRef, - ) -> Result { - let mut exec = Self::new(inputs); - let exec_schema = exec.schema(); - let fields = schema - .fields() - .iter() - .map(|dff| { - exec_schema - .field_with_name(dff.name()) - .cloned() - .map_err(|_| { - DataFusionError::Internal(format!( - "Cannot find the field {:?} in child schema", - dff.name() - )) - }) - }) - .collect::>>()?; - let schema = Arc::new(Schema::new_with_metadata( - fields, - exec.schema().metadata().clone(), - )); - exec.schema = schema; - Ok(exec) - } - /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); @@ -144,6 +112,20 @@ impl UnionExec { } } +impl DisplayAs for UnionExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "UnionExec") + } + } + } +} + impl ExecutionPlan for UnionExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -209,6 +191,46 @@ impl ExecutionPlan for UnionExec { } } + fn equivalence_properties(&self) -> EquivalenceProperties { + // TODO: In some cases, we should be able to preserve some equivalence + // classes and constants. Add support for such cases. + let children_eqs = self + .inputs + .iter() + .map(|child| child.equivalence_properties()) + .collect::>(); + let mut result = EquivalenceProperties::new(self.schema()); + // Use the ordering equivalence class of the first child as the seed: + let mut meets = children_eqs[0] + .oeq_class() + .iter() + .map(|item| item.to_vec()) + .collect::>(); + // Iterate over all the children: + for child_eqs in &children_eqs[1..] { + // Compute meet orderings of the current meets and the new ordering + // equivalence class. + let mut idx = 0; + while idx < meets.len() { + // Find all the meets of `current_meet` with this child's orderings: + let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { + child_eqs.get_meet_ordering(ordering, &meets[idx]) + }); + // Use the longest of these meets as others are redundant: + if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { + meets[idx] = next_meet; + idx += 1; + } else { + meets.swap_remove(idx); + } + } + } + // We know have all the valid orderings after union, remove redundant + // entries (implicitly) and return: + result.add_new_orderings(meets); + result + } + fn with_new_children( self: Arc, children: Vec>, @@ -242,37 +264,28 @@ impl ExecutionPlan for UnionExec { warn!("Error in Union: Partition {} not found", partition); - Err(DataFusionError::Execution(format!( - "Partition {partition} not found in Union" - ))) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "UnionExec") - } - } + exec_err!("Partition {partition} not found in Union") } fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.inputs + fn statistics(&self) -> Result { + let stats = self + .inputs .iter() - .map(|ep| ep.statistics()) + .map(|stat| stat.statistics()) + .collect::>>()?; + + Ok(stats + .into_iter() .reduce(stats_union) - .unwrap_or_default() + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) } - fn benefits_from_input_partitioning(&self) -> bool { - false + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false; self.children().len()] } } @@ -324,9 +337,9 @@ impl InterleaveExec { let schema = union_schema(&inputs); if !can_interleave(&inputs) { - return Err(DataFusionError::Internal(String::from( - "Not all InterleaveExec children have a consistent hash partitioning", - ))); + return internal_err!( + "Not all InterleaveExec children have a consistent hash partitioning" + ); } Ok(InterleaveExec { @@ -342,6 +355,20 @@ impl InterleaveExec { } } +impl DisplayAs for InterleaveExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "InterleaveExec") + } + } + } +} + impl ExecutionPlan for InterleaveExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -416,37 +443,28 @@ impl ExecutionPlan for InterleaveExec { warn!("Error in InterleaveExec: Partition {} not found", partition); - Err(DataFusionError::Execution(format!( - "Partition {partition} not found in InterleaveExec" - ))) - } - - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "InterleaveExec") - } - } + exec_err!("Partition {partition} not found in InterleaveExec") } fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - self.inputs + fn statistics(&self) -> Result { + let stats = self + .inputs .iter() - .map(|ep| ep.statistics()) + .map(|stat| stat.statistics()) + .collect::>>()?; + + Ok(stats + .into_iter() .reduce(stats_union) - .unwrap_or_default() + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) } - fn benefits_from_input_partitioning(&self) -> bool { - false + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false; self.children().len()] } } @@ -563,58 +581,73 @@ fn col_stats_union( mut left: ColumnStatistics, right: ColumnStatistics, ) -> ColumnStatistics { - left.distinct_count = None; - left.min_value = left - .min_value - .zip(right.min_value) - .map(|(a, b)| expressions::helpers::min(&a, &b)) - .and_then(Result::ok); - left.max_value = left - .max_value - .zip(right.max_value) - .map(|(a, b)| expressions::helpers::max(&a, &b)) - .and_then(Result::ok); - left.null_count = left.null_count.zip(right.null_count).map(|(a, b)| a + b); + left.distinct_count = Precision::Absent; + left.min_value = left.min_value.min(&right.min_value); + left.max_value = left.max_value.max(&right.max_value); + left.null_count = left.null_count.add(&right.null_count); left } fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { - left.is_exact = left.is_exact && right.is_exact; - left.num_rows = left.num_rows.zip(right.num_rows).map(|(a, b)| a + b); - left.total_byte_size = left - .total_byte_size - .zip(right.total_byte_size) - .map(|(a, b)| a + b); - left.column_statistics = - left.column_statistics - .zip(right.column_statistics) - .map(|(a, b)| { - a.into_iter() - .zip(b) - .map(|(ca, cb)| col_stats_union(ca, cb)) - .collect() - }); + left.num_rows = left.num_rows.add(&right.num_rows); + left.total_byte_size = left.total_byte_size.add(&right.total_byte_size); + left.column_statistics = left + .column_statistics + .into_iter() + .zip(right.column_statistics) + .map(|(a, b)| col_stats_union(a, b)) + .collect::>(); left } #[cfg(test)] mod tests { use super::*; + use crate::collect; + use crate::memory::MemoryExec; use crate::test; - use crate::prelude::SessionContext; - use crate::{physical_plan::collect, scalar::ScalarValue}; use arrow::record_batch::RecordBatch; + use arrow_schema::{DataType, SortOptions}; + use datafusion_common::ScalarValue; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::PhysicalExpr; + + // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) + fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g])); + + Ok(schema) + } + + // Convert each tuple to PhysicalSortExpr + fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], + ) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: (*expr).clone(), + options: *options, + }) + .collect::>() + } #[tokio::test] async fn test_union_partitions() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + let task_ctx = Arc::new(TaskContext::default()); - // Create csv's with different partitioning - let csv = test::scan_partitioned_csv(4)?; - let csv2 = test::scan_partitioned_csv(5)?; + // Create inputs with different partitioning + let csv = test::scan_partitioned(4); + let csv2 = test::scan_partitioned(5); let union_exec = Arc::new(UnionExec::new(vec![csv, csv2])); @@ -630,84 +663,182 @@ mod tests { #[tokio::test] async fn test_stats_union() { let left = Statistics { - is_exact: true, - num_rows: Some(5), - total_byte_size: Some(23), - column_statistics: Some(vec![ + num_rows: Precision::Exact(5), + total_byte_size: Precision::Exact(23), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(5), - max_value: Some(ScalarValue::Int64(Some(21))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(0), + distinct_count: Precision::Exact(5), + max_value: Precision::Exact(ScalarValue::Int64(Some(21))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(0), }, ColumnStatistics { - distinct_count: Some(1), - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: Some(3), + distinct_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Exact(3), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Float32(Some(1.1))), - min_value: Some(ScalarValue::Float32(Some(0.1))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), + min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), + null_count: Precision::Absent, }, - ]), + ], }; let right = Statistics { - is_exact: true, - num_rows: Some(7), - total_byte_size: Some(29), - column_statistics: Some(vec![ + num_rows: Precision::Exact(7), + total_byte_size: Precision::Exact(29), + column_statistics: vec![ ColumnStatistics { - distinct_count: Some(3), - max_value: Some(ScalarValue::Int64(Some(34))), - min_value: Some(ScalarValue::Int64(Some(1))), - null_count: Some(1), + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(34))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + null_count: Precision::Exact(1), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Utf8(Some(String::from("c")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("b")))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::from("c")), + min_value: Precision::Exact(ScalarValue::from("b")), + null_count: Precision::Absent, }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Absent, }, - ]), + ], }; let result = stats_union(left, right); let expected = Statistics { - is_exact: true, - num_rows: Some(12), - total_byte_size: Some(52), - column_statistics: Some(vec![ + num_rows: Precision::Exact(12), + total_byte_size: Precision::Exact(52), + column_statistics: vec![ ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Int64(Some(34))), - min_value: Some(ScalarValue::Int64(Some(-4))), - null_count: Some(1), + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int64(Some(34))), + min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), + null_count: Precision::Exact(1), }, ColumnStatistics { - distinct_count: None, - max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), - min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::from("x")), + min_value: Precision::Exact(ScalarValue::from("a")), + null_count: Precision::Absent, }, ColumnStatistics { - distinct_count: None, - max_value: None, - min_value: None, - null_count: None, + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Absent, }, - ]), + ], }; assert_eq!(result, expected); } + + #[tokio::test] + async fn test_union_equivalence_properties() -> Result<()> { + let schema = create_test_schema()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + let col_e = &col("e", &schema)?; + let col_f = &col("f", &schema)?; + let options = SortOptions::default(); + let test_cases = vec![ + //-----------TEST CASE 1----------// + ( + // First child orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + // Second child orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + // Union output orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + ], + ), + //-----------TEST CASE 2----------// + ( + // First child orderings + vec![ + // [a ASC, b ASC, f ASC] + vec![(col_a, options), (col_b, options), (col_f, options)], + // d ASC + vec![(col_d, options)], + ], + // Second child orderings + vec![ + // [a ASC, b ASC, c ASC] + vec![(col_a, options), (col_b, options), (col_c, options)], + // [e ASC] + vec![(col_e, options)], + ], + // Union output orderings + vec![ + // [a ASC, b ASC] + vec![(col_a, options), (col_b, options)], + ], + ), + ]; + + for ( + test_idx, + (first_child_orderings, second_child_orderings, union_orderings), + ) in test_cases.iter().enumerate() + { + let first_orderings = first_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let second_orderings = second_child_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let union_expected_orderings = union_orderings + .iter() + .map(|ordering| convert_to_sort_exprs(ordering)) + .collect::>(); + let child1 = Arc::new( + MemoryExec::try_new(&[], schema.clone(), None)? + .with_sort_information(first_orderings), + ); + let child2 = Arc::new( + MemoryExec::try_new(&[], schema.clone(), None)? + .with_sort_information(second_orderings), + ); + + let union = UnionExec::new(vec![child1, child2]); + let union_eq_properties = union.equivalence_properties(); + let union_actual_orderings = union_eq_properties.oeq_class(); + let err_msg = format!( + "Error in test id: {:?}, test case: {:?}", + test_idx, test_cases[test_idx] + ); + assert_eq!( + union_actual_orderings.len(), + union_expected_orderings.len(), + "{}", + err_msg + ); + for expected in &union_expected_orderings { + assert!(union_actual_orderings.contains(expected), "{}", err_msg); + } + } + Ok(()) + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs new file mode 100644 index 0000000000000..af4a81626cd74 --- /dev/null +++ b/datafusion/physical-plan/src/unnest.rs @@ -0,0 +1,558 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the unnest column plan for unnesting values in a column that contains a list +//! type, conceptually is like joining each row with all the values in the list column. + +use std::time::Instant; +use std::{any::Any, sync::Arc}; + +use super::DisplayAs; +use crate::{ + expressions::Column, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PhysicalExpr, PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, +}; + +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, LargeListArray, ListArray, + PrimitiveArray, +}; +use arrow::compute::kernels; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int32Type, Int64Type, Schema, SchemaRef, +}; +use arrow::record_batch::RecordBatch; +use arrow_array::{GenericListArray, OffsetSizeTrait}; +use datafusion_common::{exec_err, DataFusionError, Result, UnnestOptions}; +use datafusion_execution::TaskContext; + +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use log::trace; + +/// Unnest the given column by joining the row with each value in the +/// nested type. +/// +/// See [`UnnestOptions`] for more details and an example. +#[derive(Debug)] +pub struct UnnestExec { + /// Input execution plan + input: Arc, + /// The schema once the unnest is applied + schema: SchemaRef, + /// The unnest column + column: Column, + /// Options + options: UnnestOptions, +} + +impl UnnestExec { + /// Create a new [UnnestExec]. + pub fn new( + input: Arc, + column: Column, + schema: SchemaRef, + options: UnnestOptions, + ) -> Self { + UnnestExec { + input, + schema, + column, + options, + } + } +} + +impl DisplayAs for UnnestExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "UnnestExec") + } + } + } +} + +impl ExecutionPlan for UnnestExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns an error to indicate this. + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(UnnestExec::new( + children[0].clone(), + self.column.clone(), + self.schema.clone(), + self.options.clone(), + ))) + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution] + } + + fn output_partitioning(&self) -> Partitioning { + self.input.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context)?; + + Ok(Box::pin(UnnestStream { + input, + schema: self.schema.clone(), + column: self.column.clone(), + options: self.options.clone(), + num_input_batches: 0, + num_input_rows: 0, + num_output_batches: 0, + num_output_rows: 0, + unnest_time: 0, + })) + } +} + +/// A stream that issues [RecordBatch]es with unnested column data. +struct UnnestStream { + /// Input stream + input: SendableRecordBatchStream, + /// Unnested schema + schema: Arc, + /// The unnest column + column: Column, + /// Options + options: UnnestOptions, + /// number of input batches + num_input_batches: usize, + /// number of input rows + num_input_rows: usize, + /// number of batches produced + num_output_batches: usize, + /// number of rows produced + num_output_rows: usize, + /// total time for column unnesting, in ms + unnest_time: usize, +} + +impl RecordBatchStream for UnnestStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[async_trait] +impl Stream for UnnestStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_next_impl(cx) + } +} + +impl UnnestStream { + /// Separate implementation function that unpins the [`UnnestStream`] so + /// that partial borrows work correctly + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + self.input + .poll_next_unpin(cx) + .map(|maybe_batch| match maybe_batch { + Some(Ok(batch)) => { + let start = Instant::now(); + let result = + build_batch(&batch, &self.schema, &self.column, &self.options); + self.num_input_batches += 1; + self.num_input_rows += batch.num_rows(); + if let Ok(ref batch) = result { + self.unnest_time += start.elapsed().as_millis() as usize; + self.num_output_batches += 1; + self.num_output_rows += batch.num_rows(); + } + + Some(result) + } + other => { + trace!( + "Processed {} probe-side input batches containing {} rows and \ + produced {} output batches containing {} rows in {} ms", + self.num_input_batches, + self.num_input_rows, + self.num_output_batches, + self.num_output_rows, + self.unnest_time, + ); + other + } + }) + } +} + +fn build_batch( + batch: &RecordBatch, + schema: &SchemaRef, + column: &Column, + options: &UnnestOptions, +) -> Result { + let list_array = column.evaluate(batch)?.into_array(batch.num_rows())?; + match list_array.data_type() { + DataType::List(_) => { + let list_array = list_array.as_any().downcast_ref::().unwrap(); + build_batch_generic_list::( + batch, + schema, + column.index(), + list_array, + options, + ) + } + DataType::LargeList(_) => { + let list_array = list_array + .as_any() + .downcast_ref::() + .unwrap(); + build_batch_generic_list::( + batch, + schema, + column.index(), + list_array, + options, + ) + } + DataType::FixedSizeList(_, _) => { + let list_array = list_array + .as_any() + .downcast_ref::() + .unwrap(); + build_batch_fixedsize_list(batch, schema, column.index(), list_array, options) + } + _ => exec_err!("Invalid unnest column {column}"), + } +} + +fn build_batch_generic_list>( + batch: &RecordBatch, + schema: &SchemaRef, + unnest_column_idx: usize, + list_array: &GenericListArray, + options: &UnnestOptions, +) -> Result { + let unnested_array = unnest_generic_list::(list_array, options)?; + + let take_indicies = + create_take_indicies_generic::(list_array, unnested_array.len(), options); + + batch_from_indices( + batch, + schema, + unnest_column_idx, + &unnested_array, + &take_indicies, + ) +} + +/// Given this `GenericList` list_array: +/// +/// ```ignore +/// [1], null, [2, 3, 4], null, [5, 6] +/// ``` +/// Its values array is represented like this: +/// +/// ```ignore +/// [1, 2, 3, 4, 5, 6] +/// ``` +/// +/// So if there are no null values or `UnnestOptions.preserve_nulls` is false +/// we can return the values array without any copying. +/// +/// Otherwise we'll transfrom the values array using the take kernel and the following take indicies: +/// +/// ```ignore +/// 0, null, 1, 2, 3, null, 4, 5 +/// ``` +/// +fn unnest_generic_list>( + list_array: &GenericListArray, + options: &UnnestOptions, +) -> Result> { + let values = list_array.values(); + if list_array.null_count() == 0 || !options.preserve_nulls { + Ok(values.clone()) + } else { + let mut take_indicies_builder = + PrimitiveArray::

::builder(values.len() + list_array.null_count()); + let mut take_offset = 0; + + list_array.iter().for_each(|elem| match elem { + Some(array) => { + for i in 0..array.len() { + // take_offset + i is always positive + let take_index = P::Native::from_usize(take_offset + i).unwrap(); + take_indicies_builder.append_value(take_index); + } + take_offset += array.len(); + } + None => { + take_indicies_builder.append_null(); + } + }); + Ok(kernels::take::take( + &values, + &take_indicies_builder.finish(), + None, + )?) + } +} + +fn build_batch_fixedsize_list( + batch: &RecordBatch, + schema: &SchemaRef, + unnest_column_idx: usize, + list_array: &FixedSizeListArray, + options: &UnnestOptions, +) -> Result { + let unnested_array = unnest_fixed_list(list_array, options)?; + + let take_indicies = + create_take_indicies_fixed(list_array, unnested_array.len(), options); + + batch_from_indices( + batch, + schema, + unnest_column_idx, + &unnested_array, + &take_indicies, + ) +} + +/// Given this `FixedSizeListArray` list_array: +/// +/// ```ignore +/// [1, 2], null, [3, 4], null, [5, 6] +/// ``` +/// Its values array is represented like this: +/// +/// ```ignore +/// [1, 2, null, null 3, 4, null, null, 5, 6] +/// ``` +/// +/// So if there are no null values +/// we can return the values array without any copying. +/// +/// Otherwise we'll transfrom the values array using the take kernel. +/// +/// If `UnnestOptions.preserve_nulls` is true the take indicies will look like this: +/// +/// ```ignore +/// 0, 1, null, 4, 5, null, 8, 9 +/// ``` +/// Otherwise we drop the nulls and take indicies will look like this: +/// +/// ```ignore +/// 0, 1, 4, 5, 8, 9 +/// ``` +/// +fn unnest_fixed_list( + list_array: &FixedSizeListArray, + options: &UnnestOptions, +) -> Result> { + let values = list_array.values(); + + if list_array.null_count() == 0 { + Ok(values.clone()) + } else { + let len_without_nulls = + values.len() - list_array.null_count() * list_array.value_length() as usize; + let null_count = if options.preserve_nulls { + list_array.null_count() + } else { + 0 + }; + let mut builder = + PrimitiveArray::::builder(len_without_nulls + null_count); + let mut take_offset = 0; + let fixed_value_length = list_array.value_length() as usize; + list_array.iter().for_each(|elem| match elem { + Some(_) => { + for i in 0..fixed_value_length { + //take_offset + i is always positive + let take_index = take_offset + i; + builder.append_value(take_index as i32); + } + take_offset += fixed_value_length; + } + None => { + if options.preserve_nulls { + builder.append_null(); + } + take_offset += fixed_value_length; + } + }); + Ok(kernels::take::take(&values, &builder.finish(), None)?) + } +} + +/// Creates take indicies to be used to expand all other column's data. +/// Every column value needs to be repeated as many times as many elements there is in each corresponding array value. +/// +/// If the column being unnested looks like this: +/// +/// ```ignore +/// [1], null, [2, 3, 4], null, [5, 6] +/// ``` +/// Then `create_take_indicies_generic` will return an array like this +/// +/// ```ignore +/// [1, null, 2, 2, 2, null, 4, 4] +/// ``` +/// +fn create_take_indicies_generic>( + list_array: &GenericListArray, + capacity: usize, + options: &UnnestOptions, +) -> PrimitiveArray

{ + let mut builder = PrimitiveArray::

::builder(capacity); + let null_repeat: usize = if options.preserve_nulls { 1 } else { 0 }; + + for row in 0..list_array.len() { + let repeat = if list_array.is_null(row) { + null_repeat + } else { + list_array.value(row).len() + }; + + // `index` is a positive interger. + let index = P::Native::from_usize(row).unwrap(); + (0..repeat).for_each(|_| builder.append_value(index)); + } + + builder.finish() +} + +fn create_take_indicies_fixed( + list_array: &FixedSizeListArray, + capacity: usize, + options: &UnnestOptions, +) -> PrimitiveArray { + let mut builder = PrimitiveArray::::builder(capacity); + let null_repeat: usize = if options.preserve_nulls { 1 } else { 0 }; + + for row in 0..list_array.len() { + let repeat = if list_array.is_null(row) { + null_repeat + } else { + list_array.value_length() as usize + }; + + // `index` is a positive interger. + let index = ::Native::from_usize(row).unwrap(); + (0..repeat).for_each(|_| builder.append_value(index)); + } + + builder.finish() +} + +/// Create the final batch given the unnested column array and a `indices` array +/// that is used by the take kernel to copy values. +/// +/// For example if we have the following `RecordBatch`: +/// +/// ```ignore +/// c1: [1], null, [2, 3, 4], null, [5, 6] +/// c2: 'a', 'b', 'c', null, 'd' +/// ``` +/// +/// then the `unnested_array` contains the unnest column that will replace `c1` in +/// the final batch: +/// +/// ```ignore +/// c1: 1, null, 2, 3, 4, null, 5, 6 +/// ``` +/// +/// And the `indices` array contains the indices that are used by `take` kernel to +/// repeat the values in `c2`: +/// +/// ```ignore +/// 0, 1, 2, 2, 2, 3, 4, 4 +/// ``` +/// +/// so that the final batch will look like: +/// +/// ```ignore +/// c1: 1, null, 2, 3, 4, null, 5, 6 +/// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd' +/// ``` +/// +fn batch_from_indices( + batch: &RecordBatch, + schema: &SchemaRef, + unnest_column_idx: usize, + unnested_array: &ArrayRef, + indices: &PrimitiveArray, +) -> Result +where + T: ArrowPrimitiveType, +{ + let arrays = batch + .columns() + .iter() + .enumerate() + .map(|(col_idx, arr)| { + if col_idx == unnest_column_idx { + Ok(unnested_array.clone()) + } else { + Ok(kernels::take::take(&arr, indices, None)?) + } + }) + .collect::>>()?; + + Ok(RecordBatch::try_new(schema.clone(), arrays.to_vec())?) +} diff --git a/datafusion/core/src/physical_plan/values.rs b/datafusion/physical-plan/src/values.rs similarity index 65% rename from datafusion/core/src/physical_plan/values.rs rename to datafusion/physical-plan/src/values.rs index d1cf6927a2e30..b624fb362e656 100644 --- a/datafusion/core/src/physical_plan/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -17,20 +17,21 @@ //! Values execution plan +use std::any::Any; +use std::sync::Arc; + use super::expressions::PhysicalSortExpr; -use super::{common, SendableRecordBatchStream, Statistics}; -use crate::physical_plan::{ +use super::{common, DisplayAs, SendableRecordBatchStream, Statistics}; +use crate::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; + use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_execution::TaskContext; -use std::any::Any; -use std::sync::Arc; /// Execution plan for values list based relation (produces constant rows) #[derive(Debug)] @@ -48,7 +49,7 @@ impl ValuesExec { data: Vec>>, ) -> Result { if data.is_empty() { - return Err(DataFusionError::Plan("Values list cannot be empty".into())); + return plan_err!("Values list cannot be empty"); } let n_row = data.len(); let n_col = schema.fields().len(); @@ -66,15 +67,16 @@ impl ValuesExec { (0..n_row) .map(|i| { let r = data[i][j].evaluate(&batch); + match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) + Ok(ScalarValue::List(a)) } Ok(ColumnarValue::Array(a)) => { - Err(DataFusionError::Plan(format!( + plan_err!( "Cannot have array values {a:?} in a values list" - ))) + ) } Err(err) => Err(err), } @@ -88,12 +90,53 @@ impl ValuesExec { Ok(Self { schema, data }) } + /// Create a new plan using the provided schema and batches. + /// + /// Errors if any of the batches don't match the provided schema, or if no + /// batches are provided. + pub fn try_new_from_batches( + schema: SchemaRef, + batches: Vec, + ) -> Result { + if batches.is_empty() { + return plan_err!("Values list cannot be empty"); + } + + for batch in &batches { + let batch_schema = batch.schema(); + if batch_schema != schema { + return plan_err!( + "Batch has invalid schema. Expected: {schema}, got: {batch_schema}" + ); + } + } + + Ok(ValuesExec { + schema, + data: batches, + }) + } + /// provides the data - fn data(&self) -> Vec { + pub fn data(&self) -> Vec { self.data.clone() } } +impl DisplayAs for ValuesExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ValuesExec") + } + } + } +} + impl ExecutionPlan for ValuesExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -133,9 +176,9 @@ impl ExecutionPlan for ValuesExec { ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { - return Err(DataFusionError::Internal(format!( + return internal_err!( "ValuesExec invalid partition {partition} (expected 0)" - ))); + ); } Ok(Box::pin(MemoryStream::try_new( @@ -145,34 +188,56 @@ impl ExecutionPlan for ValuesExec { )?)) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "ValuesExec") - } - } - } - - fn statistics(&self) -> Statistics { + fn statistics(&self) -> Result { let batch = self.data(); - common::compute_record_batch_statistics(&[batch], &self.schema, None) + Ok(common::compute_record_batch_statistics( + &[batch], + &self.schema, + None, + )) } } #[cfg(test)] mod tests { use super::*; - use crate::test_util; + use crate::test::{self, make_partition}; + + use arrow_schema::{DataType, Field, Schema}; #[tokio::test] async fn values_empty_case() -> Result<()> { - let schema = test_util::aggr_test_schema(); + let schema = test::aggr_test_schema(); let empty = ValuesExec::try_new(schema, vec![]); assert!(empty.is_err()); Ok(()) } + + #[test] + fn new_exec_with_batches() { + let batch = make_partition(7); + let schema = batch.schema(); + let batches = vec![batch.clone(), batch]; + + let _exec = ValuesExec::try_new_from_batches(schema, batches).unwrap(); + } + + #[test] + fn new_exec_with_batches_empty() { + let batch = make_partition(7); + let schema = batch.schema(); + let _ = ValuesExec::try_new_from_batches(schema, Vec::new()).unwrap_err(); + } + + #[test] + fn new_exec_with_batches_invalid_schema() { + let batch = make_partition(7); + let batches = vec![batch.clone(), batch]; + + let invalid_schema = Arc::new(Schema::new(vec![ + Field::new("col0", DataType::UInt32, false), + Field::new("col1", DataType::Utf8, false), + ])); + let _ = ValuesExec::try_new_from_batches(invalid_schema, batches).unwrap_err(); + } } diff --git a/datafusion/physical-plan/src/visitor.rs b/datafusion/physical-plan/src/visitor.rs new file mode 100644 index 0000000000000..ca826c50022d4 --- /dev/null +++ b/datafusion/physical-plan/src/visitor.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::ExecutionPlan; + +/// Visit all children of this plan, according to the order defined on `ExecutionPlanVisitor`. +// Note that this would be really nice if it were a method on +// ExecutionPlan, but it can not be because it takes a generic +// parameter and `ExecutionPlan` is a trait +pub fn accept( + plan: &dyn ExecutionPlan, + visitor: &mut V, +) -> Result<(), V::Error> { + visitor.pre_visit(plan)?; + for child in plan.children() { + visit_execution_plan(child.as_ref(), visitor)?; + } + visitor.post_visit(plan)?; + Ok(()) +} + +/// Trait that implements the [Visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for a +/// depth first walk of `ExecutionPlan` nodes. `pre_visit` is called +/// before any children are visited, and then `post_visit` is called +/// after all children have been visited. +/// +/// To use, define a struct that implements this trait and then invoke +/// ['accept']. +/// +/// For example, for an execution plan that looks like: +/// +/// ```text +/// ProjectionExec: id +/// FilterExec: state = CO +/// CsvExec: +/// ``` +/// +/// The sequence of visit operations would be: +/// ```text +/// visitor.pre_visit(ProjectionExec) +/// visitor.pre_visit(FilterExec) +/// visitor.pre_visit(CsvExec) +/// visitor.post_visit(CsvExec) +/// visitor.post_visit(FilterExec) +/// visitor.post_visit(ProjectionExec) +/// ``` +pub trait ExecutionPlanVisitor { + /// The type of error returned by this visitor + type Error; + + /// Invoked on an `ExecutionPlan` plan before any of its child + /// inputs have been visited. If Ok(true) is returned, the + /// recursion continues. If Err(..) or Ok(false) are returned, the + /// recursion stops immediately and the error, if any, is returned + /// to `accept` + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result; + + /// Invoked on an `ExecutionPlan` plan *after* all of its child + /// inputs have been visited. The return value is handled the same + /// as the return value of `pre_visit`. The provided default + /// implementation returns `Ok(true)`. + fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { + Ok(true) + } +} + +/// Recursively calls `pre_visit` and `post_visit` for this node and +/// all of its children, as described on [`ExecutionPlanVisitor`] +pub fn visit_execution_plan( + plan: &dyn ExecutionPlan, + visitor: &mut V, +) -> Result<(), V::Error> { + visitor.pre_visit(plan)?; + for child in plan.children() { + visit_execution_plan(child.as_ref(), visitor)?; + } + visitor.post_visit(plan)?; + Ok(()) +} diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs similarity index 82% rename from datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs rename to datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 3a95308503e42..431a43bc6055b 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -20,66 +20,54 @@ //! the input data seen so far), which makes it appropriate when processing //! infinite inputs. -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, -}; -use crate::physical_plan::windows::{ - calc_requirements, get_ordered_partition_by_indices, window_ordering_equivalence, +use std::any::Any; +use std::cmp::{min, Ordering}; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::expressions::PhysicalSortExpr; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::windows::{ + calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, + window_equivalence_properties, }; -use crate::physical_plan::{ - ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, +use crate::{ + ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + InputOrderMode, Partitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, WindowExpr, }; -use datafusion_common::Result; -use datafusion_execution::TaskContext; -use ahash::RandomState; use arrow::{ - array::{Array, ArrayRef, UInt32Builder}, + array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices}, datatypes::{Schema, SchemaBuilder, SchemaRef}, record_batch::RecordBatch, }; -use futures::stream::Stream; -use futures::{ready, StreamExt}; -use hashbrown::raw::RawTable; -use indexmap::IndexMap; -use log::debug; - -use std::any::Any; -use std::cmp::{min, Ordering}; -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::stats::Precision; use datafusion_common::utils::{ evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, get_record_batch_at_indices, get_row_at_idx, }; -use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr::hash_utils::create_hashes; use datafusion_physical_expr::window::{ - PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, - WindowAggState, WindowState, + PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; use datafusion_physical_expr::{ - EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, - PhysicalSortRequirement, + EquivalenceProperties, PhysicalExpr, PhysicalSortRequirement, }; -#[derive(Debug, Clone, PartialEq)] -/// Specifies partition column properties in terms of input ordering -pub enum PartitionSearchMode { - /// None of the columns among the partition columns is ordered. - Linear, - /// Some columns of the partition columns are ordered but not all - PartiallySorted(Vec), - /// All Partition columns are ordered (Also empty case) - Sorted, -} +use ahash::RandomState; +use futures::stream::Stream; +use futures::{ready, StreamExt}; +use hashbrown::raw::RawTable; +use indexmap::IndexMap; +use log::debug; /// Window execution plan #[derive(Debug)] @@ -90,14 +78,12 @@ pub struct BoundedWindowAggExec { window_expr: Vec>, /// Schema after the window is run schema: SchemaRef, - /// Schema before the window - input_schema: SchemaRef, /// Partition Keys pub partition_keys: Vec>, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Partition by search mode - pub partition_search_mode: PartitionSearchMode, + /// Describes how the input is ordered relative to the partition keys + pub input_order_mode: InputOrderMode, /// Partition by indices that define ordering // For example, if input ordering is ORDER BY a, b and window expression // contains PARTITION BY b, a; `ordered_partition_by_indices` would be 1, 0. @@ -112,15 +98,14 @@ impl BoundedWindowAggExec { pub fn try_new( window_expr: Vec>, input: Arc, - input_schema: SchemaRef, partition_keys: Vec>, - partition_search_mode: PartitionSearchMode, + input_order_mode: InputOrderMode, ) -> Result { - let schema = create_schema(&input_schema, &window_expr)?; + let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let partition_by_exprs = window_expr[0].partition_by(); - let ordered_partition_by_indices = match &partition_search_mode { - PartitionSearchMode::Sorted => { + let ordered_partition_by_indices = match &input_order_mode { + InputOrderMode::Sorted => { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, @@ -131,10 +116,8 @@ impl BoundedWindowAggExec { (0..partition_by_exprs.len()).collect::>() } } - PartitionSearchMode::PartiallySorted(ordered_indices) => { - ordered_indices.clone() - } - PartitionSearchMode::Linear => { + InputOrderMode::PartiallySorted(ordered_indices) => ordered_indices.clone(), + InputOrderMode::Linear => { vec![] } }; @@ -142,10 +125,9 @@ impl BoundedWindowAggExec { input, window_expr, schema, - input_schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), - partition_search_mode, + input_order_mode, ordered_partition_by_indices, }) } @@ -160,20 +142,18 @@ impl BoundedWindowAggExec { &self.input } - /// Get the input schema before any window functions are applied - pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() - } - /// Return the output sort order of partition keys: For example /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { - // Partition by sort keys indices are stored in self.ordered_partition_by_indices. - let sort_keys = self.input.output_ordering().unwrap_or(&[]); - get_at_indices(sort_keys, &self.ordered_partition_by_indices) + let partition_by = self.window_expr()[0].partition_by(); + get_partition_by_sort_exprs( + &self.input, + partition_by, + &self.ordered_partition_by_indices, + ) } /// Initializes the appropriate [`PartitionSearcher`] implementation from @@ -181,26 +161,55 @@ impl BoundedWindowAggExec { fn get_search_algo(&self) -> Result> { let partition_by_sort_keys = self.partition_by_sort_keys()?; let ordered_partition_by_indices = self.ordered_partition_by_indices.clone(); - Ok(match &self.partition_search_mode { - PartitionSearchMode::Sorted => { + Ok(match &self.input_order_mode { + InputOrderMode::Sorted => { // In Sorted mode, all partition by columns should be ordered. if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() { - return Err(DataFusionError::Execution("All partition by columns should have an ordering in Sorted mode.".to_string())); + return exec_err!("All partition by columns should have an ordering in Sorted mode."); } Box::new(SortedSearch { partition_by_sort_keys, ordered_partition_by_indices, }) } - PartitionSearchMode::Linear | PartitionSearchMode::PartiallySorted(_) => { + InputOrderMode::Linear | InputOrderMode::PartiallySorted(_) => { Box::new(LinearSearch::new(ordered_partition_by_indices)) } }) } } +impl DisplayAs for BoundedWindowAggExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BoundedWindowAggExec: ")?; + let g: Vec = self + .window_expr + .iter() + .map(|e| { + format!( + "{}: {:?}, frame: {:?}", + e.name().to_owned(), + e.field(), + e.get_window_frame() + ) + }) + .collect(); + let mode = &self.input_order_mode; + write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; + } + } + Ok(()) + } +} + impl ExecutionPlan for BoundedWindowAggExec { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -234,7 +243,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_ordering(&self) -> Vec>> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.partition_search_mode != PartitionSearchMode::Sorted + if self.input_order_mode != InputOrderMode::Sorted || self.ordered_partition_by_indices.len() >= partition_bys.len() { let partition_bys = self @@ -256,13 +265,9 @@ impl ExecutionPlan for BoundedWindowAggExec { } } + /// Get the [`EquivalenceProperties`] within the plan fn equivalence_properties(&self) -> EquivalenceProperties { - self.input().equivalence_properties() - } - - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - window_ordering_equivalence(&self.schema, &self.input, &self.window_expr) + window_equivalence_properties(&self.schema, &self.input, &self.window_expr) } fn maintains_input_order(&self) -> Vec { @@ -276,9 +281,8 @@ impl ExecutionPlan for BoundedWindowAggExec { Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - self.input_schema.clone(), self.partition_keys.clone(), - self.partition_search_mode.clone(), + self.input_order_mode.clone(), )?)) } @@ -299,55 +303,26 @@ impl ExecutionPlan for BoundedWindowAggExec { Ok(stream) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "BoundedWindowAggExec: ")?; - let g: Vec = self - .window_expr - .iter() - .map(|e| { - format!( - "{}: {:?}, frame: {:?}", - e.name().to_owned(), - e.field(), - e.get_window_frame() - ) - }) - .collect(); - let mode = &self.partition_search_mode; - write!(f, "wdw=[{}], mode=[{:?}]", g.join(", "), mode)?; - } - } - Ok(()) - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stat = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stat = self.input.statistics()?; let win_cols = self.window_expr.len(); - let input_cols = self.input_schema.fields().len(); + let input_cols = self.input.schema().fields().len(); // TODO stats: some windowing function will maintain invariants such as min, max... let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - if let Some(input_col_stats) = input_stat.column_statistics { - column_statistics.extend(input_col_stats); - } else { - column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) } - column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); - Statistics { - is_exact: input_stat.is_exact, + Ok(Statistics { num_rows: input_stat.num_rows, - column_statistics: Some(column_statistics), - total_byte_size: None, - } + column_statistics, + total_byte_size: Precision::Absent, + }) } } @@ -609,9 +584,9 @@ impl LinearSearch { .iter() .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), - ColumnarValue::Scalar(scalar) => Err(DataFusionError::Plan(format!( - "Sort operation is not applicable to scalar value {scalar}" - ))), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(record_batch.num_rows()) + } }) .collect() } @@ -948,7 +923,7 @@ impl BoundedWindowAggStream { .columns() .iter() .map(|elem| elem.slice(0, n_out)) - .chain(window_expr_out.into_iter()) + .chain(window_expr_out) .collect::>(); let n_generated = columns_to_show[0].len(); self.prune_state(n_generated)?; @@ -1051,8 +1026,11 @@ impl BoundedWindowAggStream { .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer = - RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + self.input_buffer = RecordBatch::try_new_with_options( + self.input_buffer.schema(), + batch_to_keep, + &RecordBatchOptions::new().with_row_count(Some(n_to_keep)), + )?; Ok(()) } @@ -1126,10 +1104,138 @@ fn get_aggregate_result_out_column( } } if running_length != len_to_show { - return Err(DataFusionError::Execution(format!( + return exec_err!( "Generated row number should be {len_to_show}, it is {running_length}" - ))); + ); } result .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } + +#[cfg(test)] +mod tests { + use crate::common::collect; + use crate::memory::MemoryExec; + use crate::windows::{BoundedWindowAggExec, InputOrderMode}; + use crate::{get_plan_string, ExecutionPlan}; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{assert_batches_eq, Result, ScalarValue}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::expressions::NthValue; + use datafusion_physical_expr::window::BuiltInWindowExpr; + use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; + use std::sync::Arc; + + // Tests NTH_VALUE(negative index) with memoize feature. + // To be able to trigger memoize feature for NTH_VALUE we need to + // - feed BoundedWindowAggExec with batch stream data. + // - Window frame should contain UNBOUNDED PRECEDING. + // It hard to ensure these conditions are met, from the sql query. + #[tokio::test] + async fn test_window_nth_value_bounded_memoize() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], + )?; + + let memory_exec = MemoryExec::try_new( + &[vec![batch.clone(), batch.clone(), batch.clone()]], + schema.clone(), + None, + ) + .map(|e| Arc::new(e) as Arc)?; + let col_a = col("a", &schema)?; + let nth_value_func1 = + NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1)? + .reverse_expr() + .unwrap(); + let nth_value_func2 = + NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2)? + .reverse_expr() + .unwrap(); + let last_value_func = + Arc::new(NthValue::last("last", col_a.clone(), DataType::Int32)) as _; + let window_exprs = vec![ + // LAST_VALUE(a) + Arc::new(BuiltInWindowExpr::new( + last_value_func, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -1) + Arc::new(BuiltInWindowExpr::new( + nth_value_func1, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + // NTH_VALUE(a, -2) + Arc::new(BuiltInWindowExpr::new( + nth_value_func2, + &[], + &[], + Arc::new(WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + end_bound: WindowFrameBound::CurrentRow, + }), + )) as _, + ]; + let physical_plan = BoundedWindowAggExec::try_new( + window_exprs, + memory_exec, + vec![], + InputOrderMode::Sorted, + ) + .map(|e| Arc::new(e) as Arc)?; + + let batches = collect(physical_plan.execute(0, task_ctx)?).await?; + + let expected = vec![ + "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[3]", + ]; + // Get string representation of the plan + let actual = get_plan_string(&physical_plan); + assert_eq!( + expected, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = [ + "+---+------+---------------+---------------+", + "| a | last | nth_value(-1) | nth_value(-2) |", + "+---+------+---------------+---------------+", + "| 1 | 1 | 1 | |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "| 1 | 1 | 1 | 3 |", + "| 2 | 2 | 2 | 1 |", + "| 3 | 3 | 3 | 2 |", + "+---+------+---------------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs new file mode 100644 index 0000000000000..3187e6b0fbd3f --- /dev/null +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -0,0 +1,1031 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expressions for window functions + +use std::borrow::Borrow; +use std::convert::TryInto; +use std::sync::Arc; + +use crate::{ + aggregates, + expressions::{ + cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, + PhysicalSortExpr, RowNumber, + }, + udaf, unbounded_output, ExecutionPlan, InputOrderMode, PhysicalExpr, +}; + +use arrow::datatypes::Schema; +use arrow_schema::{DataType, Field, SchemaRef}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{ + window_function::{BuiltInWindowFunction, WindowFunction}, + PartitionEvaluator, WindowFrame, WindowUDF, +}; +use datafusion_physical_expr::equivalence::collapse_lex_req; +use datafusion_physical_expr::{ + reverse_order_bys, + window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, + AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, +}; + +mod bounded_window_agg_exec; +mod window_agg_exec; + +pub use bounded_window_agg_exec::BoundedWindowAggExec; +pub use window_agg_exec::WindowAggExec; + +pub use datafusion_physical_expr::window::{ + BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, +}; + +/// Create a physical expression for window function +pub fn create_window_expr( + fun: &WindowFunction, + name: String, + args: &[Arc], + partition_by: &[Arc], + order_by: &[PhysicalSortExpr], + window_frame: Arc, + input_schema: &Schema, +) -> Result> { + Ok(match fun { + WindowFunction::AggregateFunction(fun) => { + let aggregate = aggregates::create_aggregate_expr( + fun, + false, + args, + &[], + input_schema, + name, + )?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } + WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), + WindowFunction::AggregateUDF(fun) => { + let aggregate = + udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } + WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )), + }) +} + +/// Creates an appropriate [`WindowExpr`] based on the window frame and +fn window_expr_from_aggregate_expr( + partition_by: &[Arc], + order_by: &[PhysicalSortExpr], + window_frame: Arc, + aggregate: Arc, +) -> Arc { + // Is there a potentially unlimited sized window frame? + let unbounded_window = window_frame.start_bound.is_unbounded(); + + if !unbounded_window { + Arc::new(SlidingAggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } else { + Arc::new(PlainAggregateWindowExpr::new( + aggregate, + partition_by, + order_by, + window_frame, + )) + } +} + +fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Result> { + Ok(if let Some(field) = args.get(index) { + let tmp = field + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::NotImplemented( + format!("There is only support Literal types for field at idx: {index} in Window Function"), + ))? + .value() + .clone(); + Some(tmp) + } else { + None + }) +} + +fn create_built_in_window_expr( + fun: &BuiltInWindowFunction, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + Ok(match fun { + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)), + BuiltInWindowFunction::Rank => Arc::new(rank(name)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)), + BuiltInWindowFunction::Ntile => { + let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if n.is_unsigned() { + let n: u64 = n.try_into()?; + Arc::new(Ntile::new(name, n)) + } else { + let n: i64 = n.try_into()?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Arc::new(Ntile::new(name, n as u64)) + } + } + BuiltInWindowFunction::Lag => { + let arg = args[0].clone(); + let data_type = args[0].data_type(input_schema)?; + let shift_offset = get_scalar_value_from_args(args, 1)? + .map(|v| v.try_into()) + .and_then(|v| v.ok()); + let default_value = get_scalar_value_from_args(args, 2)?; + Arc::new(lag(name, data_type, arg, shift_offset, default_value)) + } + BuiltInWindowFunction::Lead => { + let arg = args[0].clone(); + let data_type = args[0].data_type(input_schema)?; + let shift_offset = get_scalar_value_from_args(args, 1)? + .map(|v| v.try_into()) + .and_then(|v| v.ok()); + let default_value = get_scalar_value_from_args(args, 2)?; + Arc::new(lead(name, data_type, arg, shift_offset, default_value)) + } + BuiltInWindowFunction::NthValue => { + let arg = args[0].clone(); + let n = args[1].as_any().downcast_ref::().unwrap().value(); + let n: i64 = n + .clone() + .try_into() + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + let n: u32 = n as u32; + let data_type = args[0].data_type(input_schema)?; + Arc::new(NthValue::nth(name, arg, data_type, n)?) + } + BuiltInWindowFunction::FirstValue => { + let arg = args[0].clone(); + let data_type = args[0].data_type(input_schema)?; + Arc::new(NthValue::first(name, arg, data_type)) + } + BuiltInWindowFunction::LastValue => { + let arg = args[0].clone(); + let data_type = args[0].data_type(input_schema)?; + Arc::new(NthValue::last(name, arg, data_type)) + } + }) +} + +/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function +fn create_udwf_window_expr( + fun: &Arc, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + // need to get the types into an owned vec for some reason + let input_types: Vec<_> = args + .iter() + .map(|arg| arg.data_type(input_schema)) + .collect::>()?; + + // figure out the output type + let data_type = fun.return_type(&input_types)?; + Ok(Arc::new(WindowUDFExpr { + fun: Arc::clone(fun), + args: args.to_vec(), + name, + data_type, + })) +} + +/// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`] +#[derive(Clone, Debug)] +struct WindowUDFExpr { + fun: Arc, + args: Vec>, + /// Display name + name: String, + /// result type + data_type: DataType, +} + +impl BuiltInWindowFunctionExpr for WindowUDFExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn field(&self) -> Result { + let nullable = true; + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn create_evaluator(&self) -> Result> { + self.fun.partition_evaluator_factory() + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + None + } +} + +pub(crate) fn calc_requirements< + T: Borrow>, + S: Borrow, +>( + partition_by_exprs: impl IntoIterator, + orderby_sort_exprs: impl IntoIterator, +) -> Option> { + let mut sort_reqs = partition_by_exprs + .into_iter() + .map(|partition_by| { + PhysicalSortRequirement::new(partition_by.borrow().clone(), None) + }) + .collect::>(); + for element in orderby_sort_exprs.into_iter() { + let PhysicalSortExpr { expr, options } = element.borrow(); + if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { + sort_reqs.push(PhysicalSortRequirement::new(expr.clone(), Some(*options))); + } + } + // Convert empty result to None. Otherwise wrap result inside Some() + (!sort_reqs.is_empty()).then_some(sort_reqs) +} + +/// This function calculates the indices such that when partition by expressions reordered with this indices +/// resulting expressions define a preset for existing ordering. +// For instance, if input is ordered by a, b, c and PARTITION BY b, a is used +// This vector will be [1, 0]. It means that when we iterate b,a columns with the order [1, 0] +// resulting vector (a, b) is a preset of the existing ordering (a, b, c). +pub(crate) fn get_ordered_partition_by_indices( + partition_by_exprs: &[Arc], + input: &Arc, +) -> Vec { + let (_, indices) = input + .equivalence_properties() + .find_longest_permutation(partition_by_exprs); + indices +} + +pub(crate) fn get_partition_by_sort_exprs( + input: &Arc, + partition_by_exprs: &[Arc], + ordered_partition_by_indices: &[usize], +) -> Result { + let ordered_partition_exprs = ordered_partition_by_indices + .iter() + .map(|idx| partition_by_exprs[*idx].clone()) + .collect::>(); + // Make sure ordered section doesn't move over the partition by expression + assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); + let (ordering, _) = input + .equivalence_properties() + .find_longest_permutation(&ordered_partition_exprs); + if ordering.len() == ordered_partition_exprs.len() { + Ok(ordering) + } else { + exec_err!("Expects PARTITION BY expression to be ordered") + } +} + +pub(crate) fn window_equivalence_properties( + schema: &SchemaRef, + input: &Arc, + window_expr: &[Arc], +) -> EquivalenceProperties { + // We need to update the schema, so we can not directly use + // `input.equivalence_properties()`. + let mut window_eq_properties = + EquivalenceProperties::new(schema.clone()).extend(input.equivalence_properties()); + + for expr in window_expr { + if let Some(builtin_window_expr) = + expr.as_any().downcast_ref::() + { + builtin_window_expr.add_equal_orderings(&mut window_eq_properties); + } + } + window_eq_properties +} + +/// Constructs the best-fitting windowing operator (a `WindowAggExec` or a +/// `BoundedWindowExec`) for the given `input` according to the specifications +/// of `window_exprs` and `physical_partition_keys`. Here, best-fitting means +/// not requiring additional sorting and/or partitioning for the given input. +/// - A return value of `None` represents that there is no way to construct a +/// windowing operator that doesn't need additional sorting/partitioning for +/// the given input. Existing ordering should be changed to run the given +/// windowing operation. +/// - A `Some(window exec)` value contains the optimal windowing operator (a +/// `WindowAggExec` or a `BoundedWindowExec`) for the given input. +pub fn get_best_fitting_window( + window_exprs: &[Arc], + input: &Arc, + // These are the partition keys used during repartitioning. + // They are either the same with `window_expr`'s PARTITION BY columns, + // or it is empty if partitioning is not desirable for this windowing operator. + physical_partition_keys: &[Arc], +) -> Result>> { + // Contains at least one window expr and all of the partition by and order by sections + // of the window_exprs are same. + let partitionby_exprs = window_exprs[0].partition_by(); + let orderby_keys = window_exprs[0].order_by(); + let (should_reverse, input_order_mode) = + if let Some((should_reverse, input_order_mode)) = + get_window_mode(partitionby_exprs, orderby_keys, input) + { + (should_reverse, input_order_mode) + } else { + return Ok(None); + }; + let is_unbounded = unbounded_output(input); + if !is_unbounded && input_order_mode != InputOrderMode::Sorted { + // Executor has bounded input and `input_order_mode` is not `InputOrderMode::Sorted` + // in this case removing the sort is not helpful, return: + return Ok(None); + }; + + let window_expr = if should_reverse { + if let Some(reversed_window_expr) = window_exprs + .iter() + .map(|e| e.get_reverse_expr()) + .collect::>>() + { + reversed_window_expr + } else { + // Cannot take reverse of any of the window expr + // In this case, with existing ordering window cannot be run + return Ok(None); + } + } else { + window_exprs.to_vec() + }; + + // If all window expressions can run with bounded memory, choose the + // bounded window variant: + if window_expr.iter().all(|e| e.uses_bounded_memory()) { + Ok(Some(Arc::new(BoundedWindowAggExec::try_new( + window_expr, + input.clone(), + physical_partition_keys.to_vec(), + input_order_mode, + )?) as _)) + } else if input_order_mode != InputOrderMode::Sorted { + // For `WindowAggExec` to work correctly PARTITION BY columns should be sorted. + // Hence, if `input_order_mode` is not `Sorted` we should convert + // input ordering such that it can work with `Sorted` (add `SortExec`). + // Effectively `WindowAggExec` works only in `Sorted` mode. + Ok(None) + } else { + Ok(Some(Arc::new(WindowAggExec::try_new( + window_expr, + input.clone(), + physical_partition_keys.to_vec(), + )?) as _)) + } +} + +/// Compares physical ordering (output ordering of the `input` operator) with +/// `partitionby_exprs` and `orderby_keys` to decide whether existing ordering +/// is sufficient to run the current window operator. +/// - A `None` return value indicates that we can not remove the sort in question +/// (input ordering is not sufficient to run current window executor). +/// - A `Some((bool, InputOrderMode))` value indicates that the window operator +/// can run with existing input ordering, so we can remove `SortExec` before it. +/// The `bool` field in the return value represents whether we should reverse window +/// operator to remove `SortExec` before it. The `InputOrderMode` field represents +/// the mode this window operator should work in to accommodate the existing ordering. +pub fn get_window_mode( + partitionby_exprs: &[Arc], + orderby_keys: &[PhysicalSortExpr], + input: &Arc, +) -> Option<(bool, InputOrderMode)> { + let input_eqs = input.equivalence_properties(); + let mut partition_by_reqs: Vec = vec![]; + let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); + partition_by_reqs.extend(indices.iter().map(|&idx| PhysicalSortRequirement { + expr: partitionby_exprs[idx].clone(), + options: None, + })); + // Treat partition by exprs as constant. During analysis of requirements are satisfied. + let partition_by_eqs = input_eqs.add_constants(partitionby_exprs.iter().cloned()); + let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); + let reverse_order_by_reqs = + PhysicalSortRequirement::from_sort_exprs(&reverse_order_bys(orderby_keys)); + for (should_swap, order_by_reqs) in + [(false, order_by_reqs), (true, reverse_order_by_reqs)] + { + let req = [partition_by_reqs.clone(), order_by_reqs].concat(); + let req = collapse_lex_req(req); + if partition_by_eqs.ordering_satisfy_requirement(&req) { + // Window can be run with existing ordering + let mode = if indices.len() == partitionby_exprs.len() { + InputOrderMode::Sorted + } else if indices.is_empty() { + InputOrderMode::Linear + } else { + InputOrderMode::PartiallySorted(indices) + }; + return Some((should_swap, mode)); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::aggregates::AggregateFunction; + use crate::collect; + use crate::expressions::col; + use crate::streaming::StreamingTableExec; + use crate::test::assert_is_pending; + use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, SchemaRef}; + use datafusion_execution::TaskContext; + + use futures::FutureExt; + + use InputOrderMode::{Linear, PartiallySorted, Sorted}; + + fn create_test_schema() -> Result { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + + Ok(schema) + } + + fn create_test_schema2() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); + Ok(schema) + } + + // Generate a schema which consists of 5 columns (a, b, c, d, e) + fn create_test_schema3() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, false); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, false); + let e = Field::new("e", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); + Ok(schema) + } + + /// make PhysicalSortExpr with default options + pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { + sort_expr_options(name, schema, SortOptions::default()) + } + + /// PhysicalSortExpr with specified options + pub fn sort_expr_options( + name: &str, + schema: &Schema, + options: SortOptions, + ) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options, + } + } + + /// Created a sorted Streaming Table exec + pub fn streaming_table_exec( + schema: &SchemaRef, + sort_exprs: impl IntoIterator, + infinite_source: bool, + ) -> Result> { + let sort_exprs = sort_exprs.into_iter().collect(); + + Ok(Arc::new(StreamingTableExec::try_new( + schema.clone(), + vec![], + None, + Some(sort_exprs), + infinite_source, + )?)) + } + + #[tokio::test] + async fn test_calc_requirements() -> Result<()> { + let schema = create_test_schema2()?; + let test_data = vec![ + // PARTITION BY a, ORDER BY b ASC NULLS FIRST + ( + vec!["a"], + vec![("b", true, true)], + vec![("a", None), ("b", Some((true, true)))], + ), + // PARTITION BY a, ORDER BY a ASC NULLS FIRST + (vec!["a"], vec![("a", true, true)], vec![("a", None)]), + // PARTITION BY a, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST + ( + vec!["a"], + vec![("b", true, true), ("c", false, false)], + vec![ + ("a", None), + ("b", Some((true, true))), + ("c", Some((false, false))), + ], + ), + // PARTITION BY a, c, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST + ( + vec!["a", "c"], + vec![("b", true, true), ("c", false, false)], + vec![("a", None), ("c", None), ("b", Some((true, true)))], + ), + ]; + for (pb_params, ob_params, expected_params) in test_data { + let mut partitionbys = vec![]; + for col_name in pb_params { + partitionbys.push(col(col_name, &schema)?); + } + + let mut orderbys = vec![]; + for (col_name, descending, nulls_first) in ob_params { + let expr = col(col_name, &schema)?; + let options = SortOptions { + descending, + nulls_first, + }; + orderbys.push(PhysicalSortExpr { expr, options }); + } + + let mut expected: Option> = None; + for (col_name, reqs) in expected_params { + let options = reqs.map(|(descending, nulls_first)| SortOptions { + descending, + nulls_first, + }); + let expr = col(col_name, &schema)?; + let res = PhysicalSortRequirement::new(expr, options); + if let Some(expected) = &mut expected { + expected.push(res); + } else { + expected = Some(vec![res]); + } + } + assert_eq!(calc_requirements(partitionbys, orderbys), expected); + } + Ok(()) + } + + #[tokio::test] + async fn test_drop_cancel() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let refs = blocking_exec.refs(); + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("a", &schema)?], + &[], + &[], + Arc::new(WindowFrame::new(false)), + schema.as_ref(), + )?], + blocking_exec, + vec![], + )?); + + let fut = collect(window_agg_exec, task_ctx); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } + + #[tokio::test] + async fn test_satisfiy_nullable() -> Result<()> { + let schema = create_test_schema()?; + let params = vec![ + ((true, true), (false, false), false), + ((true, true), (false, true), false), + ((true, true), (true, false), false), + ((true, false), (false, true), false), + ((true, false), (false, false), false), + ((true, false), (true, true), false), + ((true, false), (true, false), true), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + expected, + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let res = physical_ordering.satisfy(&required_ordering.into(), &schema); + assert_eq!(res, expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_satisfy_non_nullable() -> Result<()> { + let schema = create_test_schema()?; + + let params = vec![ + ((true, true), (false, false), false), + ((true, true), (false, true), false), + ((true, true), (true, false), true), + ((true, false), (false, true), false), + ((true, false), (false, false), false), + ((true, false), (true, true), true), + ((true, false), (true, false), true), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + expected, + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let res = physical_ordering.satisfy(&required_ordering.into(), &schema); + assert_eq!(res, expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_get_window_mode_exhaustive() -> Result<()> { + let test_schema = create_test_schema3()?; + // Columns a,c are nullable whereas b,d are not nullable. + // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST + // Column e is not ordered. + let sort_exprs = vec![ + sort_expr("a", &test_schema), + sort_expr("b", &test_schema), + sort_expr("c", &test_schema), + sort_expr("d", &test_schema), + ]; + let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + + // test cases consists of vector of tuples. Where each tuple represents a single test case. + // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns + // For instance `vec!["a", "b"]` corresponds to PARTITION BY a, b + // Second field in the tuple is Vec where each element in the vector represents ORDER BY columns + // For instance, vec!["c"], corresponds to ORDER BY c ASC NULLS FIRST, (ordering is default ordering. We do not check + // for reversibility in this test). + // Third field in the tuple is Option, which corresponds to expected algorithm mode. + // None represents that existing ordering is not sufficient to run executor with any one of the algorithms + // (We need to add SortExec to be able to run it). + // Some(InputOrderMode) represents, we can run algorithm with existing ordering; and algorithm should work in + // InputOrderMode. + let test_cases = vec![ + (vec!["a"], vec!["a"], Some(Sorted)), + (vec!["a"], vec!["b"], Some(Sorted)), + (vec!["a"], vec!["c"], None), + (vec!["a"], vec!["a", "b"], Some(Sorted)), + (vec!["a"], vec!["b", "c"], Some(Sorted)), + (vec!["a"], vec!["a", "c"], None), + (vec!["a"], vec!["a", "b", "c"], Some(Sorted)), + (vec!["b"], vec!["a"], Some(Linear)), + (vec!["b"], vec!["b"], Some(Linear)), + (vec!["b"], vec!["c"], None), + (vec!["b"], vec!["a", "b"], Some(Linear)), + (vec!["b"], vec!["b", "c"], None), + (vec!["b"], vec!["a", "c"], Some(Linear)), + (vec!["b"], vec!["a", "b", "c"], Some(Linear)), + (vec!["c"], vec!["a"], Some(Linear)), + (vec!["c"], vec!["b"], None), + (vec!["c"], vec!["c"], Some(Linear)), + (vec!["c"], vec!["a", "b"], Some(Linear)), + (vec!["c"], vec!["b", "c"], None), + (vec!["c"], vec!["a", "c"], Some(Linear)), + (vec!["c"], vec!["a", "b", "c"], Some(Linear)), + (vec!["b", "a"], vec!["a"], Some(Sorted)), + (vec!["b", "a"], vec!["b"], Some(Sorted)), + (vec!["b", "a"], vec!["c"], Some(Sorted)), + (vec!["b", "a"], vec!["a", "b"], Some(Sorted)), + (vec!["b", "a"], vec!["b", "c"], Some(Sorted)), + (vec!["b", "a"], vec!["a", "c"], Some(Sorted)), + (vec!["b", "a"], vec!["a", "b", "c"], Some(Sorted)), + (vec!["c", "b"], vec!["a"], Some(Linear)), + (vec!["c", "b"], vec!["b"], Some(Linear)), + (vec!["c", "b"], vec!["c"], Some(Linear)), + (vec!["c", "b"], vec!["a", "b"], Some(Linear)), + (vec!["c", "b"], vec!["b", "c"], Some(Linear)), + (vec!["c", "b"], vec!["a", "c"], Some(Linear)), + (vec!["c", "b"], vec!["a", "b", "c"], Some(Linear)), + (vec!["c", "a"], vec!["a"], Some(PartiallySorted(vec![1]))), + (vec!["c", "a"], vec!["b"], Some(PartiallySorted(vec![1]))), + (vec!["c", "a"], vec!["c"], Some(PartiallySorted(vec![1]))), + ( + vec!["c", "a"], + vec!["a", "b"], + Some(PartiallySorted(vec![1])), + ), + ( + vec!["c", "a"], + vec!["b", "c"], + Some(PartiallySorted(vec![1])), + ), + ( + vec!["c", "a"], + vec!["a", "c"], + Some(PartiallySorted(vec![1])), + ), + ( + vec!["c", "a"], + vec!["a", "b", "c"], + Some(PartiallySorted(vec![1])), + ), + (vec!["c", "b", "a"], vec!["a"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["b"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["c"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["a", "b"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["b", "c"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["a", "c"], Some(Sorted)), + (vec!["c", "b", "a"], vec!["a", "b", "c"], Some(Sorted)), + ]; + for (case_idx, test_case) in test_cases.iter().enumerate() { + let (partition_by_columns, order_by_params, expected) = &test_case; + let mut partition_by_exprs = vec![]; + for col_name in partition_by_columns { + partition_by_exprs.push(col(col_name, &test_schema)?); + } + + let mut order_by_exprs = vec![]; + for col_name in order_by_params { + let expr = col(col_name, &test_schema)?; + // Give default ordering, this is same with input ordering direction + // In this test we do check for reversibility. + let options = SortOptions::default(); + order_by_exprs.push(PhysicalSortExpr { expr, options }); + } + let res = + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); + // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option + let res = res.map(|(_, mode)| mode); + assert_eq!( + res, *expected, + "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_get_window_mode() -> Result<()> { + let test_schema = create_test_schema3()?; + // Columns a,c are nullable whereas b,d are not nullable. + // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST + // Column e is not ordered. + let sort_exprs = vec![ + sort_expr("a", &test_schema), + sort_expr("b", &test_schema), + sort_expr("c", &test_schema), + sort_expr("d", &test_schema), + ]; + let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + + // test cases consists of vector of tuples. Where each tuple represents a single test case. + // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns + // For instance `vec!["a", "b"]` corresponds to PARTITION BY a, b + // Second field in the tuple is Vec<(str, bool, bool)> where each element in the vector represents ORDER BY columns + // For instance, vec![("c", false, false)], corresponds to ORDER BY c ASC NULLS LAST, + // similarly, vec![("c", true, true)], corresponds to ORDER BY c DESC NULLS FIRST, + // Third field in the tuple is Option<(bool, InputOrderMode)>, which corresponds to expected result. + // None represents that existing ordering is not sufficient to run executor with any one of the algorithms + // (We need to add SortExec to be able to run it). + // Some((bool, InputOrderMode)) represents, we can run algorithm with existing ordering. Algorithm should work in + // InputOrderMode, bool field represents whether we should reverse window expressions to run executor with existing ordering. + // For instance, `Some((false, InputOrderMode::Sorted))`, represents that we shouldn't reverse window expressions. And algorithm + // should work in Sorted mode to work with existing ordering. + let test_cases = vec![ + // PARTITION BY a, b ORDER BY c ASC NULLS LAST + (vec!["a", "b"], vec![("c", false, false)], None), + // ORDER BY c ASC NULLS FIRST + (vec![], vec![("c", false, true)], None), + // PARTITION BY b, ORDER BY c ASC NULLS FIRST + (vec!["b"], vec![("c", false, true)], None), + // PARTITION BY a, ORDER BY c ASC NULLS FIRST + (vec!["a"], vec![("c", false, true)], None), + // PARTITION BY b, ORDER BY c ASC NULLS FIRST + ( + vec!["a", "b"], + vec![("c", false, true), ("e", false, true)], + None, + ), + // PARTITION BY a, ORDER BY b ASC NULLS FIRST + (vec!["a"], vec![("b", false, true)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY a ASC NULLS FIRST + (vec!["a"], vec![("a", false, true)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY a ASC NULLS LAST + (vec!["a"], vec![("a", false, false)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY a DESC NULLS FIRST + (vec!["a"], vec![("a", true, true)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY a DESC NULLS LAST + (vec!["a"], vec![("a", true, false)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY b ASC NULLS LAST + (vec!["a"], vec![("b", false, false)], Some((false, Sorted))), + // PARTITION BY a, ORDER BY b DESC NULLS LAST + (vec!["a"], vec![("b", true, false)], Some((true, Sorted))), + // PARTITION BY a, b ORDER BY c ASC NULLS FIRST + ( + vec!["a", "b"], + vec![("c", false, true)], + Some((false, Sorted)), + ), + // PARTITION BY b, a ORDER BY c ASC NULLS FIRST + ( + vec!["b", "a"], + vec![("c", false, true)], + Some((false, Sorted)), + ), + // PARTITION BY a, b ORDER BY c DESC NULLS LAST + ( + vec!["a", "b"], + vec![("c", true, false)], + Some((true, Sorted)), + ), + // PARTITION BY e ORDER BY a ASC NULLS FIRST + ( + vec!["e"], + vec![("a", false, true)], + // For unbounded, expects to work in Linear mode. Shouldn't reverse window function. + Some((false, Linear)), + ), + // PARTITION BY b, c ORDER BY a ASC NULLS FIRST, c ASC NULLS FIRST + ( + vec!["b", "c"], + vec![("a", false, true), ("c", false, true)], + Some((false, Linear)), + ), + // PARTITION BY b ORDER BY a ASC NULLS FIRST + (vec!["b"], vec![("a", false, true)], Some((false, Linear))), + // PARTITION BY a, e ORDER BY b ASC NULLS FIRST + ( + vec!["a", "e"], + vec![("b", false, true)], + Some((false, PartiallySorted(vec![0]))), + ), + // PARTITION BY a, c ORDER BY b ASC NULLS FIRST + ( + vec!["a", "c"], + vec![("b", false, true)], + Some((false, PartiallySorted(vec![0]))), + ), + // PARTITION BY c, a ORDER BY b ASC NULLS FIRST + ( + vec!["c", "a"], + vec![("b", false, true)], + Some((false, PartiallySorted(vec![1]))), + ), + // PARTITION BY d, b, a ORDER BY c ASC NULLS FIRST + ( + vec!["d", "b", "a"], + vec![("c", false, true)], + Some((false, PartiallySorted(vec![2, 1]))), + ), + // PARTITION BY e, b, a ORDER BY c ASC NULLS FIRST + ( + vec!["e", "b", "a"], + vec![("c", false, true)], + Some((false, PartiallySorted(vec![2, 1]))), + ), + // PARTITION BY d, a ORDER BY b ASC NULLS FIRST + ( + vec!["d", "a"], + vec![("b", false, true)], + Some((false, PartiallySorted(vec![1]))), + ), + // PARTITION BY b, ORDER BY b, a ASC NULLS FIRST + ( + vec!["a"], + vec![("b", false, true), ("a", false, true)], + Some((false, Sorted)), + ), + // ORDER BY b, a ASC NULLS FIRST + (vec![], vec![("b", false, true), ("a", false, true)], None), + ]; + for (case_idx, test_case) in test_cases.iter().enumerate() { + let (partition_by_columns, order_by_params, expected) = &test_case; + let mut partition_by_exprs = vec![]; + for col_name in partition_by_columns { + partition_by_exprs.push(col(col_name, &test_schema)?); + } + + let mut order_by_exprs = vec![]; + for (col_name, descending, nulls_first) in order_by_params { + let expr = col(col_name, &test_schema)?; + let options = SortOptions { + descending: *descending, + nulls_first: *nulls_first, + }; + order_by_exprs.push(PhysicalSortExpr { expr, options }); + } + + assert_eq!( + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded), + *expected, + "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" + ); + } + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs similarity index 82% rename from datafusion/core/src/physical_plan/windows/window_agg_exec.rs rename to datafusion/physical-plan/src/windows/window_agg_exec.rs index f57dfbc0b68b2..6c245f65ba4f9 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -17,19 +17,24 @@ //! Stream and channel implementations for window function expressions. -use crate::physical_plan::common::transpose; -use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, -}; -use crate::physical_plan::windows::{ - calc_requirements, get_ordered_partition_by_indices, window_ordering_equivalence, +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::common::transpose; +use crate::expressions::PhysicalSortExpr; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::windows::{ + calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, + window_equivalence_properties, }; -use crate::physical_plan::{ - ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties, - ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, +use crate::{ + ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, + WindowExpr, }; + use arrow::compute::{concat, concat_batches}; use arrow::datatypes::SchemaBuilder; use arrow::error::ArrowError; @@ -38,17 +43,14 @@ use arrow::{ datatypes::{Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::utils::{evaluate_partition_ranges, get_at_indices}; -use datafusion_common::DataFusionError; -use datafusion_common::Result; +use datafusion_common::stats::Precision; +use datafusion_common::utils::evaluate_partition_ranges; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; + use futures::stream::Stream; use futures::{ready, StreamExt}; -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; /// Window execution plan #[derive(Debug)] @@ -59,8 +61,6 @@ pub struct WindowAggExec { window_expr: Vec>, /// Schema after the window is run schema: SchemaRef, - /// Schema before the window - input_schema: SchemaRef, /// Partition Keys pub partition_keys: Vec>, /// Execution metrics @@ -75,10 +75,9 @@ impl WindowAggExec { pub fn try_new( window_expr: Vec>, input: Arc, - input_schema: SchemaRef, partition_keys: Vec>, ) -> Result { - let schema = create_schema(&input_schema, &window_expr)?; + let schema = create_schema(&input.schema(), &window_expr)?; let schema = Arc::new(schema); let ordered_partition_by_indices = @@ -87,7 +86,6 @@ impl WindowAggExec { input, window_expr, schema, - input_schema, partition_keys, metrics: ExecutionPlanMetricsSet::new(), ordered_partition_by_indices, @@ -104,20 +102,46 @@ impl WindowAggExec { &self.input } - /// Get the input schema before any window functions are applied - pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() - } - /// Return the output sort order of partition keys: For example /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { - // Partition by sort keys indices are stored in self.ordered_partition_by_indices. - let sort_keys = self.input.output_ordering().unwrap_or(&[]); - get_at_indices(sort_keys, &self.ordered_partition_by_indices) + let partition_by = self.window_expr()[0].partition_by(); + get_partition_by_sort_exprs( + &self.input, + partition_by, + &self.ordered_partition_by_indices, + ) + } +} + +impl DisplayAs for WindowAggExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "WindowAggExec: ")?; + let g: Vec = self + .window_expr + .iter() + .map(|e| { + format!( + "{}: {:?}, frame: {:?}", + e.name().to_owned(), + e.field(), + e.get_window_frame() + ) + }) + .collect(); + write!(f, "wdw=[{}]", g.join(", "))?; + } + } + Ok(()) } } @@ -148,10 +172,9 @@ impl ExecutionPlan for WindowAggExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - Err(DataFusionError::Plan( + plan_err!( "Window Error: Windowing is not currently support for unbounded inputs." - .to_string(), - )) + ) } else { Ok(false) } @@ -187,13 +210,9 @@ impl ExecutionPlan for WindowAggExec { } } + /// Get the [`EquivalenceProperties`] within the plan fn equivalence_properties(&self) -> EquivalenceProperties { - self.input().equivalence_properties() - } - - /// Get the OrderingEquivalenceProperties within the plan - fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - window_ordering_equivalence(&self.schema, &self.input, &self.window_expr) + window_equivalence_properties(&self.schema, &self.input, &self.window_expr) } fn with_new_children( @@ -203,7 +222,6 @@ impl ExecutionPlan for WindowAggExec { Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - self.input_schema.clone(), self.partition_keys.clone(), )?)) } @@ -225,54 +243,26 @@ impl ExecutionPlan for WindowAggExec { Ok(stream) } - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default => { - write!(f, "WindowAggExec: ")?; - let g: Vec = self - .window_expr - .iter() - .map(|e| { - format!( - "{}: {:?}, frame: {:?}", - e.name().to_owned(), - e.field(), - e.get_window_frame() - ) - }) - .collect(); - write!(f, "wdw=[{}]", g.join(", "))?; - } - } - Ok(()) - } - fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Statistics { - let input_stat = self.input.statistics(); + fn statistics(&self) -> Result { + let input_stat = self.input.statistics()?; let win_cols = self.window_expr.len(); - let input_cols = self.input_schema.fields().len(); + let input_cols = self.input.schema().fields().len(); // TODO stats: some windowing function will maintain invariants such as min, max... let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - if let Some(input_col_stats) = input_stat.column_statistics { - column_statistics.extend(input_col_stats); - } else { - column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) } - column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); - Statistics { - is_exact: input_stat.is_exact, + Ok(Statistics { num_rows: input_stat.num_rows, - column_statistics: Some(column_statistics), - total_byte_size: None, - } + column_statistics, + total_byte_size: Precision::Absent, + }) } } @@ -325,9 +315,7 @@ impl WindowAggStream { ) -> Result { // In WindowAggExec all partition by columns should be ordered. if window_expr[0].partition_by().len() != ordered_partition_by_indices.len() { - return Err(DataFusionError::Internal( - "All partition by columns should have an ordering".to_string(), - )); + return internal_err!("All partition by columns should have an ordering"); } Ok(Self { schema, diff --git a/datafusion/proto/.gitignore b/datafusion/proto/.gitignore new file mode 100644 index 0000000000000..3aa373dc479b8 --- /dev/null +++ b/datafusion/proto/.gitignore @@ -0,0 +1,4 @@ +# Files generated by regen.sh +proto/proto_descriptor.bin +src/datafusion.rs +datafusion.serde.rs diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 70137f63cf15c..4dda689fff4c3 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -19,14 +19,14 @@ name = "datafusion-proto" description = "Protobuf serialization of DataFusion logical plan expressions" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] @@ -36,21 +36,22 @@ name = "datafusion_proto" path = "src/lib.rs" [features] -default = [] +default = ["parquet"] json = ["pbjson", "serde", "serde_json"] +parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } -chrono = { version = "0.4", default-features = false } -datafusion = { path = "../core", version = "26.0.0" } -datafusion-common = { path = "../common", version = "26.0.0" } -datafusion-expr = { path = "../expr", version = "26.0.0" } -object_store = { version = "0.6.1" } +chrono = { workspace = true } +datafusion = { path = "../core", version = "33.0.0" } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +object_store = { workspace = true } pbjson = { version = "0.5", optional = true } -prost = "0.11.0" +prost = "0.12.0" serde = { version = "1.0", optional = true } -serde_json = { version = "1.0", optional = true } +serde_json = { workspace = true, optional = true } [dev-dependencies] -doc-comment = "0.3" +doc-comment = { workspace = true } tokio = "1.18" diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index 236584ded6b34..171aadb744d69 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -19,7 +19,7 @@ # DataFusion Proto -[DataFusion](df) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate is a submodule of DataFusion that provides a protocol buffer format for representing query plans and expressions. @@ -58,7 +58,7 @@ use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes}; #[tokio::main] async fn main() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await ?; let plan = ctx.table("t1").await?.into_optimized_plan()?; @@ -81,7 +81,7 @@ use datafusion_proto::bytes::{physical_plan_from_bytes,physical_plan_to_bytes}; #[tokio::main] async fn main() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await ?; let logical_plan = ctx.table("t1").await?.into_optimized_plan()?; diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 01e7bb9871a41..8b3f3f98a8a1d 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -19,8 +19,8 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" -edition = "2021" -rust-version = "1.62" +edition = { workspace = true } +rust-version = "1.64" authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" @@ -31,5 +31,5 @@ publish = false [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = { version = "=0.5.1" } -prost-build = { version = "=0.11.9" } +pbjson-build = "=0.6.2" +prost-build = "=0.12.3" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9b05dea712943..f391592dfe76b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -73,6 +73,7 @@ message LogicalPlanNode { CustomTableScanNode custom_scan = 25; PrepareNode prepare = 26; DropViewNode drop_view = 27; + DistinctOnNode distinct_on = 28; } } @@ -88,6 +89,10 @@ message ProjectionColumns { message CsvFormat { bool has_header = 1; string delimiter = 2; + string quote = 3; + oneof optional_escape { + string escape = 4; + } } message ParquetFormat { @@ -176,6 +181,25 @@ message EmptyRelationNode { bool produce_one_row = 1; } +message PrimaryKeyConstraint{ + repeated uint64 indices = 1; +} + +message UniqueConstraint{ + repeated uint64 indices = 1; +} + +message Constraint{ + oneof constraint_mode{ + PrimaryKeyConstraint primary_key = 1; + UniqueConstraint unique = 2; + } +} + +message Constraints{ + repeated Constraint constraints = 1; +} + message CreateExternalTableNode { reserved 1; // was string name OwnedTableReference name = 12; @@ -191,6 +215,8 @@ message CreateExternalTableNode { repeated LogicalExprNodeCollection order_exprs = 13; bool unbounded = 14; map options = 11; + Constraints constraints = 15; + map column_defaults = 16; } message PrepareNode { @@ -284,6 +310,13 @@ message DistinctNode { LogicalPlanNode input = 1; } +message DistinctOnNode { + repeated LogicalExprNode on_expr = 1; + repeated LogicalExprNode select_expr = 2; + repeated LogicalExprNode sort_expr = 3; + LogicalPlanNode input = 4; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } @@ -339,7 +372,7 @@ message LogicalExprNode { SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; - bool wildcard = 15; + Wildcard wildcard = 15; ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; @@ -375,6 +408,10 @@ message LogicalExprNode { } } +message Wildcard { + optional string qualifier = 1; +} + message PlaceholderNode { string id = 1; ArrowType data_type = 2; @@ -396,11 +433,26 @@ message RollupNode { repeated LogicalExprNode expr = 1; } +message NamedStructField { + ScalarValue name = 1; +} + +message ListIndex { + LogicalExprNode key = 1; +} +message ListRange { + LogicalExprNode start = 1; + LogicalExprNode stop = 2; +} message GetIndexedField { LogicalExprNode expr = 1; - ScalarValue key = 2; + oneof field { + NamedStructField named_struct_field = 2; + ListIndex list_index = 3; + ListRange list_range = 4; + } } message IsNull { @@ -442,6 +494,7 @@ message Not { message AliasNode { LogicalExprNode expr = 1; string alias = 2; + repeated OwnedTableReference relation = 3; } message BinaryExprNode { @@ -552,7 +605,7 @@ enum ScalarFunction { ArrayAppend = 86; ArrayConcat = 87; ArrayDims = 88; - ArrayFill = 89; + ArrayRepeat = 89; ArrayLength = 90; ArrayNdims = 91; ArrayPosition = 92; @@ -562,7 +615,37 @@ enum ScalarFunction { ArrayReplace = 96; ArrayToString = 97; Cardinality = 98; - TrimArray = 99; + ArrayElement = 99; + ArraySlice = 100; + Encode = 101; + Decode = 102; + Cot = 103; + ArrayHas = 104; + ArrayHasAny = 105; + ArrayHasAll = 106; + ArrayRemoveN = 107; + ArrayReplaceN = 108; + ArrayRemoveAll = 109; + ArrayReplaceAll = 110; + Nanvl = 111; + Flatten = 112; + Isnan = 113; + Iszero = 114; + ArrayEmpty = 115; + ArrayPopBack = 116; + StringToArray = 117; + ToTimestampNanos = 118; + ArrayIntersect = 119; + ArrayUnion = 120; + OverLay = 121; + Range = 122; + ArrayExcept = 123; + ArrayPopFront = 124; + Levenshtein = 125; + SubstrIndex = 126; + FindInSet = 127; + ArraySort = 128; + ArrayDistinct = 129; } message ScalarFunctionNode { @@ -599,6 +682,16 @@ enum AggregateFunction { // we append "_AGG" to obey name scoping rules. FIRST_VALUE_AGG = 24; LAST_VALUE_AGG = 25; + REGR_SLOPE = 26; + REGR_INTERCEPT = 27; + REGR_COUNT = 28; + REGR_R2 = 29; + REGR_AVGX = 30; + REGR_AVGY = 31; + REGR_SXX = 32; + REGR_SYY = 33; + REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { @@ -639,7 +732,8 @@ message WindowExprNode { oneof window_function { AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; - // udaf = 3 + string udaf = 3; + string udwf = 9; } LogicalExprNode expr = 4; repeated LogicalExprNode partition_by = 5; @@ -736,6 +830,7 @@ message WindowFrameBound { message Schema { repeated Field columns = 1; + map metadata = 2; } message Field { @@ -745,6 +840,9 @@ message Field { bool nullable = 3; // for complex data types like structs, unions repeated Field children = 4; + map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; } message FixedSizeBinary{ @@ -814,12 +912,10 @@ message Union{ repeated int32 type_ids = 3; } -message ScalarListValue{ - // encode null explicitly to distinguish a list with a null value - // from a list with no values) - bool is_null = 3; - Field field = 1; - repeated ScalarValue values = 2; +message ScalarListValue { + bytes ipc_message = 1; + bytes arrow_data = 2; + Schema schema = 3; } message ScalarTime32Value { @@ -895,13 +991,22 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; + ScalarListValue large_list_value = 16; ScalarListValue list_value = 17; - //WAS: ScalarType null_list_value = 18; + ScalarListValue fixed_size_list_value = 18; Decimal128 decimal128_value = 20; + Decimal256 decimal256_value = 39; + int64 date_64_value = 21; int32 interval_yearmonth_value = 24; int64 interval_daytime_value = 25; + + int64 duration_second_value = 35; + int64 duration_millisecond_value = 36; + int64 duration_microsecond_value = 37; + int64 duration_nanosecond_value = 38; + ScalarTimestampValue timestamp_value = 26; ScalarDictionaryValue dictionary_value = 27; bytes binary_value = 28; @@ -919,6 +1024,12 @@ message Decimal128{ int64 s = 3; } +message Decimal256{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + // Serialized data type message ArrowType{ oneof arrow_type_enum { @@ -989,8 +1100,10 @@ message PlanType { OptimizedLogicalPlanType OptimizedLogicalPlan = 2; EmptyMessage FinalLogicalPlan = 3; EmptyMessage InitialPhysicalPlan = 4; + EmptyMessage InitialPhysicalPlanWithStats = 9; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; EmptyMessage FinalPhysicalPlan = 6; + EmptyMessage FinalPhysicalPlanWithStats = 10; } } @@ -1047,9 +1160,63 @@ message PhysicalPlanNode { UnionExecNode union = 19; ExplainExecNode explain = 20; SortPreservingMergeExecNode sort_preserving_merge = 21; + NestedLoopJoinExecNode nested_loop_join = 22; + AnalyzeExecNode analyze = 23; + JsonSinkExecNode json_sink = 24; + SymmetricHashJoinExecNode symmetric_hash_join = 25; + InterleaveExecNode interleave = 26; + PlaceholderRowExecNode placeholder_row = 27; + } +} + +enum CompressionTypeVariant { + GZIP = 0; + BZIP2 = 1; + XZ = 2; + ZSTD = 3; + UNCOMPRESSED = 4; +} + +message PartitionColumn { + string name = 1; + ArrowType arrow_type = 2; +} + +message FileTypeWriterOptions { + oneof FileType { + JsonWriterOptions json_options = 1; } } +message JsonWriterOptions { + CompressionTypeVariant compression = 1; +} + +message FileSinkConfig { + reserved 6; // writer_mode + + string object_store_url = 1; + repeated PartitionedFile file_groups = 2; + repeated string table_paths = 3; + Schema output_schema = 4; + repeated PartitionColumn table_partition_cols = 5; + bool single_file_output = 7; + bool unbounded_input = 8; + bool overwrite = 9; + FileTypeWriterOptions file_type_writer_options = 10; +} + +message JsonSink { + FileSinkConfig config = 1; +} + +message JsonSinkExecNode { + PhysicalPlanNode input = 1; + JsonSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; @@ -1057,6 +1224,9 @@ message PhysicalExtensionNode { // physical expressions message PhysicalExprNode { + // Was date_time_interval_expr + reserved 17; + oneof ExprType { // column references PhysicalColumn column = 1; @@ -1087,8 +1257,6 @@ message PhysicalExprNode { PhysicalScalarUdfNode scalar_udf = 16; - PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17; - PhysicalLikeExprNode like_expr = 18; PhysicalGetIndexedFieldExprNode get_indexed_field_expr = 19; @@ -1107,6 +1275,7 @@ message PhysicalAggregateExprNode { string user_defined_aggr_function = 4; } repeated PhysicalExprNode expr = 2; + repeated PhysicalSortExprNode ordering_req = 5; bool distinct = 3; } @@ -1116,7 +1285,11 @@ message PhysicalWindowExprNode { BuiltInWindowFunction built_in_function = 2; // udaf = 3 } - PhysicalExprNode expr = 4; + repeated PhysicalExprNode args = 4; + repeated PhysicalExprNode partition_by = 5; + repeated PhysicalSortExprNode order_by = 6; + WindowFrame window_frame = 7; + string name = 8; } message PhysicalIsNull { @@ -1202,6 +1375,7 @@ message PhysicalNegativeNode { message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; + uint32 default_filter_selectivity = 3; } message FileGroup { @@ -1244,6 +1418,10 @@ message CsvScanExecNode { FileScanExecConf base_conf = 1; bool has_header = 2; string delimiter = 3; + string quote = 4; + oneof optional_escape { + string escape = 5; + } } message AvroScanExecNode { @@ -1266,6 +1444,25 @@ message HashJoinExecNode { JoinFilter filter = 8; } +enum StreamPartitionMode { + SINGLE_PARTITION = 0; + PARTITIONED_EXEC = 1; +} + +message SymmetricHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + JoinType join_type = 4; + StreamPartitionMode partition_mode = 6; + bool null_equals_null = 7; + JoinFilter filter = 8; +} + +message InterleaveExecNode { + repeated PhysicalPlanNode inputs = 1; +} + message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } @@ -1276,6 +1473,13 @@ message ExplainExecNode { bool verbose = 3; } +message AnalyzeExecNode { + bool verbose = 1; + bool show_statistics = 2; + PhysicalPlanNode input = 3; + Schema schema = 4; +} + message CrossJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; @@ -1292,8 +1496,11 @@ message JoinOn { } message EmptyExecNode { - bool produce_one_row = 1; - Schema schema = 2; + Schema schema = 1; +} + +message PlaceholderRowExecNode { + Schema schema = 1; } message ProjectionExecNode { @@ -1307,13 +1514,23 @@ enum AggregateMode { FINAL = 1; FINAL_PARTITIONED = 2; SINGLE = 3; + SINGLE_PARTITIONED = 4; +} + +message PartiallySortedInputOrderMode { + repeated uint64 columns = 6; } message WindowAggExecNode { PhysicalPlanNode input = 1; - repeated PhysicalExprNode window_expr = 2; - repeated string window_expr_name = 3; - Schema input_schema = 4; + repeated PhysicalWindowExprNode window_expr = 2; + repeated PhysicalExprNode partition_keys = 5; + // Set optional to `None` for `BoundedWindowAggExec`. + oneof input_order_mode { + EmptyMessage linear = 7; + PartiallySortedInputOrderMode partially_sorted = 8; + EmptyMessage sorted = 9; + } } message MaybeFilter { @@ -1363,6 +1580,15 @@ message SortExecNode { message SortPreservingMergeExecNode { PhysicalPlanNode input = 1; repeated PhysicalExprNode expr = 2; + // Maximum number of highest/lowest rows to fetch; negative means no limit + int64 fetch = 3; +} + +message NestedLoopJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + JoinType join_type = 3; + JoinFilter filter = 4; } message CoalesceBatchesExecNode { @@ -1424,21 +1650,48 @@ message PartitionStats { repeated ColumnStats column_stats = 4; } +message Precision{ + PrecisionInfo precision_info = 1; + ScalarValue val = 2; +} + +enum PrecisionInfo { + EXACT = 0; + INEXACT = 1; + ABSENT = 2; +} + message Statistics { - int64 num_rows = 1; - int64 total_byte_size = 2; + Precision num_rows = 1; + Precision total_byte_size = 2; repeated ColumnStats column_stats = 3; - bool is_exact = 4; } message ColumnStats { - ScalarValue min_value = 1; - ScalarValue max_value = 2; - uint32 null_count = 3; - uint32 distinct_count = 4; + Precision min_value = 1; + Precision max_value = 2; + Precision null_count = 3; + Precision distinct_count = 4; +} + +message NamedStructFieldExpr { + ScalarValue name = 1; +} + +message ListIndexExpr { + PhysicalExprNode key = 1; +} + +message ListRangeExpr { + PhysicalExprNode start = 1; + PhysicalExprNode stop = 2; } message PhysicalGetIndexedFieldExprNode { PhysicalExprNode arg = 1; - ScalarValue key = 2; + oneof field { + NamedStructFieldExpr named_struct_field_expr = 2; + ListIndexExpr list_index_expr = 3; + ListRangeExpr list_range_expr = 4; + } } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 4a7e18c8fc647..9377501499e2a 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,8 +24,11 @@ use crate::physical_plan::{ }; use crate::protobuf; use datafusion::physical_plan::functions::make_scalar_function; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{create_udaf, create_udf, Expr, LogicalPlan, Volatility}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::{ + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, + WindowUDF, +}; use prost::{ bytes::{Bytes, BytesMut}, Message, @@ -85,13 +88,13 @@ pub trait Serializeable: Sized { impl Serializeable for Expr { fn to_bytes(&self) -> Result { let mut buffer = BytesMut::new(); - let protobuf: protobuf::LogicalExprNode = self.try_into().map_err(|e| { - DataFusionError::Plan(format!("Error encoding expr as protobuf: {e}")) - })?; + let protobuf: protobuf::LogicalExprNode = self + .try_into() + .map_err(|e| plan_datafusion_err!("Error encoding expr as protobuf: {e}"))?; - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; let bytes: Bytes = buffer.into(); @@ -118,16 +121,26 @@ impl Serializeable for Expr { ))) } - fn udaf(&self, name: &str) -> Result> { + fn udaf(&self, name: &str) -> Result> { Ok(Arc::new(create_udaf( name, - arrow::datatypes::DataType::Null, + vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, Arc::new(|_| unimplemented!()), Arc::new(vec![]), ))) } + + fn udwf(&self, name: &str) -> Result> { + Ok(Arc::new(create_udwf( + name, + arrow::datatypes::DataType::Null, + Arc::new(arrow::datatypes::DataType::Null), + Volatility::Immutable, + Arc::new(|| unimplemented!()), + ))) + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; @@ -138,13 +151,11 @@ impl Serializeable for Expr { bytes: &[u8], registry: &dyn FunctionRegistry, ) -> Result { - let protobuf = protobuf::LogicalExprNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::LogicalExprNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - logical_plan::from_proto::parse_expr(&protobuf, registry).map_err(|e| { - DataFusionError::Plan(format!("Error parsing protobuf into Expr: {e}")) - }) + logical_plan::from_proto::parse_expr(&protobuf, registry) + .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } @@ -160,9 +171,9 @@ pub fn logical_plan_to_json(plan: &LogicalPlan) -> Result { let extension_codec = DefaultLogicalExtensionCodec {}; let protobuf = protobuf::LogicalPlanNode::try_from_logical_plan(plan, &extension_codec) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}"))) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } /// Serialize a LogicalPlan as bytes, using the provided extension codec @@ -173,9 +184,9 @@ pub fn logical_plan_to_bytes_with_extension_codec( let protobuf = protobuf::LogicalPlanNode::try_from_logical_plan(plan, extension_codec)?; let mut buffer = BytesMut::new(); - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; Ok(buffer.into()) } @@ -183,7 +194,7 @@ pub fn logical_plan_to_bytes_with_extension_codec( #[cfg(feature = "json")] pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result { let back: protobuf::LogicalPlanNode = serde_json::from_str(json) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; back.try_into_logical_plan(ctx, &extension_codec) } @@ -203,9 +214,8 @@ pub fn logical_plan_from_bytes_with_extension_codec( ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, ) -> Result { - let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::LogicalPlanNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; protobuf.try_into_logical_plan(ctx, extension_codec) } @@ -221,9 +231,9 @@ pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; let protobuf = protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}"))) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } /// Serialize a PhysicalPlan as bytes, using the provided extension codec @@ -234,9 +244,9 @@ pub fn physical_plan_to_bytes_with_extension_codec( let protobuf = protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; let mut buffer = BytesMut::new(); - protobuf.encode(&mut buffer).map_err(|e| { - DataFusionError::Plan(format!("Error encoding protobuf as bytes: {e}")) - })?; + protobuf + .encode(&mut buffer) + .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; Ok(buffer.into()) } @@ -247,7 +257,7 @@ pub fn physical_plan_from_json( ctx: &SessionContext, ) -> Result> { let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) - .map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {e}")))?; + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; back.try_into_physical_plan(ctx, &ctx.runtime_env(), &extension_codec) } @@ -267,227 +277,7 @@ pub fn physical_plan_from_bytes_with_extension_codec( ctx: &SessionContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let protobuf = protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| { - DataFusionError::Plan(format!("Error decoding expr as protobuf: {e}")) - })?; + let protobuf = protobuf::PhysicalPlanNode::decode(bytes) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; protobuf.try_into_physical_plan(ctx, &ctx.runtime_env(), extension_codec) } - -#[cfg(test)] -mod test { - use super::*; - use arrow::{array::ArrayRef, datatypes::DataType}; - use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::prelude::SessionContext; - use datafusion_expr::{col, create_udf, lit, Volatility}; - use std::sync::Arc; - - #[test] - #[should_panic( - expected = "Error decoding expr as protobuf: failed to decode Protobuf message" - )] - fn bad_decode() { - Expr::from_bytes(b"Leet").unwrap(); - } - - #[test] - #[cfg(feature = "json")] - fn plan_to_json() { - use datafusion_common::DFSchema; - use datafusion_expr::logical_plan::EmptyRelation; - - let plan = LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: Arc::new(DFSchema::empty()), - }); - let actual = logical_plan_to_json(&plan).unwrap(); - let expected = r#"{"emptyRelation":{}}"#.to_string(); - assert_eq!(actual, expected); - } - - #[test] - #[cfg(feature = "json")] - fn json_to_plan() { - let input = r#"{"emptyRelation":{}}"#.to_string(); - let ctx = SessionContext::new(); - let actual = logical_plan_from_json(&input, &ctx).unwrap(); - let result = matches!(actual, LogicalPlan::EmptyRelation(_)); - assert!(result, "Should parse empty relation"); - } - - #[test] - fn udf_roundtrip_with_registry() { - let ctx = context_with_udf(); - - let expr = ctx - .udf("dummy") - .expect("could not find udf") - .call(vec![lit("")]); - - let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); - - assert_eq!(expr, deserialized_expr); - } - - #[test] - #[should_panic( - expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'" - )] - fn udf_roundtrip_without_registry() { - let ctx = context_with_udf(); - - let expr = ctx - .udf("dummy") - .expect("could not find udf") - .call(vec![lit("")]); - - let bytes = expr.to_bytes().unwrap(); - // should explode - Expr::from_bytes(&bytes).unwrap(); - } - - fn roundtrip_expr(expr: &Expr) -> Expr { - let bytes = expr.to_bytes().unwrap(); - Expr::from_bytes(&bytes).unwrap() - } - - #[test] - fn exact_roundtrip_linearized_binary_expr() { - // (((A AND B) AND C) AND D) - let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D")); - assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered)); - - // Ensure that no other variation becomes equal - let other_variants = vec![ - // (((B AND A) AND C) AND D) - col("B").and(col("A")).and(col("C")).and(col("D")), - // (((A AND C) AND B) AND D) - col("A").and(col("C")).and(col("B")).and(col("D")), - // (((A AND B) AND D) AND C) - col("A").and(col("B")).and(col("D")).and(col("C")), - // A AND (B AND (C AND D))) - col("A").and(col("B").and(col("C").and(col("D")))), - ]; - for case in other_variants { - // Each variant is still equal to itself - assert_eq!(case, roundtrip_expr(&case)); - - // But non of them is equal to the original - assert_ne!(expr_ordered, roundtrip_expr(&case)); - assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(&case)); - } - } - - #[test] - fn roundtrip_deeply_nested_binary_expr() { - // We need more stack space so this doesn't overflow in dev builds - std::thread::Builder::new() - .stack_size(10_000_000) - .spawn(|| { - let n = 100; - // a < 5 - let basic_expr = col("a").lt(lit(5i32)); - // (a < 5) OR (a < 5) OR (a < 5) OR ... - let or_chain = (0..n) - .fold(basic_expr.clone(), |expr, _| expr.or(basic_expr.clone())); - // (a < 5) OR (a < 5) AND (a < 5) OR (a < 5) AND (a < 5) AND (a < 5) OR ... - let expr = - (0..n).fold(or_chain.clone(), |expr, _| expr.and(or_chain.clone())); - - // Should work fine. - let bytes = expr.to_bytes().unwrap(); - - let decoded_expr = Expr::from_bytes(&bytes).expect( - "serialization worked, so deserialization should work as well", - ); - assert_eq!(decoded_expr, expr); - }) - .expect("spawning thread") - .join() - .expect("joining thread"); - } - - #[test] - fn roundtrip_deeply_nested_binary_expr_reverse_order() { - // We need more stack space so this doesn't overflow in dev builds - std::thread::Builder::new() - .stack_size(10_000_000) - .spawn(|| { - let n = 100; - - // a < 5 - let expr_base = col("a").lt(lit(5i32)); - - // ((a < 5 AND a < 5) AND a < 5) AND ... - let and_chain = - (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone())); - - // a < 5 AND (a < 5 AND (a < 5 AND ...)) - let expr = expr_base.and(and_chain); - - // Should work fine. - let bytes = expr.to_bytes().unwrap(); - - let decoded_expr = Expr::from_bytes(&bytes).expect( - "serialization worked, so deserialization should work as well", - ); - assert_eq!(decoded_expr, expr); - }) - .expect("spawning thread") - .join() - .expect("joining thread"); - } - - #[test] - fn roundtrip_deeply_nested() { - // we need more stack space so this doesn't overflow in dev builds - std::thread::Builder::new().stack_size(10_000_000).spawn(|| { - // don't know what "too much" is, so let's slowly try to increase complexity - let n_max = 100; - - for n in 1..n_max { - println!("testing: {n}"); - - let expr_base = col("a").lt(lit(5i32)); - // Generate a tree of AND and OR expressions (no subsequent ANDs or ORs). - let expr = (0..n).fold(expr_base.clone(), |expr, n| if n % 2 == 0 { expr.and(expr_base.clone()) } else { expr.or(expr_base.clone()) }); - - // Convert it to an opaque form - let bytes = match expr.to_bytes() { - Ok(bytes) => bytes, - Err(_) => { - // found expression that is too deeply nested - return; - } - }; - - // Decode bytes from somewhere (over network, etc. - let decoded_expr = Expr::from_bytes(&bytes).expect("serialization worked, so deserialization should work as well"); - assert_eq!(expr, decoded_expr); - } - - panic!("did not find a 'too deeply nested' expression, tested up to a depth of {n_max}") - }).expect("spawning thread").join().expect("joining thread"); - } - - /// return a `SessionContext` with a `dummy` function registered as a UDF - fn context_with_udf() -> SessionContext { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); - - let udf = create_udf( - "dummy", - vec![DataType::Utf8], - Arc::new(DataType::Utf8), - Volatility::Immutable, - scalar_fn, - ); - - let ctx = SessionContext::new(); - ctx.register_udf(udf); - - ctx - } -} diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 6c66b33a9fc4f..024bb949baa99 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -18,8 +18,9 @@ use std::{collections::HashSet, sync::Arc}; use datafusion::execution::registry::FunctionRegistry; +use datafusion_common::plan_err; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// A default [`FunctionRegistry`] registry that does not resolve any /// user defined functions @@ -31,14 +32,14 @@ impl FunctionRegistry for NoRegistry { } fn udf(&self, name: &str) -> Result> { - Err(DataFusionError::Plan( - format!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'")) - ) + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'") } fn udaf(&self, name: &str) -> Result> { - Err(DataFusionError::Plan( - format!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'")) - ) + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'") + } + + fn udwf(&self, name: &str) -> Result> { + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'") } } diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index ed826f5874137..b18831048e1ab 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -15,28 +15,24 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; -pub fn csv_delimiter_to_string(b: u8) -> Result { - let b = &[b]; - let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; - Ok(b.to_owned()) -} - -pub fn str_to_byte(s: &String) -> Result { +pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { if s.len() != 1 { - return Err(DataFusionError::Internal( - "Invalid CSV delimiter".to_owned(), - )); + return internal_err!( + "Invalid CSV {description}: expected single character, got {s}" + ); } Ok(s.as_bytes()[0]) } -pub fn byte_to_string(b: u8) -> Result { +pub(crate) fn byte_to_string(b: u8, description: &str) -> Result { let b = &[b]; - let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; + let b = std::str::from_utf8(b).map_err(|_| { + DataFusionError::Internal(format!( + "Invalid CSV {description}: can not represent {b:0x?} as utf8" + )) + })?; Ok(b.to_owned()) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 890fe7221a8e3..d506b5dcce531 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -47,8 +47,8 @@ impl serde::Serialize for AggregateExecNode { struct_ser.serialize_field("aggrExpr", &self.aggr_expr)?; } if self.mode != 0 { - let v = AggregateMode::from_i32(self.mode) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + let v = AggregateMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; struct_ser.serialize_field("mode", &v)?; } if let Some(v) = self.input.as_ref() { @@ -166,7 +166,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { formatter.write_str("struct datafusion.AggregateExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -181,73 +181,73 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut groups__ = None; let mut filter_expr__ = None; let mut order_by_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { if group_expr__.is_some() { return Err(serde::de::Error::duplicate_field("groupExpr")); } - group_expr__ = Some(map.next_value()?); + group_expr__ = Some(map_.next_value()?); } GeneratedField::AggrExpr => { if aggr_expr__.is_some() { return Err(serde::de::Error::duplicate_field("aggrExpr")); } - aggr_expr__ = Some(map.next_value()?); + aggr_expr__ = Some(map_.next_value()?); } GeneratedField::Mode => { if mode__.is_some() { return Err(serde::de::Error::duplicate_field("mode")); } - mode__ = Some(map.next_value::()? as i32); + mode__ = Some(map_.next_value::()? as i32); } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::GroupExprName => { if group_expr_name__.is_some() { return Err(serde::de::Error::duplicate_field("groupExprName")); } - group_expr_name__ = Some(map.next_value()?); + group_expr_name__ = Some(map_.next_value()?); } GeneratedField::AggrExprName => { if aggr_expr_name__.is_some() { return Err(serde::de::Error::duplicate_field("aggrExprName")); } - aggr_expr_name__ = Some(map.next_value()?); + aggr_expr_name__ = Some(map_.next_value()?); } GeneratedField::InputSchema => { if input_schema__.is_some() { return Err(serde::de::Error::duplicate_field("inputSchema")); } - input_schema__ = map.next_value()?; + input_schema__ = map_.next_value()?; } GeneratedField::NullExpr => { if null_expr__.is_some() { return Err(serde::de::Error::duplicate_field("nullExpr")); } - null_expr__ = Some(map.next_value()?); + null_expr__ = Some(map_.next_value()?); } GeneratedField::Groups => { if groups__.is_some() { return Err(serde::de::Error::duplicate_field("groups")); } - groups__ = Some(map.next_value()?); + groups__ = Some(map_.next_value()?); } GeneratedField::FilterExpr => { if filter_expr__.is_some() { return Err(serde::de::Error::duplicate_field("filterExpr")); } - filter_expr__ = Some(map.next_value()?); + filter_expr__ = Some(map_.next_value()?); } GeneratedField::OrderByExpr => { if order_by_expr__.is_some() { return Err(serde::de::Error::duplicate_field("orderByExpr")); } - order_by_expr__ = Some(map.next_value()?); + order_by_expr__ = Some(map_.next_value()?); } } } @@ -294,8 +294,8 @@ impl serde::Serialize for AggregateExprNode { } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; if self.aggr_function != 0 { - let v = AggregateFunction::from_i32(self.aggr_function) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; + let v = AggregateFunction::try_from(self.aggr_function) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; struct_ser.serialize_field("aggrFunction", &v)?; } if !self.expr.is_empty() { @@ -377,7 +377,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { formatter.write_str("struct datafusion.AggregateExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -386,37 +386,37 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::AggrFunction => { if aggr_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); } - aggr_function__ = Some(map.next_value::()? as i32); + aggr_function__ = Some(map_.next_value::()? as i32); } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } GeneratedField::Distinct => { if distinct__.is_some() { return Err(serde::de::Error::duplicate_field("distinct")); } - distinct__ = Some(map.next_value()?); + distinct__ = Some(map_.next_value()?); } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); } - filter__ = map.next_value()?; + filter__ = map_.next_value()?; } GeneratedField::OrderBy => { if order_by__.is_some() { return Err(serde::de::Error::duplicate_field("orderBy")); } - order_by__ = Some(map.next_value()?); + order_by__ = Some(map_.next_value()?); } } } @@ -465,6 +465,16 @@ impl serde::Serialize for AggregateFunction { Self::BoolOr => "BOOL_OR", Self::FirstValueAgg => "FIRST_VALUE_AGG", Self::LastValueAgg => "LAST_VALUE_AGG", + Self::RegrSlope => "REGR_SLOPE", + Self::RegrIntercept => "REGR_INTERCEPT", + Self::RegrCount => "REGR_COUNT", + Self::RegrR2 => "REGR_R2", + Self::RegrAvgx => "REGR_AVGX", + Self::RegrAvgy => "REGR_AVGY", + Self::RegrSxx => "REGR_SXX", + Self::RegrSyy => "REGR_SYY", + Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -502,6 +512,16 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BOOL_OR", "FIRST_VALUE_AGG", "LAST_VALUE_AGG", + "REGR_SLOPE", + "REGR_INTERCEPT", + "REGR_COUNT", + "REGR_R2", + "REGR_AVGX", + "REGR_AVGY", + "REGR_SXX", + "REGR_SYY", + "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -517,10 +537,9 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(AggregateFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -530,10 +549,9 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(AggregateFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -570,6 +588,16 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BOOL_OR" => Ok(AggregateFunction::BoolOr), "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg), "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg), + "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), + "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), + "REGR_COUNT" => Ok(AggregateFunction::RegrCount), + "REGR_R2" => Ok(AggregateFunction::RegrR2), + "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), + "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), + "REGR_SXX" => Ok(AggregateFunction::RegrSxx), + "REGR_SYY" => Ok(AggregateFunction::RegrSyy), + "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -588,6 +616,7 @@ impl serde::Serialize for AggregateMode { Self::Final => "FINAL", Self::FinalPartitioned => "FINAL_PARTITIONED", Self::Single => "SINGLE", + Self::SinglePartitioned => "SINGLE_PARTITIONED", }; serializer.serialize_str(variant) } @@ -603,6 +632,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL", "FINAL_PARTITIONED", "SINGLE", + "SINGLE_PARTITIONED", ]; struct GeneratedVisitor; @@ -618,10 +648,9 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(AggregateMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -631,10 +660,9 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(AggregateMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -649,6 +677,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode { "FINAL" => Ok(AggregateMode::Final), "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), "SINGLE" => Ok(AggregateMode::Single), + "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -744,32 +773,32 @@ impl<'de> serde::Deserialize<'de> for AggregateNode { formatter.write_str("struct datafusion.AggregateNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut group_expr__ = None; let mut aggr_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::GroupExpr => { if group_expr__.is_some() { return Err(serde::de::Error::duplicate_field("groupExpr")); } - group_expr__ = Some(map.next_value()?); + group_expr__ = Some(map_.next_value()?); } GeneratedField::AggrExpr => { if aggr_expr__.is_some() { return Err(serde::de::Error::duplicate_field("aggrExpr")); } - aggr_expr__ = Some(map.next_value()?); + aggr_expr__ = Some(map_.next_value()?); } } } @@ -880,7 +909,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { formatter.write_str("struct datafusion.AggregateUDFExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -888,31 +917,31 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { let mut args__ = None; let mut filter__ = None; let mut order_by__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { if fun_name__.is_some() { return Err(serde::de::Error::duplicate_field("funName")); } - fun_name__ = Some(map.next_value()?); + fun_name__ = Some(map_.next_value()?); } GeneratedField::Args => { if args__.is_some() { return Err(serde::de::Error::duplicate_field("args")); } - args__ = Some(map.next_value()?); + args__ = Some(map_.next_value()?); } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); } - filter__ = map.next_value()?; + filter__ = map_.next_value()?; } GeneratedField::OrderBy => { if order_by__.is_some() { return Err(serde::de::Error::duplicate_field("orderBy")); } - order_by__ = Some(map.next_value()?); + order_by__ = Some(map_.next_value()?); } } } @@ -941,6 +970,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { len += 1; } + if !self.relation.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -948,6 +980,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { struct_ser.serialize_field("alias", &self.alias)?; } + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } struct_ser.end() } } @@ -960,12 +995,14 @@ impl<'de> serde::Deserialize<'de> for AliasNode { const FIELDS: &[&str] = &[ "expr", "alias", + "relation", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Alias, + Relation, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -989,6 +1026,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { match value { "expr" => Ok(GeneratedField::Expr), "alias" => Ok(GeneratedField::Alias), + "relation" => Ok(GeneratedField::Relation), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1004,37 +1042,188 @@ impl<'de> serde::Deserialize<'de> for AliasNode { formatter.write_str("struct datafusion.AliasNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut alias__ = None; - while let Some(k) = map.next_key()? { + let mut relation__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Alias => { if alias__.is_some() { return Err(serde::de::Error::duplicate_field("alias")); } - alias__ = Some(map.next_value()?); + alias__ = Some(map_.next_value()?); + } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); } } } Ok(AliasNode { expr: expr__, alias: alias__.unwrap_or_default(), + relation: relation__.unwrap_or_default(), }) } } deserializer.deserialize_struct("datafusion.AliasNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for AnalyzeExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.verbose { + len += 1; + } + if self.show_statistics { + len += 1; + } + if self.input.is_some() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AnalyzeExecNode", len)?; + if self.verbose { + struct_ser.serialize_field("verbose", &self.verbose)?; + } + if self.show_statistics { + struct_ser.serialize_field("showStatistics", &self.show_statistics)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AnalyzeExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "verbose", + "show_statistics", + "showStatistics", + "input", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Verbose, + ShowStatistics, + Input, + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "verbose" => Ok(GeneratedField::Verbose), + "showStatistics" | "show_statistics" => Ok(GeneratedField::ShowStatistics), + "input" => Ok(GeneratedField::Input), + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AnalyzeExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.AnalyzeExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut verbose__ = None; + let mut show_statistics__ = None; + let mut input__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Verbose => { + if verbose__.is_some() { + return Err(serde::de::Error::duplicate_field("verbose")); + } + verbose__ = Some(map_.next_value()?); + } + GeneratedField::ShowStatistics => { + if show_statistics__.is_some() { + return Err(serde::de::Error::duplicate_field("showStatistics")); + } + show_statistics__ = Some(map_.next_value()?); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(AnalyzeExecNode { + verbose: verbose__.unwrap_or_default(), + show_statistics: show_statistics__.unwrap_or_default(), + input: input__, + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.AnalyzeExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AnalyzeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -1112,25 +1301,25 @@ impl<'de> serde::Deserialize<'de> for AnalyzeNode { formatter.write_str("struct datafusion.AnalyzeNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut verbose__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Verbose => { if verbose__.is_some() { return Err(serde::de::Error::duplicate_field("verbose")); } - verbose__ = Some(map.next_value()?); + verbose__ = Some(map_.next_value()?); } } } @@ -1212,18 +1401,18 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { formatter.write_str("struct datafusion.AnalyzedLogicalPlanType") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut analyzer_name__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::AnalyzerName => { if analyzer_name__.is_some() { return Err(serde::de::Error::duplicate_field("analyzerName")); } - analyzer_name__ = Some(map.next_value()?); + analyzer_name__ = Some(map_.next_value()?); } } } @@ -1310,26 +1499,26 @@ impl serde::Serialize for ArrowType { struct_ser.serialize_field("DATE64", v)?; } arrow_type::ArrowTypeEnum::Duration(v) => { - let v = TimeUnit::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("DURATION", &v)?; } arrow_type::ArrowTypeEnum::Timestamp(v) => { struct_ser.serialize_field("TIMESTAMP", v)?; } arrow_type::ArrowTypeEnum::Time32(v) => { - let v = TimeUnit::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("TIME32", &v)?; } arrow_type::ArrowTypeEnum::Time64(v) => { - let v = TimeUnit::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("TIME64", &v)?; } arrow_type::ArrowTypeEnum::Interval(v) => { - let v = IntervalUnit::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = IntervalUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("INTERVAL", &v)?; } arrow_type::ArrowTypeEnum::Decimal(v) => { @@ -1512,237 +1701,237 @@ impl<'de> serde::Deserialize<'de> for ArrowType { formatter.write_str("struct datafusion.ArrowType") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut arrow_type_enum__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::None => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("NONE")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::None) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::None) ; } GeneratedField::Bool => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("BOOL")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Bool) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Bool) ; } GeneratedField::Uint8 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UINT8")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint8) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint8) ; } GeneratedField::Int8 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("INT8")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int8) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int8) ; } GeneratedField::Uint16 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UINT16")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint16) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint16) ; } GeneratedField::Int16 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("INT16")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int16) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int16) ; } GeneratedField::Uint32 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UINT32")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint32) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint32) ; } GeneratedField::Int32 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("INT32")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int32) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int32) ; } GeneratedField::Uint64 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UINT64")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint64) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint64) ; } GeneratedField::Int64 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("INT64")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int64) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int64) ; } GeneratedField::Float16 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("FLOAT16")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float16) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float16) ; } GeneratedField::Float32 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("FLOAT32")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float32) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float32) ; } GeneratedField::Float64 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("FLOAT64")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float64) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float64) ; } GeneratedField::Utf8 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UTF8")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8) ; } GeneratedField::LargeUtf8 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("LARGEUTF8")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeUtf8) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeUtf8) ; } GeneratedField::Binary => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("BINARY")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Binary) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Binary) ; } GeneratedField::FixedSizeBinary => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("FIXEDSIZEBINARY")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| arrow_type::ArrowTypeEnum::FixedSizeBinary(x.0)); + arrow_type_enum__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| arrow_type::ArrowTypeEnum::FixedSizeBinary(x.0)); } GeneratedField::LargeBinary => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("LARGEBINARY")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeBinary) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeBinary) ; } GeneratedField::Date32 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("DATE32")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date32) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date32) ; } GeneratedField::Date64 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("DATE64")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date64) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date64) ; } GeneratedField::Duration => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("DURATION")); } - arrow_type_enum__ = map.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Duration(x as i32)); + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Duration(x as i32)); } GeneratedField::Timestamp => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("TIMESTAMP")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Timestamp) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Timestamp) ; } GeneratedField::Time32 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("TIME32")); } - arrow_type_enum__ = map.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time32(x as i32)); + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time32(x as i32)); } GeneratedField::Time64 => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("TIME64")); } - arrow_type_enum__ = map.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time64(x as i32)); + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time64(x as i32)); } GeneratedField::Interval => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("INTERVAL")); } - arrow_type_enum__ = map.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Interval(x as i32)); + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Interval(x as i32)); } GeneratedField::Decimal => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("DECIMAL")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) ; } GeneratedField::List => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("LIST")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::List) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::List) ; } GeneratedField::LargeList => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("LARGELIST")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeList) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeList) ; } GeneratedField::FixedSizeList => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("FIXEDSIZELIST")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::FixedSizeList) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::FixedSizeList) ; } GeneratedField::Struct => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("STRUCT")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Struct) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Struct) ; } GeneratedField::Union => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("UNION")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Union) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Union) ; } GeneratedField::Dictionary => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("DICTIONARY")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Dictionary) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Dictionary) ; } GeneratedField::Map => { if arrow_type_enum__.is_some() { return Err(serde::de::Error::duplicate_field("MAP")); } - arrow_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) ; } } @@ -1812,12 +2001,12 @@ impl<'de> serde::Deserialize<'de> for AvroFormat { formatter.write_str("struct datafusion.AvroFormat") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map.next_key::()?.is_some() { - let _ = map.next_value::()?; + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; } Ok(AvroFormat { }) @@ -1895,18 +2084,18 @@ impl<'de> serde::Deserialize<'de> for AvroScanExecNode { formatter.write_str("struct datafusion.AvroScanExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut base_conf__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::BaseConf => { if base_conf__.is_some() { return Err(serde::de::Error::duplicate_field("baseConf")); } - base_conf__ = map.next_value()?; + base_conf__ = map_.next_value()?; } } } @@ -1986,18 +2175,18 @@ impl<'de> serde::Deserialize<'de> for BareTableReference { formatter.write_str("struct datafusion.BareTableReference") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut table__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Table => { if table__.is_some() { return Err(serde::de::Error::duplicate_field("table")); } - table__ = Some(map.next_value()?); + table__ = Some(map_.next_value()?); } } } @@ -2104,7 +2293,7 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { formatter.write_str("struct datafusion.BetweenNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -2112,31 +2301,31 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { let mut negated__ = None; let mut low__ = None; let mut high__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } GeneratedField::Low => { if low__.is_some() { return Err(serde::de::Error::duplicate_field("low")); } - low__ = map.next_value()?; + low__ = map_.next_value()?; } GeneratedField::High => { if high__.is_some() { return Err(serde::de::Error::duplicate_field("high")); } - high__ = map.next_value()?; + high__ = map_.next_value()?; } } } @@ -2228,25 +2417,25 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { formatter.write_str("struct datafusion.BinaryExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut operands__ = None; let mut op__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Operands => { if operands__.is_some() { return Err(serde::de::Error::duplicate_field("operands")); } - operands__ = Some(map.next_value()?); + operands__ = Some(map_.next_value()?); } GeneratedField::Op => { if op__.is_some() { return Err(serde::de::Error::duplicate_field("op")); } - op__ = Some(map.next_value()?); + op__ = Some(map_.next_value()?); } } } @@ -2314,10 +2503,9 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(BuiltInWindowFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -2327,10 +2515,9 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(BuiltInWindowFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -2447,32 +2634,32 @@ impl<'de> serde::Deserialize<'de> for CaseNode { formatter.write_str("struct datafusion.CaseNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut when_then_expr__ = None; let mut else_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::WhenThenExpr => { if when_then_expr__.is_some() { return Err(serde::de::Error::duplicate_field("whenThenExpr")); } - when_then_expr__ = Some(map.next_value()?); + when_then_expr__ = Some(map_.next_value()?); } GeneratedField::ElseExpr => { if else_expr__.is_some() { return Err(serde::de::Error::duplicate_field("elseExpr")); } - else_expr__ = map.next_value()?; + else_expr__ = map_.next_value()?; } } } @@ -2564,25 +2751,25 @@ impl<'de> serde::Deserialize<'de> for CastNode { formatter.write_str("struct datafusion.CastNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut arrow_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::ArrowType => { if arrow_type__.is_some() { return Err(serde::de::Error::duplicate_field("arrowType")); } - arrow_type__ = map.next_value()?; + arrow_type__ = map_.next_value()?; } } } @@ -2673,26 +2860,26 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { formatter.write_str("struct datafusion.CoalesceBatchesExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut target_batch_size__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::TargetBatchSize => { if target_batch_size__.is_some() { return Err(serde::de::Error::duplicate_field("targetBatchSize")); } target_batch_size__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -2774,18 +2961,18 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { formatter.write_str("struct datafusion.CoalescePartitionsExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } } } @@ -2874,25 +3061,25 @@ impl<'de> serde::Deserialize<'de> for Column { formatter.write_str("struct datafusion.Column") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; let mut relation__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::Relation => { if relation__.is_some() { return Err(serde::de::Error::duplicate_field("relation")); } - relation__ = map.next_value()?; + relation__ = map_.next_value()?; } } } @@ -2924,8 +3111,8 @@ impl serde::Serialize for ColumnIndex { struct_ser.serialize_field("index", &self.index)?; } if self.side != 0 { - let v = JoinSide::from_i32(self.side) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.side)))?; + let v = JoinSide::try_from(self.side) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.side)))?; struct_ser.serialize_field("side", &v)?; } struct_ser.end() @@ -2984,27 +3171,27 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { formatter.write_str("struct datafusion.ColumnIndex") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut index__ = None; let mut side__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Index => { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } index__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::Side => { if side__.is_some() { return Err(serde::de::Error::duplicate_field("side")); } - side__ = Some(map.next_value::()? as i32); + side__ = Some(map_.next_value::()? as i32); } } } @@ -3085,18 +3272,18 @@ impl<'de> serde::Deserialize<'de> for ColumnRelation { formatter.write_str("struct datafusion.ColumnRelation") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut relation__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Relation => { if relation__.is_some() { return Err(serde::de::Error::duplicate_field("relation")); } - relation__ = Some(map.next_value()?); + relation__ = Some(map_.next_value()?); } } } @@ -3122,10 +3309,10 @@ impl serde::Serialize for ColumnStats { if self.max_value.is_some() { len += 1; } - if self.null_count != 0 { + if self.null_count.is_some() { len += 1; } - if self.distinct_count != 0 { + if self.distinct_count.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ColumnStats", len)?; @@ -3135,11 +3322,11 @@ impl serde::Serialize for ColumnStats { if let Some(v) = self.max_value.as_ref() { struct_ser.serialize_field("maxValue", v)?; } - if self.null_count != 0 { - struct_ser.serialize_field("nullCount", &self.null_count)?; + if let Some(v) = self.null_count.as_ref() { + struct_ser.serialize_field("nullCount", v)?; } - if self.distinct_count != 0 { - struct_ser.serialize_field("distinctCount", &self.distinct_count)?; + if let Some(v) = self.distinct_count.as_ref() { + struct_ser.serialize_field("distinctCount", v)?; } struct_ser.end() } @@ -3207,7 +3394,7 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { formatter.write_str("struct datafusion.ColumnStats") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -3215,50 +3402,126 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { let mut max_value__ = None; let mut null_count__ = None; let mut distinct_count__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::MinValue => { if min_value__.is_some() { return Err(serde::de::Error::duplicate_field("minValue")); } - min_value__ = map.next_value()?; + min_value__ = map_.next_value()?; } GeneratedField::MaxValue => { if max_value__.is_some() { return Err(serde::de::Error::duplicate_field("maxValue")); } - max_value__ = map.next_value()?; + max_value__ = map_.next_value()?; } GeneratedField::NullCount => { if null_count__.is_some() { return Err(serde::de::Error::duplicate_field("nullCount")); } - null_count__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + null_count__ = map_.next_value()?; } GeneratedField::DistinctCount => { if distinct_count__.is_some() { return Err(serde::de::Error::duplicate_field("distinctCount")); } - distinct_count__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + distinct_count__ = map_.next_value()?; } } } Ok(ColumnStats { min_value: min_value__, max_value: max_value__, - null_count: null_count__.unwrap_or_default(), - distinct_count: distinct_count__.unwrap_or_default(), + null_count: null_count__, + distinct_count: distinct_count__, }) } } deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogNode { +impl serde::Serialize for CompressionTypeVariant { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GZIP", + "BZIP2", + "XZ", + "ZSTD", + "UNCOMPRESSED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CompressionTypeVariant; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "GZIP" => Ok(CompressionTypeVariant::Gzip), + "BZIP2" => Ok(CompressionTypeVariant::Bzip2), + "XZ" => Ok(CompressionTypeVariant::Xz), + "ZSTD" => Ok(CompressionTypeVariant::Zstd), + "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for Constraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3266,47 +3529,39 @@ impl serde::Serialize for CreateCatalogNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog_name.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if self.schema.is_some() { + if self.constraint_mode.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; - if !self.catalog_name.is_empty() { - struct_ser.serialize_field("catalogName", &self.catalog_name)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Constraint", len)?; + if let Some(v) = self.constraint_mode.as_ref() { + match v { + constraint::ConstraintMode::PrimaryKey(v) => { + struct_ser.serialize_field("primaryKey", v)?; + } + constraint::ConstraintMode::Unique(v) => { + struct_ser.serialize_field("unique", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateCatalogNode { +impl<'de> serde::Deserialize<'de> for Constraint { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog_name", - "catalogName", - "if_not_exists", - "ifNotExists", - "schema", + "primary_key", + "primaryKey", + "unique", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - CatalogName, - IfNotExists, - Schema, + PrimaryKey, + Unique, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3328,9 +3583,8 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { E: serde::de::Error, { match value { - "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "schema" => Ok(GeneratedField::Schema), + "primaryKey" | "primary_key" => Ok(GeneratedField::PrimaryKey), + "unique" => Ok(GeneratedField::Unique), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3340,52 +3594,44 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateCatalogNode; + type Value = Constraint; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateCatalogNode") + formatter.write_str("struct datafusion.Constraint") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog_name__ = None; - let mut if_not_exists__ = None; - let mut schema__ = None; - while let Some(k) = map.next_key()? { + let mut constraint_mode__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::CatalogName => { - if catalog_name__.is_some() { - return Err(serde::de::Error::duplicate_field("catalogName")); - } - catalog_name__ = Some(map.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); + GeneratedField::PrimaryKey => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("primaryKey")); } - if_not_exists__ = Some(map.next_value()?); + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::PrimaryKey) +; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Unique => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("unique")); } - schema__ = map.next_value()?; + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::Unique) +; } } } - Ok(CreateCatalogNode { - catalog_name: catalog_name__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - schema: schema__, + Ok(Constraint { + constraint_mode: constraint_mode__, }) } } - deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Constraint", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogSchemaNode { +impl serde::Serialize for Constraints { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3393,11 +3639,229 @@ impl serde::Serialize for CreateCatalogSchemaNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema_name.is_empty() { + if !self.constraints.is_empty() { len += 1; } - if self.if_not_exists { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.Constraints", len)?; + if !self.constraints.is_empty() { + struct_ser.serialize_field("constraints", &self.constraints)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Constraints { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "constraints", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Constraints, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "constraints" => Ok(GeneratedField::Constraints), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Constraints; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Constraints") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut constraints__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); + } + constraints__ = Some(map_.next_value()?); + } + } + } + Ok(Constraints { + constraints: constraints__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CreateCatalogNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.catalog_name.is_empty() { + len += 1; + } + if self.if_not_exists { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; + if !self.catalog_name.is_empty() { + struct_ser.serialize_field("catalogName", &self.catalog_name)?; + } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CreateCatalogNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "catalog_name", + "catalogName", + "if_not_exists", + "ifNotExists", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + CatalogName, + IfNotExists, + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CreateCatalogNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CreateCatalogNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut catalog_name__ = None; + let mut if_not_exists__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::CatalogName => { + if catalog_name__.is_some() { + return Err(serde::de::Error::duplicate_field("catalogName")); + } + catalog_name__ = Some(map_.next_value()?); + } + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); + } + if_not_exists__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(CreateCatalogNode { + catalog_name: catalog_name__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CreateCatalogSchemaNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.schema_name.is_empty() { + len += 1; + } + if self.if_not_exists { + len += 1; } if self.schema.is_some() { len += 1; @@ -3473,32 +3937,32 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { formatter.write_str("struct datafusion.CreateCatalogSchemaNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut schema_name__ = None; let mut if_not_exists__ = None; let mut schema__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::SchemaName => { if schema_name__.is_some() { return Err(serde::de::Error::duplicate_field("schemaName")); } - schema_name__ = Some(map.next_value()?); + schema_name__ = Some(map_.next_value()?); } GeneratedField::IfNotExists => { if if_not_exists__.is_some() { return Err(serde::de::Error::duplicate_field("ifNotExists")); } - if_not_exists__ = Some(map.next_value()?); + if_not_exists__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } } } @@ -3559,6 +4023,12 @@ impl serde::Serialize for CreateExternalTableNode { if !self.options.is_empty() { len += 1; } + if self.constraints.is_some() { + len += 1; + } + if !self.column_defaults.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; if let Some(v) = self.name.as_ref() { struct_ser.serialize_field("name", v)?; @@ -3599,6 +4069,12 @@ impl serde::Serialize for CreateExternalTableNode { if !self.options.is_empty() { struct_ser.serialize_field("options", &self.options)?; } + if let Some(v) = self.constraints.as_ref() { + struct_ser.serialize_field("constraints", v)?; + } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; + } struct_ser.end() } } @@ -3628,6 +4104,9 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "orderExprs", "unbounded", "options", + "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] @@ -3645,6 +4124,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { OrderExprs, Unbounded, Options, + Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3679,6 +4160,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), "unbounded" => Ok(GeneratedField::Unbounded), "options" => Ok(GeneratedField::Options), + "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3694,7 +4177,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { formatter.write_str("struct datafusion.CreateExternalTableNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -3711,86 +4194,102 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut order_exprs__ = None; let mut unbounded__ = None; let mut options__ = None; - while let Some(k) = map.next_key()? { + let mut constraints__ = None; + let mut column_defaults__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = map.next_value()?; + name__ = map_.next_value()?; } GeneratedField::Location => { if location__.is_some() { return Err(serde::de::Error::duplicate_field("location")); } - location__ = Some(map.next_value()?); + location__ = Some(map_.next_value()?); } GeneratedField::FileType => { if file_type__.is_some() { return Err(serde::de::Error::duplicate_field("fileType")); } - file_type__ = Some(map.next_value()?); + file_type__ = Some(map_.next_value()?); } GeneratedField::HasHeader => { if has_header__.is_some() { return Err(serde::de::Error::duplicate_field("hasHeader")); } - has_header__ = Some(map.next_value()?); + has_header__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::TablePartitionCols => { if table_partition_cols__.is_some() { return Err(serde::de::Error::duplicate_field("tablePartitionCols")); } - table_partition_cols__ = Some(map.next_value()?); + table_partition_cols__ = Some(map_.next_value()?); } GeneratedField::IfNotExists => { if if_not_exists__.is_some() { return Err(serde::de::Error::duplicate_field("ifNotExists")); } - if_not_exists__ = Some(map.next_value()?); + if_not_exists__ = Some(map_.next_value()?); } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); } - delimiter__ = Some(map.next_value()?); + delimiter__ = Some(map_.next_value()?); } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); } - definition__ = Some(map.next_value()?); + definition__ = Some(map_.next_value()?); } GeneratedField::FileCompressionType => { if file_compression_type__.is_some() { return Err(serde::de::Error::duplicate_field("fileCompressionType")); } - file_compression_type__ = Some(map.next_value()?); + file_compression_type__ = Some(map_.next_value()?); } GeneratedField::OrderExprs => { if order_exprs__.is_some() { return Err(serde::de::Error::duplicate_field("orderExprs")); } - order_exprs__ = Some(map.next_value()?); + order_exprs__ = Some(map_.next_value()?); } GeneratedField::Unbounded => { if unbounded__.is_some() { return Err(serde::de::Error::duplicate_field("unbounded")); } - unbounded__ = Some(map.next_value()?); + unbounded__ = Some(map_.next_value()?); } GeneratedField::Options => { if options__.is_some() { return Err(serde::de::Error::duplicate_field("options")); } options__ = Some( - map.next_value::>()? + map_.next_value::>()? + ); + } + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); + } + constraints__ = map_.next_value()?; + } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? ); } } @@ -3809,6 +4308,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { order_exprs: order_exprs__.unwrap_or_default(), unbounded: unbounded__.unwrap_or_default(), options: options__.unwrap_or_default(), + constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), }) } } @@ -3911,7 +4412,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { formatter.write_str("struct datafusion.CreateViewNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -3919,31 +4420,31 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { let mut input__ = None; let mut or_replace__ = None; let mut definition__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = map.next_value()?; + name__ = map_.next_value()?; } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::OrReplace => { if or_replace__.is_some() { return Err(serde::de::Error::duplicate_field("orReplace")); } - or_replace__ = Some(map.next_value()?); + or_replace__ = Some(map_.next_value()?); } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); } - definition__ = Some(map.next_value()?); + definition__ = Some(map_.next_value()?); } } } @@ -4035,25 +4536,25 @@ impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { formatter.write_str("struct datafusion.CrossJoinExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut left__ = None; let mut right__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { if left__.is_some() { return Err(serde::de::Error::duplicate_field("left")); } - left__ = map.next_value()?; + left__ = map_.next_value()?; } GeneratedField::Right => { if right__.is_some() { return Err(serde::de::Error::duplicate_field("right")); } - right__ = map.next_value()?; + right__ = map_.next_value()?; } } } @@ -4143,25 +4644,25 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { formatter.write_str("struct datafusion.CrossJoinNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut left__ = None; let mut right__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { if left__.is_some() { return Err(serde::de::Error::duplicate_field("left")); } - left__ = map.next_value()?; + left__ = map_.next_value()?; } GeneratedField::Right => { if right__.is_some() { return Err(serde::de::Error::duplicate_field("right")); } - right__ = map.next_value()?; + right__ = map_.next_value()?; } } } @@ -4188,6 +4689,12 @@ impl serde::Serialize for CsvFormat { if !self.delimiter.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CsvFormat", len)?; if self.has_header { struct_ser.serialize_field("hasHeader", &self.has_header)?; @@ -4195,6 +4702,16 @@ impl serde::Serialize for CsvFormat { if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_format::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } + } struct_ser.end() } } @@ -4208,12 +4725,16 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { "has_header", "hasHeader", "delimiter", + "quote", + "escape", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { HasHeader, Delimiter, + Quote, + Escape, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4237,6 +4758,8 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { match value { "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4252,31 +4775,47 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { formatter.write_str("struct datafusion.CsvFormat") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut has_header__ = None; let mut delimiter__ = None; - while let Some(k) = map.next_key()? { + let mut quote__ = None; + let mut optional_escape__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { if has_header__.is_some() { return Err(serde::de::Error::duplicate_field("hasHeader")); } - has_header__ = Some(map.next_value()?); + has_header__ = Some(map_.next_value()?); } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); } - delimiter__ = Some(map.next_value()?); + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_format::OptionalEscape::Escape); } } } Ok(CsvFormat { has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, }) } } @@ -4300,6 +4839,12 @@ impl serde::Serialize for CsvScanExecNode { if !self.delimiter.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; if let Some(v) = self.base_conf.as_ref() { struct_ser.serialize_field("baseConf", v)?; @@ -4310,6 +4855,16 @@ impl serde::Serialize for CsvScanExecNode { if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if let Some(v) = self.optional_escape.as_ref() { + match v { + csv_scan_exec_node::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; + } + } + } struct_ser.end() } } @@ -4325,6 +4880,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "has_header", "hasHeader", "delimiter", + "quote", + "escape", ]; #[allow(clippy::enum_variant_names)] @@ -4332,6 +4889,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { BaseConf, HasHeader, Delimiter, + Quote, + Escape, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4356,6 +4915,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4371,32 +4932,46 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { formatter.write_str("struct datafusion.CsvScanExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut base_conf__ = None; let mut has_header__ = None; let mut delimiter__ = None; - while let Some(k) = map.next_key()? { + let mut quote__ = None; + let mut optional_escape__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::BaseConf => { if base_conf__.is_some() { return Err(serde::de::Error::duplicate_field("baseConf")); } - base_conf__ = map.next_value()?; + base_conf__ = map_.next_value()?; } GeneratedField::HasHeader => { if has_header__.is_some() { return Err(serde::de::Error::duplicate_field("hasHeader")); } - has_header__ = Some(map.next_value()?); + has_header__ = Some(map_.next_value()?); } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); } - delimiter__ = Some(map.next_value()?); + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); } } } @@ -4404,6 +4979,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { base_conf: base_conf__, has_header: has_header__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + optional_escape: optional_escape__, }) } } @@ -4478,18 +5055,18 @@ impl<'de> serde::Deserialize<'de> for CubeNode { formatter.write_str("struct datafusion.CubeNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } } } @@ -4538,6 +5115,7 @@ impl serde::Serialize for CustomTableScanNode { struct_ser.serialize_field("filters", &self.filters)?; } if !self.custom_table_data.is_empty() { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; } struct_ser.end() @@ -4607,7 +5185,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { formatter.write_str("struct datafusion.CustomTableScanNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -4616,38 +5194,38 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { let mut schema__ = None; let mut filters__ = None; let mut custom_table_data__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::TableName => { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = map.next_value()?; + table_name__ = map_.next_value()?; } GeneratedField::Projection => { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - projection__ = map.next_value()?; + projection__ = map_.next_value()?; } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::Filters => { if filters__.is_some() { return Err(serde::de::Error::duplicate_field("filters")); } - filters__ = Some(map.next_value()?); + filters__ = Some(map_.next_value()?); } GeneratedField::CustomTableData => { if custom_table_data__.is_some() { return Err(serde::de::Error::duplicate_field("customTableData")); } custom_table_data__ = - Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } } @@ -4701,10 +5279,9 @@ impl<'de> serde::Deserialize<'de> for DateUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(DateUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -4714,10 +5291,9 @@ impl<'de> serde::Deserialize<'de> for DateUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(DateUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -4814,20 +5390,20 @@ impl<'de> serde::Deserialize<'de> for Decimal { formatter.write_str("struct datafusion.Decimal") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut precision__ = None; let mut scale__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Precision => { if precision__.is_some() { return Err(serde::de::Error::duplicate_field("precision")); } precision__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::Scale => { @@ -4835,7 +5411,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { return Err(serde::de::Error::duplicate_field("scale")); } scale__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -4868,12 +5444,15 @@ impl serde::Serialize for Decimal128 { } let mut struct_ser = serializer.serialize_struct("datafusion.Decimal128", len)?; if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; } if self.p != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } if self.s != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() @@ -4935,21 +5514,21 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { formatter.write_str("struct datafusion.Decimal128") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut value__ = None; let mut p__ = None; let mut s__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } value__ = - Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } GeneratedField::P => { @@ -4957,7 +5536,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { return Err(serde::de::Error::duplicate_field("p")); } p__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::S => { @@ -4965,7 +5544,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { return Err(serde::de::Error::duplicate_field("s")); } s__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -4980,7 +5559,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DfField { +impl serde::Serialize for Decimal256 { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4988,31 +5567,165 @@ impl serde::Serialize for DfField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field.is_some() { + if !self.value.is_empty() { len += 1; } - if self.qualifier.is_some() { + if self.p != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DfField", len)?; - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; + if self.s != 0 { + len += 1; } - if let Some(v) = self.qualifier.as_ref() { - struct_ser.serialize_field("qualifier", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DfField { +impl<'de> serde::Deserialize<'de> for Decimal256 { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field", - "qualifier", + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Decimal256") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for DfField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field.is_some() { + len += 1; + } + if self.qualifier.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DfField", len)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DfField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field", + "qualifier", ]; #[allow(clippy::enum_variant_names)] @@ -5057,25 +5770,25 @@ impl<'de> serde::Deserialize<'de> for DfField { formatter.write_str("struct datafusion.DfField") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut field__ = None; let mut qualifier__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Field => { if field__.is_some() { return Err(serde::de::Error::duplicate_field("field")); } - field__ = map.next_value()?; + field__ = map_.next_value()?; } GeneratedField::Qualifier => { if qualifier__.is_some() { return Err(serde::de::Error::duplicate_field("qualifier")); } - qualifier__ = map.next_value()?; + qualifier__ = map_.next_value()?; } } } @@ -5165,26 +5878,26 @@ impl<'de> serde::Deserialize<'de> for DfSchema { formatter.write_str("struct datafusion.DfSchema") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut columns__ = None; let mut metadata__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Columns => { if columns__.is_some() { return Err(serde::de::Error::duplicate_field("columns")); } - columns__ = Some(map.next_value()?); + columns__ = Some(map_.next_value()?); } GeneratedField::Metadata => { if metadata__.is_some() { return Err(serde::de::Error::duplicate_field("metadata")); } metadata__ = Some( - map.next_value::>()? + map_.next_value::>()? ); } } @@ -5275,25 +5988,25 @@ impl<'de> serde::Deserialize<'de> for Dictionary { formatter.write_str("struct datafusion.Dictionary") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut key__ = None; let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Key => { if key__.is_some() { return Err(serde::de::Error::duplicate_field("key")); } - key__ = map.next_value()?; + key__ = map_.next_value()?; } GeneratedField::Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = map.next_value()?; + value__ = map_.next_value()?; } } } @@ -5374,18 +6087,18 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { formatter.write_str("struct datafusion.DistinctNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } } } @@ -5397,6 +6110,151 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for DistinctOnNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.on_expr.is_empty() { + len += 1; + } + if !self.select_expr.is_empty() { + len += 1; + } + if !self.sort_expr.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; + } + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; + } + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DistinctOnNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OnExpr, + SelectExpr, + SortExpr, + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DistinctOnNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.DistinctOnNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; + let mut input__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); + } + on_expr__ = Some(map_.next_value()?); + } + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); + } + select_expr__ = Some(map_.next_value()?); + } + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map_.next_value()?); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5484,32 +6342,32 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { formatter.write_str("struct datafusion.DropViewNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; let mut if_exists__ = None; let mut schema__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = map.next_value()?; + name__ = map_.next_value()?; } GeneratedField::IfExists => { if if_exists__.is_some() { return Err(serde::de::Error::duplicate_field("ifExists")); } - if_exists__ = Some(map.next_value()?); + if_exists__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } } } @@ -5531,16 +6389,10 @@ impl serde::Serialize for EmptyExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { - len += 1; - } if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; - } if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } @@ -5554,14 +6406,11 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -5584,7 +6433,6 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -5601,30 +6449,22 @@ impl<'de> serde::Deserialize<'de> for EmptyExecNode { formatter.write_str("struct datafusion.EmptyExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; let mut schema__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); - } - produce_one_row__ = Some(map.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } } } Ok(EmptyExecNode { - produce_one_row: produce_one_row__.unwrap_or_default(), schema: schema__, }) } @@ -5689,12 +6529,12 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { formatter.write_str("struct datafusion.EmptyMessage") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map.next_key::()?.is_some() { - let _ = map.next_value::()?; + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; } Ok(EmptyMessage { }) @@ -5772,18 +6612,18 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { formatter.write_str("struct datafusion.EmptyRelationNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut produce_one_row__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::ProduceOneRow => { if produce_one_row__.is_some() { return Err(serde::de::Error::duplicate_field("produceOneRow")); } - produce_one_row__ = Some(map.next_value()?); + produce_one_row__ = Some(map_.next_value()?); } } } @@ -5882,32 +6722,32 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { formatter.write_str("struct datafusion.ExplainExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut schema__ = None; let mut stringified_plans__ = None; let mut verbose__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::StringifiedPlans => { if stringified_plans__.is_some() { return Err(serde::de::Error::duplicate_field("stringifiedPlans")); } - stringified_plans__ = Some(map.next_value()?); + stringified_plans__ = Some(map_.next_value()?); } GeneratedField::Verbose => { if verbose__.is_some() { return Err(serde::de::Error::duplicate_field("verbose")); } - verbose__ = Some(map.next_value()?); + verbose__ = Some(map_.next_value()?); } } } @@ -5998,25 +6838,25 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { formatter.write_str("struct datafusion.ExplainNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut verbose__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Verbose => { if verbose__.is_some() { return Err(serde::de::Error::duplicate_field("verbose")); } - verbose__ = Some(map.next_value()?); + verbose__ = Some(map_.next_value()?); } } } @@ -6049,6 +6889,15 @@ impl serde::Serialize for Field { if !self.children.is_empty() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -6062,6 +6911,16 @@ impl serde::Serialize for Field { if !self.children.is_empty() { struct_ser.serialize_field("children", &self.children)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } struct_ser.end() } } @@ -6077,6 +6936,11 @@ impl<'de> serde::Deserialize<'de> for Field { "arrowType", "nullable", "children", + "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -6085,6 +6949,9 @@ impl<'de> serde::Deserialize<'de> for Field { ArrowType, Nullable, Children, + Metadata, + DictId, + DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6110,6 +6977,9 @@ impl<'de> serde::Deserialize<'de> for Field { "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), + "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6125,7 +6995,7 @@ impl<'de> serde::Deserialize<'de> for Field { formatter.write_str("struct datafusion.Field") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -6133,31 +7003,56 @@ impl<'de> serde::Deserialize<'de> for Field { let mut arrow_type__ = None; let mut nullable__ = None; let mut children__ = None; - while let Some(k) = map.next_key()? { + let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::ArrowType => { if arrow_type__.is_some() { return Err(serde::de::Error::duplicate_field("arrowType")); } - arrow_type__ = map.next_value()?; + arrow_type__ = map_.next_value()?; } GeneratedField::Nullable => { if nullable__.is_some() { return Err(serde::de::Error::duplicate_field("nullable")); } - nullable__ = Some(map.next_value()?); + nullable__ = Some(map_.next_value()?); } GeneratedField::Children => { if children__.is_some() { return Err(serde::de::Error::duplicate_field("children")); } - children__ = Some(map.next_value()?); + children__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); } } } @@ -6166,6 +7061,9 @@ impl<'de> serde::Deserialize<'de> for Field { arrow_type: arrow_type__, nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -6240,18 +7138,18 @@ impl<'de> serde::Deserialize<'de> for FileGroup { formatter.write_str("struct datafusion.FileGroup") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut files__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Files => { if files__.is_some() { return Err(serde::de::Error::duplicate_field("files")); } - files__ = Some(map.next_value()?); + files__ = Some(map_.next_value()?); } } } @@ -6279,9 +7177,11 @@ impl serde::Serialize for FileRange { } let mut struct_ser = serializer.serialize_struct("datafusion.FileRange", len)?; if self.start != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; } if self.end != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; } struct_ser.end() @@ -6340,20 +7240,20 @@ impl<'de> serde::Deserialize<'de> for FileRange { formatter.write_str("struct datafusion.FileRange") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut start__ = None; let mut end__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Start => { if start__.is_some() { return Err(serde::de::Error::duplicate_field("start")); } start__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::End => { @@ -6361,7 +7261,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { return Err(serde::de::Error::duplicate_field("end")); } end__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -6510,7 +7410,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { formatter.write_str("struct datafusion.FileScanExecConf") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -6522,26 +7422,26 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { let mut table_partition_cols__ = None; let mut object_store_url__ = None; let mut output_ordering__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FileGroups => { if file_groups__.is_some() { return Err(serde::de::Error::duplicate_field("fileGroups")); } - file_groups__ = Some(map.next_value()?); + file_groups__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::Projection => { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } projection__ = - Some(map.next_value::>>()? + Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; } @@ -6549,31 +7449,31 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { if limit__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); } - limit__ = map.next_value()?; + limit__ = map_.next_value()?; } GeneratedField::Statistics => { if statistics__.is_some() { return Err(serde::de::Error::duplicate_field("statistics")); } - statistics__ = map.next_value()?; + statistics__ = map_.next_value()?; } GeneratedField::TablePartitionCols => { if table_partition_cols__.is_some() { return Err(serde::de::Error::duplicate_field("tablePartitionCols")); } - table_partition_cols__ = Some(map.next_value()?); + table_partition_cols__ = Some(map_.next_value()?); } GeneratedField::ObjectStoreUrl => { if object_store_url__.is_some() { return Err(serde::de::Error::duplicate_field("objectStoreUrl")); } - object_store_url__ = Some(map.next_value()?); + object_store_url__ = Some(map_.next_value()?); } GeneratedField::OutputOrdering => { if output_ordering__.is_some() { return Err(serde::de::Error::duplicate_field("outputOrdering")); } - output_ordering__ = Some(map.next_value()?); + output_ordering__ = Some(map_.next_value()?); } } } @@ -6592,6 +7492,338 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FileSinkConfig { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.object_store_url.is_empty() { + len += 1; + } + if !self.file_groups.is_empty() { + len += 1; + } + if !self.table_paths.is_empty() { + len += 1; + } + if self.output_schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.single_file_output { + len += 1; + } + if self.unbounded_input { + len += 1; + } + if self.overwrite { + len += 1; + } + if self.file_type_writer_options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if !self.table_paths.is_empty() { + struct_ser.serialize_field("tablePaths", &self.table_paths)?; + } + if let Some(v) = self.output_schema.as_ref() { + struct_ser.serialize_field("outputSchema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.single_file_output { + struct_ser.serialize_field("singleFileOutput", &self.single_file_output)?; + } + if self.unbounded_input { + struct_ser.serialize_field("unboundedInput", &self.unbounded_input)?; + } + if self.overwrite { + struct_ser.serialize_field("overwrite", &self.overwrite)?; + } + if let Some(v) = self.file_type_writer_options.as_ref() { + struct_ser.serialize_field("fileTypeWriterOptions", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FileSinkConfig { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "object_store_url", + "objectStoreUrl", + "file_groups", + "fileGroups", + "table_paths", + "tablePaths", + "output_schema", + "outputSchema", + "table_partition_cols", + "tablePartitionCols", + "single_file_output", + "singleFileOutput", + "unbounded_input", + "unboundedInput", + "overwrite", + "file_type_writer_options", + "fileTypeWriterOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ObjectStoreUrl, + FileGroups, + TablePaths, + OutputSchema, + TablePartitionCols, + SingleFileOutput, + UnboundedInput, + Overwrite, + FileTypeWriterOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), + "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "singleFileOutput" | "single_file_output" => Ok(GeneratedField::SingleFileOutput), + "unboundedInput" | "unbounded_input" => Ok(GeneratedField::UnboundedInput), + "overwrite" => Ok(GeneratedField::Overwrite), + "fileTypeWriterOptions" | "file_type_writer_options" => Ok(GeneratedField::FileTypeWriterOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileSinkConfig; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FileSinkConfig") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut object_store_url__ = None; + let mut file_groups__ = None; + let mut table_paths__ = None; + let mut output_schema__ = None; + let mut table_partition_cols__ = None; + let mut single_file_output__ = None; + let mut unbounded_input__ = None; + let mut overwrite__ = None; + let mut file_type_writer_options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); + } + object_store_url__ = Some(map_.next_value()?); + } + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); + } + file_groups__ = Some(map_.next_value()?); + } + GeneratedField::TablePaths => { + if table_paths__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePaths")); + } + table_paths__ = Some(map_.next_value()?); + } + GeneratedField::OutputSchema => { + if output_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("outputSchema")); + } + output_schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::SingleFileOutput => { + if single_file_output__.is_some() { + return Err(serde::de::Error::duplicate_field("singleFileOutput")); + } + single_file_output__ = Some(map_.next_value()?); + } + GeneratedField::UnboundedInput => { + if unbounded_input__.is_some() { + return Err(serde::de::Error::duplicate_field("unboundedInput")); + } + unbounded_input__ = Some(map_.next_value()?); + } + GeneratedField::Overwrite => { + if overwrite__.is_some() { + return Err(serde::de::Error::duplicate_field("overwrite")); + } + overwrite__ = Some(map_.next_value()?); + } + GeneratedField::FileTypeWriterOptions => { + if file_type_writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("fileTypeWriterOptions")); + } + file_type_writer_options__ = map_.next_value()?; + } + } + } + Ok(FileSinkConfig { + object_store_url: object_store_url__.unwrap_or_default(), + file_groups: file_groups__.unwrap_or_default(), + table_paths: table_paths__.unwrap_or_default(), + output_schema: output_schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + single_file_output: single_file_output__.unwrap_or_default(), + unbounded_input: unbounded_input__.unwrap_or_default(), + overwrite: overwrite__.unwrap_or_default(), + file_type_writer_options: file_type_writer_options__, + }) + } + } + deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FileTypeWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.file_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileTypeWriterOptions", len)?; + if let Some(v) = self.file_type.as_ref() { + match v { + file_type_writer_options::FileType::JsonOptions(v) => { + struct_ser.serialize_field("jsonOptions", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "json_options", + "jsonOptions", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + JsonOptions, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileTypeWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FileTypeWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut file_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::JsonOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::JsonOptions) +; + } + } + } + Ok(FileTypeWriterOptions { + file_type: file_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.FileTypeWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for FilterExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -6606,6 +7838,9 @@ impl serde::Serialize for FilterExecNode { if self.expr.is_some() { len += 1; } + if self.default_filter_selectivity != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -6613,6 +7848,9 @@ impl serde::Serialize for FilterExecNode { if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; + } struct_ser.end() } } @@ -6625,12 +7863,15 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { const FIELDS: &[&str] = &[ "input", "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, Expr, + DefaultFilterSelectivity, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6654,6 +7895,7 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { match value { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6669,31 +7911,41 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { formatter.write_str("struct datafusion.FilterExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; - while let Some(k) = map.next_key()? { + let mut default_filter_selectivity__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; + } + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); + } + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } Ok(FilterExecNode { input: input__, expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), }) } } @@ -6768,19 +8020,19 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { formatter.write_str("struct datafusion.FixedSizeBinary") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut length__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Length => { if length__.is_some() { return Err(serde::de::Error::duplicate_field("length")); } length__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -6872,26 +8124,26 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { formatter.write_str("struct datafusion.FixedSizeList") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut field_type__ = None; let mut list_size__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FieldType => { if field_type__.is_some() { return Err(serde::de::Error::duplicate_field("fieldType")); } - field_type__ = map.next_value()?; + field_type__ = map_.next_value()?; } GeneratedField::ListSize => { if list_size__.is_some() { return Err(serde::de::Error::duplicate_field("listSize")); } list_size__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -6991,32 +8243,32 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { formatter.write_str("struct datafusion.FullTableReference") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut catalog__ = None; let mut schema__ = None; let mut table__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Catalog => { if catalog__.is_some() { return Err(serde::de::Error::duplicate_field("catalog")); } - catalog__ = Some(map.next_value()?); + catalog__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = Some(map.next_value()?); + schema__ = Some(map_.next_value()?); } GeneratedField::Table => { if table__.is_some() { return Err(serde::de::Error::duplicate_field("table")); } - table__ = Some(map.next_value()?); + table__ = Some(map_.next_value()?); } } } @@ -7041,15 +8293,25 @@ impl serde::Serialize for GetIndexedField { if self.expr.is_some() { len += 1; } - if self.key.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + if let Some(v) = self.field.as_ref() { + match v { + get_indexed_field::Field::NamedStructField(v) => { + struct_ser.serialize_field("namedStructField", v)?; + } + get_indexed_field::Field::ListIndex(v) => { + struct_ser.serialize_field("listIndex", v)?; + } + get_indexed_field::Field::ListRange(v) => { + struct_ser.serialize_field("listRange", v)?; + } + } } struct_ser.end() } @@ -7062,13 +8324,20 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { { const FIELDS: &[&str] = &[ "expr", - "key", + "named_struct_field", + "namedStructField", + "list_index", + "listIndex", + "list_range", + "listRange", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - Key, + NamedStructField, + ListIndex, + ListRange, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7091,7 +8360,9 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { { match value { "expr" => Ok(GeneratedField::Expr), - "key" => Ok(GeneratedField::Key), + "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), + "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), + "listRange" | "list_range" => Ok(GeneratedField::ListRange), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7107,31 +8378,46 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { formatter.write_str("struct datafusion.GetIndexedField") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut key__ = None; - while let Some(k) = map.next_key()? { + let mut field__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::NamedStructField => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructField")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) +; + } + GeneratedField::ListIndex => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndex")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) +; + } + GeneratedField::ListRange => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRange")); } - key__ = map.next_value()?; + field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) +; } } } Ok(GetIndexedField { expr: expr__, - key: key__, + field: field__, }) } } @@ -7163,6 +8449,7 @@ impl serde::Serialize for GlobalLimitExecNode { struct_ser.serialize_field("skip", &self.skip)?; } if self.fetch != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() @@ -7224,27 +8511,27 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { formatter.write_str("struct datafusion.GlobalLimitExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut skip__ = None; let mut fetch__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Skip => { if skip__.is_some() { return Err(serde::de::Error::duplicate_field("skip")); } skip__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::Fetch => { @@ -7252,7 +8539,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { return Err(serde::de::Error::duplicate_field("fetch")); } fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -7335,18 +8622,18 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { formatter.write_str("struct datafusion.GroupingSetNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } } } @@ -7398,13 +8685,13 @@ impl serde::Serialize for HashJoinExecNode { struct_ser.serialize_field("on", &self.on)?; } if self.join_type != 0 { - let v = JoinType::from_i32(self.join_type) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; struct_ser.serialize_field("joinType", &v)?; } if self.partition_mode != 0 { - let v = PartitionMode::from_i32(self.partition_mode) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + let v = PartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } if self.null_equals_null { @@ -7487,7 +8774,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { formatter.write_str("struct datafusion.HashJoinExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -7498,49 +8785,49 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut partition_mode__ = None; let mut null_equals_null__ = None; let mut filter__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { if left__.is_some() { return Err(serde::de::Error::duplicate_field("left")); } - left__ = map.next_value()?; + left__ = map_.next_value()?; } GeneratedField::Right => { if right__.is_some() { return Err(serde::de::Error::duplicate_field("right")); } - right__ = map.next_value()?; + right__ = map_.next_value()?; } GeneratedField::On => { if on__.is_some() { return Err(serde::de::Error::duplicate_field("on")); } - on__ = Some(map.next_value()?); + on__ = Some(map_.next_value()?); } GeneratedField::JoinType => { if join_type__.is_some() { return Err(serde::de::Error::duplicate_field("joinType")); } - join_type__ = Some(map.next_value::()? as i32); + join_type__ = Some(map_.next_value::()? as i32); } GeneratedField::PartitionMode => { if partition_mode__.is_some() { return Err(serde::de::Error::duplicate_field("partitionMode")); } - partition_mode__ = Some(map.next_value::()? as i32); + partition_mode__ = Some(map_.next_value::()? as i32); } GeneratedField::NullEqualsNull => { if null_equals_null__.is_some() { return Err(serde::de::Error::duplicate_field("nullEqualsNull")); } - null_equals_null__ = Some(map.next_value()?); + null_equals_null__ = Some(map_.next_value()?); } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); } - filter__ = map.next_value()?; + filter__ = map_.next_value()?; } } } @@ -7577,6 +8864,7 @@ impl serde::Serialize for HashRepartition { struct_ser.serialize_field("hashExpr", &self.hash_expr)?; } if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; } struct_ser.end() @@ -7637,26 +8925,26 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { formatter.write_str("struct datafusion.HashRepartition") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut hash_expr__ = None; let mut partition_count__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::HashExpr => { if hash_expr__.is_some() { return Err(serde::de::Error::duplicate_field("hashExpr")); } - hash_expr__ = Some(map.next_value()?); + hash_expr__ = Some(map_.next_value()?); } GeneratedField::PartitionCount => { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } partition_count__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -7766,7 +9054,7 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { formatter.write_str("struct datafusion.ILikeNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -7774,31 +9062,31 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { let mut expr__ = None; let mut pattern__ = None; let mut escape_char__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Pattern => { if pattern__.is_some() { return Err(serde::de::Error::duplicate_field("pattern")); } - pattern__ = map.next_value()?; + pattern__ = map_.next_value()?; } GeneratedField::EscapeChar => { if escape_char__.is_some() { return Err(serde::de::Error::duplicate_field("escapeChar")); } - escape_char__ = Some(map.next_value()?); + escape_char__ = Some(map_.next_value()?); } } } @@ -7899,32 +9187,32 @@ impl<'de> serde::Deserialize<'de> for InListNode { formatter.write_str("struct datafusion.InListNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut list__ = None; let mut negated__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::List => { if list__.is_some() { return Err(serde::de::Error::duplicate_field("list")); } - list__ = Some(map.next_value()?); + list__ = Some(map_.next_value()?); } GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } } } @@ -7938,6 +9226,97 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InterleaveExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inputs", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inputs, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InterleaveExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.InterleaveExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } + } + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for IntervalMonthDayNanoValue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7963,6 +9342,7 @@ impl serde::Serialize for IntervalMonthDayNanoValue { struct_ser.serialize_field("days", &self.days)?; } if self.nanos != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; } struct_ser.end() @@ -8024,21 +9404,21 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { formatter.write_str("struct datafusion.IntervalMonthDayNanoValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut months__ = None; let mut days__ = None; let mut nanos__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Months => { if months__.is_some() { return Err(serde::de::Error::duplicate_field("months")); } months__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::Days => { @@ -8046,7 +9426,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { return Err(serde::de::Error::duplicate_field("days")); } days__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::Nanos => { @@ -8054,7 +9434,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { return Err(serde::de::Error::duplicate_field("nanos")); } nanos__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -8108,10 +9488,9 @@ impl<'de> serde::Deserialize<'de> for IntervalUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(IntervalUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -8121,10 +9500,9 @@ impl<'de> serde::Deserialize<'de> for IntervalUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(IntervalUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -8213,18 +9591,18 @@ impl<'de> serde::Deserialize<'de> for IsFalse { formatter.write_str("struct datafusion.IsFalse") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8304,18 +9682,18 @@ impl<'de> serde::Deserialize<'de> for IsNotFalse { formatter.write_str("struct datafusion.IsNotFalse") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8395,18 +9773,18 @@ impl<'de> serde::Deserialize<'de> for IsNotNull { formatter.write_str("struct datafusion.IsNotNull") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8486,18 +9864,18 @@ impl<'de> serde::Deserialize<'de> for IsNotTrue { formatter.write_str("struct datafusion.IsNotTrue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8577,18 +9955,18 @@ impl<'de> serde::Deserialize<'de> for IsNotUnknown { formatter.write_str("struct datafusion.IsNotUnknown") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8668,18 +10046,18 @@ impl<'de> serde::Deserialize<'de> for IsNull { formatter.write_str("struct datafusion.IsNull") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8759,18 +10137,18 @@ impl<'de> serde::Deserialize<'de> for IsTrue { formatter.write_str("struct datafusion.IsTrue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8850,18 +10228,18 @@ impl<'de> serde::Deserialize<'de> for IsUnknown { formatter.write_str("struct datafusion.IsUnknown") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -8910,10 +10288,9 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinConstraint::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -8923,10 +10300,9 @@ impl<'de> serde::Deserialize<'de> for JoinConstraint { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinConstraint::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -9033,32 +10409,32 @@ impl<'de> serde::Deserialize<'de> for JoinFilter { formatter.write_str("struct datafusion.JoinFilter") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expression__ = None; let mut column_indices__ = None; let mut schema__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expression => { if expression__.is_some() { return Err(serde::de::Error::duplicate_field("expression")); } - expression__ = map.next_value()?; + expression__ = map_.next_value()?; } GeneratedField::ColumnIndices => { if column_indices__.is_some() { return Err(serde::de::Error::duplicate_field("columnIndices")); } - column_indices__ = Some(map.next_value()?); + column_indices__ = Some(map_.next_value()?); } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } } } @@ -9112,13 +10488,13 @@ impl serde::Serialize for JoinNode { struct_ser.serialize_field("right", v)?; } if self.join_type != 0 { - let v = JoinType::from_i32(self.join_type) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; struct_ser.serialize_field("joinType", &v)?; } if self.join_constraint != 0 { - let v = JoinConstraint::from_i32(self.join_constraint) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; + let v = JoinConstraint::try_from(self.join_constraint) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; struct_ser.serialize_field("joinConstraint", &v)?; } if !self.left_join_key.is_empty() { @@ -9212,7 +10588,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { formatter.write_str("struct datafusion.JoinNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -9224,55 +10600,55 @@ impl<'de> serde::Deserialize<'de> for JoinNode { let mut right_join_key__ = None; let mut null_equals_null__ = None; let mut filter__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { if left__.is_some() { return Err(serde::de::Error::duplicate_field("left")); } - left__ = map.next_value()?; + left__ = map_.next_value()?; } GeneratedField::Right => { if right__.is_some() { return Err(serde::de::Error::duplicate_field("right")); } - right__ = map.next_value()?; + right__ = map_.next_value()?; } GeneratedField::JoinType => { if join_type__.is_some() { return Err(serde::de::Error::duplicate_field("joinType")); } - join_type__ = Some(map.next_value::()? as i32); + join_type__ = Some(map_.next_value::()? as i32); } GeneratedField::JoinConstraint => { if join_constraint__.is_some() { return Err(serde::de::Error::duplicate_field("joinConstraint")); } - join_constraint__ = Some(map.next_value::()? as i32); + join_constraint__ = Some(map_.next_value::()? as i32); } GeneratedField::LeftJoinKey => { if left_join_key__.is_some() { return Err(serde::de::Error::duplicate_field("leftJoinKey")); } - left_join_key__ = Some(map.next_value()?); + left_join_key__ = Some(map_.next_value()?); } GeneratedField::RightJoinKey => { if right_join_key__.is_some() { return Err(serde::de::Error::duplicate_field("rightJoinKey")); } - right_join_key__ = Some(map.next_value()?); + right_join_key__ = Some(map_.next_value()?); } GeneratedField::NullEqualsNull => { if null_equals_null__.is_some() { return Err(serde::de::Error::duplicate_field("nullEqualsNull")); } - null_equals_null__ = Some(map.next_value()?); + null_equals_null__ = Some(map_.next_value()?); } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); } - filter__ = map.next_value()?; + filter__ = map_.next_value()?; } } } @@ -9368,25 +10744,25 @@ impl<'de> serde::Deserialize<'de> for JoinOn { formatter.write_str("struct datafusion.JoinOn") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut left__ = None; let mut right__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Left => { if left__.is_some() { return Err(serde::de::Error::duplicate_field("left")); } - left__ = map.next_value()?; + left__ = map_.next_value()?; } GeneratedField::Right => { if right__.is_some() { return Err(serde::de::Error::duplicate_field("right")); } - right__ = map.next_value()?; + right__ = map_.next_value()?; } } } @@ -9436,10 +10812,9 @@ impl<'de> serde::Deserialize<'de> for JoinSide { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinSide::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -9449,10 +10824,9 @@ impl<'de> serde::Deserialize<'de> for JoinSide { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinSide::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -9521,10 +10895,9 @@ impl<'de> serde::Deserialize<'de> for JoinType { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinType::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -9534,10 +10907,9 @@ impl<'de> serde::Deserialize<'de> for JoinType { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(JoinType::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -9563,7 +10935,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for LikeNode { +impl serde::Serialize for JsonSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9571,54 +10943,29 @@ impl serde::Serialize for LikeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } - if self.expr.is_some() { + if self.config.is_some() { len += 1; } - if self.pattern.is_some() { - len += 1; - } - if !self.escape_char.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; - } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LikeNode { +impl<'de> serde::Deserialize<'de> for JsonSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "expr", - "pattern", - "escape_char", - "escapeChar", + "config", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - Expr, - Pattern, - EscapeChar, + Config, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9640,10 +10987,7 @@ impl<'de> serde::Deserialize<'de> for LikeNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + "config" => Ok(GeneratedField::Config), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9653,60 +10997,36 @@ impl<'de> serde::Deserialize<'de> for LikeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LikeNode; + type Value = JsonSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LikeNode") + formatter.write_str("struct datafusion.JsonSink") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; - while let Some(k) = map.next_key()? { + let mut config__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map.next_value()?); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map.next_value()?; - } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); - } - pattern__ = map.next_value()?; - } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - escape_char__ = Some(map.next_value()?); + config__ = map_.next_value()?; } } } - Ok(LikeNode { - negated: negated__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), + Ok(JsonSink { + config: config__, }) } } - deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for LimitNode { +impl serde::Serialize for JsonSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9717,26 +11037,32 @@ impl serde::Serialize for LimitNode { if self.input.is_some() { len += 1; } - if self.skip != 0 { + if self.sink.is_some() { len += 1; } - if self.fetch != 0 { + if self.sink_schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } - if self.skip != 0 { - struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; } - if self.fetch != 0 { - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for LimitNode { +impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -9744,15 +11070,19 @@ impl<'de> serde::Deserialize<'de> for LimitNode { { const FIELDS: &[&str] = &[ "input", - "skip", - "fetch", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, - Skip, - Fetch, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9775,8 +11105,9 @@ impl<'de> serde::Deserialize<'de> for LimitNode { { match value { "input" => Ok(GeneratedField::Input), - "skip" => Ok(GeneratedField::Skip), - "fetch" => Ok(GeneratedField::Fetch), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9786,56 +11117,60 @@ impl<'de> serde::Deserialize<'de> for LimitNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LimitNode; + type Value = JsonSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LimitNode") + formatter.write_str("struct datafusion.JsonSinkExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; - let mut skip__ = None; - let mut fetch__ = None; - while let Some(k) = map.next_key()? { + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - skip__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + sink__ = map_.next_value()?; } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; } } } - Ok(LimitNode { + Ok(JsonSinkExecNode { input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for List { +impl serde::Serialize for JsonWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9843,30 +11178,31 @@ impl serde::Serialize for List { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { + if self.compression != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JsonWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for List { +impl<'de> serde::Deserialize<'de> for JsonWriterOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", + "compression", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, + Compression, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9888,7 +11224,7 @@ impl<'de> serde::Deserialize<'de> for List { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "compression" => Ok(GeneratedField::Compression), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9898,36 +11234,36 @@ impl<'de> serde::Deserialize<'de> for List { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = List; + type Value = JsonWriterOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.List") + formatter.write_str("struct datafusion.JsonWriterOptions") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; - while let Some(k) = map.next_key()? { + let mut compression__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); } - field_type__ = map.next_value()?; + compression__ = Some(map_.next_value::()? as i32); } } } - Ok(List { - field_type: field_type__, + Ok(JsonWriterOptions { + compression: compression__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonWriterOptions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ListingTableScanNode { +impl serde::Serialize for LikeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9935,35 +11271,799 @@ impl serde::Serialize for ListingTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if !self.paths.is_empty() { - len += 1; - } - if !self.file_extension.is_empty() { + if self.negated { len += 1; } - if self.projection.is_some() { + if self.expr.is_some() { len += 1; } - if self.schema.is_some() { + if self.pattern.is_some() { len += 1; } - if !self.filters.is_empty() { + if !self.escape_char.is_empty() { len += 1; } - if !self.table_partition_cols.is_empty() { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } - if self.collect_stat { - len += 1; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if self.target_partitions != 0 { - len += 1; + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; } - if !self.file_sort_order.is_empty() { - len += 1; + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LikeNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "negated", + "expr", + "pattern", + "escape_char", + "escapeChar", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Negated, + Expr, + Pattern, + EscapeChar, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "negated" => Ok(GeneratedField::Negated), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LikeNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LikeNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut negated__ = None; + let mut expr__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); + } + pattern__ = map_.next_value()?; + } + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); + } + escape_char__ = Some(map_.next_value()?); + } + } + } + Ok(LikeNode { + negated: negated__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for LimitNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.skip != 0 { + len += 1; + } + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.skip != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for LimitNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "skip", + "fetch", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Skip, + Fetch, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "skip" => Ok(GeneratedField::Skip), + "fetch" => Ok(GeneratedField::Fetch), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LimitNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LimitNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut skip__ = None; + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); + } + skip__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(LimitNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for List { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for List { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.List") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + } + } + Ok(List { + field_type: field_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndex { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndex { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndex; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndex") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map_.next_value()?; + } + } + } + Ok(ListIndex { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndexExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndexExpr", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndexExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndexExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndexExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map_.next_value()?; + } + } + } + Ok(ListIndexExpr { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListRange { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListRange { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListRange; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListRange") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + } + } + Ok(ListRange { + start: start__, + stop: stop__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListRangeExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListRangeExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListRangeExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListRangeExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = map_.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map_.next_value()?; + } + } + } + Ok(ListRangeExpr { + start: start__, + stop: stop__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListingTableScanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.table_name.is_some() { + len += 1; + } + if !self.paths.is_empty() { + len += 1; + } + if !self.file_extension.is_empty() { + len += 1; + } + if self.projection.is_some() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.filters.is_empty() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.collect_stat { + len += 1; + } + if self.target_partitions != 0 { + len += 1; + } + if !self.file_sort_order.is_empty() { + len += 1; } if self.file_format_type.is_some() { len += 1; @@ -10107,7 +12207,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { formatter.write_str("struct datafusion.ListingTableScanNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -10122,89 +12222,89 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { let mut target_partitions__ = None; let mut file_sort_order__ = None; let mut file_format_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::TableName => { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = map.next_value()?; + table_name__ = map_.next_value()?; } GeneratedField::Paths => { if paths__.is_some() { return Err(serde::de::Error::duplicate_field("paths")); } - paths__ = Some(map.next_value()?); + paths__ = Some(map_.next_value()?); } GeneratedField::FileExtension => { if file_extension__.is_some() { return Err(serde::de::Error::duplicate_field("fileExtension")); } - file_extension__ = Some(map.next_value()?); + file_extension__ = Some(map_.next_value()?); } GeneratedField::Projection => { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - projection__ = map.next_value()?; + projection__ = map_.next_value()?; } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::Filters => { if filters__.is_some() { return Err(serde::de::Error::duplicate_field("filters")); } - filters__ = Some(map.next_value()?); + filters__ = Some(map_.next_value()?); } GeneratedField::TablePartitionCols => { if table_partition_cols__.is_some() { return Err(serde::de::Error::duplicate_field("tablePartitionCols")); } - table_partition_cols__ = Some(map.next_value()?); + table_partition_cols__ = Some(map_.next_value()?); } GeneratedField::CollectStat => { if collect_stat__.is_some() { return Err(serde::de::Error::duplicate_field("collectStat")); } - collect_stat__ = Some(map.next_value()?); + collect_stat__ = Some(map_.next_value()?); } GeneratedField::TargetPartitions => { if target_partitions__.is_some() { return Err(serde::de::Error::duplicate_field("targetPartitions")); } target_partitions__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::FileSortOrder => { if file_sort_order__.is_some() { return Err(serde::de::Error::duplicate_field("fileSortOrder")); } - file_sort_order__ = Some(map.next_value()?); + file_sort_order__ = Some(map_.next_value()?); } GeneratedField::Csv => { if file_format_type__.is_some() { return Err(serde::de::Error::duplicate_field("csv")); } - file_format_type__ = map.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) ; } GeneratedField::Parquet => { if file_format_type__.is_some() { return Err(serde::de::Error::duplicate_field("parquet")); } - file_format_type__ = map.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) ; } GeneratedField::Avro => { if file_format_type__.is_some() { return Err(serde::de::Error::duplicate_field("avro")); } - file_format_type__ = map.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) ; } } @@ -10304,26 +12404,26 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { formatter.write_str("struct datafusion.LocalLimitExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut fetch__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Fetch => { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -10405,18 +12505,18 @@ impl<'de> serde::Deserialize<'de> for LogicalExprList { formatter.write_str("struct datafusion.LogicalExprList") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } } } @@ -10719,248 +12819,249 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { formatter.write_str("struct datafusion.LogicalExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) ; } GeneratedField::Alias => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("alias")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) ; } GeneratedField::Literal => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("literal")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) ; } GeneratedField::BinaryExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("binaryExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) ; } GeneratedField::AggregateExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("aggregateExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) ; } GeneratedField::IsNullExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNullExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) ; } GeneratedField::IsNotNullExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNotNullExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) ; } GeneratedField::NotExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("notExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) ; } GeneratedField::Between => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("between")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) ; } GeneratedField::Case => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("case")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) ; } GeneratedField::Cast => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("cast")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) ; } GeneratedField::Sort => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("sort")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) ; } GeneratedField::Negative => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("negative")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) ; } GeneratedField::InList => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("inList")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) ; } GeneratedField::Wildcard => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("wildcard")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) +; } GeneratedField::ScalarFunction => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("scalarFunction")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarFunction) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarFunction) ; } GeneratedField::TryCast => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("tryCast")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) ; } GeneratedField::WindowExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("windowExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) ; } GeneratedField::AggregateUdfExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("aggregateUdfExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) ; } GeneratedField::ScalarUdfExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("scalarUdfExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) ; } GeneratedField::GetIndexedField => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("getIndexedField")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GetIndexedField) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GetIndexedField) ; } GeneratedField::GroupingSet => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("groupingSet")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) ; } GeneratedField::Cube => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("cube")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) ; } GeneratedField::Rollup => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("rollup")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) ; } GeneratedField::IsTrue => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isTrue")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) ; } GeneratedField::IsFalse => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isFalse")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) ; } GeneratedField::IsUnknown => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isUnknown")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) ; } GeneratedField::IsNotTrue => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNotTrue")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) ; } GeneratedField::IsNotFalse => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNotFalse")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) ; } GeneratedField::IsNotUnknown => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNotUnknown")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) ; } GeneratedField::Like => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("like")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) ; } GeneratedField::Ilike => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("ilike")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) ; } GeneratedField::SimilarTo => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("similarTo")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) ; } GeneratedField::Placeholder => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("placeholder")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) ; } } @@ -11042,18 +13143,18 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { formatter.write_str("struct datafusion.LogicalExprNodeCollection") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut logical_expr_nodes__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::LogicalExprNodes => { if logical_expr_nodes__.is_some() { return Err(serde::de::Error::duplicate_field("logicalExprNodes")); } - logical_expr_nodes__ = Some(map.next_value()?); + logical_expr_nodes__ = Some(map_.next_value()?); } } } @@ -11081,6 +13182,7 @@ impl serde::Serialize for LogicalExtensionNode { } let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExtensionNode", len)?; if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; } if !self.inputs.is_empty() { @@ -11142,27 +13244,27 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { formatter.write_str("struct datafusion.LogicalExtensionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut node__ = None; let mut inputs__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Node => { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } node__ = - Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } GeneratedField::Inputs => { if inputs__.is_some() { return Err(serde::de::Error::duplicate_field("inputs")); } - inputs__ = Some(map.next_value()?); + inputs__ = Some(map_.next_value()?); } } } @@ -11267,6 +13369,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DropView(v) => { struct_ser.serialize_field("dropView", v)?; } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } } } struct_ser.end() @@ -11316,6 +13421,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare", "drop_view", "dropView", + "distinct_on", + "distinctOn", ]; #[allow(clippy::enum_variant_names)] @@ -11346,6 +13453,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { CustomScan, Prepare, DropView, + DistinctOn, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11393,6 +13501,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11408,193 +13517,200 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { formatter.write_str("struct datafusion.LogicalPlanNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut logical_plan_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::ListingScan => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("listingScan")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) ; } GeneratedField::Projection => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) ; } GeneratedField::Selection => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("selection")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) ; } GeneratedField::Limit => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) ; } GeneratedField::Aggregate => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("aggregate")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) ; } GeneratedField::Join => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("join")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) ; } GeneratedField::Sort => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("sort")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) ; } GeneratedField::Repartition => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("repartition")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) ; } GeneratedField::EmptyRelation => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("emptyRelation")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) ; } GeneratedField::CreateExternalTable => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("createExternalTable")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) ; } GeneratedField::Explain => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("explain")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) ; } GeneratedField::Window => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("window")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) ; } GeneratedField::Analyze => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("analyze")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) ; } GeneratedField::CrossJoin => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("crossJoin")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) ; } GeneratedField::Values => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("values")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) ; } GeneratedField::Extension => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("extension")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) ; } GeneratedField::CreateCatalogSchema => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("createCatalogSchema")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) ; } GeneratedField::Union => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("union")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) ; } GeneratedField::CreateCatalog => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("createCatalog")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) ; } GeneratedField::SubqueryAlias => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("subqueryAlias")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) ; } GeneratedField::CreateView => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("createView")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) ; } GeneratedField::Distinct => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("distinct")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) ; } GeneratedField::ViewScan => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("viewScan")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) ; } GeneratedField::CustomScan => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("customScan")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) ; } GeneratedField::Prepare => { if logical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("prepare")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) +; + } + GeneratedField::DropView => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dropView")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) ; } - GeneratedField::DropView => { + GeneratedField::DistinctOn => { if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dropView")); + return Err(serde::de::Error::duplicate_field("distinctOn")); } - logical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) ; } } @@ -11686,25 +13802,25 @@ impl<'de> serde::Deserialize<'de> for Map { formatter.write_str("struct datafusion.Map") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut field_type__ = None; let mut keys_sorted__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FieldType => { if field_type__.is_some() { return Err(serde::de::Error::duplicate_field("fieldType")); } - field_type__ = map.next_value()?; + field_type__ = map_.next_value()?; } GeneratedField::KeysSorted => { if keys_sorted__.is_some() { return Err(serde::de::Error::duplicate_field("keysSorted")); } - keys_sorted__ = Some(map.next_value()?); + keys_sorted__ = Some(map_.next_value()?); } } } @@ -11785,18 +13901,18 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { formatter.write_str("struct datafusion.MaybeFilter") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -11877,18 +13993,18 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { formatter.write_str("struct datafusion.MaybePhysicalSortExprs") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut sort_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::SortExpr => { if sort_expr__.is_some() { return Err(serde::de::Error::duplicate_field("sortExpr")); } - sort_expr__ = Some(map.next_value()?); + sort_expr__ = Some(map_.next_value()?); } } } @@ -11900,6 +14016,188 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NamedStructField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NamedStructField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NamedStructField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NamedStructField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = map_.next_value()?; + } + } + } + Ok(NamedStructField { + name: name__, + }) + } + } + deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for NamedStructFieldExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructFieldExpr", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NamedStructFieldExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NamedStructFieldExpr") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = map_.next_value()?; + } + } + } + Ok(NamedStructFieldExpr { + name: name__, + }) + } + } + deserializer.deserialize_struct("datafusion.NamedStructFieldExpr", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -11911,26 +14209,144 @@ impl serde::Serialize for NegativeNode { if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NegativeNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NegativeNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NegativeNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + } + } + Ok(NegativeNode { + expr: expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for NestedLoopJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for NegativeNode { +impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "left", + "right", + "join_type", + "joinType", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Left, + Right, + JoinType, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11952,7 +14368,10 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11962,33 +14381,57 @@ impl<'de> serde::Deserialize<'de> for NegativeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NegativeNode; + type Value = NestedLoopJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NegativeNode") + formatter.write_str("struct datafusion.NestedLoopJoinExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - while let Some(k) = map.next_key()? { + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut filter__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); } - expr__ = map.next_value()?; + filter__ = map_.next_value()?; } } } - Ok(NegativeNode { - expr: expr__, + Ok(NestedLoopJoinExecNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for Not { @@ -12059,18 +14502,18 @@ impl<'de> serde::Deserialize<'de> for Not { formatter.write_str("struct datafusion.Not") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -12151,18 +14594,18 @@ impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { formatter.write_str("struct datafusion.OptimizedLogicalPlanType") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut optimizer_name__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::OptimizerName => { if optimizer_name__.is_some() { return Err(serde::de::Error::duplicate_field("optimizerName")); } - optimizer_name__ = Some(map.next_value()?); + optimizer_name__ = Some(map_.next_value()?); } } } @@ -12243,18 +14686,18 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { formatter.write_str("struct datafusion.OptimizedPhysicalPlanType") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut optimizer_name__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::OptimizerName => { if optimizer_name__.is_some() { return Err(serde::de::Error::duplicate_field("optimizerName")); } - optimizer_name__ = Some(map.next_value()?); + optimizer_name__ = Some(map_.next_value()?); } } } @@ -12350,32 +14793,32 @@ impl<'de> serde::Deserialize<'de> for OwnedTableReference { formatter.write_str("struct datafusion.OwnedTableReference") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut table_reference_enum__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Bare => { if table_reference_enum__.is_some() { return Err(serde::de::Error::duplicate_field("bare")); } - table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Bare) + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Bare) ; } GeneratedField::Partial => { if table_reference_enum__.is_some() { return Err(serde::de::Error::duplicate_field("partial")); } - table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Partial) + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Partial) ; } GeneratedField::Full => { if table_reference_enum__.is_some() { return Err(serde::de::Error::duplicate_field("full")); } - table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Full) + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Full) ; } } @@ -12431,7 +14874,207 @@ impl<'de> serde::Deserialize<'de> for ParquetFormat { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(ParquetFormat { + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetFormat", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetScanExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.base_conf.is_some() { + len += 1; + } + if self.predicate.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + if let Some(v) = self.predicate.as_ref() { + struct_ser.serialize_field("predicate", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "base_conf", + "baseConf", + "predicate", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + BaseConf, + Predicate, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + "predicate" => Ok(GeneratedField::Predicate), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetScanExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetScanExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut base_conf__ = None; + let mut predicate__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); + } + base_conf__ = map_.next_value()?; + } + GeneratedField::Predicate => { + if predicate__.is_some() { + return Err(serde::de::Error::duplicate_field("predicate")); + } + predicate__ = map_.next_value()?; + } + } + } + Ok(ParquetScanExecNode { + base_conf: base_conf__, + predicate: predicate__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PartialTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.schema.is_empty() { + len += 1; + } + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartialTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -12439,27 +15082,44 @@ impl<'de> serde::Deserialize<'de> for ParquetFormat { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetFormat; + type Value = PartialTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetFormat") + formatter.write_str("struct datafusion.PartialTableReference") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map.next_key::()?.is_some() { - let _ = map.next_value::()?; + let mut schema__ = None; + let mut table__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = Some(map_.next_value()?); + } + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map_.next_value()?); + } + } } - Ok(ParquetFormat { + Ok(PartialTableReference { + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ParquetFormat", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ParquetScanExecNode { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12467,38 +15127,29 @@ impl serde::Serialize for ParquetScanExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.base_conf.is_some() { - len += 1; - } - if self.predicate.is_some() { + if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; - } - if let Some(v) = self.predicate.as_ref() { - struct_ser.serialize_field("predicate", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", - "predicate", + "columns", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - BaseConf, - Predicate, + Columns, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12520,8 +15171,7 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { E: serde::de::Error, { match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "predicate" => Ok(GeneratedField::Predicate), + "columns" => Ok(GeneratedField::Columns), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12531,44 +15181,39 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetScanExecNode; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetScanExecNode") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut base_conf__ = None; - let mut predicate__ = None; - while let Some(k) = map.next_key()? { + let mut columns__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); - } - base_conf__ = map.next_value()?; - } - GeneratedField::Predicate => { - if predicate__.is_some() { - return Err(serde::de::Error::duplicate_field("predicate")); + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); } - predicate__ = map.next_value()?; + columns__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } } } - Ok(ParquetScanExecNode { - base_conf: base_conf__, - predicate: predicate__, + Ok(PartiallySortedInputOrderMode { + columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartialTableReference { +impl serde::Serialize for PartitionColumn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -12576,37 +15221,38 @@ impl serde::Serialize for PartialTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema.is_empty() { + if !self.name.is_empty() { len += 1; } - if !self.table.is_empty() { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartialTableReference { +impl<'de> serde::Deserialize<'de> for PartitionColumn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", - "table", + "name", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, - Table, + Name, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12628,8 +15274,8 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12639,41 +15285,41 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartialTableReference; + type Value = PartitionColumn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartialTableReference") + formatter.write_str("struct datafusion.PartitionColumn") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; - let mut table__ = None; - while let Some(k) = map.next_key()? { + let mut name__ = None; + let mut arrow_type__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - schema__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - table__ = Some(map.next_value()?); + arrow_type__ = map_.next_value()?; } } } - Ok(PartialTableReference { - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(PartitionColumn { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) } } impl serde::Serialize for PartitionMode { @@ -12715,10 +15361,9 @@ impl<'de> serde::Deserialize<'de> for PartitionMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(PartitionMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -12728,10 +15373,9 @@ impl<'de> serde::Deserialize<'de> for PartitionMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(PartitionMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -12774,12 +15418,15 @@ impl serde::Serialize for PartitionStats { } let mut struct_ser = serializer.serialize_struct("datafusion.PartitionStats", len)?; if self.num_rows != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; } if self.num_batches != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("numBatches", ToString::to_string(&self.num_batches).as_str())?; } if self.num_bytes != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("numBytes", ToString::to_string(&self.num_bytes).as_str())?; } if !self.column_stats.is_empty() { @@ -12851,7 +15498,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { formatter.write_str("struct datafusion.PartitionStats") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -12859,14 +15506,14 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { let mut num_batches__ = None; let mut num_bytes__ = None; let mut column_stats__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::NumRows => { if num_rows__.is_some() { return Err(serde::de::Error::duplicate_field("numRows")); } num_rows__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::NumBatches => { @@ -12874,7 +15521,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { return Err(serde::de::Error::duplicate_field("numBatches")); } num_batches__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::NumBytes => { @@ -12882,14 +15529,14 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { return Err(serde::de::Error::duplicate_field("numBytes")); } num_bytes__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::ColumnStats => { if column_stats__.is_some() { return Err(serde::de::Error::duplicate_field("columnStats")); } - column_stats__ = Some(map.next_value()?); + column_stats__ = Some(map_.next_value()?); } } } @@ -12932,9 +15579,11 @@ impl serde::Serialize for PartitionedFile { struct_ser.serialize_field("path", &self.path)?; } if self.size != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("size", ToString::to_string(&self.size).as_str())?; } if self.last_modified_ns != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("lastModifiedNs", ToString::to_string(&self.last_modified_ns).as_str())?; } if !self.partition_values.is_empty() { @@ -13010,7 +15659,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { formatter.write_str("struct datafusion.PartitionedFile") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -13019,20 +15668,20 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { let mut last_modified_ns__ = None; let mut partition_values__ = None; let mut range__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Path => { if path__.is_some() { return Err(serde::de::Error::duplicate_field("path")); } - path__ = Some(map.next_value()?); + path__ = Some(map_.next_value()?); } GeneratedField::Size => { if size__.is_some() { return Err(serde::de::Error::duplicate_field("size")); } size__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::LastModifiedNs => { @@ -13040,20 +15689,20 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { return Err(serde::de::Error::duplicate_field("lastModifiedNs")); } last_modified_ns__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::PartitionValues => { if partition_values__.is_some() { return Err(serde::de::Error::duplicate_field("partitionValues")); } - partition_values__ = Some(map.next_value()?); + partition_values__ = Some(map_.next_value()?); } GeneratedField::Range => { if range__.is_some() { return Err(serde::de::Error::duplicate_field("range")); } - range__ = map.next_value()?; + range__ = map_.next_value()?; } } } @@ -13080,6 +15729,9 @@ impl serde::Serialize for PhysicalAggregateExprNode { if !self.expr.is_empty() { len += 1; } + if !self.ordering_req.is_empty() { + len += 1; + } if self.distinct { len += 1; } @@ -13090,14 +15742,17 @@ impl serde::Serialize for PhysicalAggregateExprNode { if !self.expr.is_empty() { struct_ser.serialize_field("expr", &self.expr)?; } + if !self.ordering_req.is_empty() { + struct_ser.serialize_field("orderingReq", &self.ordering_req)?; + } if self.distinct { struct_ser.serialize_field("distinct", &self.distinct)?; } if let Some(v) = self.aggregate_function.as_ref() { match v { physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { - let v = AggregateFunction::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = AggregateFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("aggrFunction", &v)?; } physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { @@ -13116,6 +15771,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { const FIELDS: &[&str] = &[ "expr", + "ordering_req", + "orderingReq", "distinct", "aggr_function", "aggrFunction", @@ -13126,6 +15783,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, + OrderingReq, Distinct, AggrFunction, UserDefinedAggrFunction, @@ -13151,6 +15809,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { match value { "expr" => Ok(GeneratedField::Expr), + "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), "distinct" => Ok(GeneratedField::Distinct), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), @@ -13169,43 +15828,51 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { formatter.write_str("struct datafusion.PhysicalAggregateExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; + let mut ordering_req__ = None; let mut distinct__ = None; let mut aggregate_function__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); + } + GeneratedField::OrderingReq => { + if ordering_req__.is_some() { + return Err(serde::de::Error::duplicate_field("orderingReq")); + } + ordering_req__ = Some(map_.next_value()?); } GeneratedField::Distinct => { if distinct__.is_some() { return Err(serde::de::Error::duplicate_field("distinct")); } - distinct__ = Some(map.next_value()?); + distinct__ = Some(map_.next_value()?); } GeneratedField::AggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); } - aggregate_function__ = map.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); + aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); } GeneratedField::UserDefinedAggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); } - aggregate_function__ = map.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); + aggregate_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); } } } Ok(PhysicalAggregateExprNode { expr: expr__.unwrap_or_default(), + ordering_req: ordering_req__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), aggregate_function: aggregate_function__, }) @@ -13291,25 +15958,25 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { formatter.write_str("struct datafusion.PhysicalAliasNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut alias__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Alias => { if alias__.is_some() { return Err(serde::de::Error::duplicate_field("alias")); } - alias__ = Some(map.next_value()?); + alias__ = Some(map_.next_value()?); } } } @@ -13408,32 +16075,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { formatter.write_str("struct datafusion.PhysicalBinaryExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut l__ = None; let mut r__ = None; let mut op__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::L => { if l__.is_some() { return Err(serde::de::Error::duplicate_field("l")); } - l__ = map.next_value()?; + l__ = map_.next_value()?; } GeneratedField::R => { if r__.is_some() { return Err(serde::de::Error::duplicate_field("r")); } - r__ = map.next_value()?; + r__ = map_.next_value()?; } GeneratedField::Op => { if op__.is_some() { return Err(serde::de::Error::duplicate_field("op")); } - op__ = Some(map.next_value()?); + op__ = Some(map_.next_value()?); } } } @@ -13535,32 +16202,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { formatter.write_str("struct datafusion.PhysicalCaseNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut when_then_expr__ = None; let mut else_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::WhenThenExpr => { if when_then_expr__.is_some() { return Err(serde::de::Error::duplicate_field("whenThenExpr")); } - when_then_expr__ = Some(map.next_value()?); + when_then_expr__ = Some(map_.next_value()?); } GeneratedField::ElseExpr => { if else_expr__.is_some() { return Err(serde::de::Error::duplicate_field("elseExpr")); } - else_expr__ = map.next_value()?; + else_expr__ = map_.next_value()?; } } } @@ -13652,25 +16319,25 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { formatter.write_str("struct datafusion.PhysicalCastNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut arrow_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::ArrowType => { if arrow_type__.is_some() { return Err(serde::de::Error::duplicate_field("arrowType")); } - arrow_type__ = map.next_value()?; + arrow_type__ = map_.next_value()?; } } } @@ -13760,26 +16427,26 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { formatter.write_str("struct datafusion.PhysicalColumn") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; let mut index__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::Index => { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } index__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -13879,32 +16546,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { formatter.write_str("struct datafusion.PhysicalDateTimeIntervalExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut l__ = None; let mut r__ = None; let mut op__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::L => { if l__.is_some() { return Err(serde::de::Error::duplicate_field("l")); } - l__ = map.next_value()?; + l__ = map_.next_value()?; } GeneratedField::R => { if r__.is_some() { return Err(serde::de::Error::duplicate_field("r")); } - r__ = map.next_value()?; + r__ = map_.next_value()?; } GeneratedField::Op => { if op__.is_some() { return Err(serde::de::Error::duplicate_field("op")); } - op__ = Some(map.next_value()?); + op__ = Some(map_.next_value()?); } } } @@ -13980,9 +16647,6 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::ScalarUdf(v) => { struct_ser.serialize_field("scalarUdf", v)?; } - physical_expr_node::ExprType::DateTimeIntervalExpr(v) => { - struct_ser.serialize_field("dateTimeIntervalExpr", v)?; - } physical_expr_node::ExprType::LikeExpr(v) => { struct_ser.serialize_field("likeExpr", v)?; } @@ -14028,8 +16692,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "windowExpr", "scalar_udf", "scalarUdf", - "date_time_interval_expr", - "dateTimeIntervalExpr", "like_expr", "likeExpr", "get_indexed_field_expr", @@ -14054,7 +16716,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { TryCast, WindowExpr, ScalarUdf, - DateTimeIntervalExpr, LikeExpr, GetIndexedFieldExpr, } @@ -14094,7 +16755,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), - "dateTimeIntervalExpr" | "date_time_interval_expr" => Ok(GeneratedField::DateTimeIntervalExpr), "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), "getIndexedFieldExpr" | "get_indexed_field_expr" => Ok(GeneratedField::GetIndexedFieldExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -14112,144 +16772,137 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { formatter.write_str("struct datafusion.PhysicalExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) ; } GeneratedField::Literal => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("literal")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) ; } GeneratedField::BinaryExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("binaryExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) ; } GeneratedField::AggregateExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("aggregateExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) ; } GeneratedField::IsNullExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNullExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) ; } GeneratedField::IsNotNullExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("isNotNullExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) ; } GeneratedField::NotExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("notExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) ; } GeneratedField::Case => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("case")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) ; } GeneratedField::Cast => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("cast")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) ; } GeneratedField::Sort => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("sort")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) ; } GeneratedField::Negative => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("negative")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) ; } GeneratedField::InList => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("inList")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) ; } GeneratedField::ScalarFunction => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("scalarFunction")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarFunction) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarFunction) ; } GeneratedField::TryCast => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("tryCast")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) ; } GeneratedField::WindowExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("windowExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) ; } GeneratedField::ScalarUdf => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("scalarUdf")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) -; - } - GeneratedField::DateTimeIntervalExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dateTimeIntervalExpr")); - } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::DateTimeIntervalExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) ; } GeneratedField::LikeExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("likeExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) ; } GeneratedField::GetIndexedFieldExpr => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("getIndexedFieldExpr")); } - expr_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::GetIndexedFieldExpr) + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::GetIndexedFieldExpr) ; } } @@ -14278,6 +16931,7 @@ impl serde::Serialize for PhysicalExtensionNode { } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionNode", len)?; if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; } if !self.inputs.is_empty() { @@ -14339,27 +16993,27 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { formatter.write_str("struct datafusion.PhysicalExtensionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut node__ = None; let mut inputs__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Node => { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } node__ = - Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } GeneratedField::Inputs => { if inputs__.is_some() { return Err(serde::de::Error::duplicate_field("inputs")); } - inputs__ = Some(map.next_value()?); + inputs__ = Some(map_.next_value()?); } } } @@ -14383,15 +17037,25 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { if self.arg.is_some() { len += 1; } - if self.key.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", len)?; if let Some(v) = self.arg.as_ref() { struct_ser.serialize_field("arg", v)?; } - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + if let Some(v) = self.field.as_ref() { + match v { + physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(v) => { + struct_ser.serialize_field("namedStructFieldExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { + struct_ser.serialize_field("listIndexExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { + struct_ser.serialize_field("listRangeExpr", v)?; + } + } } struct_ser.end() } @@ -14404,13 +17068,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { { const FIELDS: &[&str] = &[ "arg", - "key", + "named_struct_field_expr", + "namedStructFieldExpr", + "list_index_expr", + "listIndexExpr", + "list_range_expr", + "listRangeExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Arg, - Key, + NamedStructFieldExpr, + ListIndexExpr, + ListRangeExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14433,7 +17104,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { { match value { "arg" => Ok(GeneratedField::Arg), - "key" => Ok(GeneratedField::Key), + "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), + "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), + "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14449,31 +17122,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { formatter.write_str("struct datafusion.PhysicalGetIndexedFieldExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut arg__ = None; - let mut key__ = None; - while let Some(k) = map.next_key()? { + let mut field__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Arg => { if arg__.is_some() { return Err(serde::de::Error::duplicate_field("arg")); } - arg__ = map.next_value()?; + arg__ = map_.next_value()?; } - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::NamedStructFieldExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructFieldExpr")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr) +; + } + GeneratedField::ListIndexExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndexExpr")); + } + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) +; + } + GeneratedField::ListRangeExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRangeExpr")); } - key__ = map.next_value()?; + field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) +; } } } Ok(PhysicalGetIndexedFieldExprNode { arg: arg__, - key: key__, + field: field__, }) } } @@ -14499,6 +17187,7 @@ impl serde::Serialize for PhysicalHashRepartition { struct_ser.serialize_field("hashExpr", &self.hash_expr)?; } if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; } struct_ser.end() @@ -14559,26 +17248,26 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { formatter.write_str("struct datafusion.PhysicalHashRepartition") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut hash_expr__ = None; let mut partition_count__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::HashExpr => { if hash_expr__.is_some() { return Err(serde::de::Error::duplicate_field("hashExpr")); } - hash_expr__ = Some(map.next_value()?); + hash_expr__ = Some(map_.next_value()?); } GeneratedField::PartitionCount => { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } partition_count__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -14678,32 +17367,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { formatter.write_str("struct datafusion.PhysicalInListNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut list__ = None; let mut negated__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::List => { if list__.is_some() { return Err(serde::de::Error::duplicate_field("list")); } - list__ = Some(map.next_value()?); + list__ = Some(map_.next_value()?); } GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } } } @@ -14785,18 +17474,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { formatter.write_str("struct datafusion.PhysicalIsNotNull") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -14876,18 +17565,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { formatter.write_str("struct datafusion.PhysicalIsNull") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -14995,7 +17684,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { formatter.write_str("struct datafusion.PhysicalLikeExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -15003,31 +17692,31 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { let mut case_insensitive__ = None; let mut expr__ = None; let mut pattern__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } GeneratedField::CaseInsensitive => { if case_insensitive__.is_some() { return Err(serde::de::Error::duplicate_field("caseInsensitive")); } - case_insensitive__ = Some(map.next_value()?); + case_insensitive__ = Some(map_.next_value()?); } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Pattern => { if pattern__.is_some() { return Err(serde::de::Error::duplicate_field("pattern")); } - pattern__ = map.next_value()?; + pattern__ = map_.next_value()?; } } } @@ -15110,18 +17799,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { formatter.write_str("struct datafusion.PhysicalNegativeNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -15201,18 +17890,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { formatter.write_str("struct datafusion.PhysicalNot") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -15298,6 +17987,24 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { struct_ser.serialize_field("sortPreservingMerge", v)?; } + physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { + struct_ser.serialize_field("nestedLoopJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Analyze(v) => { + struct_ser.serialize_field("analyze", v)?; + } + physical_plan_node::PhysicalPlanType::JsonSink(v) => { + struct_ser.serialize_field("jsonSink", v)?; + } + physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { + struct_ser.serialize_field("symmetricHashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } } } struct_ser.end() @@ -15339,6 +18046,16 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "explain", "sort_preserving_merge", "sortPreservingMerge", + "nested_loop_join", + "nestedLoopJoin", + "analyze", + "json_sink", + "jsonSink", + "symmetric_hash_join", + "symmetricHashJoin", + "interleave", + "placeholder_row", + "placeholderRow", ]; #[allow(clippy::enum_variant_names)] @@ -15363,6 +18080,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { Union, Explain, SortPreservingMerge, + NestedLoopJoin, + Analyze, + JsonSink, + SymmetricHashJoin, + Interleave, + PlaceholderRow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15404,6 +18127,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "union" => Ok(GeneratedField::Union), "explain" => Ok(GeneratedField::Explain), "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), + "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), + "analyze" => Ok(GeneratedField::Analyze), + "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), + "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -15419,151 +18148,193 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { formatter.write_str("struct datafusion.PhysicalPlanNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut physical_plan_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::ParquetScan => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("parquetScan")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) ; } GeneratedField::CsvScan => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("csvScan")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) ; } GeneratedField::Empty => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("empty")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) ; } GeneratedField::Projection => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) ; } GeneratedField::GlobalLimit => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("globalLimit")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) ; } GeneratedField::LocalLimit => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("localLimit")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) ; } GeneratedField::Aggregate => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("aggregate")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) ; } GeneratedField::HashJoin => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("hashJoin")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) ; } GeneratedField::Sort => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("sort")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) ; } GeneratedField::CoalesceBatches => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("coalesceBatches")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) ; } GeneratedField::Filter => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) ; } GeneratedField::Merge => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("merge")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) ; } GeneratedField::Repartition => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("repartition")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) ; } GeneratedField::Window => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("window")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) ; } GeneratedField::CrossJoin => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("crossJoin")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) ; } GeneratedField::AvroScan => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("avroScan")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) ; } GeneratedField::Extension => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("extension")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) ; } GeneratedField::Union => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("union")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) ; } GeneratedField::Explain => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("explain")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) ; } GeneratedField::SortPreservingMerge => { if physical_plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); } - physical_plan_type__ = map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) +; + } + GeneratedField::NestedLoopJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) +; + } + GeneratedField::Analyze => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("analyze")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) +; + } + GeneratedField::JsonSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) +; + } + GeneratedField::SymmetricHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) ; } } @@ -15601,8 +18372,8 @@ impl serde::Serialize for PhysicalScalarFunctionNode { struct_ser.serialize_field("name", &self.name)?; } if self.fun != 0 { - let v = ScalarFunction::from_i32(self.fun) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; + let v = ScalarFunction::try_from(self.fun) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; struct_ser.serialize_field("fun", &v)?; } if !self.args.is_empty() { @@ -15674,7 +18445,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { formatter.write_str("struct datafusion.PhysicalScalarFunctionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -15682,31 +18453,31 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarFunctionNode { let mut fun__ = None; let mut args__ = None; let mut return_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::Fun => { if fun__.is_some() { return Err(serde::de::Error::duplicate_field("fun")); } - fun__ = Some(map.next_value::()? as i32); + fun__ = Some(map_.next_value::()? as i32); } GeneratedField::Args => { if args__.is_some() { return Err(serde::de::Error::duplicate_field("args")); } - args__ = Some(map.next_value()?); + args__ = Some(map_.next_value()?); } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); } - return_type__ = map.next_value()?; + return_type__ = map_.next_value()?; } } } @@ -15808,32 +18579,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { formatter.write_str("struct datafusion.PhysicalScalarUdfNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; let mut args__ = None; let mut return_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::Args => { if args__.is_some() { return Err(serde::de::Error::duplicate_field("args")); } - args__ = Some(map.next_value()?); + args__ = Some(map_.next_value()?); } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); } - return_type__ = map.next_value()?; + return_type__ = map_.next_value()?; } } } @@ -15934,32 +18705,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { formatter.write_str("struct datafusion.PhysicalSortExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut asc__ = None; let mut nulls_first__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Asc => { if asc__.is_some() { return Err(serde::de::Error::duplicate_field("asc")); } - asc__ = Some(map.next_value()?); + asc__ = Some(map_.next_value()?); } GeneratedField::NullsFirst => { if nulls_first__.is_some() { return Err(serde::de::Error::duplicate_field("nullsFirst")); } - nulls_first__ = Some(map.next_value()?); + nulls_first__ = Some(map_.next_value()?); } } } @@ -16042,18 +18813,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { formatter.write_str("struct datafusion.PhysicalSortExprNodeCollection") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut physical_sort_expr_nodes__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::PhysicalSortExprNodes => { if physical_sort_expr_nodes__.is_some() { return Err(serde::de::Error::duplicate_field("physicalSortExprNodes")); } - physical_sort_expr_nodes__ = Some(map.next_value()?); + physical_sort_expr_nodes__ = Some(map_.next_value()?); } } } @@ -16143,25 +18914,25 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { formatter.write_str("struct datafusion.PhysicalTryCastNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut arrow_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::ArrowType => { if arrow_type__.is_some() { return Err(serde::de::Error::duplicate_field("arrowType")); } - arrow_type__ = map.next_value()?; + arrow_type__ = map_.next_value()?; } } } @@ -16253,25 +19024,25 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { formatter.write_str("struct datafusion.PhysicalWhenThen") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut when_expr__ = None; let mut then_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::WhenExpr => { if when_expr__.is_some() { return Err(serde::de::Error::duplicate_field("whenExpr")); } - when_expr__ = map.next_value()?; + when_expr__ = map_.next_value()?; } GeneratedField::ThenExpr => { if then_expr__.is_some() { return Err(serde::de::Error::duplicate_field("thenExpr")); } - then_expr__ = map.next_value()?; + then_expr__ = map_.next_value()?; } } } @@ -16292,26 +19063,50 @@ impl serde::Serialize for PhysicalWindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.args.is_empty() { + len += 1; + } + if !self.partition_by.is_empty() { + len += 1; + } + if !self.order_by.is_empty() { + len += 1; + } + if self.window_frame.is_some() { + len += 1; + } + if !self.name.is_empty() { len += 1; } if self.window_function.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if !self.partition_by.is_empty() { + struct_ser.serialize_field("partitionBy", &self.partition_by)?; + } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } + if let Some(v) = self.window_frame.as_ref() { + struct_ser.serialize_field("windowFrame", v)?; + } + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } if let Some(v) = self.window_function.as_ref() { match v { physical_window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = AggregateFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("aggrFunction", &v)?; } physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { - let v = BuiltInWindowFunction::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = BuiltInWindowFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("builtInFunction", &v)?; } } @@ -16326,7 +19121,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "args", + "partition_by", + "partitionBy", + "order_by", + "orderBy", + "window_frame", + "windowFrame", + "name", "aggr_function", "aggrFunction", "built_in_function", @@ -16335,7 +19137,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Args, + PartitionBy, + OrderBy, + WindowFrame, + Name, AggrFunction, BuiltInFunction, } @@ -16359,7 +19165,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "args" => Ok(GeneratedField::Args), + "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), + "name" => Ok(GeneratedField::Name), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -16377,36 +19187,68 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { formatter.write_str("struct datafusion.PhysicalWindowExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut args__ = None; + let mut partition_by__ = None; + let mut order_by__ = None; + let mut window_frame__ = None; + let mut name__ = None; let mut window_function__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); + } + args__ = Some(map_.next_value()?); + } + GeneratedField::PartitionBy => { + if partition_by__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionBy")); + } + partition_by__ = Some(map_.next_value()?); + } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map_.next_value()?); + } + GeneratedField::WindowFrame => { + if window_frame__.is_some() { + return Err(serde::de::Error::duplicate_field("windowFrame")); + } + window_frame__ = map_.next_value()?; + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - expr__ = map.next_value()?; + name__ = Some(map_.next_value()?); } GeneratedField::AggrFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); } - window_function__ = map.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); } GeneratedField::BuiltInFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("builtInFunction")); } - window_function__ = map.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); } } } Ok(PhysicalWindowExprNode { - expr: expr__, + args: args__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), + order_by: order_by__.unwrap_or_default(), + window_frame: window_frame__, + name: name__.unwrap_or_default(), window_function: window_function__, }) } @@ -16492,25 +19334,25 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { formatter.write_str("struct datafusion.PlaceholderNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut id__ = None; let mut data_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Id => { if id__.is_some() { return Err(serde::de::Error::duplicate_field("id")); } - id__ = Some(map.next_value()?); + id__ = Some(map_.next_value()?); } GeneratedField::DataType => { if data_type__.is_some() { return Err(serde::de::Error::duplicate_field("dataType")); } - data_type__ = map.next_value()?; + data_type__ = map_.next_value()?; } } } @@ -16523,6 +19365,97 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PlaceholderRowExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlaceholderRowExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlaceholderRowExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(PlaceholderRowExecNode { + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -16555,12 +19488,18 @@ impl serde::Serialize for PlanType { plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { struct_ser.serialize_field("InitialPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; + } plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; } plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { struct_ser.serialize_field("FinalPhysicalPlan", v)?; } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; + } } } struct_ser.end() @@ -16579,8 +19518,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan", "FinalLogicalPlan", "InitialPhysicalPlan", + "InitialPhysicalPlanWithStats", "OptimizedPhysicalPlan", "FinalPhysicalPlan", + "FinalPhysicalPlanWithStats", ]; #[allow(clippy::enum_variant_names)] @@ -16591,8 +19532,10 @@ impl<'de> serde::Deserialize<'de> for PlanType { OptimizedLogicalPlan, FinalLogicalPlan, InitialPhysicalPlan, + InitialPhysicalPlanWithStats, OptimizedPhysicalPlan, FinalPhysicalPlan, + FinalPhysicalPlanWithStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16620,8 +19563,177 @@ impl<'de> serde::Deserialize<'de> for PlanType { "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), + "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PlanType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PlanType") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut plan_type_enum__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::InitialLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) +; + } + GeneratedField::AnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) +; + } + GeneratedField::FinalAnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) +; + } + GeneratedField::OptimizedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) +; + } + GeneratedField::FinalLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) +; + } + GeneratedField::InitialPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) +; + } + GeneratedField::InitialPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) +; + } + GeneratedField::OptimizedPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) +; + } + } + } + Ok(PlanType { + plan_type_enum: plan_type_enum__, + }) + } + } + deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Precision { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision_info != 0 { + len += 1; + } + if self.val.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Precision", len)?; + if self.precision_info != 0 { + let v = PrecisionInfo::try_from(self.precision_info) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.precision_info)))?; + struct_ser.serialize_field("precisionInfo", &v)?; + } + if let Some(v) = self.val.as_ref() { + struct_ser.serialize_field("val", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Precision { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision_info", + "precisionInfo", + "val", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + PrecisionInfo, + Val, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precisionInfo" | "precision_info" => Ok(GeneratedField::PrecisionInfo), + "val" => Ok(GeneratedField::Val), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16631,83 +19743,115 @@ impl<'de> serde::Deserialize<'de> for PlanType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlanType; + type Value = Precision; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlanType") + formatter.write_str("struct datafusion.Precision") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut plan_type_enum__ = None; - while let Some(k) = map.next_key()? { + let mut precision_info__ = None; + let mut val__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::InitialLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) -; - } - GeneratedField::AnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) -; - } - GeneratedField::FinalAnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) -; - } - GeneratedField::OptimizedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) -; - } - GeneratedField::FinalLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); + GeneratedField::PrecisionInfo => { + if precision_info__.is_some() { + return Err(serde::de::Error::duplicate_field("precisionInfo")); } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) -; - } - GeneratedField::InitialPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) -; - } - GeneratedField::OptimizedPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); - } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) -; + precision_info__ = Some(map_.next_value::()? as i32); } - GeneratedField::FinalPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); + GeneratedField::Val => { + if val__.is_some() { + return Err(serde::de::Error::duplicate_field("val")); } - plan_type_enum__ = map.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) -; + val__ = map_.next_value()?; } } } - Ok(PlanType { - plan_type_enum: plan_type_enum__, + Ok(Precision { + precision_info: precision_info__.unwrap_or_default(), + val: val__, }) } } - deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Precision", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PrecisionInfo { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for PrecisionInfo { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "EXACT", + "INEXACT", + "ABSENT", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrecisionInfo; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "EXACT" => Ok(PrecisionInfo::Exact), + "INEXACT" => Ok(PrecisionInfo::Inexact), + "ABSENT" => Ok(PrecisionInfo::Absent), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) } } impl serde::Serialize for PrepareNode { @@ -16797,32 +19941,32 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { formatter.write_str("struct datafusion.PrepareNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; let mut data_types__ = None; let mut input__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = Some(map_.next_value()?); } GeneratedField::DataTypes => { if data_types__.is_some() { return Err(serde::de::Error::duplicate_field("dataTypes")); } - data_types__ = Some(map.next_value()?); + data_types__ = Some(map_.next_value()?); } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } } } @@ -16836,6 +19980,100 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PrimaryKeyConstraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.indices.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PrimaryKeyConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "indices", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Indices, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "indices" => Ok(GeneratedField::Indices), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrimaryKeyConstraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PrimaryKeyConstraint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut indices__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); + } + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(PrimaryKeyConstraint { + indices: indices__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ProjectionColumns { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -16904,18 +20142,18 @@ impl<'de> serde::Deserialize<'de> for ProjectionColumns { formatter.write_str("struct datafusion.ProjectionColumns") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut columns__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Columns => { if columns__.is_some() { return Err(serde::de::Error::duplicate_field("columns")); } - columns__ = Some(map.next_value()?); + columns__ = Some(map_.next_value()?); } } } @@ -17014,32 +20252,32 @@ impl<'de> serde::Deserialize<'de> for ProjectionExecNode { formatter.write_str("struct datafusion.ProjectionExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; let mut expr_name__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } GeneratedField::ExprName => { if expr_name__.is_some() { return Err(serde::de::Error::duplicate_field("exprName")); } - expr_name__ = Some(map.next_value()?); + expr_name__ = Some(map_.next_value()?); } } } @@ -17143,32 +20381,32 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { formatter.write_str("struct datafusion.ProjectionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; let mut optional_alias__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } GeneratedField::Alias => { if optional_alias__.is_some() { return Err(serde::de::Error::duplicate_field("alias")); } - optional_alias__ = map.next_value::<::std::option::Option<_>>()?.map(projection_node::OptionalAlias::Alias); + optional_alias__ = map_.next_value::<::std::option::Option<_>>()?.map(projection_node::OptionalAlias::Alias); } } } @@ -17203,12 +20441,14 @@ impl serde::Serialize for RepartitionExecNode { if let Some(v) = self.partition_method.as_ref() { match v { repartition_exec_node::PartitionMethod::RoundRobin(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; } repartition_exec_node::PartitionMethod::Hash(v) => { struct_ser.serialize_field("hash", v)?; } repartition_exec_node::PartitionMethod::Unknown(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("unknown", ToString::to_string(&v).as_str())?; } } @@ -17276,38 +20516,38 @@ impl<'de> serde::Deserialize<'de> for RepartitionExecNode { formatter.write_str("struct datafusion.RepartitionExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut partition_method__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::RoundRobin => { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("roundRobin")); } - partition_method__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::RoundRobin(x.0)); + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::RoundRobin(x.0)); } GeneratedField::Hash => { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("hash")); } - partition_method__ = map.next_value::<::std::option::Option<_>>()?.map(repartition_exec_node::PartitionMethod::Hash) + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_exec_node::PartitionMethod::Hash) ; } GeneratedField::Unknown => { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("unknown")); } - partition_method__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::Unknown(x.0)); + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::Unknown(x.0)); } } } @@ -17341,6 +20581,7 @@ impl serde::Serialize for RepartitionNode { if let Some(v) = self.partition_method.as_ref() { match v { repartition_node::PartitionMethod::RoundRobin(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; } repartition_node::PartitionMethod::Hash(v) => { @@ -17408,31 +20649,31 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { formatter.write_str("struct datafusion.RepartitionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut partition_method__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::RoundRobin => { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("roundRobin")); } - partition_method__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_node::PartitionMethod::RoundRobin(x.0)); + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_node::PartitionMethod::RoundRobin(x.0)); } GeneratedField::Hash => { if partition_method__.is_some() { return Err(serde::de::Error::duplicate_field("hash")); } - partition_method__ = map.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) ; } } @@ -17514,18 +20755,18 @@ impl<'de> serde::Deserialize<'de> for RollupNode { formatter.write_str("struct datafusion.RollupNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } } } @@ -17615,25 +20856,25 @@ impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { formatter.write_str("struct datafusion.ScalarDictionaryValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut index_type__ = None; let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::IndexType => { if index_type__.is_some() { return Err(serde::de::Error::duplicate_field("indexType")); } - index_type__ = map.next_value()?; + index_type__ = map_.next_value()?; } GeneratedField::Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = map.next_value()?; + value__ = map_.next_value()?; } } } @@ -17662,6 +20903,7 @@ impl serde::Serialize for ScalarFixedSizeBinary { } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarFixedSizeBinary", len)?; if !self.values.is_empty() { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("values", pbjson::private::base64::encode(&self.values).as_str())?; } if self.length != 0 { @@ -17723,20 +20965,20 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { formatter.write_str("struct datafusion.ScalarFixedSizeBinary") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut values__ = None; let mut length__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Values => { if values__.is_some() { return Err(serde::de::Error::duplicate_field("values")); } values__ = - Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } GeneratedField::Length => { @@ -17744,7 +20986,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { return Err(serde::de::Error::duplicate_field("length")); } length__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -17854,7 +21096,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayAppend => "ArrayAppend", Self::ArrayConcat => "ArrayConcat", Self::ArrayDims => "ArrayDims", - Self::ArrayFill => "ArrayFill", + Self::ArrayRepeat => "ArrayRepeat", Self::ArrayLength => "ArrayLength", Self::ArrayNdims => "ArrayNdims", Self::ArrayPosition => "ArrayPosition", @@ -17864,7 +21106,37 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplace => "ArrayReplace", Self::ArrayToString => "ArrayToString", Self::Cardinality => "Cardinality", - Self::TrimArray => "TrimArray", + Self::ArrayElement => "ArrayElement", + Self::ArraySlice => "ArraySlice", + Self::Encode => "Encode", + Self::Decode => "Decode", + Self::Cot => "Cot", + Self::ArrayHas => "ArrayHas", + Self::ArrayHasAny => "ArrayHasAny", + Self::ArrayHasAll => "ArrayHasAll", + Self::ArrayRemoveN => "ArrayRemoveN", + Self::ArrayReplaceN => "ArrayReplaceN", + Self::ArrayRemoveAll => "ArrayRemoveAll", + Self::ArrayReplaceAll => "ArrayReplaceAll", + Self::Nanvl => "Nanvl", + Self::Flatten => "Flatten", + Self::Isnan => "Isnan", + Self::Iszero => "Iszero", + Self::ArrayEmpty => "ArrayEmpty", + Self::ArrayPopBack => "ArrayPopBack", + Self::StringToArray => "StringToArray", + Self::ToTimestampNanos => "ToTimestampNanos", + Self::ArrayIntersect => "ArrayIntersect", + Self::ArrayUnion => "ArrayUnion", + Self::OverLay => "OverLay", + Self::Range => "Range", + Self::ArrayExcept => "ArrayExcept", + Self::ArrayPopFront => "ArrayPopFront", + Self::Levenshtein => "Levenshtein", + Self::SubstrIndex => "SubstrIndex", + Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", + Self::ArrayDistinct => "ArrayDistinct", }; serializer.serialize_str(variant) } @@ -17965,7 +21237,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend", "ArrayConcat", "ArrayDims", - "ArrayFill", + "ArrayRepeat", "ArrayLength", "ArrayNdims", "ArrayPosition", @@ -17975,7 +21247,37 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace", "ArrayToString", "Cardinality", - "TrimArray", + "ArrayElement", + "ArraySlice", + "Encode", + "Decode", + "Cot", + "ArrayHas", + "ArrayHasAny", + "ArrayHasAll", + "ArrayRemoveN", + "ArrayReplaceN", + "ArrayRemoveAll", + "ArrayReplaceAll", + "Nanvl", + "Flatten", + "Isnan", + "Iszero", + "ArrayEmpty", + "ArrayPopBack", + "StringToArray", + "ToTimestampNanos", + "ArrayIntersect", + "ArrayUnion", + "OverLay", + "Range", + "ArrayExcept", + "ArrayPopFront", + "Levenshtein", + "SubstrIndex", + "FindInSet", + "ArraySort", + "ArrayDistinct", ]; struct GeneratedVisitor; @@ -17991,10 +21293,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(ScalarFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -18004,10 +21305,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(ScalarFunction::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -18107,7 +21407,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), "ArrayDims" => Ok(ScalarFunction::ArrayDims), - "ArrayFill" => Ok(ScalarFunction::ArrayFill), + "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), "ArrayLength" => Ok(ScalarFunction::ArrayLength), "ArrayNdims" => Ok(ScalarFunction::ArrayNdims), "ArrayPosition" => Ok(ScalarFunction::ArrayPosition), @@ -18117,7 +21417,37 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "ArrayToString" => Ok(ScalarFunction::ArrayToString), "Cardinality" => Ok(ScalarFunction::Cardinality), - "TrimArray" => Ok(ScalarFunction::TrimArray), + "ArrayElement" => Ok(ScalarFunction::ArrayElement), + "ArraySlice" => Ok(ScalarFunction::ArraySlice), + "Encode" => Ok(ScalarFunction::Encode), + "Decode" => Ok(ScalarFunction::Decode), + "Cot" => Ok(ScalarFunction::Cot), + "ArrayHas" => Ok(ScalarFunction::ArrayHas), + "ArrayHasAny" => Ok(ScalarFunction::ArrayHasAny), + "ArrayHasAll" => Ok(ScalarFunction::ArrayHasAll), + "ArrayRemoveN" => Ok(ScalarFunction::ArrayRemoveN), + "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), + "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), + "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), + "Nanvl" => Ok(ScalarFunction::Nanvl), + "Flatten" => Ok(ScalarFunction::Flatten), + "Isnan" => Ok(ScalarFunction::Isnan), + "Iszero" => Ok(ScalarFunction::Iszero), + "ArrayEmpty" => Ok(ScalarFunction::ArrayEmpty), + "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), + "StringToArray" => Ok(ScalarFunction::StringToArray), + "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), + "ArrayUnion" => Ok(ScalarFunction::ArrayUnion), + "OverLay" => Ok(ScalarFunction::OverLay), + "Range" => Ok(ScalarFunction::Range), + "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), + "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), + "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), + "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), + "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -18141,8 +21471,8 @@ impl serde::Serialize for ScalarFunctionNode { } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarFunctionNode", len)?; if self.fun != 0 { - let v = ScalarFunction::from_i32(self.fun) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; + let v = ScalarFunction::try_from(self.fun) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.fun)))?; struct_ser.serialize_field("fun", &v)?; } if !self.args.is_empty() { @@ -18204,25 +21534,25 @@ impl<'de> serde::Deserialize<'de> for ScalarFunctionNode { formatter.write_str("struct datafusion.ScalarFunctionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut fun__ = None; let mut args__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Fun => { if fun__.is_some() { return Err(serde::de::Error::duplicate_field("fun")); } - fun__ = Some(map.next_value::()? as i32); + fun__ = Some(map_.next_value::()? as i32); } GeneratedField::Args => { if args__.is_some() { return Err(serde::de::Error::duplicate_field("args")); } - args__ = Some(map.next_value()?); + args__ = Some(map_.next_value()?); } } } @@ -18243,24 +21573,26 @@ impl serde::Serialize for ScalarListValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.is_null { + if !self.ipc_message.is_empty() { len += 1; } - if self.field.is_some() { + if !self.arrow_data.is_empty() { len += 1; } - if !self.values.is_empty() { + if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarListValue", len)?; - if self.is_null { - struct_ser.serialize_field("isNull", &self.is_null)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } - if !self.values.is_empty() { - struct_ser.serialize_field("values", &self.values)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } @@ -18272,17 +21604,18 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "is_null", - "isNull", - "field", - "values", + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - IsNull, - Field, - Values, + IpcMessage, + ArrowData, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18304,9 +21637,9 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { E: serde::de::Error, { match value { - "isNull" | "is_null" => Ok(GeneratedField::IsNull), - "field" => Ok(GeneratedField::Field), - "values" => Ok(GeneratedField::Values), + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18322,39 +21655,43 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { formatter.write_str("struct datafusion.ScalarListValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut is_null__ = None; - let mut field__ = None; - let mut values__ = None; - while let Some(k) = map.next_key()? { + let mut ipc_message__ = None; + let mut arrow_data__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::IsNull => { - if is_null__.is_some() { - return Err(serde::de::Error::duplicate_field("isNull")); + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); } - is_null__ = Some(map.next_value()?); + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); } - field__ = map.next_value()?; + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Values => { - if values__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - values__ = Some(map.next_value()?); + schema__ = map_.next_value()?; } } } Ok(ScalarListValue { - is_null: is_null__.unwrap_or_default(), - field: field__, - values: values__.unwrap_or_default(), + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + schema: schema__, }) } } @@ -18441,24 +21778,24 @@ impl<'de> serde::Deserialize<'de> for ScalarTime32Value { formatter.write_str("struct datafusion.ScalarTime32Value") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Time32SecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time32SecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32SecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32SecondValue(x.0)); } GeneratedField::Time32MillisecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time32MillisecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32MillisecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32MillisecondValue(x.0)); } } } @@ -18485,9 +21822,11 @@ impl serde::Serialize for ScalarTime64Value { if let Some(v) = self.value.as_ref() { match v { scalar_time64_value::Value::Time64MicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("time64MicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_time64_value::Value::Time64NanosecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("time64NanosecondValue", ToString::to_string(&v).as_str())?; } } @@ -18550,24 +21889,24 @@ impl<'de> serde::Deserialize<'de> for ScalarTime64Value { formatter.write_str("struct datafusion.ScalarTime64Value") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Time64MicrosecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time64MicrosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64MicrosecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64MicrosecondValue(x.0)); } GeneratedField::Time64NanosecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time64NanosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64NanosecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64NanosecondValue(x.0)); } } } @@ -18600,15 +21939,19 @@ impl serde::Serialize for ScalarTimestampValue { if let Some(v) = self.value.as_ref() { match v { scalar_timestamp_value::Value::TimeMicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("timeMicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeNanosecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("timeNanosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeSecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("timeSecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeMillisecondValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("timeMillisecondValue", ToString::to_string(&v).as_str())?; } } @@ -18682,43 +22025,43 @@ impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { formatter.write_str("struct datafusion.ScalarTimestampValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut timezone__ = None; let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Timezone => { if timezone__.is_some() { return Err(serde::de::Error::duplicate_field("timezone")); } - timezone__ = Some(map.next_value()?); + timezone__ = Some(map_.next_value()?); } GeneratedField::TimeMicrosecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timeMicrosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMicrosecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMicrosecondValue(x.0)); } GeneratedField::TimeNanosecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timeNanosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeNanosecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeNanosecondValue(x.0)); } GeneratedField::TimeSecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timeSecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeSecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeSecondValue(x.0)); } GeneratedField::TimeMillisecondValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timeMillisecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMillisecondValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMillisecondValue(x.0)); } } } @@ -18809,25 +22152,25 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { formatter.write_str("struct datafusion.ScalarUDFExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut fun_name__ = None; let mut args__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { if fun_name__.is_some() { return Err(serde::de::Error::duplicate_field("funName")); } - fun_name__ = Some(map.next_value()?); + fun_name__ = Some(map_.next_value()?); } GeneratedField::Args => { if args__.is_some() { return Err(serde::de::Error::duplicate_field("args")); } - args__ = Some(map.next_value()?); + args__ = Some(map_.next_value()?); } } } @@ -18876,6 +22219,7 @@ impl serde::Serialize for ScalarValue { struct_ser.serialize_field("int32Value", v)?; } scalar_value::Value::Int64Value(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("int64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Uint8Value(v) => { @@ -18888,6 +22232,7 @@ impl serde::Serialize for ScalarValue { struct_ser.serialize_field("uint32Value", v)?; } scalar_value::Value::Uint64Value(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("uint64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Float32Value(v) => { @@ -18902,21 +22247,48 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Time32Value(v) => { struct_ser.serialize_field("time32Value", v)?; } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } + scalar_value::Value::FixedSizeListValue(v) => { + struct_ser.serialize_field("fixedSizeListValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } + scalar_value::Value::Decimal256Value(v) => { + struct_ser.serialize_field("decimal256Value", v)?; + } scalar_value::Value::Date64Value(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::IntervalYearmonthValue(v) => { struct_ser.serialize_field("intervalYearmonthValue", v)?; } scalar_value::Value::IntervalDaytimeValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("intervalDaytimeValue", ToString::to_string(&v).as_str())?; } + scalar_value::Value::DurationSecondValue(v) => { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMillisecondValue(v) => { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationNanosecondValue(v) => { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; + } scalar_value::Value::TimestampValue(v) => { struct_ser.serialize_field("timestampValue", v)?; } @@ -18924,9 +22296,11 @@ impl serde::Serialize for ScalarValue { struct_ser.serialize_field("dictionaryValue", v)?; } scalar_value::Value::BinaryValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("binaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::LargeBinaryValue(v) => { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::Time64Value(v) => { @@ -18985,16 +22359,30 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value", "time32_value", "time32Value", + "large_list_value", + "largeListValue", "list_value", "listValue", + "fixed_size_list_value", + "fixedSizeListValue", "decimal128_value", "decimal128Value", + "decimal256_value", + "decimal256Value", "date_64_value", "date64Value", "interval_yearmonth_value", "intervalYearmonthValue", "interval_daytime_value", "intervalDaytimeValue", + "duration_second_value", + "durationSecondValue", + "duration_millisecond_value", + "durationMillisecondValue", + "duration_microsecond_value", + "durationMicrosecondValue", + "duration_nanosecond_value", + "durationNanosecondValue", "timestamp_value", "timestampValue", "dictionary_value", @@ -19031,11 +22419,18 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Float64Value, Date32Value, Time32Value, + LargeListValue, ListValue, + FixedSizeListValue, Decimal128Value, + Decimal256Value, Date64Value, IntervalYearmonthValue, IntervalDaytimeValue, + DurationSecondValue, + DurationMillisecondValue, + DurationMicrosecondValue, + DurationNanosecondValue, TimestampValue, DictionaryValue, BinaryValue, @@ -19081,11 +22476,18 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), + "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), + "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), + "durationSecondValue" | "duration_second_value" => Ok(GeneratedField::DurationSecondValue), + "durationMillisecondValue" | "duration_millisecond_value" => Ok(GeneratedField::DurationMillisecondValue), + "durationMicrosecondValue" | "duration_microsecond_value" => Ok(GeneratedField::DurationMicrosecondValue), + "durationNanosecondValue" | "duration_nanosecond_value" => Ok(GeneratedField::DurationNanosecondValue), "timestampValue" | "timestamp_value" => Ok(GeneratedField::TimestampValue), "dictionaryValue" | "dictionary_value" => Ok(GeneratedField::DictionaryValue), "binaryValue" | "binary_value" => Ok(GeneratedField::BinaryValue), @@ -19109,195 +22511,240 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { formatter.write_str("struct datafusion.ScalarValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::NullValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("nullValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::NullValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::NullValue) ; } GeneratedField::BoolValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("boolValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::BoolValue); + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::BoolValue); } GeneratedField::Utf8Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("utf8Value")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8Value); + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8Value); } GeneratedField::LargeUtf8Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("largeUtf8Value")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeUtf8Value); + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeUtf8Value); } GeneratedField::Int8Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("int8Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int8Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int8Value(x.0)); } GeneratedField::Int16Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("int16Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int16Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int16Value(x.0)); } GeneratedField::Int32Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("int32Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int32Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int32Value(x.0)); } GeneratedField::Int64Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("int64Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int64Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int64Value(x.0)); } GeneratedField::Uint8Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("uint8Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint8Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint8Value(x.0)); } GeneratedField::Uint16Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("uint16Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint16Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint16Value(x.0)); } GeneratedField::Uint32Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("uint32Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint32Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint32Value(x.0)); } GeneratedField::Uint64Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("uint64Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint64Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint64Value(x.0)); } GeneratedField::Float32Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("float32Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float32Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float32Value(x.0)); } GeneratedField::Float64Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("float64Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float64Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float64Value(x.0)); } GeneratedField::Date32Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("date32Value")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date32Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date32Value(x.0)); } GeneratedField::Time32Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time32Value")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) ; } GeneratedField::ListValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("listValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) +; + } + GeneratedField::FixedSizeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) ; } GeneratedField::Decimal128Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("decimal128Value")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) -; + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) +; + } + GeneratedField::Decimal256Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal256Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value) +; + } + GeneratedField::Date64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("date64Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date64Value(x.0)); + } + GeneratedField::IntervalYearmonthValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("intervalYearmonthValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalYearmonthValue(x.0)); + } + GeneratedField::IntervalDaytimeValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("intervalDaytimeValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalDaytimeValue(x.0)); + } + GeneratedField::DurationSecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationSecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationSecondValue(x.0)); } - GeneratedField::Date64Value => { + GeneratedField::DurationMillisecondValue => { if value__.is_some() { - return Err(serde::de::Error::duplicate_field("date64Value")); + return Err(serde::de::Error::duplicate_field("durationMillisecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date64Value(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMillisecondValue(x.0)); } - GeneratedField::IntervalYearmonthValue => { + GeneratedField::DurationMicrosecondValue => { if value__.is_some() { - return Err(serde::de::Error::duplicate_field("intervalYearmonthValue")); + return Err(serde::de::Error::duplicate_field("durationMicrosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalYearmonthValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMicrosecondValue(x.0)); } - GeneratedField::IntervalDaytimeValue => { + GeneratedField::DurationNanosecondValue => { if value__.is_some() { - return Err(serde::de::Error::duplicate_field("intervalDaytimeValue")); + return Err(serde::de::Error::duplicate_field("durationNanosecondValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalDaytimeValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationNanosecondValue(x.0)); } GeneratedField::TimestampValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("timestampValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::TimestampValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::TimestampValue) ; } GeneratedField::DictionaryValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("dictionaryValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::DictionaryValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::DictionaryValue) ; } GeneratedField::BinaryValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("binaryValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryValue(x.0)); } GeneratedField::LargeBinaryValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("largeBinaryValue")); } - value__ = map.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::LargeBinaryValue(x.0)); + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::LargeBinaryValue(x.0)); } GeneratedField::Time64Value => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("time64Value")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time64Value) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time64Value) ; } GeneratedField::IntervalMonthDayNano => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("intervalMonthDayNano")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalMonthDayNano) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalMonthDayNano) ; } GeneratedField::StructValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("structValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) ; } GeneratedField::FixedSizeBinaryValue => { if value__.is_some() { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } - value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) ; } } @@ -19378,19 +22825,19 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { formatter.write_str("struct datafusion.ScanLimit") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut limit__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Limit => { if limit__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); } limit__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -19414,10 +22861,16 @@ impl serde::Serialize for Schema { if !self.columns.is_empty() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.Schema", len)?; if !self.columns.is_empty() { struct_ser.serialize_field("columns", &self.columns)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } struct_ser.end() } } @@ -19429,11 +22882,13 @@ impl<'de> serde::Deserialize<'de> for Schema { { const FIELDS: &[&str] = &[ "columns", + "metadata", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Columns, + Metadata, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19456,6 +22911,7 @@ impl<'de> serde::Deserialize<'de> for Schema { { match value { "columns" => Ok(GeneratedField::Columns), + "metadata" => Ok(GeneratedField::Metadata), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19471,23 +22927,33 @@ impl<'de> serde::Deserialize<'de> for Schema { formatter.write_str("struct datafusion.Schema") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut columns__ = None; - while let Some(k) = map.next_key()? { + let mut metadata__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Columns => { if columns__.is_some() { return Err(serde::de::Error::duplicate_field("columns")); } - columns__ = Some(map.next_value()?); + columns__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); } } } Ok(Schema { columns: columns__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), }) } } @@ -19562,18 +23028,18 @@ impl<'de> serde::Deserialize<'de> for SelectionExecNode { formatter.write_str("struct datafusion.SelectionExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -19662,25 +23128,25 @@ impl<'de> serde::Deserialize<'de> for SelectionNode { formatter.write_str("struct datafusion.SelectionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } } } @@ -19789,7 +23255,7 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { formatter.write_str("struct datafusion.SimilarToNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -19797,31 +23263,31 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { let mut expr__ = None; let mut pattern__ = None; let mut escape_char__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Negated => { if negated__.is_some() { return Err(serde::de::Error::duplicate_field("negated")); } - negated__ = Some(map.next_value()?); + negated__ = Some(map_.next_value()?); } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Pattern => { if pattern__.is_some() { return Err(serde::de::Error::duplicate_field("pattern")); } - pattern__ = map.next_value()?; + pattern__ = map_.next_value()?; } GeneratedField::EscapeChar => { if escape_char__.is_some() { return Err(serde::de::Error::duplicate_field("escapeChar")); } - escape_char__ = Some(map.next_value()?); + escape_char__ = Some(map_.next_value()?); } } } @@ -19864,6 +23330,7 @@ impl serde::Serialize for SortExecNode { struct_ser.serialize_field("expr", &self.expr)?; } if self.fetch != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } if self.preserve_partitioning { @@ -19932,7 +23399,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { formatter.write_str("struct datafusion.SortExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -19940,33 +23407,33 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { let mut expr__ = None; let mut fetch__ = None; let mut preserve_partitioning__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } GeneratedField::Fetch => { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::PreservePartitioning => { if preserve_partitioning__.is_some() { return Err(serde::de::Error::duplicate_field("preservePartitioning")); } - preserve_partitioning__ = Some(map.next_value()?); + preserve_partitioning__ = Some(map_.next_value()?); } } } @@ -20068,32 +23535,32 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { formatter.write_str("struct datafusion.SortExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut asc__ = None; let mut nulls_first__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::Asc => { if asc__.is_some() { return Err(serde::de::Error::duplicate_field("asc")); } - asc__ = Some(map.next_value()?); + asc__ = Some(map_.next_value()?); } GeneratedField::NullsFirst => { if nulls_first__.is_some() { return Err(serde::de::Error::duplicate_field("nullsFirst")); } - nulls_first__ = Some(map.next_value()?); + nulls_first__ = Some(map_.next_value()?); } } } @@ -20132,6 +23599,7 @@ impl serde::Serialize for SortNode { struct_ser.serialize_field("expr", &self.expr)?; } if self.fetch != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() @@ -20193,33 +23661,33 @@ impl<'de> serde::Deserialize<'de> for SortNode { formatter.write_str("struct datafusion.SortNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; let mut fetch__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); } GeneratedField::Fetch => { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } @@ -20248,6 +23716,9 @@ impl serde::Serialize for SortPreservingMergeExecNode { if !self.expr.is_empty() { len += 1; } + if self.fetch != 0 { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -20255,6 +23726,10 @@ impl serde::Serialize for SortPreservingMergeExecNode { if !self.expr.is_empty() { struct_ser.serialize_field("expr", &self.expr)?; } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + } struct_ser.end() } } @@ -20267,12 +23742,14 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { const FIELDS: &[&str] = &[ "input", "expr", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, Expr, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20296,6 +23773,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { match value { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20311,31 +23789,41 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { formatter.write_str("struct datafusion.SortPreservingMergeExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; - while let Some(k) = map.next_key()? { + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map.next_value()?); + expr__ = Some(map_.next_value()?); + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } Ok(SortPreservingMergeExecNode { input: input__, expr: expr__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), }) } } @@ -20350,31 +23838,25 @@ impl serde::Serialize for Statistics { { use serde::ser::SerializeStruct; let mut len = 0; - if self.num_rows != 0 { + if self.num_rows.is_some() { len += 1; } - if self.total_byte_size != 0 { + if self.total_byte_size.is_some() { len += 1; } if !self.column_stats.is_empty() { len += 1; } - if self.is_exact { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion.Statistics", len)?; - if self.num_rows != 0 { - struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; + if let Some(v) = self.num_rows.as_ref() { + struct_ser.serialize_field("numRows", v)?; } - if self.total_byte_size != 0 { - struct_ser.serialize_field("totalByteSize", ToString::to_string(&self.total_byte_size).as_str())?; + if let Some(v) = self.total_byte_size.as_ref() { + struct_ser.serialize_field("totalByteSize", v)?; } if !self.column_stats.is_empty() { struct_ser.serialize_field("columnStats", &self.column_stats)?; } - if self.is_exact { - struct_ser.serialize_field("isExact", &self.is_exact)?; - } struct_ser.end() } } @@ -20391,8 +23873,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { "totalByteSize", "column_stats", "columnStats", - "is_exact", - "isExact", ]; #[allow(clippy::enum_variant_names)] @@ -20400,7 +23880,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { NumRows, TotalByteSize, ColumnStats, - IsExact, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20425,7 +23904,6 @@ impl<'de> serde::Deserialize<'de> for Statistics { "numRows" | "num_rows" => Ok(GeneratedField::NumRows), "totalByteSize" | "total_byte_size" => Ok(GeneratedField::TotalByteSize), "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), - "isExact" | "is_exact" => Ok(GeneratedField::IsExact), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20441,57 +23919,116 @@ impl<'de> serde::Deserialize<'de> for Statistics { formatter.write_str("struct datafusion.Statistics") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut num_rows__ = None; let mut total_byte_size__ = None; let mut column_stats__ = None; - let mut is_exact__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::NumRows => { if num_rows__.is_some() { return Err(serde::de::Error::duplicate_field("numRows")); } - num_rows__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + num_rows__ = map_.next_value()?; } GeneratedField::TotalByteSize => { if total_byte_size__.is_some() { return Err(serde::de::Error::duplicate_field("totalByteSize")); } - total_byte_size__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + total_byte_size__ = map_.next_value()?; } GeneratedField::ColumnStats => { if column_stats__.is_some() { return Err(serde::de::Error::duplicate_field("columnStats")); } - column_stats__ = Some(map.next_value()?); - } - GeneratedField::IsExact => { - if is_exact__.is_some() { - return Err(serde::de::Error::duplicate_field("isExact")); - } - is_exact__ = Some(map.next_value()?); + column_stats__ = Some(map_.next_value()?); } } } Ok(Statistics { - num_rows: num_rows__.unwrap_or_default(), - total_byte_size: total_byte_size__.unwrap_or_default(), + num_rows: num_rows__, + total_byte_size: total_byte_size__, column_stats: column_stats__.unwrap_or_default(), - is_exact: is_exact__.unwrap_or_default(), }) } } deserializer.deserialize_struct("datafusion.Statistics", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for StreamPartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for StreamPartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "SINGLE_PARTITION", + "PARTITIONED_EXEC", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = StreamPartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), + "PARTITIONED_EXEC" => Ok(StreamPartitionMode::PartitionedExec), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for StringifiedPlan { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20570,25 +24107,25 @@ impl<'de> serde::Deserialize<'de> for StringifiedPlan { formatter.write_str("struct datafusion.StringifiedPlan") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut plan_type__ = None; let mut plan__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::PlanType => { if plan_type__.is_some() { return Err(serde::de::Error::duplicate_field("planType")); } - plan_type__ = map.next_value()?; + plan_type__ = map_.next_value()?; } GeneratedField::Plan => { if plan__.is_some() { return Err(serde::de::Error::duplicate_field("plan")); } - plan__ = Some(map.next_value()?); + plan__ = Some(map_.next_value()?); } } } @@ -20670,18 +24207,18 @@ impl<'de> serde::Deserialize<'de> for Struct { formatter.write_str("struct datafusion.Struct") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut sub_field_types__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::SubFieldTypes => { if sub_field_types__.is_some() { return Err(serde::de::Error::duplicate_field("subFieldTypes")); } - sub_field_types__ = Some(map.next_value()?); + sub_field_types__ = Some(map_.next_value()?); } } } @@ -20771,25 +24308,25 @@ impl<'de> serde::Deserialize<'de> for StructValue { formatter.write_str("struct datafusion.StructValue") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut field_values__ = None; let mut fields__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::FieldValues => { if field_values__.is_some() { return Err(serde::de::Error::duplicate_field("fieldValues")); } - field_values__ = Some(map.next_value()?); + field_values__ = Some(map_.next_value()?); } GeneratedField::Fields => { if fields__.is_some() { return Err(serde::de::Error::duplicate_field("fields")); } - fields__ = Some(map.next_value()?); + fields__ = Some(map_.next_value()?); } } } @@ -20862,8 +24399,168 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "alias" => Ok(GeneratedField::Alias), + "input" => Ok(GeneratedField::Input), + "alias" => Ok(GeneratedField::Alias), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SubqueryAliasNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SubqueryAliasNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut alias__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); + } + alias__ = map_.next_value()?; + } + } + } + Ok(SubqueryAliasNode { + input: input__, + alias: alias__, + }) + } + } + deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = StreamPartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20873,41 +24570,81 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SubqueryAliasNode; + type Value = SymmetricHashJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SubqueryAliasNode") + formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut alias__ = None; - while let Some(k) = map.next_key()? { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; + while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - input__ = map.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); } - alias__ = map.next_value()?; + filter__ = map_.next_value()?; } } } - Ok(SubqueryAliasNode { - input: input__, - alias: alias__, + Ok(SymmetricHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for TimeUnit { @@ -20951,10 +24688,9 @@ impl<'de> serde::Deserialize<'de> for TimeUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(TimeUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -20964,10 +24700,9 @@ impl<'de> serde::Deserialize<'de> for TimeUnit { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(TimeUnit::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -21005,8 +24740,8 @@ impl serde::Serialize for Timestamp { } let mut struct_ser = serializer.serialize_struct("datafusion.Timestamp", len)?; if self.time_unit != 0 { - let v = TimeUnit::from_i32(self.time_unit) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.time_unit)))?; + let v = TimeUnit::try_from(self.time_unit) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.time_unit)))?; struct_ser.serialize_field("timeUnit", &v)?; } if !self.timezone.is_empty() { @@ -21069,25 +24804,25 @@ impl<'de> serde::Deserialize<'de> for Timestamp { formatter.write_str("struct datafusion.Timestamp") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut time_unit__ = None; let mut timezone__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::TimeUnit => { if time_unit__.is_some() { return Err(serde::de::Error::duplicate_field("timeUnit")); } - time_unit__ = Some(map.next_value::()? as i32); + time_unit__ = Some(map_.next_value::()? as i32); } GeneratedField::Timezone => { if timezone__.is_some() { return Err(serde::de::Error::duplicate_field("timezone")); } - timezone__ = Some(map.next_value()?); + timezone__ = Some(map_.next_value()?); } } } @@ -21178,25 +24913,25 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { formatter.write_str("struct datafusion.TryCastNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; let mut arrow_type__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::ArrowType => { if arrow_type__.is_some() { return Err(serde::de::Error::duplicate_field("arrowType")); } - arrow_type__ = map.next_value()?; + arrow_type__ = map_.next_value()?; } } } @@ -21231,8 +24966,8 @@ impl serde::Serialize for Union { struct_ser.serialize_field("unionTypes", &self.union_types)?; } if self.union_mode != 0 { - let v = UnionMode::from_i32(self.union_mode) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.union_mode)))?; + let v = UnionMode::try_from(self.union_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.union_mode)))?; struct_ser.serialize_field("unionMode", &v)?; } if !self.type_ids.is_empty() { @@ -21300,33 +25035,33 @@ impl<'de> serde::Deserialize<'de> for Union { formatter.write_str("struct datafusion.Union") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut union_types__ = None; let mut union_mode__ = None; let mut type_ids__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::UnionTypes => { if union_types__.is_some() { return Err(serde::de::Error::duplicate_field("unionTypes")); } - union_types__ = Some(map.next_value()?); + union_types__ = Some(map_.next_value()?); } GeneratedField::UnionMode => { if union_mode__.is_some() { return Err(serde::de::Error::duplicate_field("unionMode")); } - union_mode__ = Some(map.next_value::()? as i32); + union_mode__ = Some(map_.next_value::()? as i32); } GeneratedField::TypeIds => { if type_ids__.is_some() { return Err(serde::de::Error::duplicate_field("typeIds")); } type_ids__ = - Some(map.next_value::>>()? + Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; } @@ -21410,18 +25145,18 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { formatter.write_str("struct datafusion.UnionExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut inputs__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Inputs => { if inputs__.is_some() { return Err(serde::de::Error::duplicate_field("inputs")); } - inputs__ = Some(map.next_value()?); + inputs__ = Some(map_.next_value()?); } } } @@ -21470,10 +25205,9 @@ impl<'de> serde::Deserialize<'de> for UnionMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(UnionMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -21483,10 +25217,9 @@ impl<'de> serde::Deserialize<'de> for UnionMode { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(UnionMode::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -21574,18 +25307,18 @@ impl<'de> serde::Deserialize<'de> for UnionNode { formatter.write_str("struct datafusion.UnionNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut inputs__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Inputs => { if inputs__.is_some() { return Err(serde::de::Error::duplicate_field("inputs")); } - inputs__ = Some(map.next_value()?); + inputs__ = Some(map_.next_value()?); } } } @@ -21597,6 +25330,100 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UniqueConstraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.indices.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UniqueConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UniqueConstraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "indices", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Indices, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "indices" => Ok(GeneratedField::Indices), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UniqueConstraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UniqueConstraint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut indices__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); + } + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(UniqueConstraint { + indices: indices__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UniqueConstraint", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ValuesNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -21613,6 +25440,7 @@ impl serde::Serialize for ValuesNode { } let mut struct_ser = serializer.serialize_struct("datafusion.ValuesNode", len)?; if self.n_cols != 0 { + #[allow(clippy::needless_borrow)] struct_ser.serialize_field("nCols", ToString::to_string(&self.n_cols).as_str())?; } if !self.values_list.is_empty() { @@ -21676,27 +25504,27 @@ impl<'de> serde::Deserialize<'de> for ValuesNode { formatter.write_str("struct datafusion.ValuesNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut n_cols__ = None; let mut values_list__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::NCols => { if n_cols__.is_some() { return Err(serde::de::Error::duplicate_field("nCols")); } n_cols__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } GeneratedField::ValuesList => { if values_list__.is_some() { return Err(serde::de::Error::duplicate_field("valuesList")); } - values_list__ = Some(map.next_value()?); + values_list__ = Some(map_.next_value()?); } } } @@ -21814,7 +25642,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { formatter.write_str("struct datafusion.ViewTableScanNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -21823,37 +25651,37 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { let mut schema__ = None; let mut projection__ = None; let mut definition__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::TableName => { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = map.next_value()?; + table_name__ = map_.next_value()?; } GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = map.next_value()?; + schema__ = map_.next_value()?; } GeneratedField::Projection => { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - projection__ = map.next_value()?; + projection__ = map_.next_value()?; } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); } - definition__ = Some(map.next_value()?); + definition__ = Some(map_.next_value()?); } } } @@ -21948,25 +25776,25 @@ impl<'de> serde::Deserialize<'de> for WhenThen { formatter.write_str("struct datafusion.WhenThen") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut when_expr__ = None; let mut then_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::WhenExpr => { if when_expr__.is_some() { return Err(serde::de::Error::duplicate_field("whenExpr")); } - when_expr__ = map.next_value()?; + when_expr__ = map_.next_value()?; } GeneratedField::ThenExpr => { if then_expr__.is_some() { return Err(serde::de::Error::duplicate_field("thenExpr")); } - then_expr__ = map.next_value()?; + then_expr__ = map_.next_value()?; } } } @@ -21979,6 +25807,97 @@ impl<'de> serde::Deserialize<'de> for WhenThen { deserializer.deserialize_struct("datafusion.WhenThen", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for Wildcard { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.qualifier.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Wildcard { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "qualifier", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Qualifier, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "qualifier" => Ok(GeneratedField::Qualifier), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Wildcard; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.Wildcard") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut qualifier__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); + } + qualifier__ = map_.next_value()?; + } + } + } + Ok(Wildcard { + qualifier: qualifier__, + }) + } + } + deserializer.deserialize_struct("datafusion.Wildcard", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for WindowAggExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -21993,10 +25912,10 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { len += 1; } - if !self.window_expr_name.is_empty() { + if !self.partition_keys.is_empty() { len += 1; } - if self.input_schema.is_some() { + if self.input_order_mode.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowAggExecNode", len)?; @@ -22006,11 +25925,21 @@ impl serde::Serialize for WindowAggExecNode { if !self.window_expr.is_empty() { struct_ser.serialize_field("windowExpr", &self.window_expr)?; } - if !self.window_expr_name.is_empty() { - struct_ser.serialize_field("windowExprName", &self.window_expr_name)?; + if !self.partition_keys.is_empty() { + struct_ser.serialize_field("partitionKeys", &self.partition_keys)?; } - if let Some(v) = self.input_schema.as_ref() { - struct_ser.serialize_field("inputSchema", v)?; + if let Some(v) = self.input_order_mode.as_ref() { + match v { + window_agg_exec_node::InputOrderMode::Linear(v) => { + struct_ser.serialize_field("linear", v)?; + } + window_agg_exec_node::InputOrderMode::PartiallySorted(v) => { + struct_ser.serialize_field("partiallySorted", v)?; + } + window_agg_exec_node::InputOrderMode::Sorted(v) => { + struct_ser.serialize_field("sorted", v)?; + } + } } struct_ser.end() } @@ -22025,18 +25954,22 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { "input", "window_expr", "windowExpr", - "window_expr_name", - "windowExprName", - "input_schema", - "inputSchema", + "partition_keys", + "partitionKeys", + "linear", + "partially_sorted", + "partiallySorted", + "sorted", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, WindowExpr, - WindowExprName, - InputSchema, + PartitionKeys, + Linear, + PartiallySorted, + Sorted, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22060,8 +25993,10 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { match value { "input" => Ok(GeneratedField::Input), "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "windowExprName" | "window_expr_name" => Ok(GeneratedField::WindowExprName), - "inputSchema" | "input_schema" => Ok(GeneratedField::InputSchema), + "partitionKeys" | "partition_keys" => Ok(GeneratedField::PartitionKeys), + "linear" => Ok(GeneratedField::Linear), + "partiallySorted" | "partially_sorted" => Ok(GeneratedField::PartiallySorted), + "sorted" => Ok(GeneratedField::Sorted), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22077,47 +26012,62 @@ impl<'de> serde::Deserialize<'de> for WindowAggExecNode { formatter.write_str("struct datafusion.WindowAggExecNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut window_expr__ = None; - let mut window_expr_name__ = None; - let mut input_schema__ = None; - while let Some(k) = map.next_key()? { + let mut partition_keys__ = None; + let mut input_order_mode__ = None; + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::WindowExpr => { if window_expr__.is_some() { return Err(serde::de::Error::duplicate_field("windowExpr")); } - window_expr__ = Some(map.next_value()?); + window_expr__ = Some(map_.next_value()?); } - GeneratedField::WindowExprName => { - if window_expr_name__.is_some() { - return Err(serde::de::Error::duplicate_field("windowExprName")); + GeneratedField::PartitionKeys => { + if partition_keys__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionKeys")); } - window_expr_name__ = Some(map.next_value()?); + partition_keys__ = Some(map_.next_value()?); } - GeneratedField::InputSchema => { - if input_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("inputSchema")); + GeneratedField::Linear => { + if input_order_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("linear")); } - input_schema__ = map.next_value()?; + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Linear) +; + } + GeneratedField::PartiallySorted => { + if input_order_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partiallySorted")); + } + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::PartiallySorted) +; + } + GeneratedField::Sorted => { + if input_order_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("sorted")); + } + input_order_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(window_agg_exec_node::InputOrderMode::Sorted) +; } } } Ok(WindowAggExecNode { input: input__, window_expr: window_expr__.unwrap_or_default(), - window_expr_name: window_expr_name__.unwrap_or_default(), - input_schema: input_schema__, + partition_keys: partition_keys__.unwrap_or_default(), + input_order_mode: input_order_mode__, }) } } @@ -22163,15 +26113,21 @@ impl serde::Serialize for WindowExprNode { if let Some(v) = self.window_function.as_ref() { match v { window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = AggregateFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("aggrFunction", &v)?; } window_expr_node::WindowFunction::BuiltInFunction(v) => { - let v = BuiltInWindowFunction::from_i32(*v) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + let v = BuiltInWindowFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("builtInFunction", &v)?; } + window_expr_node::WindowFunction::Udaf(v) => { + struct_ser.serialize_field("udaf", v)?; + } + window_expr_node::WindowFunction::Udwf(v) => { + struct_ser.serialize_field("udwf", v)?; + } } } struct_ser.end() @@ -22195,6 +26151,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "aggrFunction", "built_in_function", "builtInFunction", + "udaf", + "udwf", ]; #[allow(clippy::enum_variant_names)] @@ -22205,6 +26163,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { WindowFrame, AggrFunction, BuiltInFunction, + Udaf, + Udwf, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22232,6 +26192,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "udaf" => Ok(GeneratedField::Udaf), + "udwf" => Ok(GeneratedField::Udwf), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22247,7 +26209,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { formatter.write_str("struct datafusion.WindowExprNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -22256,43 +26218,55 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut order_by__ = None; let mut window_frame__ = None; let mut window_function__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map.next_value()?; + expr__ = map_.next_value()?; } GeneratedField::PartitionBy => { if partition_by__.is_some() { return Err(serde::de::Error::duplicate_field("partitionBy")); } - partition_by__ = Some(map.next_value()?); + partition_by__ = Some(map_.next_value()?); } GeneratedField::OrderBy => { if order_by__.is_some() { return Err(serde::de::Error::duplicate_field("orderBy")); } - order_by__ = Some(map.next_value()?); + order_by__ = Some(map_.next_value()?); } GeneratedField::WindowFrame => { if window_frame__.is_some() { return Err(serde::de::Error::duplicate_field("windowFrame")); } - window_frame__ = map.next_value()?; + window_frame__ = map_.next_value()?; } GeneratedField::AggrFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("aggrFunction")); } - window_function__ = map.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::AggrFunction(x as i32)); + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::AggrFunction(x as i32)); } GeneratedField::BuiltInFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("builtInFunction")); } - window_function__ = map.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::BuiltInFunction(x as i32)); + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::BuiltInFunction(x as i32)); + } + GeneratedField::Udaf => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("udaf")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(window_expr_node::WindowFunction::Udaf); + } + GeneratedField::Udwf => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("udwf")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(window_expr_node::WindowFunction::Udwf); } } } @@ -22327,8 +26301,8 @@ impl serde::Serialize for WindowFrame { } let mut struct_ser = serializer.serialize_struct("datafusion.WindowFrame", len)?; if self.window_frame_units != 0 { - let v = WindowFrameUnits::from_i32(self.window_frame_units) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.window_frame_units)))?; + let v = WindowFrameUnits::try_from(self.window_frame_units) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.window_frame_units)))?; struct_ser.serialize_field("windowFrameUnits", &v)?; } if let Some(v) = self.start_bound.as_ref() { @@ -22402,32 +26376,32 @@ impl<'de> serde::Deserialize<'de> for WindowFrame { formatter.write_str("struct datafusion.WindowFrame") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut window_frame_units__ = None; let mut start_bound__ = None; let mut end_bound__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::WindowFrameUnits => { if window_frame_units__.is_some() { return Err(serde::de::Error::duplicate_field("windowFrameUnits")); } - window_frame_units__ = Some(map.next_value::()? as i32); + window_frame_units__ = Some(map_.next_value::()? as i32); } GeneratedField::StartBound => { if start_bound__.is_some() { return Err(serde::de::Error::duplicate_field("startBound")); } - start_bound__ = map.next_value()?; + start_bound__ = map_.next_value()?; } GeneratedField::Bound => { if end_bound__.is_some() { return Err(serde::de::Error::duplicate_field("bound")); } - end_bound__ = map.next_value::<::std::option::Option<_>>()?.map(window_frame::EndBound::Bound) + end_bound__ = map_.next_value::<::std::option::Option<_>>()?.map(window_frame::EndBound::Bound) ; } } @@ -22458,8 +26432,8 @@ impl serde::Serialize for WindowFrameBound { } let mut struct_ser = serializer.serialize_struct("datafusion.WindowFrameBound", len)?; if self.window_frame_bound_type != 0 { - let v = WindowFrameBoundType::from_i32(self.window_frame_bound_type) - .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.window_frame_bound_type)))?; + let v = WindowFrameBoundType::try_from(self.window_frame_bound_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.window_frame_bound_type)))?; struct_ser.serialize_field("windowFrameBoundType", &v)?; } if let Some(v) = self.bound_value.as_ref() { @@ -22523,25 +26497,25 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBound { formatter.write_str("struct datafusion.WindowFrameBound") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut window_frame_bound_type__ = None; let mut bound_value__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::WindowFrameBoundType => { if window_frame_bound_type__.is_some() { return Err(serde::de::Error::duplicate_field("windowFrameBoundType")); } - window_frame_bound_type__ = Some(map.next_value::()? as i32); + window_frame_bound_type__ = Some(map_.next_value::()? as i32); } GeneratedField::BoundValue => { if bound_value__.is_some() { return Err(serde::de::Error::duplicate_field("boundValue")); } - bound_value__ = map.next_value()?; + bound_value__ = map_.next_value()?; } } } @@ -22593,10 +26567,9 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBoundType { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(WindowFrameBoundType::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -22606,10 +26579,9 @@ impl<'de> serde::Deserialize<'de> for WindowFrameBoundType { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(WindowFrameBoundType::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -22669,10 +26641,9 @@ impl<'de> serde::Deserialize<'de> for WindowFrameUnits { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(WindowFrameUnits::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) }) @@ -22682,10 +26653,9 @@ impl<'de> serde::Deserialize<'de> for WindowFrameUnits { where E: serde::de::Error, { - use std::convert::TryFrom; i32::try_from(v) .ok() - .and_then(WindowFrameUnits::from_i32) + .and_then(|x| x.try_into().ok()) .ok_or_else(|| { serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) }) @@ -22784,25 +26754,25 @@ impl<'de> serde::Deserialize<'de> for WindowNode { formatter.write_str("struct datafusion.WindowNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut window_expr__ = None; - while let Some(k) = map.next_key()? { + while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { if input__.is_some() { return Err(serde::de::Error::duplicate_field("input")); } - input__ = map.next_value()?; + input__ = map_.next_value()?; } GeneratedField::WindowExpr => { if window_expr__.is_some() { return Err(serde::de::Error::duplicate_field("windowExpr")); } - window_expr__ = Some(map.next_value()?); + window_expr__ = Some(map_.next_value()?); } } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b1ae0058dcb2c..8aadc96349ca5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28" )] pub logical_plan_type: ::core::option::Option, } @@ -99,6 +99,8 @@ pub mod logical_plan_node { Prepare(::prost::alloc::boxed::Box), #[prost(message, tag = "27")] DropView(super::DropViewNode), + #[prost(message, tag = "28")] + DistinctOn(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -122,6 +124,19 @@ pub struct CsvFormat { pub has_header: bool, #[prost(string, tag = "2")] pub delimiter: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub quote: ::prost::alloc::string::String, + #[prost(oneof = "csv_format::OptionalEscape", tags = "4")] + pub optional_escape: ::core::option::Option, +} +/// Nested message and enum types in `CsvFormat`. +pub mod csv_format { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum OptionalEscape { + #[prost(string, tag = "4")] + Escape(::prost::alloc::string::String), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -278,6 +293,41 @@ pub struct EmptyRelationNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrimaryKeyConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UniqueConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraint { + #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] + pub constraint_mode: ::core::option::Option, +} +/// Nested message and enum types in `Constraint`. +pub mod constraint { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ConstraintMode { + #[prost(message, tag = "1")] + PrimaryKey(super::PrimaryKeyConstraint), + #[prost(message, tag = "2")] + Unique(super::UniqueConstraint), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraints { + #[prost(message, repeated, tag = "1")] + pub constraints: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateExternalTableNode { #[prost(message, optional, tag = "12")] pub name: ::core::option::Option, @@ -308,6 +358,13 @@ pub struct CreateExternalTableNode { ::prost::alloc::string::String, ::prost::alloc::string::String, >, + #[prost(message, optional, tag = "15")] + pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "16")] + pub column_defaults: ::std::collections::HashMap< + ::prost::alloc::string::String, + LogicalExprNode, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -433,6 +490,18 @@ pub struct DistinctNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DistinctOnNode { + #[prost(message, repeated, tag = "1")] + pub on_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub select_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub sort_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "4")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -519,8 +588,8 @@ pub mod logical_expr_node { Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] InList(::prost::alloc::boxed::Box), - #[prost(bool, tag = "15")] - Wildcard(bool), + #[prost(message, tag = "15")] + Wildcard(super::Wildcard), #[prost(message, tag = "16")] ScalarFunction(super::ScalarFunctionNode), #[prost(message, tag = "17")] @@ -566,6 +635,12 @@ pub mod logical_expr_node { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Wildcard { + #[prost(string, optional, tag = "1")] + pub qualifier: ::core::option::Option<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, @@ -598,11 +673,44 @@ pub struct RollupNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NamedStructField { + #[prost(message, optional, tag = "1")] + pub name: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListIndex { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListRange { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct GetIndexedField { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] + pub field: ::core::option::Option, +} +/// Nested message and enum types in `GetIndexedField`. +pub mod get_indexed_field { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Field { + #[prost(message, tag = "2")] + NamedStructField(super::NamedStructField), + #[prost(message, tag = "3")] + ListIndex(::prost::alloc::boxed::Box), + #[prost(message, tag = "4")] + ListRange(::prost::alloc::boxed::Box), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -665,6 +773,8 @@ pub struct AliasNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub relation: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -747,7 +857,7 @@ pub struct WindowExprNode { /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, - #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2")] + #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")] pub window_function: ::core::option::Option, } /// Nested message and enum types in `WindowExprNode`. @@ -757,9 +867,12 @@ pub mod window_expr_node { pub enum WindowFunction { #[prost(enumeration = "super::AggregateFunction", tag = "1")] AggrFunction(i32), - /// udaf = 3 #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), + #[prost(string, tag = "3")] + Udaf(::prost::alloc::string::String), + #[prost(string, tag = "9")] + Udwf(::prost::alloc::string::String), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -890,6 +1003,11 @@ pub struct WindowFrameBound { pub struct Schema { #[prost(message, repeated, tag = "1")] pub columns: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "2")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -904,6 +1022,15 @@ pub struct Field { /// for complex data types like structs, unions #[prost(message, repeated, tag = "4")] pub children: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "5")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -976,14 +1103,12 @@ pub struct Union { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarListValue { - /// encode null explicitly to distinguish a list with a null value - /// from a list with no values) - #[prost(bool, tag = "3")] - pub is_null: bool, - #[prost(message, optional, tag = "1")] - pub field: ::core::option::Option, - #[prost(message, repeated, tag = "2")] - pub values: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "3")] + pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1084,7 +1209,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1128,17 +1253,30 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), - /// WAS: ScalarType null_list_value = 18; + #[prost(message, tag = "16")] + LargeListValue(super::ScalarListValue), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarListValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), + #[prost(message, tag = "39")] + Decimal256Value(super::Decimal256), #[prost(int64, tag = "21")] Date64Value(i64), #[prost(int32, tag = "24")] IntervalYearmonthValue(i32), #[prost(int64, tag = "25")] IntervalDaytimeValue(i64), + #[prost(int64, tag = "35")] + DurationSecondValue(i64), + #[prost(int64, tag = "36")] + DurationMillisecondValue(i64), + #[prost(int64, tag = "37")] + DurationMicrosecondValue(i64), + #[prost(int64, tag = "38")] + DurationNanosecondValue(i64), #[prost(message, tag = "26")] TimestampValue(super::ScalarTimestampValue), #[prost(message, tag = "27")] @@ -1167,6 +1305,16 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} /// Serialized data type #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1288,7 +1436,7 @@ pub struct OptimizedPhysicalPlanType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 5, 6")] + #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. @@ -1308,10 +1456,14 @@ pub mod plan_type { FinalLogicalPlan(super::EmptyMessage), #[prost(message, tag = "4")] InitialPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "9")] + InitialPhysicalPlanWithStats(super::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] FinalPhysicalPlan(super::EmptyMessage), + #[prost(message, tag = "10")] + FinalPhysicalPlanWithStats(super::EmptyMessage), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1373,7 +1525,7 @@ pub mod owned_table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" )] pub physical_plan_type: ::core::option::Option, } @@ -1424,10 +1576,91 @@ pub mod physical_plan_node { SortPreservingMerge( ::prost::alloc::boxed::Box, ), + #[prost(message, tag = "22")] + NestedLoopJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "23")] + Analyze(::prost::alloc::boxed::Box), + #[prost(message, tag = "24")] + JsonSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "25")] + SymmetricHashJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + Interleave(super::InterleaveExecNode), + #[prost(message, tag = "27")] + PlaceholderRow(super::PlaceholderRowExecNode), } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartitionColumn { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub arrow_type: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileTypeWriterOptions { + #[prost(oneof = "file_type_writer_options::FileType", tags = "1")] + pub file_type: ::core::option::Option, +} +/// Nested message and enum types in `FileTypeWriterOptions`. +pub mod file_type_writer_options { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum FileType { + #[prost(message, tag = "1")] + JsonOptions(super::JsonWriterOptions), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileSinkConfig { + #[prost(string, tag = "1")] + pub object_store_url: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "2")] + pub file_groups: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "4")] + pub output_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub table_partition_cols: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "7")] + pub single_file_output: bool, + #[prost(bool, tag = "8")] + pub unbounded_input: bool, + #[prost(bool, tag = "9")] + pub overwrite: bool, + #[prost(message, optional, tag = "10")] + pub file_type_writer_options: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, @@ -1440,7 +1673,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19" )] pub expr_type: ::core::option::Option, } @@ -1483,13 +1716,9 @@ pub mod physical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "15")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::PhysicalWindowExprNode), #[prost(message, tag = "16")] ScalarUdf(super::PhysicalScalarUdfNode), - #[prost(message, tag = "17")] - DateTimeIntervalExpr( - ::prost::alloc::boxed::Box, - ), #[prost(message, tag = "18")] LikeExpr(::prost::alloc::boxed::Box), #[prost(message, tag = "19")] @@ -1513,6 +1742,8 @@ pub struct PhysicalScalarUdfNode { pub struct PhysicalAggregateExprNode { #[prost(message, repeated, tag = "2")] pub expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "5")] + pub ordering_req: ::prost::alloc::vec::Vec, #[prost(bool, tag = "3")] pub distinct: bool, #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] @@ -1534,8 +1765,16 @@ pub mod physical_aggregate_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalWindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub args: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "5")] + pub partition_by: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "6")] + pub order_by: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "7")] + pub window_frame: ::core::option::Option, + #[prost(string, tag = "8")] + pub name: ::prost::alloc::string::String, #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, @@ -1592,10 +1831,10 @@ pub struct PhysicalBinaryExprNode { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalDateTimeIntervalExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub l: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub r: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "1")] + pub l: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub r: ::core::option::Option, #[prost(string, tag = "3")] pub op: ::prost::alloc::string::String, } @@ -1690,6 +1929,8 @@ pub struct FilterExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub default_filter_selectivity: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1747,6 +1988,19 @@ pub struct CsvScanExecNode { pub has_header: bool, #[prost(string, tag = "3")] pub delimiter: ::prost::alloc::string::String, + #[prost(string, tag = "4")] + pub quote: ::prost::alloc::string::String, + #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] + pub optional_escape: ::core::option::Option, +} +/// Nested message and enum types in `CsvScanExecNode`. +pub mod csv_scan_exec_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum OptionalEscape { + #[prost(string, tag = "5")] + Escape(::prost::alloc::string::String), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1774,6 +2028,30 @@ pub struct HashJoinExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SymmetricHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "JoinType", tag = "4")] + pub join_type: i32, + #[prost(enumeration = "StreamPartitionMode", tag = "6")] + pub partition_mode: i32, + #[prost(bool, tag = "7")] + pub null_equals_null: bool, + #[prost(message, optional, tag = "8")] + pub filter: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InterleaveExecNode { + #[prost(message, repeated, tag = "1")] + pub inputs: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -1790,6 +2068,18 @@ pub struct ExplainExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct AnalyzeExecNode { + #[prost(bool, tag = "1")] + pub verbose: bool, + #[prost(bool, tag = "2")] + pub show_statistics: bool, + #[prost(message, optional, boxed, tag = "3")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "4")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct CrossJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -1815,9 +2105,13 @@ pub struct JoinOn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { - #[prost(bool, tag = "1")] - pub produce_one_row: bool, - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PlaceholderRowExecNode { + #[prost(message, optional, tag = "1")] pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1832,15 +2126,36 @@ pub struct ProjectionExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartiallySortedInputOrderMode { + #[prost(uint64, repeated, tag = "6")] + pub columns: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowAggExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub window_expr: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "3")] - pub window_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(message, optional, tag = "4")] - pub input_schema: ::core::option::Option, + pub window_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "5")] + pub partition_keys: ::prost::alloc::vec::Vec, + /// Set optional to `None` for `BoundedWindowAggExec`. + #[prost(oneof = "window_agg_exec_node::InputOrderMode", tags = "7, 8, 9")] + pub input_order_mode: ::core::option::Option, +} +/// Nested message and enum types in `WindowAggExecNode`. +pub mod window_agg_exec_node { + /// Set optional to `None` for `BoundedWindowAggExec`. + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum InputOrderMode { + #[prost(message, tag = "7")] + Linear(super::EmptyMessage), + #[prost(message, tag = "8")] + PartiallySorted(super::PartiallySortedInputOrderMode), + #[prost(message, tag = "9")] + Sorted(super::EmptyMessage), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1921,6 +2236,21 @@ pub struct SortPreservingMergeExecNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] pub expr: ::prost::alloc::vec::Vec, + /// Maximum number of highest/lowest rows to fetch; negative means no limit + #[prost(int64, tag = "3")] + pub fetch: i64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NestedLoopJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(enumeration = "JoinType", tag = "3")] + pub join_type: i32, + #[prost(message, optional, tag = "4")] + pub filter: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2019,35 +2349,74 @@ pub struct PartitionStats { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Precision { + #[prost(enumeration = "PrecisionInfo", tag = "1")] + pub precision_info: i32, + #[prost(message, optional, tag = "2")] + pub val: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Statistics { - #[prost(int64, tag = "1")] - pub num_rows: i64, - #[prost(int64, tag = "2")] - pub total_byte_size: i64, + #[prost(message, optional, tag = "1")] + pub num_rows: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub total_byte_size: ::core::option::Option, #[prost(message, repeated, tag = "3")] pub column_stats: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "4")] - pub is_exact: bool, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnStats { #[prost(message, optional, tag = "1")] - pub min_value: ::core::option::Option, + pub min_value: ::core::option::Option, #[prost(message, optional, tag = "2")] - pub max_value: ::core::option::Option, - #[prost(uint32, tag = "3")] - pub null_count: u32, - #[prost(uint32, tag = "4")] - pub distinct_count: u32, + pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub null_count: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub distinct_count: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct NamedStructFieldExpr { + #[prost(message, optional, tag = "1")] + pub name: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListIndexExpr { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListRangeExpr { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalGetIndexedFieldExprNode { #[prost(message, optional, boxed, tag = "1")] pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4")] + pub field: ::core::option::Option, +} +/// Nested message and enum types in `PhysicalGetIndexedFieldExprNode`. +pub mod physical_get_indexed_field_expr_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Field { + #[prost(message, tag = "2")] + NamedStructFieldExpr(super::NamedStructFieldExpr), + #[prost(message, tag = "3")] + ListIndexExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "4")] + ListRangeExpr(::prost::alloc::boxed::Box), + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -2211,7 +2580,7 @@ pub enum ScalarFunction { ArrayAppend = 86, ArrayConcat = 87, ArrayDims = 88, - ArrayFill = 89, + ArrayRepeat = 89, ArrayLength = 90, ArrayNdims = 91, ArrayPosition = 92, @@ -2221,7 +2590,37 @@ pub enum ScalarFunction { ArrayReplace = 96, ArrayToString = 97, Cardinality = 98, - TrimArray = 99, + ArrayElement = 99, + ArraySlice = 100, + Encode = 101, + Decode = 102, + Cot = 103, + ArrayHas = 104, + ArrayHasAny = 105, + ArrayHasAll = 106, + ArrayRemoveN = 107, + ArrayReplaceN = 108, + ArrayRemoveAll = 109, + ArrayReplaceAll = 110, + Nanvl = 111, + Flatten = 112, + Isnan = 113, + Iszero = 114, + ArrayEmpty = 115, + ArrayPopBack = 116, + StringToArray = 117, + ToTimestampNanos = 118, + ArrayIntersect = 119, + ArrayUnion = 120, + OverLay = 121, + Range = 122, + ArrayExcept = 123, + ArrayPopFront = 124, + Levenshtein = 125, + SubstrIndex = 126, + FindInSet = 127, + ArraySort = 128, + ArrayDistinct = 129, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2319,7 +2718,7 @@ impl ScalarFunction { ScalarFunction::ArrayAppend => "ArrayAppend", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", - ScalarFunction::ArrayFill => "ArrayFill", + ScalarFunction::ArrayRepeat => "ArrayRepeat", ScalarFunction::ArrayLength => "ArrayLength", ScalarFunction::ArrayNdims => "ArrayNdims", ScalarFunction::ArrayPosition => "ArrayPosition", @@ -2329,7 +2728,37 @@ impl ScalarFunction { ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::ArrayToString => "ArrayToString", ScalarFunction::Cardinality => "Cardinality", - ScalarFunction::TrimArray => "TrimArray", + ScalarFunction::ArrayElement => "ArrayElement", + ScalarFunction::ArraySlice => "ArraySlice", + ScalarFunction::Encode => "Encode", + ScalarFunction::Decode => "Decode", + ScalarFunction::Cot => "Cot", + ScalarFunction::ArrayHas => "ArrayHas", + ScalarFunction::ArrayHasAny => "ArrayHasAny", + ScalarFunction::ArrayHasAll => "ArrayHasAll", + ScalarFunction::ArrayRemoveN => "ArrayRemoveN", + ScalarFunction::ArrayReplaceN => "ArrayReplaceN", + ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", + ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", + ScalarFunction::Nanvl => "Nanvl", + ScalarFunction::Flatten => "Flatten", + ScalarFunction::Isnan => "Isnan", + ScalarFunction::Iszero => "Iszero", + ScalarFunction::ArrayEmpty => "ArrayEmpty", + ScalarFunction::ArrayPopBack => "ArrayPopBack", + ScalarFunction::StringToArray => "StringToArray", + ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::ArrayIntersect => "ArrayIntersect", + ScalarFunction::ArrayUnion => "ArrayUnion", + ScalarFunction::OverLay => "OverLay", + ScalarFunction::Range => "Range", + ScalarFunction::ArrayExcept => "ArrayExcept", + ScalarFunction::ArrayPopFront => "ArrayPopFront", + ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::SubstrIndex => "SubstrIndex", + ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", + ScalarFunction::ArrayDistinct => "ArrayDistinct", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2424,7 +2853,7 @@ impl ScalarFunction { "ArrayAppend" => Some(Self::ArrayAppend), "ArrayConcat" => Some(Self::ArrayConcat), "ArrayDims" => Some(Self::ArrayDims), - "ArrayFill" => Some(Self::ArrayFill), + "ArrayRepeat" => Some(Self::ArrayRepeat), "ArrayLength" => Some(Self::ArrayLength), "ArrayNdims" => Some(Self::ArrayNdims), "ArrayPosition" => Some(Self::ArrayPosition), @@ -2434,7 +2863,37 @@ impl ScalarFunction { "ArrayReplace" => Some(Self::ArrayReplace), "ArrayToString" => Some(Self::ArrayToString), "Cardinality" => Some(Self::Cardinality), - "TrimArray" => Some(Self::TrimArray), + "ArrayElement" => Some(Self::ArrayElement), + "ArraySlice" => Some(Self::ArraySlice), + "Encode" => Some(Self::Encode), + "Decode" => Some(Self::Decode), + "Cot" => Some(Self::Cot), + "ArrayHas" => Some(Self::ArrayHas), + "ArrayHasAny" => Some(Self::ArrayHasAny), + "ArrayHasAll" => Some(Self::ArrayHasAll), + "ArrayRemoveN" => Some(Self::ArrayRemoveN), + "ArrayReplaceN" => Some(Self::ArrayReplaceN), + "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), + "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), + "Nanvl" => Some(Self::Nanvl), + "Flatten" => Some(Self::Flatten), + "Isnan" => Some(Self::Isnan), + "Iszero" => Some(Self::Iszero), + "ArrayEmpty" => Some(Self::ArrayEmpty), + "ArrayPopBack" => Some(Self::ArrayPopBack), + "StringToArray" => Some(Self::StringToArray), + "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "ArrayIntersect" => Some(Self::ArrayIntersect), + "ArrayUnion" => Some(Self::ArrayUnion), + "OverLay" => Some(Self::OverLay), + "Range" => Some(Self::Range), + "ArrayExcept" => Some(Self::ArrayExcept), + "ArrayPopFront" => Some(Self::ArrayPopFront), + "Levenshtein" => Some(Self::Levenshtein), + "SubstrIndex" => Some(Self::SubstrIndex), + "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), + "ArrayDistinct" => Some(Self::ArrayDistinct), _ => None, } } @@ -2470,6 +2929,16 @@ pub enum AggregateFunction { /// we append "_AGG" to obey name scoping rules. FirstValueAgg = 24, LastValueAgg = 25, + RegrSlope = 26, + RegrIntercept = 27, + RegrCount = 28, + RegrR2 = 29, + RegrAvgx = 30, + RegrAvgy = 31, + RegrSxx = 32, + RegrSyy = 33, + RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2506,6 +2975,16 @@ impl AggregateFunction { AggregateFunction::BoolOr => "BOOL_OR", AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG", AggregateFunction::LastValueAgg => "LAST_VALUE_AGG", + AggregateFunction::RegrSlope => "REGR_SLOPE", + AggregateFunction::RegrIntercept => "REGR_INTERCEPT", + AggregateFunction::RegrCount => "REGR_COUNT", + AggregateFunction::RegrR2 => "REGR_R2", + AggregateFunction::RegrAvgx => "REGR_AVGX", + AggregateFunction::RegrAvgy => "REGR_AVGY", + AggregateFunction::RegrSxx => "REGR_SXX", + AggregateFunction::RegrSyy => "REGR_SYY", + AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2539,6 +3018,16 @@ impl AggregateFunction { "BOOL_OR" => Some(Self::BoolOr), "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg), "LAST_VALUE_AGG" => Some(Self::LastValueAgg), + "REGR_SLOPE" => Some(Self::RegrSlope), + "REGR_INTERCEPT" => Some(Self::RegrIntercept), + "REGR_COUNT" => Some(Self::RegrCount), + "REGR_R2" => Some(Self::RegrR2), + "REGR_AVGX" => Some(Self::RegrAvgx), + "REGR_AVGY" => Some(Self::RegrAvgy), + "REGR_SXX" => Some(Self::RegrSxx), + "REGR_SYY" => Some(Self::RegrSyy), + "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } @@ -2769,6 +3258,41 @@ impl UnionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + CompressionTypeVariant::Gzip => "GZIP", + CompressionTypeVariant::Bzip2 => "BZIP2", + CompressionTypeVariant::Xz => "XZ", + CompressionTypeVariant::Zstd => "ZSTD", + CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, @@ -2798,11 +3322,38 @@ impl PartitionMode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum StreamPartitionMode { + SinglePartition = 0, + PartitionedExec = 1, +} +impl StreamPartitionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + StreamPartitionMode::SinglePartition => "SINGLE_PARTITION", + StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "SINGLE_PARTITION" => Some(Self::SinglePartition), + "PARTITIONED_EXEC" => Some(Self::PartitionedExec), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum AggregateMode { Partial = 0, Final = 1, FinalPartitioned = 2, Single = 3, + SinglePartitioned = 4, } impl AggregateMode { /// String value of the enum field names used in the ProtoBuf definition. @@ -2815,6 +3366,7 @@ impl AggregateMode { AggregateMode::Final => "FINAL", AggregateMode::FinalPartitioned => "FINAL_PARTITIONED", AggregateMode::Single => "SINGLE", + AggregateMode::SinglePartitioned => "SINGLE_PARTITIONED", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2824,6 +3376,7 @@ impl AggregateMode { "FINAL" => Some(Self::Final), "FINAL_PARTITIONED" => Some(Self::FinalPartitioned), "SINGLE" => Some(Self::Single), + "SINGLE_PARTITIONED" => Some(Self::SinglePartitioned), _ => None, } } @@ -2854,3 +3407,32 @@ impl JoinSide { } } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum PrecisionInfo { + Exact = 0, + Inexact = 1, + Absent = 2, +} +impl PrecisionInfo { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + PrecisionInfo::Exact => "EXACT", + PrecisionInfo::Inexact => "INEXACT", + PrecisionInfo::Absent => "ABSENT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "EXACT" => Some(Self::Exact), + "INEXACT" => Some(Self::Inexact), + "ABSENT" => Some(Self::Absent), + _ => None, + } + } +} diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ab2985f448a81..193e0947d6d9c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,45 +19,57 @@ use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, - UnionFields, UnionMode, +use arrow::{ + buffer::Buffer, + datatypes::{ + i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + UnionFields, UnionMode, + }, + ipc::{reader::read_record_batch, root_as_message}, }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, - ScalarValue, + internal_err, plan_datafusion_err, Column, Constraint, Constraints, DFField, + DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, }; -use datafusion_expr::expr::Placeholder; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill, - array_length, array_ndims, array_position, array_positions, array_prepend, - array_remove, array_replace, array_to_string, ascii, asin, asinh, atan, atan2, atanh, - bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, - concat_expr, concat_ws_expr, cos, cosh, date_bin, date_part, date_trunc, degrees, - digest, exp, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct, + array_element, array_except, array_has, array_has_all, array_has_any, + array_intersect, array_length, array_ndims, array_position, array_positions, + array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, + array_replace, array_replace_all, array_replace_n, array_slice, array_sort, + array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, + btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, + date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, + lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, now, nullif, octet_length, pi, power, radians, random, - regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, - sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, - strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trim, trim_array, trunc, upper, - uuid, - window_frame::regularize, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, + radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, + sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, + substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetIndexedField, GroupingSet, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use datafusion_expr::{ + array_empty, array_pop_back, array_pop_front, + expr::{Alias, Placeholder}, +}; use std::sync::Arc; #[derive(Debug)] @@ -326,8 +338,8 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { .collect::>()?, ), arrow_type::ArrowTypeEnum::Union(union) => { - let union_mode = protobuf::UnionMode::from_i32(union.union_mode) - .ok_or_else(|| Error::unknown("UnionMode", union.union_mode))?; + let union_mode = protobuf::UnionMode::try_from(union.union_mode) + .map_err(|_| Error::unknown("UnionMode", union.union_mode))?; let union_mode = match union_mode { protobuf::UnionMode::Dense => UnionMode::Dense, protobuf::UnionMode::Sparse => UnionMode::Sparse, @@ -365,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field { type Error = Error; fn try_from(field: &protobuf::Field) -> Result { let datatype = field.arrow_type.as_deref().required("arrow_type")?; - - Ok(Self::new(field.name.as_str(), datatype, field.nullable)) + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) } } @@ -396,12 +420,14 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, + InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), } } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, + FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, }, plan: Arc::new(stringified_plan.plan.clone()), } @@ -417,6 +443,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sin => Self::Sin, ScalarFunction::Cos => Self::Cos, ScalarFunction::Tan => Self::Tan, + ScalarFunction::Cot => Self::Cot, ScalarFunction::Asin => Self::Asin, ScalarFunction::Acos => Self::Acos, ScalarFunction::Atan => Self::Atan, @@ -449,20 +476,38 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, + ScalarFunction::ArrayEmpty => Self::ArrayEmpty, + ScalarFunction::ArrayExcept => Self::ArrayExcept, + ScalarFunction::ArrayHasAll => Self::ArrayHasAll, + ScalarFunction::ArrayHasAny => Self::ArrayHasAny, + ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, - ScalarFunction::ArrayFill => Self::ArrayFill, + ScalarFunction::ArrayDistinct => Self::ArrayDistinct, + ScalarFunction::ArrayElement => Self::ArrayElement, + ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, + ScalarFunction::ArrayPopFront => Self::ArrayPopFront, + ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, ScalarFunction::ArrayPrepend => Self::ArrayPrepend, + ScalarFunction::ArrayRepeat => Self::ArrayRepeat, ScalarFunction::ArrayRemove => Self::ArrayRemove, + ScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, + ScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, ScalarFunction::ArrayReplace => Self::ArrayReplace, + ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, + ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, + ScalarFunction::ArrayUnion => Self::ArrayUnion, + ScalarFunction::Range => Self::Range, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, - ScalarFunction::TrimArray => Self::TrimArray, ScalarFunction::NullIf => Self::NullIf, ScalarFunction::DatePart => Self::DatePart, ScalarFunction::DateTrunc => Self::DateTrunc, @@ -473,6 +518,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sha384 => Self::SHA384, ScalarFunction::Sha512 => Self::SHA512, ScalarFunction::Digest => Self::Digest, + ScalarFunction::Encode => Self::Encode, + ScalarFunction::Decode => Self::Decode, ScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, @@ -493,11 +540,13 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Right => Self::Right, ScalarFunction::Rpad => Self::Rpad, ScalarFunction::SplitPart => Self::SplitPart, + ScalarFunction::StringToArray => Self::StringToArray, ScalarFunction::StartsWith => Self::StartsWith, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::ToHex => Self::ToHex, ScalarFunction::ToTimestampMicros => Self::ToTimestampMicros, + ScalarFunction::ToTimestampNanos => Self::ToTimestampNanos, ScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds, ScalarFunction::Now => Self::Now, ScalarFunction::CurrentDate => Self::CurrentDate, @@ -511,7 +560,14 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::StructFun => Self::Struct, ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, + ScalarFunction::Nanvl => Self::Nanvl, + ScalarFunction::Isnan => Self::Isnan, + ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, + ScalarFunction::OverLay => Self::OverLay, + ScalarFunction::Levenshtein => Self::Levenshtein, + ScalarFunction::SubstrIndex => Self::SubstrIndex, + ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -538,6 +594,15 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, protobuf::AggregateFunction::Correlation => Self::Correlation, + protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, + protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, + protobuf::AggregateFunction::RegrCount => Self::RegrCount, + protobuf::AggregateFunction::RegrR2 => Self::RegrR2, + protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, + protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, + protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, + protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, + protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, protobuf::AggregateFunction::ApproxPercentileCont => { Self::ApproxPercentileCont } @@ -549,6 +614,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -578,19 +644,9 @@ impl TryFrom<&protobuf::Schema> for Schema { let fields = schema .columns .iter() - .map(|c| { - let pb_arrow_type_res = c - .arrow_type - .as_ref() - .ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'")); - let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res { - Ok(res) => res, - Err(e) => return Err(e), - }; - Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable)) - }) + .map(Field::try_from) .collect::, _>>()?; - Ok(Self::new(fields)) + Ok(Self::new_with_metadata(fields, schema.metadata.clone())) } } @@ -620,25 +676,56 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - Value::ListValue(scalar_list) => { + // ScalarValue::List is serialized using arrow IPC format + Value::ListValue(scalar_list) + | Value::FixedSizeListValue(scalar_list) + | Value::LargeListValue(scalar_list) => { let protobuf::ScalarListValue { - is_null, - values, - field, + ipc_message, + arrow_data, + schema, } = &scalar_list; - let field: Field = field.as_ref().required("field")?; - let field = Arc::new(field); - - let values: Result, Error> = - values.iter().map(|val| val.try_into()).collect(); - let values = values?; + let schema: Schema = if let Some(schema_ref) = schema { + schema_ref.try_into()? + } else { + return Err(Error::General( + "Invalid schema while deserializing ScalarValue::List" + .to_string(), + )); + }; - validate_list_values(field.as_ref(), &values)?; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data); - let values = if *is_null { None } else { Some(values) }; + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List" + .to_string(), + ) + })?; - Self::List(values, field) + let record_batch = read_record_batch( + &buffer, + ipc_batch, + Arc::new(schema), + &Default::default(), + None, + &message.version(), + ) + .map_err(DataFusionError::ArrowError) + .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + let arr = record_batch.column(0); + match value { + Value::ListValue(_) => Self::List(arr.to_owned()), + Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), + Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + _ => unreachable!(), + } } Value::NullValue(v) => { let null_type: DataType = v.try_into()?; @@ -652,6 +739,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { val.s as i8, ) } + Value::Decimal256Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal256( + Some(i256::from_be_bytes(array)), + val.p as u8, + val.s as i8, + ) + } Value::Date64Value(v) => Self::Date64(Some(*v)), Value::Time32Value(v) => { let time_value = @@ -679,6 +774,10 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } Value::IntervalYearmonthValue(v) => Self::IntervalYearMonth(Some(*v)), Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some(*v)), + Value::DurationSecondValue(v) => Self::DurationSecond(Some(*v)), + Value::DurationMillisecondValue(v) => Self::DurationMillisecond(Some(*v)), + Value::DurationMicrosecondValue(v) => Self::DurationMicrosecond(Some(*v)), + Value::DurationNanosecondValue(v) => Self::DurationNanosecond(Some(*v)), Value::TimestampValue(v) => { let timezone = if v.timezone.is_empty() { None @@ -758,8 +857,8 @@ impl TryFrom for WindowFrame { type Error = Error; fn try_from(window: protobuf::WindowFrame) -> Result { - let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units) - .ok_or_else(|| Error::unknown("WindowFrameUnits", window.window_frame_units))? + let units = protobuf::WindowFrameUnits::try_from(window.window_frame_units) + .map_err(|_| Error::unknown("WindowFrameUnits", window.window_frame_units))? .into(); let start_bound = window.start_bound.required("start_bound")?; let end_bound = window @@ -784,8 +883,8 @@ impl TryFrom for WindowFrameBound { fn try_from(bound: protobuf::WindowFrameBound) -> Result { let bound_type = - protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type) - .ok_or_else(|| { + protobuf::WindowFrameBoundType::try_from(bound.window_frame_bound_type) + .map_err(|_| { Error::unknown("WindowFrameBoundType", bound.window_frame_bound_type) })?; match bound_type { @@ -847,38 +946,49 @@ impl From for JoinConstraint { } } +impl From for Constraints { + fn from(constraints: protobuf::Constraints) -> Self { + Constraints::new_unverified( + constraints + .constraints + .into_iter() + .map(|item| item.into()) + .collect(), + ) + } +} + +impl From for Constraint { + fn from(value: protobuf::Constraint) -> Self { + match value.constraint_mode.unwrap() { + protobuf::constraint::ConstraintMode::PrimaryKey(elem) => { + Constraint::PrimaryKey( + elem.indices.into_iter().map(|item| item as usize).collect(), + ) + } + protobuf::constraint::ConstraintMode::Unique(elem) => Constraint::Unique( + elem.indices.into_iter().map(|item| item as usize).collect(), + ), + } + } +} + pub fn parse_i32_to_time_unit(value: &i32) -> Result { - protobuf::TimeUnit::from_i32(*value) + protobuf::TimeUnit::try_from(*value) .map(|t| t.into()) - .ok_or_else(|| Error::unknown("TimeUnit", *value)) + .map_err(|_| Error::unknown("TimeUnit", *value)) } pub fn parse_i32_to_interval_unit(value: &i32) -> Result { - protobuf::IntervalUnit::from_i32(*value) + protobuf::IntervalUnit::try_from(*value) .map(|t| t.into()) - .ok_or_else(|| Error::unknown("IntervalUnit", *value)) + .map_err(|_| Error::unknown("IntervalUnit", *value)) } pub fn parse_i32_to_aggregate_function(value: &i32) -> Result { - protobuf::AggregateFunction::from_i32(*value) + protobuf::AggregateFunction::try_from(*value) .map(|a| a.into()) - .ok_or_else(|| Error::unknown("AggregateFunction", *value)) -} - -/// Ensures that all `values` are of type DataType::List and have the -/// same type as field -fn validate_list_values(field: &Field, values: &[ScalarValue]) -> Result<(), Error> { - for value in values { - let field_type = field.data_type(); - let value_type = value.get_datatype(); - - if field_type != &value_type { - return Err(proto_error(format!( - "Expected field type {field_type:?}, got scalar of type: {value_type:?}" - ))); - } - } - Ok(()) + .map_err(|_| Error::unknown("AggregateFunction", *value)) } pub fn parse_expr( @@ -916,18 +1026,48 @@ pub fn parse_expr( }) .expect("Binary expression could not be reduced to a single expression.")) } - ExprType::GetIndexedField(field) => { - let key = field - .key - .as_ref() - .ok_or_else(|| Error::required("value"))? - .try_into()?; - - let expr = parse_required_expr(field.expr.as_deref(), registry, "expr")?; + ExprType::GetIndexedField(get_indexed_field) => { + let expr = + parse_required_expr(get_indexed_field.expr.as_deref(), registry, "expr")?; + let field = match &get_indexed_field.field { + Some(protobuf::get_indexed_field::Field::NamedStructField( + named_struct_field, + )) => GetFieldAccess::NamedStructField { + name: named_struct_field + .name + .as_ref() + .ok_or_else(|| Error::required("value"))? + .try_into()?, + }, + Some(protobuf::get_indexed_field::Field::ListIndex(list_index)) => { + GetFieldAccess::ListIndex { + key: Box::new(parse_required_expr( + list_index.key.as_deref(), + registry, + "key", + )?), + } + } + Some(protobuf::get_indexed_field::Field::ListRange(list_range)) => { + GetFieldAccess::ListRange { + start: Box::new(parse_required_expr( + list_range.start.as_deref(), + registry, + "start", + )?), + stop: Box::new(parse_required_expr( + list_range.stop.as_deref(), + registry, + "stop", + )?), + } + } + None => return Err(proto_error("Field must not be None")), + }; Ok(Expr::GetIndexedField(GetIndexedField::new( Box::new(expr), - key, + field, ))) } ExprType::Column(column) => Ok(Expr::Column(column.into())), @@ -945,7 +1085,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?; - let order_by = expr + let mut order_by = expr .order_by .iter() .map(|e| parse_expr(e, registry)) @@ -955,7 +1095,8 @@ pub fn parse_expr( .as_ref() .map::, _>(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { @@ -963,6 +1104,7 @@ pub fn parse_expr( "missing window frame during deserialization".to_string(), ) })?; + regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { @@ -979,8 +1121,8 @@ pub fn parse_expr( ))) } window_expr_node::WindowFunction::BuiltInFunction(i) => { - let built_in_function = protobuf::BuiltInWindowFunction::from_i32(*i) - .ok_or_else(|| Error::unknown("BuiltInWindowFunction", *i))? + let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) + .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); let args = parse_optional_expr(expr.expr.as_deref(), registry)? @@ -997,6 +1139,36 @@ pub fn parse_expr( window_frame, ))) } + window_expr_node::WindowFunction::Udaf(udaf_name) => { + let udaf_function = registry.udaf(udaf_name)?; + let args = parse_optional_expr(expr.expr.as_deref(), registry)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); + Ok(Expr::WindowFunction(WindowFunction::new( + datafusion_expr::window_function::WindowFunction::AggregateUDF( + udaf_function, + ), + args, + partition_by, + order_by, + window_frame, + ))) + } + window_expr_node::WindowFunction::Udwf(udwf_name) => { + let udwf_function = registry.udwf(udwf_name)?; + let args = parse_optional_expr(expr.expr.as_deref(), registry)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); + Ok(Expr::WindowFunction(WindowFunction::new( + datafusion_expr::window_function::WindowFunction::WindowUDF( + udwf_function, + ), + args, + partition_by, + order_by, + window_frame, + ))) + } } } ExprType::AggregateExpr(expr) => { @@ -1013,14 +1185,15 @@ pub fn parse_expr( parse_vec_expr(&expr.order_by, registry)?, ))) } - ExprType::Alias(alias) => Ok(Expr::Alias( - Box::new(parse_required_expr( - alias.expr.as_deref(), - registry, - "expr", - )?), + ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( + parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + alias + .relation + .first() + .map(|r| OwnedTableReference::try_from(r.clone())) + .transpose()?, alias.alias.clone(), - )), + ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), registry, @@ -1089,8 +1262,9 @@ pub fn parse_expr( "pattern", )?), parse_escape_char(&like.escape_char)?, + false, ))), - ExprType::Ilike(like) => Ok(Expr::ILike(Like::new( + ExprType::Ilike(like) => Ok(Expr::Like(Like::new( like.negated, Box::new(parse_required_expr(like.expr.as_deref(), registry, "expr")?), Box::new(parse_required_expr( @@ -1099,6 +1273,7 @@ pub fn parse_expr( "pattern", )?), parse_escape_char(&like.escape_char)?, + true, ))), ExprType::SimilarTo(like) => Ok(Expr::SimilarTo(Like::new( like.negated, @@ -1109,6 +1284,7 @@ pub fn parse_expr( "pattern", )?), parse_escape_char(&like.escape_char)?, + false, ))), ExprType::Case(case) => { let when_then_expr = case @@ -1161,10 +1337,12 @@ pub fn parse_expr( .collect::, _>>()?, in_list.negated, ))), - ExprType::Wildcard(_) => Ok(Expr::Wildcard), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { + qualifier: qualifier.clone(), + }), ExprType::ScalarFunction(expr) => { - let scalar_function = protobuf::ScalarFunction::from_i32(expr.fun) - .ok_or_else(|| Error::unknown("ScalarFunction", expr.fun))?; + let scalar_function = protobuf::ScalarFunction::try_from(expr.fun) + .map_err(|_| Error::unknown("ScalarFunction", expr.fun))?; let args = &expr.args; match scalar_function { @@ -1182,6 +1360,17 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::ArrayPopFront => { + Ok(array_pop_front(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ArrayPopBack => { + Ok(array_pop_back(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayPrepend => Ok(array_prepend( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1192,7 +1381,23 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), - ScalarFunction::ArrayFill => Ok(array_fill( + ScalarFunction::ArrayExcept => Ok(array_except( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayHasAll => Ok(array_has_all( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayHasAny => Ok(array_has_any( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayHas => Ok(array_has( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayIntersect => Ok(array_intersect( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), @@ -1205,26 +1410,57 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayRepeat => Ok(array_repeat( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayRemove => Ok(array_remove( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayRemoveN => Ok(array_remove_n( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::ArrayRemoveAll => Ok(array_remove_all( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayReplace => Ok(array_replace( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::ArrayReplaceN => Ok(array_replace_n( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + parse_expr(&args[3], registry)?, + )), + ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::ArraySlice => Ok(array_slice( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::ArrayToString => Ok(array_to_string( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Range => Ok(gen_range( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } - ScalarFunction::TrimArray => Ok(trim_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::ArrayLength => Ok(array_length( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1232,9 +1468,25 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayDistinct => { + Ok(array_distinct(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ArrayElement => Ok(array_element( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayEmpty => { + Ok(array_empty(parse_expr(&args[0], registry)?)) + } ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayUnion => Ok(array( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)), @@ -1262,7 +1514,12 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), - ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)), + ScalarFunction::Trunc => Ok(trunc( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)), ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), ScalarFunction::OctetLength => { @@ -1291,6 +1548,14 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), + ScalarFunction::Encode => Ok(encode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::Decode => Ok(decode( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1406,6 +1671,10 @@ pub fn parse_expr( )) } } + ScalarFunction::Levenshtein => Ok(levenshtein( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) @@ -1413,6 +1682,9 @@ pub fn parse_expr( ScalarFunction::ToTimestampMicros => { Ok(to_timestamp_micros(parse_expr(&args[0], registry)?)) } + ScalarFunction::ToTimestampNanos => { + Ok(to_timestamp_nanos(parse_expr(&args[0], registry)?)) + } ScalarFunction::ToTimestampSeconds => { Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) } @@ -1444,14 +1716,50 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - _ => Err(proto_error( - "Protobuf deserialization error: Unsupported scalar function", + ScalarFunction::CurrentDate => Ok(current_date()), + ScalarFunction::CurrentTime => Ok(current_time()), + ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Nanvl => Ok(nanvl( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, )), + ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), + ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), + ScalarFunction::ArrowTypeof => { + Ok(arrow_typeof(parse_expr(&args[0], registry)?)) + } + ScalarFunction::ToTimestamp => { + Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?)) + } + ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], registry)?)), + ScalarFunction::StringToArray => Ok(string_to_array( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::OverLay => Ok(overlay( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), + ScalarFunction::SubstrIndex => Ok(substr_index( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), + ScalarFunction::FindInSet => Ok(find_in_set( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::StructFun => { + Ok(struct_fun(parse_expr(&args[0], registry)?)) + } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { let scalar_fn = registry.udf(fun_name.as_str())?; - Ok(Expr::ScalarUDF(expr::ScalarUDF::new( + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() .map(|expr| parse_expr(expr, registry)) @@ -1461,12 +1769,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) @@ -1512,9 +1821,7 @@ fn parse_escape_char(s: &str) -> Result> { match s.len() { 0 => Ok(None), 1 => Ok(s.chars().next()), - _ => Err(DataFusionError::Internal( - "Invalid length for escape char".to_string(), - )), + _ => internal_err!("Invalid length for escape char"), } } @@ -1552,6 +1859,8 @@ pub fn from_proto_binary_op(op: &str) -> Result { "RegexNotIMatch" => Ok(Operator::RegexNotIMatch), "RegexNotMatch" => Ok(Operator::RegexNotMatch), "StringConcat" => Ok(Operator::StringConcat), + "AtArrow" => Ok(Operator::AtArrow), + "ArrowAt" => Ok(Operator::ArrowAt), other => Err(proto_error(format!( "Unsupported binary operator '{other:?}'" ))), @@ -1564,9 +1873,7 @@ fn parse_vec_expr( ) -> Result>, Error> { let res = p .iter() - .map(|elem| { - parse_expr(elem, registry).map_err(|e| DataFusionError::Plan(e.to_string())) - }) + .map(|elem| parse_expr(elem, registry).map_err(|e| plan_datafusion_err!("{}", e))) .collect::>>()?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 3774ce14305dc..50bca0295def6 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; +use std::fmt::Debug; +use std::str::FromStr; +use std::sync::Arc; + use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; @@ -25,12 +30,13 @@ use crate::{ logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, }, }; + use arrow::datatypes::{DataType, Schema, SchemaRef}; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::{ datasource::{ - file_format::{ - avro::AvroFormat, csv::CsvFormat, parquet::ParquetFormat, FileFormat, - }, + file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, view::ViewTable, TableProvider, @@ -38,39 +44,36 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; +use datafusion_common::plan_datafusion_err; use datafusion_common::{ - context, parsers::CompressionTypeVariant, DataFusionError, OwnedTableReference, - Result, + context, internal_err, not_impl_err, parsers::CompressionTypeVariant, + DataFusionError, OwnedTableReference, Result, }; -use datafusion_expr::logical_plan::DdlStatement; -use datafusion_expr::DropView; use datafusion_expr::{ logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, CrossJoin, Distinct, EmptyRelation, Extension, - Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, - SubqueryAlias, TableScan, Values, Window, + CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, + EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, + Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - Expr, LogicalPlan, LogicalPlanBuilder, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; + use prost::bytes::BufMut; use prost::Message; -use std::fmt::Debug; -use std::str::FromStr; -use std::sync::Arc; pub mod from_proto; pub mod to_proto; impl From for DataFusionError { fn from(e: from_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) + plan_datafusion_err!("{}", e) } } impl From for DataFusionError { fn from(e: to_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) + plan_datafusion_err!("{}", e) } } @@ -132,15 +135,11 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { _inputs: &[LogicalPlan], _ctx: &SessionContext, ) -> Result { - Err(DataFusionError::NotImplemented( - "LogicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("LogicalExtensionCodec is not provided") } fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { - Err(DataFusionError::NotImplemented( - "LogicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("LogicalExtensionCodec is not provided") } fn try_decode_table_provider( @@ -149,9 +148,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { - Err(DataFusionError::NotImplemented( - "LogicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("LogicalExtensionCodec is not provided") } fn try_encode_table_provider( @@ -159,9 +156,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { _node: Arc, _buf: &mut Vec, ) -> Result<()> { - Err(DataFusionError::NotImplemented( - "LogicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("LogicalExtensionCodec is not provided") } } @@ -225,11 +220,11 @@ impl AsLogicalPlan for LogicalPlanNode { let values: Vec> = if values.values_list.is_empty() { Ok(Vec::new()) } else if values.values_list.len() % n_cols != 0 { - Err(DataFusionError::Internal(format!( + internal_err!( "Invalid values list length, expect {} to be divisible by {}", values.values_list.len(), n_cols - ))) + ) } else { values .values_list @@ -342,18 +337,25 @@ impl AsLogicalPlan for LogicalPlanNode { "logical_plan::from_proto() Unsupported file format '{self:?}'" )) })? { + #[cfg(feature = "parquet")] &FileFormatType::Parquet(protobuf::ParquetFormat {}) => { Arc::new(ParquetFormat::default()) } FileFormatType::Csv(protobuf::CsvFormat { has_header, delimiter, - }) => Arc::new( - CsvFormat::default() - .with_has_header(*has_header) - .with_delimiter(str_to_byte(delimiter)?), - ), - FileFormatType::Avro(..) => Arc::new(AvroFormat::default()), + quote, + optional_escape + }) => { + let mut csv = CsvFormat::default() + .with_has_header(*has_header) + .with_delimiter(str_to_byte(delimiter, "delimiter")?) + .with_quote(str_to_byte(quote, "quote")?); + if let Some(protobuf::csv_format::OptionalEscape::Escape(escape)) = optional_escape { + csv = csv.with_quote(str_to_byte(escape, "escape")?); + } + Arc::new(csv)}, + FileFormatType::Avro(..) => Arc::new(AvroFormat), }; let table_paths = &scan @@ -363,7 +365,7 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, _>>()?; let options = ListingOptions::new(file_format) - .with_file_extension(scan.file_extension.clone()) + .with_file_extension(&scan.file_extension) .with_table_partition_cols( scan.table_partition_cols .iter() @@ -388,7 +390,12 @@ impl AsLogicalPlan for LogicalPlanNode { .with_listing_options(options) .with_schema(Arc::new(schema)); - let provider = ListingTable::try_new(config)?; + let provider = ListingTable::try_new(config)?.with_cache( + ctx.state() + .runtime_env() + .cache_manager + .get_file_statistic_cache(), + ); let table_name = from_owned_table_reference( scan.table_name.as_ref(), @@ -453,7 +460,7 @@ impl AsLogicalPlan for LogicalPlanNode { let input: LogicalPlan = into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; - let pb_partition_method = repartition.partition_method.clone().ok_or_else(|| { + let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| { DataFusionError::Internal(String::from( "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'", )) @@ -468,10 +475,10 @@ impl AsLogicalPlan for LogicalPlanNode { .iter() .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::, _>>()?, - partition_count as usize, + *partition_count as usize, ), PartitionMethod::RoundRobin(partition_count) => { - Partitioning::RoundRobinBatch(partition_count as usize) + Partitioning::RoundRobinBatch(*partition_count as usize) } }; @@ -489,6 +496,11 @@ impl AsLogicalPlan for LogicalPlanNode { )) })?; + let constraints = (create_extern_table.constraints.clone()).ok_or_else(|| { + DataFusionError::Internal(String::from( + "Protobuf deserialization error, CreateExternalTableNode was missing required table constraints.", + )) + })?; let definition = if !create_extern_table.definition.is_empty() { Some(create_extern_table.definition.clone()) } else { @@ -497,9 +509,7 @@ impl AsLogicalPlan for LogicalPlanNode { let file_type = create_extern_table.file_type.as_str(); if ctx.table_factory(file_type).is_none() { - Err(DataFusionError::Internal(format!( - "No TableProviderFactory for file type: {file_type}" - )))? + internal_err!("No TableProviderFactory for file type: {file_type}")? } let mut order_exprs = vec![]; @@ -512,6 +522,13 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs.push(order_expr) } + let mut column_defaults = + HashMap::with_capacity(create_extern_table.column_defaults.len()); + for (col_name, expr) in &create_extern_table.column_defaults { + let expr = from_proto::parse_expr(expr, ctx)?; + column_defaults.insert(col_name.clone(), expr); + } + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -530,6 +547,8 @@ impl AsLogicalPlan for LogicalPlanNode { definition, unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), + constraints: constraints.into(), + column_defaults, }))) } LogicalPlanType::CreateView(create_view) => { @@ -632,16 +651,16 @@ impl AsLogicalPlan for LogicalPlanNode { .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::, _>>()?; let join_type = - protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| { + protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a JoinNode message with unknown JoinType {}", join.join_type )) })?; - let join_constraint = protobuf::JoinConstraint::from_i32( + let join_constraint = protobuf::JoinConstraint::try_from( join.join_constraint, ) - .ok_or_else(|| { + .map_err(|_| { proto_error(format!( "Received a JoinNode message with unknown JoinConstraint {}", join.join_constraint @@ -724,6 +743,33 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(distinct.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input).distinct()?.build() } + LogicalPlanType::DistinctOn(distinct_on) => { + let input: LogicalPlan = + into_logical_plan!(distinct_on.input, ctx, extension_codec)?; + let on_expr = distinct_on + .on_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let select_expr = distinct_on + .select_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let sort_expr = match distinct_on.sort_expr.len() { + 0 => None, + _ => Some( + distinct_on + .sort_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?, + ), + }; + LogicalPlanBuilder::from(input) + .distinct_on(on_expr, select_expr, sort_expr)? + .build() + } LogicalPlanType::ViewScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -840,20 +886,49 @@ impl AsLogicalPlan for LogicalPlanNode { if let Some(listing_table) = source.downcast_ref::() { let any = listing_table.options().format.as_any(); - let file_format_type = if any.is::() { - FileFormatType::Parquet(protobuf::ParquetFormat {}) - } else if let Some(csv) = any.downcast_ref::() { - FileFormatType::Csv(protobuf::CsvFormat { - delimiter: byte_to_string(csv.delimiter())?, - has_header: csv.has_header(), - }) - } else if any.is::() { - FileFormatType::Avro(protobuf::AvroFormat {}) - } else { - return Err(proto_error(format!( + let file_format_type = { + let mut maybe_some_type = None; + + #[cfg(feature = "parquet")] + if any.is::() { + maybe_some_type = + Some(FileFormatType::Parquet(protobuf::ParquetFormat {})) + }; + + if let Some(csv) = any.downcast_ref::() { + maybe_some_type = + Some(FileFormatType::Csv(protobuf::CsvFormat { + delimiter: byte_to_string( + csv.delimiter(), + "delimiter", + )?, + has_header: csv.has_header(), + quote: byte_to_string(csv.quote(), "quote")?, + optional_escape: if let Some(escape) = csv.escape() { + Some( + protobuf::csv_format::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + ), + ) + } else { + None + }, + })) + } + + if any.is::() { + maybe_some_type = + Some(FileFormatType::Avro(protobuf::AvroFormat {})) + } + + if let Some(file_format_type) = maybe_some_type { + file_format_type + } else { + return Err(proto_error(format!( "Error converting file format, {:?} is invalid as a datafusion format.", listing_table.options().format ))); + } }; let options = listing_table.options(); @@ -966,7 +1041,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -980,6 +1055,42 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + })) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr = match sort_expr { + None => vec![], + Some(sort_expr) => sort_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( + protobuf::DistinctOnNode { + on_expr: on_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + select_expr: select_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + sort_expr, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::Window(Window { input, window_expr, .. }) => { @@ -1075,9 +1186,9 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => Err(DataFusionError::NotImplemented( - "LogicalPlan serde is not yet implemented for subqueries".to_string(), - )), + LogicalPlan::Subquery(_) => { + not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") + } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( @@ -1158,9 +1269,7 @@ impl AsLogicalPlan for LogicalPlanNode { PartitionMethod::RoundRobin(*partition_count as u64) } Partitioning::DistributeBy(_) => { - return Err(DataFusionError::NotImplemented( - "DistributeBy".to_string(), - )) + return not_impl_err!("DistributeBy") } }; @@ -1197,6 +1306,8 @@ impl AsLogicalPlan for LogicalPlanNode { order_exprs, unbounded, options, + constraints, + column_defaults, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1211,6 +1322,12 @@ impl AsLogicalPlan for LogicalPlanNode { converted_order_exprs.push(temp); } + let mut converted_column_defaults = + HashMap::with_capacity(column_defaults.len()); + for (col_name, expr) in column_defaults { + converted_column_defaults.insert(col_name.clone(), expr.try_into()?); + } + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { @@ -1227,6 +1344,8 @@ impl AsLogicalPlan for LogicalPlanNode { file_compression_type: file_compression_type.to_string(), unbounded: *unbounded, options: options.clone(), + constraints: Some(constraints.clone().into()), + column_defaults: converted_column_defaults, }, )), }) @@ -1415,1383 +1534,12 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Dml(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for Dml", )), + LogicalPlan::Copy(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for Copy", + )), LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), } } } - -#[cfg(test)] -mod roundtrip_tests { - use super::from_proto::parse_expr; - use super::protobuf; - use crate::bytes::{ - logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, - logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, - }; - use crate::logical_plan::LogicalExtensionCodec; - use arrow::datatypes::{Fields, Schema, SchemaRef, UnionFields}; - use arrow::{ - array::ArrayRef, - datatypes::{ - DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - TimeUnit, UnionMode, - }, - }; - use datafusion::datasource::datasource::TableProviderFactory; - use datafusion::datasource::TableProvider; - use datafusion::execution::context::SessionState; - use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::prelude::{ - create_udf, CsvReadOptions, SessionConfig, SessionContext, - }; - use datafusion::test_util::{TestTableFactory, TestTableProvider}; - use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; - use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - ScalarUDF, Sort, - }; - use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; - use datafusion_expr::{ - col, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, TryCast, Volatility, - }; - use datafusion_expr::{ - create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, - }; - use prost::Message; - use std::collections::HashMap; - use std::fmt; - use std::fmt::Debug; - use std::fmt::Formatter; - use std::sync::Arc; - - #[cfg(feature = "json")] - fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { - let string = serde_json::to_string(proto).unwrap(); - let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); - assert_eq!(proto, &back); - } - - #[cfg(not(feature = "json"))] - fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} - - // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test - // equality. - fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) - where - for<'a> &'a T: TryInto + Debug, - E: Debug, - { - let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); - let round_trip: Expr = parse_expr(&proto, &ctx).unwrap(); - - assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); - - roundtrip_json_test(&proto); - } - - fn new_arc_field(name: &str, dt: DataType, nullable: bool) -> Arc { - Arc::new(Field::new(name, dt, nullable)) - } - - #[tokio::test] - async fn roundtrip_logical_plan() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - let scan = ctx.table("t1").await?.into_optimized_plan()?; - let topk_plan = LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), - }); - let extension_codec = TopKExtensionCodec {}; - let bytes = - logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; - let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; - assert_eq!(format!("{topk_plan:?}"), format!("{logical_round_trip:?}")); - Ok(()) - } - - #[derive(Clone, PartialEq, Eq, ::prost::Message)] - pub struct TestTableProto { - /// URL of the table root - #[prost(string, tag = "1")] - pub url: String, - } - - #[derive(Debug)] - pub struct TestTableProviderCodec {} - - impl LogicalExtensionCodec for TestTableProviderCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[LogicalPlan], - _ctx: &SessionContext, - ) -> Result { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_decode_table_provider( - &self, - buf: &[u8], - schema: SchemaRef, - _ctx: &SessionContext, - ) -> Result> { - let msg = TestTableProto::decode(buf).map_err(|_| { - DataFusionError::Internal("Error decoding test table".to_string()) - })?; - let provider = TestTableProvider { - url: msg.url, - schema, - }; - Ok(Arc::new(provider)) - } - - fn try_encode_table_provider( - &self, - node: Arc, - buf: &mut Vec, - ) -> Result<()> { - let table = node - .as_ref() - .as_any() - .downcast_ref::() - .expect("Can't encode non-test tables"); - let msg = TestTableProto { - url: table.url.clone(), - }; - msg.encode(buf).map_err(|_| { - DataFusionError::Internal("Error encoding test table".to_string()) - }) - } - } - - #[tokio::test] - async fn roundtrip_custom_tables() -> Result<()> { - let mut table_factories: HashMap> = - HashMap::new(); - table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::with_config_rt(ses, Arc::new(env)); - // replace factories - *state.table_factories_mut() = table_factories; - let ctx = SessionContext::with_state(state); - - let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; - ctx.sql(sql).await.unwrap(); - - let codec = TestTableProviderCodec {}; - let scan = ctx.table("t").await?.into_optimized_plan()?; - let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; - let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; - assert_eq!(format!("{scan:?}"), format!("{logical_round_trip:?}")); - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_aggregation() -> Result<()> { - let ctx = SessionContext::new(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Decimal128(15, 2), true), - ]); - - ctx.register_csv( - "t1", - "testdata/test.csv", - CsvReadOptions::default().schema(&schema), - ) - .await?; - - let query = - "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC"; - let plan = ctx.sql(query).await?.into_optimized_plan()?; - - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_single_count_distinct() -> Result<()> { - let ctx = SessionContext::new(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Decimal128(15, 2), true), - ]); - - ctx.register_csv( - "t1", - "testdata/test.csv", - CsvReadOptions::default().schema(&schema), - ) - .await?; - - let query = "SELECT a, COUNT(DISTINCT b) as b_cd FROM t1 GROUP BY a"; - let plan = ctx.sql(query).await?.into_optimized_plan()?; - - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_with_extension() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - let plan = ctx.table("t1").await?.into_optimized_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_with_view_scan() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - ctx.sql("CREATE VIEW view_t1(a, b) AS SELECT a, b FROM t1") - .await?; - - // SELECT - let plan = ctx - .sql("SELECT * FROM view_t1") - .await? - .into_optimized_plan()?; - - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - - // DROP - let plan = ctx.sql("DROP VIEW view_t1").await?.into_optimized_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - - Ok(()) - } - - pub mod proto { - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct TopKPlanProto { - #[prost(uint64, tag = "1")] - pub k: u64, - - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, - } - - #[derive(Clone, PartialEq, Eq, ::prost::Message)] - pub struct TopKExecProto { - #[prost(uint64, tag = "1")] - pub k: u64, - } - } - - #[derive(PartialEq, Eq, Hash)] - struct TopKPlanNode { - k: usize, - input: LogicalPlan, - /// The sort expression (this example only supports a single sort - /// expr) - expr: Expr, - } - - impl TopKPlanNode { - pub fn new(k: usize, input: LogicalPlan, expr: Expr) -> Self { - Self { k, input, expr } - } - } - - impl Debug for TopKPlanNode { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.fmt_for_explain(f) - } - } - - impl UserDefinedLogicalNodeCore for TopKPlanNode { - fn name(&self) -> &str { - "TopK" - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - vec![&self.input] - } - - /// Schema for TopK is the same as the input - fn schema(&self) -> &DFSchemaRef { - self.input.schema() - } - - fn expressions(&self) -> Vec { - vec![self.expr.clone()] - } - - /// For example: `TopK: k=10` - fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TopK: k={}", self.k) - } - - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { - assert_eq!(inputs.len(), 1, "input size inconsistent"); - assert_eq!(exprs.len(), 1, "expression size inconsistent"); - Self { - k: self.k, - input: inputs[0].clone(), - expr: exprs[0].clone(), - } - } - } - - #[derive(Debug)] - pub struct TopKExtensionCodec {} - - impl LogicalExtensionCodec for TopKExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[LogicalPlan], - ctx: &SessionContext, - ) -> Result { - if let Some((input, _)) = inputs.split_first() { - let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { - DataFusionError::Internal(format!( - "failed to decode logical plan: {e:?}" - )) - })?; - - if let Some(expr) = proto.expr.as_ref() { - let node = TopKPlanNode::new( - proto.k as usize, - input.clone(), - parse_expr(expr, ctx)?, - ); - - Ok(Extension { - node: Arc::new(node), - }) - } else { - Err(DataFusionError::Internal( - "invalid plan, no expr".to_string(), - )) - } - } else { - Err(DataFusionError::Internal( - "invalid plan, no input".to_string(), - )) - } - } - - fn try_encode(&self, node: &Extension, buf: &mut Vec) -> Result<()> { - if let Some(exec) = node.node.as_any().downcast_ref::() { - let proto = proto::TopKPlanProto { - k: exec.k as u64, - expr: Some((&exec.expr).try_into()?), - }; - - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!( - "failed to encode logical plan: {e:?}" - )) - })?; - - Ok(()) - } else { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - } - - fn try_decode_table_provider( - &self, - _buf: &[u8], - _schema: SchemaRef, - _ctx: &SessionContext, - ) -> Result> { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - - fn try_encode_table_provider( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<()> { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - } - - #[test] - fn scalar_values_error_serialization() { - let should_fail_on_seralize: Vec = vec![ - // Should fail due to empty values - ScalarValue::Struct( - Some(vec![]), - vec![Field::new("item", DataType::Int16, true)].into(), - ), - // Should fail due to inconsistent types in the list - ScalarValue::new_list( - Some(vec![ - ScalarValue::Int16(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::Int16, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - None, - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - None, - DataType::List(new_arc_field( - "lists are typed inconsistently", - DataType::Int16, - true, - )), - ), - ]), - DataType::List(new_arc_field( - "level1", - DataType::List(new_arc_field("level2", DataType::Float32, true)), - true, - )), - ), - ]; - - for test_case in should_fail_on_seralize.into_iter() { - let proto: Result = - (&test_case).try_into(); - - // Validation is also done on read, so if serialization passed - // also try to convert back to ScalarValue - if let Ok(proto) = proto { - let res: Result = (&proto).try_into(); - assert!( - res.is_err(), - "The value {test_case:?} unexpectedly serialized without error:{res:?}" - ); - } - } - } - - #[test] - fn round_trip_scalar_values() { - let should_pass: Vec = vec![ - ScalarValue::Boolean(None), - ScalarValue::Float32(None), - ScalarValue::Float64(None), - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ScalarValue::new_list(None, DataType::Boolean), - ScalarValue::Date32(None), - ScalarValue::Boolean(Some(true)), - ScalarValue::Boolean(Some(false)), - ScalarValue::Float32(Some(1.0)), - ScalarValue::Float32(Some(f32::MAX)), - ScalarValue::Float32(Some(f32::MIN)), - ScalarValue::Float32(Some(-2000.0)), - ScalarValue::Float64(Some(1.0)), - ScalarValue::Float64(Some(f64::MAX)), - ScalarValue::Float64(Some(f64::MIN)), - ScalarValue::Float64(Some(-2000.0)), - ScalarValue::Int8(Some(i8::MIN)), - ScalarValue::Int8(Some(i8::MAX)), - ScalarValue::Int8(Some(0)), - ScalarValue::Int8(Some(-15)), - ScalarValue::Int16(Some(i16::MIN)), - ScalarValue::Int16(Some(i16::MAX)), - ScalarValue::Int16(Some(0)), - ScalarValue::Int16(Some(-15)), - ScalarValue::Int32(Some(i32::MIN)), - ScalarValue::Int32(Some(i32::MAX)), - ScalarValue::Int32(Some(0)), - ScalarValue::Int32(Some(-15)), - ScalarValue::Int64(Some(i64::MIN)), - ScalarValue::Int64(Some(i64::MAX)), - ScalarValue::Int64(Some(0)), - ScalarValue::Int64(Some(-15)), - ScalarValue::UInt8(Some(u8::MAX)), - ScalarValue::UInt8(Some(0)), - ScalarValue::UInt16(Some(u16::MAX)), - ScalarValue::UInt16(Some(0)), - ScalarValue::UInt32(Some(u32::MAX)), - ScalarValue::UInt32(Some(0)), - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::UInt64(Some(0)), - ScalarValue::Utf8(Some(String::from("Test string "))), - ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), - ScalarValue::Date32(Some(0)), - ScalarValue::Date32(Some(i32::MAX)), - ScalarValue::Date32(None), - ScalarValue::Date64(Some(0)), - ScalarValue::Date64(Some(i64::MAX)), - ScalarValue::Date64(None), - ScalarValue::Time32Second(Some(0)), - ScalarValue::Time32Second(Some(i32::MAX)), - ScalarValue::Time32Second(None), - ScalarValue::Time32Millisecond(Some(0)), - ScalarValue::Time32Millisecond(Some(i32::MAX)), - ScalarValue::Time32Millisecond(None), - ScalarValue::Time64Microsecond(Some(0)), - ScalarValue::Time64Microsecond(Some(i64::MAX)), - ScalarValue::Time64Microsecond(None), - ScalarValue::Time64Nanosecond(Some(0)), - ScalarValue::Time64Nanosecond(Some(i64::MAX)), - ScalarValue::Time64Nanosecond(None), - ScalarValue::TimestampNanosecond(Some(0), None), - ScalarValue::TimestampNanosecond(Some(i64::MAX), None), - ScalarValue::TimestampNanosecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampNanosecond(None, None), - ScalarValue::TimestampMicrosecond(Some(0), None), - ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), - ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampMicrosecond(None, None), - ScalarValue::TimestampMillisecond(Some(0), None), - ScalarValue::TimestampMillisecond(Some(i64::MAX), None), - ScalarValue::TimestampMillisecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampMillisecond(None, None), - ScalarValue::TimestampSecond(Some(0), None), - ScalarValue::TimestampSecond(Some(i64::MAX), None), - ScalarValue::TimestampSecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampSecond(None, None), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( - i32::MAX, - i32::MAX, - ))), - ScalarValue::IntervalDayTime(None), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(0, 0, 0), - )), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(1, 2, 3), - )), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(i32::MAX, i32::MAX, i64::MAX), - )), - ScalarValue::IntervalMonthDayNano(None), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list(None, DataType::Float32), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ]), - DataType::List(new_arc_field("item", DataType::Float32, true)), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(None)), - ), - ScalarValue::Binary(Some(b"bar".to_vec())), - ScalarValue::Binary(None), - ScalarValue::LargeBinary(Some(b"bar".to_vec())), - ScalarValue::LargeBinary(None), - ScalarValue::Struct( - Some(vec![ - ScalarValue::Int32(Some(23)), - ScalarValue::Boolean(Some(false)), - ]), - Fields::from(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Boolean, false), - ]), - ), - ScalarValue::Struct( - None, - Fields::from(vec![ - Field::new("a", DataType::Int32, true), - Field::new("a", DataType::Boolean, false), - ]), - ), - ScalarValue::FixedSizeBinary( - b"bar".to_vec().len() as i32, - Some(b"bar".to_vec()), - ), - ScalarValue::FixedSizeBinary(0, None), - ScalarValue::FixedSizeBinary(5, None), - ]; - - for test_case in should_pass.into_iter() { - let proto: super::protobuf::ScalarValue = (&test_case) - .try_into() - .expect("failed conversion to protobuf"); - - let roundtrip: ScalarValue = (&proto) - .try_into() - .expect("failed conversion from protobuf"); - - assert_eq!( - test_case, roundtrip, - "ScalarValue was not the same after round trip!\n\n\ - Input: {test_case:?}\n\nRoundtrip: {roundtrip:?}" - ); - } - } - - #[test] - fn round_trip_scalar_types() { - let should_pass: Vec = vec![ - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float32, - DataType::Float64, - DataType::Date32, - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Utf8, - DataType::LargeUtf8, - // Recursive list tests - DataType::List(new_arc_field("level1", DataType::Boolean, true)), - DataType::List(new_arc_field( - "Level1", - DataType::List(new_arc_field("level2", DataType::Date32, true)), - true, - )), - ]; - - for test_case in should_pass.into_iter() { - let field = Field::new("item", test_case, true); - let proto: super::protobuf::Field = (&field).try_into().unwrap(); - let roundtrip: Field = (&proto).try_into().unwrap(); - assert_eq!(format!("{field:?}"), format!("{roundtrip:?}")); - } - } - - #[test] - fn round_trip_datatype() { - let test_cases: Vec = vec![ - DataType::Null, - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float16, - DataType::Float32, - DataType::Float64, - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Nanosecond), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::DayTime), - DataType::Binary, - DataType::FixedSizeBinary(0), - DataType::FixedSizeBinary(1234), - DataType::FixedSizeBinary(-432), - DataType::LargeBinary, - DataType::Utf8, - DataType::LargeUtf8, - DataType::Decimal128(7, 12), - // Recursive list tests - DataType::List(new_arc_field("Level1", DataType::Binary, true)), - DataType::List(new_arc_field( - "Level1", - DataType::List(new_arc_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), - true, - )), - // Fixed size lists - DataType::FixedSizeList(new_arc_field("Level1", DataType::Binary, true), 4), - DataType::FixedSizeList( - new_arc_field( - "Level1", - DataType::List(new_arc_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), - true, - ), - 41, - ), - // Struct Testing - DataType::Struct(Fields::from(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ])), - DataType::Struct(Fields::from(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new( - "nested_struct", - DataType::Struct(Fields::from(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ])), - true, - ), - ])), - DataType::Union( - UnionFields::new( - vec![7, 5, 3], - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ], - ), - UnionMode::Sparse, - ), - DataType::Union( - UnionFields::new( - vec![5, 8, 1], - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new_struct( - "nested_struct", - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ], - true, - ), - ], - ), - UnionMode::Dense, - ), - DataType::Dictionary( - Box::new(DataType::Utf8), - Box::new(DataType::Struct(Fields::from(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]))), - ), - DataType::Dictionary( - Box::new(DataType::Decimal128(10, 50)), - Box::new(DataType::FixedSizeList( - new_arc_field("Level1", DataType::Binary, true), - 4, - )), - ), - DataType::Map( - new_arc_field( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("keys", DataType::Utf8, false), - Field::new("values", DataType::Int32, true), - ])), - true, - ), - false, - ), - ]; - - for test_case in test_cases.into_iter() { - let proto: super::protobuf::ArrowType = (&test_case).try_into().unwrap(); - let roundtrip: DataType = (&proto).try_into().unwrap(); - assert_eq!(format!("{test_case:?}"), format!("{roundtrip:?}")); - } - } - - #[test] - fn roundtrip_null_scalar_values() { - let test_types = vec![ - ScalarValue::Boolean(None), - ScalarValue::Float32(None), - ScalarValue::Float64(None), - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ScalarValue::Date32(None), - ScalarValue::TimestampMicrosecond(None, None), - ScalarValue::TimestampNanosecond(None, None), - ScalarValue::List( - None, - Arc::new(Field::new("item", DataType::Boolean, false)), - ), - ]; - - for test_case in test_types.into_iter() { - let proto_scalar: super::protobuf::ScalarValue = - (&test_case).try_into().unwrap(); - let returned_scalar: datafusion::scalar::ScalarValue = - (&proto_scalar).try_into().unwrap(); - assert_eq!(format!("{:?}", &test_case), format!("{returned_scalar:?}")); - } - } - - #[test] - fn roundtrip_not() { - let test_expr = Expr::Not(Box::new(lit(1.0_f32))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_is_null() { - let test_expr = Expr::IsNull(Box::new(col("id"))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_is_not_null() { - let test_expr = Expr::IsNotNull(Box::new(col("id"))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_between() { - let test_expr = Expr::Between(Between::new( - Box::new(lit(1.0_f32)), - true, - Box::new(lit(2.0_f32)), - Box::new(lit(3.0_f32)), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_binary_op() { - fn test(op: Operator) { - let test_expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(lit(1.0_f32)), - op, - Box::new(lit(2.0_f32)), - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - test(Operator::StringConcat); - test(Operator::RegexNotIMatch); - test(Operator::RegexNotMatch); - test(Operator::RegexIMatch); - test(Operator::RegexMatch); - test(Operator::BitwiseShiftRight); - test(Operator::BitwiseShiftLeft); - test(Operator::BitwiseAnd); - test(Operator::BitwiseOr); - test(Operator::BitwiseXor); - test(Operator::IsDistinctFrom); - test(Operator::IsNotDistinctFrom); - test(Operator::And); - test(Operator::Or); - test(Operator::Eq); - test(Operator::NotEq); - test(Operator::Lt); - test(Operator::LtEq); - test(Operator::Gt); - test(Operator::GtEq); - } - - #[test] - fn roundtrip_case() { - let test_expr = Expr::Case(Case::new( - Some(Box::new(lit(1.0_f32))), - vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(lit(4.0_f32))), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_case_with_null() { - let test_expr = Expr::Case(Case::new( - Some(Box::new(lit(1.0_f32))), - vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_try_cast() { - let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - - let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_sort_expr() { - let test_expr = Expr::Sort(Sort::new(Box::new(lit(1.0_f32)), true, true)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_negative() { - let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_inlist() { - let test_expr = Expr::InList(InList::new( - Box::new(lit(1.0_f32)), - vec![lit(2.0_f32)], - true, - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_sqrt() { - let test_expr = Expr::ScalarFunction(ScalarFunction::new(Sqrt, vec![col("col")])); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_like() { - fn like(negated: bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - like(true, Some('X')); - like(false, Some('\\')); - like(true, None); - like(false, None); - } - - #[test] - fn roundtrip_ilike() { - fn ilike(negated: bool, escape_char: Option) { - let test_expr = Expr::ILike(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - ilike(true, Some('X')); - ilike(false, Some('\\')); - ilike(true, None); - ilike(false, None); - } - - #[test] - fn roundtrip_similar_to() { - fn similar_to(negated: bool, escape_char: Option) { - let test_expr = Expr::SimilarTo(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - similar_to(true, Some('X')); - similar_to(false, Some('\\')); - similar_to(true, None); - similar_to(false, None); - } - - #[test] - fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ApproxPercentileCont, - vec![col("bananas"), lit(0.42_f32)], - false, - None, - None, - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_aggregate_udf() { - #[derive(Debug)] - struct Dummy {} - - impl Accumulator for Dummy { - fn state(&self) -> datafusion::error::Result> { - Ok(vec![]) - } - - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { - Ok(()) - } - - fn merge_batch( - &mut self, - _states: &[ArrayRef], - ) -> datafusion::error::Result<()> { - Ok(()) - } - - fn evaluate(&self) -> datafusion::error::Result { - Ok(ScalarValue::Float64(None)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - } - - let dummy_agg = create_udaf( - // the name; used to represent it in plan descriptions and in the registry, to use in SQL. - "dummy_agg", - // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. - DataType::Float64, - // the return type; DataFusion expects this to match the type returned by `evaluate`. - Arc::new(DataType::Float64), - Volatility::Immutable, - // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), - // This is the description of the state. `state()` must match the types here. - Arc::new(vec![DataType::Float64, DataType::UInt32]), - ); - - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( - Arc::new(dummy_agg.clone()), - vec![lit(1.0_f64)], - Some(Box::new(lit(true))), - None, - )); - - let ctx = SessionContext::new(); - ctx.register_udaf(dummy_agg); - - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_scalar_udf() { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); - - let udf = create_udf( - "dummy", - vec![DataType::Utf8], - Arc::new(DataType::Utf8), - Volatility::Immutable, - scalar_fn, - ); - - let test_expr = - Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")])); - - let ctx = SessionContext::new(); - ctx.register_udf(udf); - - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_grouping_sets() { - let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ - vec![col("a")], - vec![col("b")], - vec![col("a"), col("b")], - ])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_rollup() { - let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_cube() { - let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_substr() { - // substr(string, position) - let test_expr = Expr::ScalarFunction(ScalarFunction::new( - Substr, - vec![col("col"), lit(1_i64)], - )); - - // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( - Substr, - vec![col("col"), lit(1_i64), lit(1_i64)], - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx.clone()); - roundtrip_expr_test(test_expr_with_count, ctx); - } - #[test] - fn roundtrip_window() { - let ctx = SessionContext::new(); - - // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(true), - )); - - // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(true), - )); - - // 3. with window_frame with row numbers - let range_number_frame = WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), - end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), - }; - - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - )); - - // 4. test with AggregateFunction - let row_number_frame = WindowFrame { - units: WindowFrameUnits::Rows, - start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), - end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), - }; - - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame, - )); - - roundtrip_expr_test(test_expr1, ctx.clone()); - roundtrip_expr_test(test_expr2, ctx.clone()); - roundtrip_expr_test(test_expr3, ctx.clone()); - roundtrip_expr_test(test_expr4, ctx); - } -} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index dbf9432bdc2dc..2997d147424d8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -24,20 +24,29 @@ use crate::protobuf::{ arrow_type::ArrowTypeEnum, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, - UnionMode, +use arrow::{ + datatypes::{ + DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, + TimeUnit, UnionMode, + }, + ipc::writer::{DictionaryTracker, IpcDataGenerator}, + record_batch::RecordBatch, +}; +use datafusion_common::{ + Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, OwnedTableReference, + ScalarValue, }; -use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference, ScalarValue}; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, InList, Like, - Placeholder, ScalarFunction, ScalarUDF, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -49,13 +58,6 @@ use datafusion_expr::{ pub enum Error { General(String), - InconsistentListTyping(DataType, DataType), - - InconsistentListDesignated { - value: ScalarValue, - designated: DataType, - }, - InvalidScalarValue(ScalarValue), InvalidScalarType(DataType), @@ -73,18 +75,6 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::General(desc) => write!(f, "General error: {desc}"), - Self::InconsistentListTyping(type1, type2) => { - write!( - f, - "Lists with inconsistent typing; {type1:?} and {type2:?} found within list", - ) - } - Self::InconsistentListDesignated { value, designated } => { - write!( - f, - "Value {value:?} was inconsistent with designated type {designated:?}" - ) - } Self::InvalidScalarValue(value) => { write!(f, "{value:?} is invalid as a DataFusion scalar value") } @@ -117,6 +107,9 @@ impl TryFrom<&Field> for protobuf::Field { arrow_type: Some(Box::new(arrow_type)), nullable: field.is_nullable(), children: Vec::new(), + metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } @@ -266,6 +259,7 @@ impl TryFrom<&Schema> for protobuf::Schema { .iter() .map(|f| f.as_ref().try_into()) .collect::, Error>>()?, + metadata: schema.metadata.clone(), }) } } @@ -280,6 +274,7 @@ impl TryFrom for protobuf::Schema { .iter() .map(|f| f.as_ref().try_into()) .collect::, Error>>()?, + metadata: schema.metadata.clone(), }) } } @@ -297,10 +292,10 @@ impl TryFrom<&DFField> for protobuf::DfField { } } -impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { +impl TryFrom<&DFSchema> for protobuf::DfSchema { type Error = Error; - fn try_from(s: &DFSchemaRef) -> Result { + fn try_from(s: &DFSchema) -> Result { let columns = s .fields() .iter() @@ -313,6 +308,14 @@ impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { } } +impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { + type Error = Error; + + fn try_from(s: &DFSchemaRef) -> Result { + s.as_ref().try_into() + } +} + impl From<&StringifiedPlan> for protobuf::StringifiedPlan { fn from(stringified_plan: &StringifiedPlan) -> Self { Self { @@ -353,6 +356,12 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::FinalPhysicalPlan => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlan(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), + }), + PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } @@ -381,6 +390,15 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, + AggregateFunction::RegrSlope => Self::RegrSlope, + AggregateFunction::RegrIntercept => Self::RegrIntercept, + AggregateFunction::RegrCount => Self::RegrCount, + AggregateFunction::RegrR2 => Self::RegrR2, + AggregateFunction::RegrAvgx => Self::RegrAvgx, + AggregateFunction::RegrAvgy => Self::RegrAvgy, + AggregateFunction::RegrSXX => Self::RegrSxx, + AggregateFunction::RegrSYY => Self::RegrSyy, + AggregateFunction::RegrSXY => Self::RegrSxy, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight @@ -390,6 +408,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -468,10 +487,18 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(expr, alias) => { + Expr::Alias(Alias { + expr, + relation, + name, + }) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(expr.as_ref().try_into()?)), - alias: alias.to_owned(), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), + alias: name.to_owned(), }); Self { expr_type: Some(ExprType::Alias(alias)), @@ -523,31 +550,34 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr, pattern, escape_char, + case_insensitive, }) => { - let pb = Box::new(protobuf::LikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), - }); - Self { - expr_type: Some(ExprType::Like(pb)), - } - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let pb = Box::new(protobuf::ILikeNode { - negated: *negated, - expr: Some(Box::new(expr.as_ref().try_into()?)), - pattern: Some(Box::new(pattern.as_ref().try_into()?)), - escape_char: escape_char.map(|ch| ch.to_string()).unwrap_or_default(), - }); - Self { - expr_type: Some(ExprType::Ilike(pb)), + if *case_insensitive { + let pb = Box::new(protobuf::ILikeNode { + negated: *negated, + expr: Some(Box::new(expr.as_ref().try_into()?)), + pattern: Some(Box::new(pattern.as_ref().try_into()?)), + escape_char: escape_char + .map(|ch| ch.to_string()) + .unwrap_or_default(), + }); + + Self { + expr_type: Some(ExprType::Ilike(pb)), + } + } else { + let pb = Box::new(protobuf::LikeNode { + negated: *negated, + expr: Some(Box::new(expr.as_ref().try_into()?)), + pattern: Some(Box::new(pattern.as_ref().try_into()?)), + escape_char: escape_char + .map(|ch| ch.to_string()) + .unwrap_or_default(), + }); + + Self { + expr_type: Some(ExprType::Like(pb)), + } } } Expr::SimilarTo(Like { @@ -555,6 +585,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr, pattern, escape_char, + case_insensitive: _, }) => { let pb = Box::new(protobuf::SimilarToNode { negated: *negated, @@ -584,11 +615,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { protobuf::BuiltInWindowFunction::from(fun).into(), ) } - // TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584 - WindowFunction::AggregateUDF(_) => { - return Err(Error::NotImplemented( - "UDAF as window function in proto".to_string(), - )) + WindowFunction::AggregateUDF(aggr_udf) => { + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ) + } + WindowFunction::WindowUDF(window_udf) => { + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ) } }; let arg_expr: Option> = if !args.is_empty() { @@ -620,144 +655,178 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let args = args .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, - }, - )), + .map(|expr| expr.try_into()) + .collect::, Error>>()?; + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), + args, + }, + )), + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } - Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name.clone(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - })), - }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name.clone(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), @@ -929,8 +998,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard => Self { - expr_type: Some(ExprType::Wildcard(true)), + Expr::Wildcard { qualifier } => Self { + expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { + qualifier: qualifier.clone(), + })), }, Expr::ScalarSubquery(_) | Expr::InSubquery(_) @@ -940,14 +1011,41 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - key: Some(key.try_into()?), - expr: Some(Box::new(expr.as_ref().try_into()?)), - }, - ))), - }, + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + protobuf::get_indexed_field::Field::NamedStructField( + protobuf::NamedStructField { + name: Some(name.try_into()?), + }, + ) + } + GetFieldAccess::ListIndex { key } => { + protobuf::get_indexed_field::Field::ListIndex(Box::new( + protobuf::ListIndex { + key: Some(Box::new(key.as_ref().try_into()?)), + }, + )) + } + GetFieldAccess::ListRange { start, stop } => { + protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { + start: Some(Box::new(start.as_ref().try_into()?)), + stop: Some(Box::new(stop.as_ref().try_into()?)), + }, + )) + } + }; + + Self { + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + expr: Some(Box::new(expr.as_ref().try_into()?)), + field: Some(field), + }, + ))), + } + } Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { expr_type: Some(ExprType::Cube(CubeNode { @@ -994,11 +1092,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { })), } } - - Expr::QualifiedWildcard { .. } => return Err(Error::General( - "Proto serialization error: Expr::QualifiedWildcard { .. } not supported" - .to_string(), - )), }; Ok(expr_node) @@ -1009,90 +1102,116 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { type Error = Error; fn try_from(val: &ScalarValue) -> Result { - use datafusion_common::scalar; use protobuf::scalar_value::Value; - let data_type = val.get_datatype(); + let data_type = val.data_type(); match val { - scalar::ScalarValue::Boolean(val) => { + ScalarValue::Boolean(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) } - scalar::ScalarValue::Float32(val) => { + ScalarValue::Float32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) } - scalar::ScalarValue::Float64(val) => { + ScalarValue::Float64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float64Value(*s)) } - scalar::ScalarValue::Int8(val) => { + ScalarValue::Int8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Int8Value(*s as i32) }) } - scalar::ScalarValue::Int16(val) => { + ScalarValue::Int16(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Int16Value(*s as i32) }) } - scalar::ScalarValue::Int32(val) => { + ScalarValue::Int32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int32Value(*s)) } - scalar::ScalarValue::Int64(val) => { + ScalarValue::Int64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int64Value(*s)) } - scalar::ScalarValue::UInt8(val) => { + ScalarValue::UInt8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Uint8Value(*s as u32) }) } - scalar::ScalarValue::UInt16(val) => { + ScalarValue::UInt16(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Uint16Value(*s as u32) }) } - scalar::ScalarValue::UInt32(val) => { + ScalarValue::UInt32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint32Value(*s)) } - scalar::ScalarValue::UInt64(val) => { + ScalarValue::UInt64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint64Value(*s)) } - scalar::ScalarValue::Utf8(val) => { + ScalarValue::Utf8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::Utf8Value(s.to_owned()) }) } - scalar::ScalarValue::LargeUtf8(val) => { + ScalarValue::LargeUtf8(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::LargeUtf8Value(s.to_owned()) }) } - scalar::ScalarValue::List(values, boxed_field) => { - let is_null = values.is_none(); + // ScalarValue::List and ScalarValue::FixedSizeList are serialized using + // Arrow IPC messages as a single column RecordBatch + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { + // Wrap in a "field_name" column + let batch = RecordBatch::try_from_iter(vec![( + "field_name", + arr.to_owned(), + )]) + .map_err(|e| { + Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!( + "Error encoding ScalarValue::List as IPC: {e}" + )) + })?; - let values = if let Some(values) = values.as_ref() { - values - .iter() - .map(|v| v.try_into()) - .collect::, _>>()? - } else { - vec![] - }; + let schema: protobuf::Schema = batch.schema().try_into()?; - let field = boxed_field.as_ref().try_into()?; + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; - Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - protobuf::ScalarListValue { - is_null, - field: Some(field), - values, - }, - )), - }) + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } } - datafusion::scalar::ScalarValue::Date32(val) => { + ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) } - datafusion::scalar::ScalarValue::TimestampMicrosecond(val, tz) => { + ScalarValue::TimestampMicrosecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1104,7 +1223,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::TimestampNanosecond(val, tz) => { + ScalarValue::TimestampNanosecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1116,7 +1235,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::Decimal128(val, p, s) => match *val { + ScalarValue::Decimal128(val, p, s) => match *val { Some(v) => { let array = v.to_be_bytes(); let vec_val: Vec = array.to_vec(); @@ -1134,10 +1253,28 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }), }, - datafusion::scalar::ScalarValue::Date64(val) => { + ScalarValue::Decimal256(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal256Value(protobuf::Decimal256 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, + ScalarValue::Date64(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) } - datafusion::scalar::ScalarValue::TimestampSecond(val, tz) => { + ScalarValue::TimestampSecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1147,7 +1284,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::TimestampMillisecond(val, tz) => { + ScalarValue::TimestampMillisecond(val, tz) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::TimestampValue(protobuf::ScalarTimestampValue { timezone: tz.as_deref().unwrap_or("").to_string(), @@ -1159,31 +1296,31 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } - datafusion::scalar::ScalarValue::IntervalYearMonth(val) => { + ScalarValue::IntervalYearMonth(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::IntervalYearmonthValue(*s) }) } - datafusion::scalar::ScalarValue::IntervalDayTime(val) => { + ScalarValue::IntervalDayTime(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::IntervalDaytimeValue(*s) }) } - datafusion::scalar::ScalarValue::Null => Ok(protobuf::ScalarValue { + ScalarValue::Null => Ok(protobuf::ScalarValue { value: Some(Value::NullValue((&data_type).try_into()?)), }), - scalar::ScalarValue::Binary(val) => { + ScalarValue::Binary(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::BinaryValue(s.to_owned()) }) } - scalar::ScalarValue::LargeBinary(val) => { + ScalarValue::LargeBinary(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::LargeBinaryValue(s.to_owned()) }) } - scalar::ScalarValue::FixedSizeBinary(length, val) => { + ScalarValue::FixedSizeBinary(length, val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { values: s.to_owned(), @@ -1192,7 +1329,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time32Second(v) => { + ScalarValue::Time32Second(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time32Value(protobuf::ScalarTime32Value { value: Some( @@ -1202,7 +1339,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time32Millisecond(v) => { + ScalarValue::Time32Millisecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time32Value(protobuf::ScalarTime32Value { value: Some( @@ -1214,7 +1351,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time64Microsecond(v) => { + ScalarValue::Time64Microsecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time64Value(protobuf::ScalarTime64Value { value: Some( @@ -1226,7 +1363,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Time64Nanosecond(v) => { + ScalarValue::Time64Nanosecond(v) => { create_proto_scalar(v.as_ref(), &data_type, |v| { Value::Time64Value(protobuf::ScalarTime64Value { value: Some( @@ -1238,7 +1375,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::IntervalMonthDayNano(v) => { + ScalarValue::IntervalMonthDayNano(v) => { let value = if let Some(v) = v { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); Value::IntervalMonthDayNano(protobuf::IntervalMonthDayNanoValue { @@ -1247,13 +1384,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { nanos, }) } else { - protobuf::scalar_value::Value::NullValue((&data_type).try_into()?) + Value::NullValue((&data_type).try_into()?) }; Ok(protobuf::ScalarValue { value: Some(value) }) } - datafusion::scalar::ScalarValue::Struct(values, fields) => { + ScalarValue::DurationSecond(v) => { + let value = match v { + Some(v) => Value::DurationSecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMillisecond(v) => { + let value = match v { + Some(v) => Value::DurationMillisecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMicrosecond(v) => { + let value = match v { + Some(v) => Value::DurationMicrosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationNanosecond(v) => { + let value = match v { + Some(v) => Value::DurationNanosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + + ScalarValue::Struct(values, fields) => { // encode null as empty field values list let field_values = if let Some(values) = values { if values.is_empty() { @@ -1280,7 +1446,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) } - datafusion::scalar::ScalarValue::Dictionary(index_type, val) => { + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { value: Some(Value::DictionaryValue(Box::new( @@ -1305,6 +1471,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Sin => Self::Sin, BuiltinScalarFunction::Cos => Self::Cos, BuiltinScalarFunction::Tan => Self::Tan, + BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, BuiltinScalarFunction::Tanh => Self::Tanh, @@ -1337,20 +1504,38 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, + BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, + BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, + BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, + BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, + BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, - BuiltinScalarFunction::ArrayFill => Self::ArrayFill, + BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, + BuiltinScalarFunction::ArrayElement => Self::ArrayElement, + BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, + BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, + BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend, + BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat, BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove, + BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, + BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, + BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, + BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, + BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion, + BuiltinScalarFunction::Range => Self::Range, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, - BuiltinScalarFunction::TrimArray => Self::TrimArray, BuiltinScalarFunction::NullIf => Self::NullIf, BuiltinScalarFunction::DatePart => Self::DatePart, BuiltinScalarFunction::DateTrunc => Self::DateTrunc, @@ -1361,6 +1546,8 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::SHA384 => Self::Sha384, BuiltinScalarFunction::SHA512 => Self::Sha512, BuiltinScalarFunction::Digest => Self::Digest, + BuiltinScalarFunction::Decode => Self::Decode, + BuiltinScalarFunction::Encode => Self::Encode, BuiltinScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, @@ -1382,11 +1569,13 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Right => Self::Right, BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::SplitPart => Self::SplitPart, + BuiltinScalarFunction::StringToArray => Self::StringToArray, BuiltinScalarFunction::StartsWith => Self::StartsWith, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::ToHex => Self::ToHex, BuiltinScalarFunction::ToTimestampMicros => Self::ToTimestampMicros, + BuiltinScalarFunction::ToTimestampNanos => Self::ToTimestampNanos, BuiltinScalarFunction::ToTimestampSeconds => Self::ToTimestampSeconds, BuiltinScalarFunction::Now => Self::Now, BuiltinScalarFunction::CurrentDate => Self::CurrentDate, @@ -1399,7 +1588,14 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, + BuiltinScalarFunction::Nanvl => Self::Nanvl, + BuiltinScalarFunction::Isnan => Self::Isnan, + BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, + BuiltinScalarFunction::OverLay => Self::OverLay, + BuiltinScalarFunction::Levenshtein => Self::Levenshtein, + BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, + BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) @@ -1483,6 +1679,35 @@ impl From for protobuf::JoinConstraint { } } +impl From for protobuf::Constraints { + fn from(value: Constraints) -> Self { + let constraints = value.into_iter().map(|item| item.into()).collect(); + protobuf::Constraints { constraints } + } +} + +impl From for protobuf::Constraint { + fn from(value: Constraint) -> Self { + let res = match value { + Constraint::PrimaryKey(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + Constraint::Unique(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + }; + protobuf::Constraint { + constraint_mode: Some(res), + } + } +} + /// Creates a scalar protobuf value from an optional value (T), and /// encoding None as the appropriate datatype fn create_proto_scalar protobuf::scalar_value::Value>( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7a52e5f0d09fd..dcebfbf2dabbd 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -17,43 +17,44 @@ //! Serde code to convert from protocol buffers to Rust data structures. -use crate::protobuf; -use arrow::datatypes::DataType; -use chrono::TimeZone; -use chrono::Utc; +use std::convert::{TryFrom, TryInto}; +use std::sync::Arc; + +use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::window_function::WindowFunction; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - date_time_interval_expr, GetIndexedFieldExpr, + in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, + Literal, NegativeExpr, NotExpr, TryCastExpr, }; -use datafusion::physical_plan::expressions::{in_list, LikeExpr}; +use datafusion::physical_plan::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; +use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ - expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal, - NegativeExpr, NotExpr, TryCastExpr, - }, - functions, Partitioning, + functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, +}; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::stats::Precision; +use datafusion_common::{ + not_impl_err, DataFusionError, FileTypeWriterOptions, JoinSide, Result, ScalarValue, }; -use datafusion::physical_plan::{ColumnStatistics, PhysicalExpr, Statistics}; -use datafusion_common::{DataFusionError, Result}; -use object_store::path::Path; -use object_store::ObjectMeta; -use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; -use std::sync::Arc; use crate::common::proto_error; use crate::convert_required; use crate::logical_plan; +use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use datafusion::physical_plan::joins::utils::JoinSide; -use datafusion::physical_plan::sorts::sort::SortOptions; + +use chrono::{TimeZone, Utc}; +use object_store::path::Path; +use object_store::ObjectMeta; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -86,6 +87,61 @@ pub fn parse_physical_sort_expr( } } +/// Parses a physical window expr from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical window exprression node. +/// * `name` - Name of the window expression. +/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +pub fn parse_physical_window_expr( + proto: &protobuf::PhysicalWindowExprNode, + registry: &dyn FunctionRegistry, + input_schema: &Schema, +) -> Result> { + let window_node_expr = proto + .args + .iter() + .map(|e| parse_physical_expr(e, registry, input_schema)) + .collect::>>()?; + + let partition_by = proto + .partition_by + .iter() + .map(|p| parse_physical_expr(p, registry, input_schema)) + .collect::>>()?; + + let order_by = proto + .order_by + .iter() + .map(|o| parse_physical_sort_expr(o, registry, input_schema)) + .collect::>>()?; + + let window_frame = proto + .window_frame + .as_ref() + .map(|wf| wf.clone().try_into()) + .transpose() + .map_err(|e| DataFusionError::Internal(format!("{e}")))? + .ok_or_else(|| { + DataFusionError::Internal( + "Missing required field 'window_frame' in protobuf".to_string(), + ) + })?; + + create_window_expr( + &convert_required!(proto.window_function)?, + proto.name.clone(), + &window_node_expr, + &partition_by, + &order_by, + Arc::new(window_frame), + input_schema, + ) +} + /// Parses a physical expression from a protobuf. /// /// # Arguments @@ -125,36 +181,18 @@ pub fn parse_physical_expr( input_schema, )?, )), - ExprType::DateTimeIntervalExpr(expr) => date_time_interval_expr( - parse_required_physical_expr( - expr.l.as_deref(), - registry, - "left", - input_schema, - )?, - logical_plan::from_proto::from_proto_binary_op(&expr.op)?, - parse_required_physical_expr( - expr.r.as_deref(), - registry, - "right", - input_schema, - )?, - input_schema, - )?, ExprType::AggregateExpr(_) => { - return Err(DataFusionError::NotImplemented( - "Cannot convert aggregate expr node to physical expression".to_owned(), - )); + return not_impl_err!( + "Cannot convert aggregate expr node to physical expression" + ); } ExprType::WindowExpr(_) => { - return Err(DataFusionError::NotImplemented( - "Cannot convert window expr node to physical expression".to_owned(), - )); + return not_impl_err!( + "Cannot convert window expr node to physical expression" + ); } ExprType::Sort(_) => { - return Err(DataFusionError::NotImplemented( - "Cannot convert sort expr node to physical expression".to_owned(), - )); + return not_impl_err!("Cannot convert sort expr node to physical expression"); } ExprType::IsNullExpr(e) => { Arc::new(IsNullExpr::new(parse_required_physical_expr( @@ -250,7 +288,7 @@ pub fn parse_physical_expr( )), ExprType::ScalarFunction(e) => { let scalar_function = - protobuf::ScalarFunction::from_i32(e.fun).ok_or_else(|| { + protobuf::ScalarFunction::try_from(e.fun).map_err(|_| { proto_error( format!("Received an unknown scalar function: {}", e.fun,), ) @@ -274,11 +312,12 @@ pub fn parse_physical_expr( &e.name, fun_expr, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, + None, )) } ExprType::ScalarUdf(e) => { - let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun; + let scalar_fun = registry.udf(e.name.as_str())?.fun().clone(); let args = e .args @@ -290,7 +329,8 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, + None, )) } ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new( @@ -310,6 +350,36 @@ pub fn parse_physical_expr( )?, )), ExprType::GetIndexedFieldExpr(get_indexed_field_expr) => { + let field = match &get_indexed_field_expr.field { + Some(protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(named_struct_field_expr)) => GetFieldAccessExpr::NamedStructField{ + name: convert_required!(named_struct_field_expr.name)?, + }, + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListIndexExpr(list_index_expr)) => GetFieldAccessExpr::ListIndex{ + key: parse_required_physical_expr( + list_index_expr.key.as_deref(), + registry, + "key", + input_schema, + )?}, + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(list_range_expr)) => GetFieldAccessExpr::ListRange{ + start: parse_required_physical_expr( + list_range_expr.start.as_deref(), + registry, + "start", + input_schema, + )?, + stop: parse_required_physical_expr( + list_range_expr.stop.as_deref(), + registry, + "stop", + input_schema + )?, + }, + None => return Err(proto_error( + "Field must not be None", + )), + }; + Arc::new(GetIndexedFieldExpr::new( parse_required_physical_expr( get_indexed_field_expr.arg.as_deref(), @@ -317,7 +387,7 @@ pub fn parse_physical_expr( "arg", input_schema, )?, - convert_required!(get_indexed_field_expr.key)?, + field, )) } }; @@ -346,7 +416,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun ) -> Result { match expr { protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { - let f = protobuf::AggregateFunction::from_i32(*n).ok_or_else(|| { + let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { proto_error(format!( "Received an unknown window aggregate function: {n}" )) @@ -355,12 +425,11 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun Ok(WindowFunction::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { - let f = - protobuf::BuiltInWindowFunction::from_i32(*n).ok_or_else(|| { - proto_error(format!( - "Received an unknown window builtin function: {n}" - )) - })?; + let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { + proto_error(format!( + "Received an unknown window builtin function: {n}" + )) + })?; Ok(WindowFunction::BuiltInWindowFunction(f.into())) } @@ -422,13 +491,8 @@ pub fn parse_protobuf_file_scan_config( let table_partition_cols = proto .table_partition_cols .iter() - .map(|col| { - Ok(( - col.to_owned(), - schema.field_with_name(col)?.data_type().clone(), - )) - }) - .collect::>>()?; + .map(|col| Ok(schema.field_with_name(col)?.clone())) + .collect::>>()?; let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { @@ -476,6 +540,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), size: val.size as usize, e_tag: None, + version: None, }, partition_values: val .partition_values @@ -513,10 +578,96 @@ impl TryFrom<&protobuf::FileGroup> for Vec { impl From<&protobuf::ColumnStats> for ColumnStatistics { fn from(cs: &protobuf::ColumnStats) -> ColumnStatistics { ColumnStatistics { - null_count: Some(cs.null_count as usize), - max_value: cs.max_value.as_ref().map(|m| m.try_into().unwrap()), - min_value: cs.min_value.as_ref().map(|m| m.try_into().unwrap()), - distinct_count: Some(cs.distinct_count as usize), + null_count: if let Some(nc) = &cs.null_count { + nc.clone().into() + } else { + Precision::Absent + }, + max_value: if let Some(max) = &cs.max_value { + max.clone().into() + } else { + Precision::Absent + }, + min_value: if let Some(min) = &cs.min_value { + min.clone().into() + } else { + Precision::Absent + }, + distinct_count: if let Some(dc) = &cs.distinct_count { + dc.clone().into() + } else { + Precision::Absent + }, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Exact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Inexact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Exact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Inexact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, } } } @@ -535,27 +686,91 @@ impl TryFrom<&protobuf::Statistics> for Statistics { fn try_from(s: &protobuf::Statistics) -> Result { // Keep it sync with Statistics::to_proto - let none_value = -1_i64; - let column_statistics = - s.column_stats.iter().map(|s| s.into()).collect::>(); Ok(Statistics { - num_rows: if s.num_rows == none_value { - None + num_rows: if let Some(nr) = &s.num_rows { + nr.clone().into() } else { - Some(s.num_rows as usize) + Precision::Absent }, - total_byte_size: if s.total_byte_size == none_value { - None + total_byte_size: if let Some(tbs) = &s.total_byte_size { + tbs.clone().into() } else { - Some(s.total_byte_size as usize) + Precision::Absent }, // No column statistic (None) is encoded with empty array - column_statistics: if column_statistics.is_empty() { - None - } else { - Some(column_statistics) - }, - is_exact: s.is_exact, + column_statistics: s.column_stats.iter().map(|s| s.into()).collect(), }) } } + +impl TryFrom<&protobuf::JsonSink> for JsonSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::JsonSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &protobuf::FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ListingTableUrl::parse) + .collect::>>()?; + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|protobuf::PartitionColumn { name, arrow_type }| { + let data_type = convert_required!(arrow_type)?; + Ok((name.clone(), data_type)) + }) + .collect::>>()?; + Ok(Self { + object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, + file_groups, + table_paths, + output_schema: Arc::new(convert_required!(conf.output_schema)?), + table_partition_cols, + single_file_output: conf.single_file_output, + unbounded_input: conf.unbounded_input, + overwrite: conf.overwrite, + file_type_writer_options: convert_required!(conf.file_type_writer_options)?, + }) + } +} + +impl From for CompressionTypeVariant { + fn from(value: protobuf::CompressionTypeVariant) -> Self { + match value { + protobuf::CompressionTypeVariant::Gzip => Self::GZIP, + protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, + protobuf::CompressionTypeVariant::Xz => Self::XZ, + protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, + protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, + } + } +} + +impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(value: &protobuf::FileTypeWriterOptions) -> Result { + let file_type = value + .file_type + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?; + match file_type { + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( + Self::JSON(JsonWriterOptions::new(opts.compression().into())), + ), + } + } +} diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 3c14981355ec0..73091a6fced9f 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,38 +21,46 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::datasource::file_format::file_type::FileCompressionType; -use datafusion::datasource::physical_plan::{AvroExec, CsvExec, ParquetExec}; +use datafusion::datasource::file_format::file_compression_type::FileCompressionType; +use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::WindowFrame; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; +use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::joins::CrossJoinExec; +use datafusion::physical_plan::joins::{ + CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, + udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, }; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use prost::bytes::BufMut; use prost::Message; -use crate::common::proto_error; -use crate::common::{csv_delimiter_to_string, str_to_byte}; +use crate::common::str_to_byte; +use crate::common::{byte_to_string, proto_error}; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_protobuf_file_scan_config, }; @@ -60,9 +68,13 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, PhysicalPlanNode}; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; use crate::{convert_required, into_required}; +use self::from_proto::parse_physical_window_expr; + pub mod from_proto; pub mod to_proto; @@ -147,7 +159,16 @@ impl AsExecutionPlan for PhysicalPlanNode { .to_owned(), ) })?; - Ok(Arc::new(FilterExec::try_new(predicate, input)?)) + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let filter = FilterExec::try_new(predicate, input)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(DataFusionError::Internal( + "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), + )), + } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( parse_protobuf_file_scan_config( @@ -155,9 +176,19 @@ impl AsExecutionPlan for PhysicalPlanNode { registry, )?, scan.has_header, - str_to_byte(&scan.delimiter)?, + str_to_byte(&scan.delimiter, "delimiter")?, + str_to_byte(&scan.quote, "quote")?, + if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( + escape, + )) = &scan.optional_escape + { + Some(str_to_byte(escape, "escape")?) + } else { + None + }, FileCompressionType::UNCOMPRESSED, ))), + #[cfg(feature = "parquet")] PhysicalPlanType::ParquetScan(scan) => { let base_config = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), @@ -240,9 +271,7 @@ impl AsExecutionPlan for PhysicalPlanNode { ), )?)) } - _ => Err(DataFusionError::Internal( - "Invalid partitioning scheme".to_owned(), - )), + _ => internal_err!("Invalid partitioning scheme"), } } PhysicalPlanType::GlobalLimit(limit) => { @@ -271,70 +300,56 @@ impl AsExecutionPlan for PhysicalPlanNode { runtime, extension_codec, )?; - let input_schema = window_agg - .input_schema - .as_ref() - .ok_or_else(|| { - DataFusionError::Internal( - "input_schema in WindowAggrNode is missing.".to_owned(), - ) - })? - .clone(); - let physical_schema: SchemaRef = - SchemaRef::new((&input_schema).try_into()?); + let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg .window_expr .iter() - .zip(window_agg.window_expr_name.iter()) - .map(|(expr, name)| { - let expr_type = expr.expr_type.as_ref().ok_or_else(|| { - proto_error("Unexpected empty window physical expression") - })?; - - match expr_type { - ExprType::WindowExpr(window_node) => { - let window_node_expr = window_node - .expr - .as_ref() - .map(|e| { - parse_physical_expr( - e.as_ref(), - registry, - &physical_schema, - ) - }) - .transpose()? - .ok_or_else(|| { - proto_error( - "missing window_node expr expression" - .to_string(), - ) - })?; - - Ok(create_window_expr( - &convert_required!(window_node.window_function)?, - name.to_owned(), - &[window_node_expr], - &[], - &[], - Arc::new(WindowFrame::new(false)), - &physical_schema, - )?) - } - _ => Err(DataFusionError::Internal( - "Invalid expression for WindowAggrExec".to_string(), - )), - } + .map(|window_expr| { + parse_physical_window_expr( + window_expr, + registry, + input_schema.as_ref(), + ) }) .collect::, _>>()?; - //todo fill partition keys and sort keys - Ok(Arc::new(WindowAggExec::try_new( - physical_window_expr, - input, - Arc::new((&input_schema).try_into()?), - vec![], - )?)) + + let partition_keys = window_agg + .partition_keys + .iter() + .map(|expr| { + parse_physical_expr(expr, registry, input.schema().as_ref()) + }) + .collect::>>>()?; + + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => { + InputOrderMode::Linear + } + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( + columns.iter().map(|c| *c as usize).collect(), + ), + window_agg_exec_node::InputOrderMode::Sorted(_) => { + InputOrderMode::Sorted + } + }; + + Ok(Arc::new(BoundedWindowAggExec::try_new( + physical_window_expr, + input, + partition_keys, + input_order_mode, + )?)) + } else { + Ok(Arc::new(WindowAggExec::try_new( + physical_window_expr, + input, + partition_keys, + )?)) + } } PhysicalPlanType::Aggregate(hash_agg) => { let input: Arc = into_physical_plan( @@ -343,8 +358,8 @@ impl AsExecutionPlan for PhysicalPlanNode { runtime, extension_codec, )?; - let mode = protobuf::AggregateMode::from_i32(hash_agg.mode).ok_or_else( - || { + let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err( + |_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", hash_agg.mode @@ -358,6 +373,9 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateMode::FinalPartitioned } protobuf::AggregateMode::Single => AggregateMode::Single, + protobuf::AggregateMode::SinglePartitioned => { + AggregateMode::SinglePartitioned + } }; let num_expr = hash_agg.group_expr.len(); @@ -392,17 +410,12 @@ impl AsExecutionPlan for PhysicalPlanNode { vec![] }; - let input_schema = hash_agg - .input_schema - .as_ref() - .ok_or_else(|| { - DataFusionError::Internal( - "input_schema in AggregateNode is missing.".to_owned(), - ) - })? - .clone(); - let physical_schema: SchemaRef = - SchemaRef::new((&input_schema).try_into()?); + let input_schema = hash_agg.input_schema.as_ref().ok_or_else(|| { + DataFusionError::Internal( + "input_schema in AggregateNode is missing.".to_owned(), + ) + })?; + let physical_schema: SchemaRef = SchemaRef::new(input_schema.try_into()?); let physical_filter_expr = hash_agg .filter_expr @@ -441,13 +454,14 @@ impl AsExecutionPlan for PhysicalPlanNode { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); - + let ordering_req: Vec = agg_node.ordering_req.iter() + .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { match func { AggregateFunction::AggrFunction(i) => { - let aggr_function = protobuf::AggregateFunction::from_i32(*i) - .ok_or_else( - || { + let aggr_function = protobuf::AggregateFunction::try_from(*i) + .map_err( + |_| { proto_error(format!( "Received an unknown aggregate function: {i}" )) @@ -458,6 +472,7 @@ impl AsExecutionPlan for PhysicalPlanNode { &aggr_function.into(), agg_node.distinct, input_phy_expr.as_slice(), + &ordering_req, &physical_schema, name.to_string(), ) @@ -471,10 +486,9 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Invalid AggregateExpr, missing aggregate_function") }) } - _ => Err(DataFusionError::Internal( + _ => internal_err!( "Invalid aggregate expression for AggregateExec" - .to_string(), - )), + ), } }) .collect::, _>>()?; @@ -486,7 +500,7 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_filter_expr, physical_order_by_expr, input, - Arc::new((&input_schema).try_into()?), + Arc::new(input_schema.try_into()?), )?)) } PhysicalPlanType::HashJoin(hashjoin) => { @@ -511,8 +525,8 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok((left, right)) }) .collect::>()?; - let join_type = protobuf::JoinType::from_i32(hashjoin.join_type) - .ok_or_else(|| { + let join_type = protobuf::JoinType::try_from(hashjoin.join_type) + .map_err(|_| { proto_error(format!( "Received a HashJoinNode message with unknown JoinType {}", hashjoin.join_type @@ -532,18 +546,18 @@ impl AsExecutionPlan for PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema + registry, &schema, )?; let column_indices = f.column_indices .iter() .map(|i| { - let side = protobuf::JoinSide::from_i32(i.side) - .ok_or_else(|| proto_error(format!( + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( "Received a HashJoinNode message with JoinSide in Filter {}", i.side)) )?; - Ok(ColumnIndex{ + Ok(ColumnIndex { index: i.index as usize, side: side.into(), }) @@ -554,14 +568,15 @@ impl AsExecutionPlan for PhysicalPlanNode { }) .map_or(Ok(None), |v: Result| v.map(Some))?; - let partition_mode = - protobuf::PartitionMode::from_i32(hashjoin.partition_mode) - .ok_or_else(|| { - proto_error(format!( + let partition_mode = protobuf::PartitionMode::try_from( + hashjoin.partition_mode, + ) + .map_err(|_| { + proto_error(format!( "Received a HashJoinNode message with unknown PartitionMode {}", hashjoin.partition_mode )) - })?; + })?; let partition_mode = match partition_mode { protobuf::PartitionMode::CollectLeft => PartitionMode::CollectLeft, protobuf::PartitionMode::Partitioned => PartitionMode::Partitioned, @@ -577,6 +592,97 @@ impl AsExecutionPlan for PhysicalPlanNode { hashjoin.null_equals_null, )?)) } + PhysicalPlanType::SymmetricHashJoin(sym_join) => { + let left = into_physical_plan( + &sym_join.left, + registry, + runtime, + extension_codec, + )?; + let right = into_physical_plan( + &sym_join.right, + registry, + runtime, + extension_codec, + )?; + let on = sym_join + .on + .iter() + .map(|col| { + let left = into_required!(col.left)?; + let right = into_required!(col.right)?; + Ok((left, right)) + }) + .collect::>()?; + let join_type = protobuf::JoinType::try_from(sym_join.join_type) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown JoinType {}", + sym_join.join_type + )) + })?; + let filter = sym_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema, + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = + protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown PartitionMode {}", + sym_join.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::StreamPartitionMode::SinglePartition => { + StreamJoinPartitionMode::SinglePartition + } + protobuf::StreamPartitionMode::PartitionedExec => { + StreamJoinPartitionMode::Partitioned + } + }; + SymmetricHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + sym_join.null_equals_null, + partition_mode, + ) + .map(|e| Arc::new(e) as _) + } PhysicalPlanType::Union(union) => { let mut inputs: Vec> = vec![]; for input in &union.inputs { @@ -588,6 +694,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } Ok(Arc::new(UnionExec::new(inputs))) } + PhysicalPlanType::Interleave(interleave) => { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan( + registry, + runtime, + extension_codec, + )?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } PhysicalPlanType::CrossJoin(crossjoin) => { let left: Arc = into_physical_plan( &crossjoin.left, @@ -605,7 +722,11 @@ impl AsExecutionPlan for PhysicalPlanNode { } PhysicalPlanType::Empty(empty) => { let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(empty.produce_one_row, schema))) + Ok(Arc::new(EmptyExec::new(schema))) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) } PhysicalPlanType::Sort(sort) => { let input: Arc = @@ -630,16 +751,16 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, }, }) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "physical_plan::from_proto() {self:?}" - ))) + ) } }) .collect::, _>>()?; @@ -677,20 +798,27 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr,registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, }, }) } else { - Err(DataFusionError::Internal(format!( + internal_err!( "physical_plan::from_proto() {self:?}" - ))) + ) } }) .collect::, _>>()?; - Ok(Arc::new(SortPreservingMergeExec::new(exprs, input))) + let fetch = if sort.fetch < 0 { + None + } else { + Some(sort.fetch as usize) + }; + Ok(Arc::new( + SortPreservingMergeExec::new(exprs, input).with_fetch(fetch), + )) } PhysicalPlanType::Extension(extension) => { let inputs: Vec> = extension @@ -707,6 +835,106 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(extension_node) } + PhysicalPlanType::NestedLoopJoin(join) => { + let left: Arc = + into_physical_plan(&join.left, registry, runtime, extension_codec)?; + let right: Arc = + into_physical_plan(&join.right, registry, runtime, extension_codec)?; + let join_type = + protobuf::JoinType::try_from(join.join_type).map_err(|_| { + proto_error(format!( + "Received a NestedLoopJoinExecNode message with unknown JoinType {}", + join.join_type + )) + })?; + let filter = join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + registry, &schema, + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a NestedLoopJoinExecNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>>()?; + + Ok(JoinFilter::new(expression, column_indices, schema)) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, + right, + filter, + &join_type.into(), + )?)) + } + PhysicalPlanType::Analyze(analyze) => { + let input: Arc = into_physical_plan( + &analyze.input, + registry, + runtime, + extension_codec, + )?; + Ok(Arc::new(AnalyzeExec::new( + analyze.verbose, + analyze.show_statistics, + input, + Arc::new(convert_required!(analyze.schema)?), + ))) + } + PhysicalPlanType::JsonSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: JsonSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } } } @@ -721,7 +949,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( protobuf::ExplainExecNode { schema: Some(exec.schema().as_ref().try_into()?), @@ -733,8 +961,10 @@ impl AsExecutionPlan for PhysicalPlanNode { verbose: exec.verbose(), }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -745,7 +975,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.0.clone().try_into()) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Projection(Box::new( protobuf::ProjectionExecNode { input: Some(Box::new(input)), @@ -753,27 +983,49 @@ impl AsExecutionPlan for PhysicalPlanNode { expr_name, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( + protobuf::AnalyzeExecNode { + verbose: exec.verbose(), + show_statistics: exec.show_statistics(), + input: Some(Box::new(input)), + schema: Some(exec.schema().as_ref().try_into()?), + }, + ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(exec.predicate().clone().try_into()?), + default_filter_selectivity: exec.default_selectivity() as u32, }, ))), - }) - } else if let Some(limit) = plan.downcast_ref::() { + }); + } + + if let Some(limit) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( limit.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::GlobalLimit(Box::new( protobuf::GlobalLimitExecNode { input: Some(Box::new(input)), @@ -784,21 +1036,25 @@ impl AsExecutionPlan for PhysicalPlanNode { }, }, ))), - }) - } else if let Some(limit) = plan.downcast_ref::() { + }); + } + + if let Some(limit) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( limit.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( protobuf::LocalLimitExecNode { input: Some(Box::new(input)), fetch: limit.fetch() as u32, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), extension_codec, @@ -853,7 +1109,7 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMode::Auto => protobuf::PartitionMode::Auto, }; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { left: Some(Box::new(left)), @@ -865,8 +1121,10 @@ impl AsExecutionPlan for PhysicalPlanNode { filter, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let left = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.left().to_owned(), extension_codec, @@ -875,15 +1133,89 @@ impl AsExecutionPlan for PhysicalPlanNode { exec.right().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + let on = exec + .on() + .iter() + .map(|tuple| protobuf::JoinOn { + left: Some(protobuf::PhysicalColumn { + name: tuple.0.name().to_string(), + index: tuple.0.index() as u32, + }), + right: Some(protobuf::PhysicalColumn { + name: tuple.1.name().to_string(), + index: tuple.1.index() as u32, + }), + }) + .collect(); + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = match exec.partition_mode() { + StreamJoinPartitionMode::SinglePartition => { + protobuf::StreamPartitionMode::SinglePartition + } + StreamJoinPartitionMode::Partitioned => { + protobuf::StreamPartitionMode::PartitionedExec + } + }; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( + protobuf::SymmetricHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equals_null: exec.null_equals_null(), + filter, + }, + ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( protobuf::CrossJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + if let Some(exec) = plan.downcast_ref::() { let groups: Vec = exec .group_expr() .groups() @@ -932,6 +1264,9 @@ impl AsExecutionPlan for PhysicalPlanNode { protobuf::AggregateMode::FinalPartitioned } AggregateMode::Single => protobuf::AggregateMode::Single, + AggregateMode::SinglePartitioned => { + protobuf::AggregateMode::SinglePartitioned + } }; let input_schema = exec.input_schema(); let input = protobuf::PhysicalPlanNode::try_from_physical_plan( @@ -953,7 +1288,7 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.0.to_owned().try_into()) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( protobuf::AggregateExecNode { group_expr, @@ -969,75 +1304,107 @@ impl AsExecutionPlan for PhysicalPlanNode { groups, }, ))), - }) - } else if let Some(empty) = plan.downcast_ref::() { + }); + } + + if let Some(empty) = plan.downcast_ref::() { let schema = empty.schema().as_ref().try_into()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty( protobuf::EmptyExecNode { - produce_one_row: empty.produce_one_row(), schema: Some(schema), }, )), - }) - } else if let Some(coalesce_batches) = plan.downcast_ref::() - { + }); + } + + if let Some(empty) = plan.downcast_ref::() { + let schema = empty.schema().as_ref().try_into()?; + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { + schema: Some(schema), + }, + )), + }); + } + + if let Some(coalesce_batches) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( coalesce_batches.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( protobuf::CoalesceBatchesExecNode { input: Some(Box::new(input)), target_batch_size: coalesce_batches.target_batch_size() as u32, }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CsvScan( protobuf::CsvScanExecNode { base_conf: Some(exec.base_config().try_into()?), has_header: exec.has_header(), - delimiter: csv_delimiter_to_string(exec.delimiter())?, + delimiter: byte_to_string(exec.delimiter(), "delimiter")?, + quote: byte_to_string(exec.quote(), "quote")?, + optional_escape: if let Some(escape) = exec.escape() { + Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + )) + } else { + None + }, }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + #[cfg(feature = "parquet")] + if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() .map(|pred| pred.clone().try_into()) .transpose()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(exec.base_config().try_into()?), predicate, }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { - Ok(protobuf::PhysicalPlanNode { + }); + } + + if let Some(exec) = plan.downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::AvroScan( protobuf::AvroScanExecNode { base_conf: Some(exec.base_config().try_into()?), }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( protobuf::CoalescePartitionsExecNode { input: Some(Box::new(input)), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1061,15 +1428,17 @@ impl AsExecutionPlan for PhysicalPlanNode { } }; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( protobuf::RepartitionExecNode { input: Some(Box::new(input)), partition_method: Some(pb_partition_method), }, ))), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1090,7 +1459,7 @@ impl AsExecutionPlan for PhysicalPlanNode { }) }) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( protobuf::SortExecNode { input: Some(Box::new(input)), @@ -1102,8 +1471,10 @@ impl AsExecutionPlan for PhysicalPlanNode { preserve_partitioning: exec.preserve_partitioning(), }, ))), - }) - } else if let Some(union) = plan.downcast_ref::() { + }); + } + + if let Some(union) = plan.downcast_ref::() { let mut inputs: Vec = vec![]; for input in union.inputs() { inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( @@ -1111,12 +1482,29 @@ impl AsExecutionPlan for PhysicalPlanNode { extension_codec, )?); } - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union( protobuf::UnionExecNode { inputs }, )), - }) - } else if let Some(exec) = plan.downcast_ref::() { + }); + } + + if let Some(interleave) = plan.downcast_ref::() { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); + } + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }); + } + + if let Some(exec) = plan.downcast_ref::() { let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, @@ -1137,39 +1525,206 @@ impl AsExecutionPlan for PhysicalPlanNode { }) }) .collect::>>()?; - Ok(protobuf::PhysicalPlanNode { + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortPreservingMerge( Box::new(protobuf::SortPreservingMergeExecNode { input: Some(Box::new(input)), expr, + fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1), }), )), - }) - } else { - let mut buf: Vec = vec![]; - match extension_codec.try_encode(plan_clone.clone(), &mut buf) { - Ok(_) => { - let inputs: Vec = plan_clone - .children() - .into_iter() + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = f.expression().to_owned().try_into()?; + let column_indices = f + .column_indices() + .iter() .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( - i, - extension_codec, - ) + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } }) - .collect::>()?; - - Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Extension( - protobuf::PhysicalExtensionNode { node: buf, inputs }, - )), + .collect(); + let schema = f.schema().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::NestedLoopJoin(Box::new( + protobuf::NestedLoopJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + join_type: join_type.into(), + filter, + }, + ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let window_expr = + exec.window_expr() + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let partition_keys = exec + .partition_keys + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + partition_keys, + input_order_mode: None, + }, + ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let window_expr = + exec.window_expr() + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let partition_keys = exec + .partition_keys + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( + protobuf::EmptyMessage {}, + ), + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { + columns: columns.iter().map(|c| *c as u64).collect(), + }, + ) } - Err(e) => Err(DataFusionError::Internal(format!( - "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" + InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( + protobuf::EmptyMessage {}, + ), + }; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + partition_keys, + input_order_mode: Some(input_order_mode), + }, ))), + }); + } + + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let sort_order = match exec.sort_order() { + Some(requirements) => { + let expr = requirements + .iter() + .map(|requirement| { + let expr: PhysicalSortExpr = requirement.to_owned().into(); + let sort_expr = protobuf::PhysicalSortExprNode { + expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }; + Ok(sort_expr) + }) + .collect::>>()?; + Some(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: expr, + }) + } + None => None, + }; + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( + protobuf::JsonSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); } + + // If unknown DataSink then let extension handle it + } + + let mut buf: Vec = vec![]; + match extension_codec.try_encode(plan_clone.clone(), &mut buf) { + Ok(_) => { + let inputs: Vec = plan_clone + .children() + .into_iter() + .map(|i| { + protobuf::PhysicalPlanNode::try_from_physical_plan( + i, + extension_codec, + ) + }) + .collect::>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + protobuf::PhysicalExtensionNode { node: buf, inputs }, + )), + }) + } + Err(e) => internal_err!( + "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" + ), } } } @@ -1220,9 +1775,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { _inputs: &[Arc], _registry: &dyn FunctionRegistry, ) -> Result> { - Err(DataFusionError::NotImplemented( - "PhysicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("PhysicalExtensionCodec is not provided") } fn try_encode( @@ -1230,9 +1783,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { _node: Arc, _buf: &mut Vec, ) -> Result<()> { - Err(DataFusionError::NotImplemented( - "PhysicalExtensionCodec is not provided".to_string(), - )) + not_impl_err!("PhysicalExtensionCodec is not provided") } } @@ -1248,534 +1799,3 @@ fn into_physical_plan( Err(proto_error("Missing required field in protobuf")) } } - -#[cfg(test)] -mod roundtrip_tests { - use std::ops::Deref; - use std::sync::Arc; - - use super::super::protobuf; - use crate::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; - use datafusion::arrow::array::ArrayRef; - use datafusion::arrow::datatypes::IntervalUnit; - use datafusion::datasource::object_store::ObjectStoreUrl; - use datafusion::execution::context::ExecutionProps; - use datafusion::logical_expr::create_udf; - use datafusion::logical_expr::{BuiltinScalarFunction, Volatility}; - use datafusion::physical_expr::expressions::in_list; - use datafusion::physical_expr::ScalarFunctionExpr; - use datafusion::physical_plan::aggregates::PhysicalGroupBy; - use datafusion::physical_plan::expressions::{ - date_time_interval_expr, like, BinaryExpr, GetIndexedFieldExpr, - }; - use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::physical_plan::projection::ProjectionExec; - use datafusion::physical_plan::{functions, udaf}; - use datafusion::{ - arrow::{ - compute::kernels::sort::SortOptions, - datatypes::{DataType, Field, Schema}, - }, - datasource::{ - listing::PartitionedFile, - physical_plan::{FileScanConfig, ParquetExec}, - }, - logical_expr::{JoinType, Operator}, - physical_plan::{ - aggregates::{AggregateExec, AggregateMode}, - empty::EmptyExec, - expressions::{binary, col, lit, NotExpr}, - expressions::{Avg, Column, DistinctCount, PhysicalSortExpr}, - filter::FilterExec, - joins::{HashJoinExec, PartitionMode}, - limit::{GlobalLimitExec, LocalLimitExec}, - sorts::sort::SortExec, - AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, - }, - prelude::SessionContext, - scalar::ScalarValue, - }; - use datafusion_common::Result; - use datafusion_expr::{ - Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, - Signature, StateTypeFunction, - }; - - fn roundtrip_test(exec_plan: Arc) -> Result<()> { - let ctx = SessionContext::new(); - let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) - .expect("to proto"); - let runtime = ctx.runtime_env(); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) - .expect("from proto"); - assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) - } - - fn roundtrip_test_with_context( - exec_plan: Arc, - ctx: SessionContext, - ) -> Result<()> { - let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) - .expect("to proto"); - let runtime = ctx.runtime_env(); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) - .expect("from proto"); - assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) - } - - #[test] - fn roundtrip_empty() -> Result<()> { - roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())))) - } - - #[test] - fn roundtrip_date_time_interval() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("some_date", DataType::Date32, false), - Field::new( - "some_interval", - DataType::Interval(IntervalUnit::DayTime), - false, - ), - ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); - let date_expr = col("some_date", &schema)?; - let literal_expr = col("some_interval", &schema)?; - let date_time_interval_expr = - date_time_interval_expr(date_expr, Operator::Plus, literal_expr, &schema)?; - let plan = Arc::new(ProjectionExec::try_new( - vec![(date_time_interval_expr, "result".to_string())], - input, - )?); - roundtrip_test(plan) - } - - #[test] - fn roundtrip_local_limit() -> Result<()> { - roundtrip_test(Arc::new(LocalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), - 25, - ))) - } - - #[test] - fn roundtrip_global_limit() -> Result<()> { - roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), - 0, - Some(25), - ))) - } - - #[test] - fn roundtrip_global_skip_no_limit() -> Result<()> { - roundtrip_test(Arc::new(GlobalLimitExec::new( - Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))), - 10, - None, // no limit - ))) - } - - #[test] - fn roundtrip_hash_join() -> Result<()> { - let field_a = Field::new("col", DataType::Int64, false); - let schema_left = Schema::new(vec![field_a.clone()]); - let schema_right = Schema::new(vec![field_a]); - let on = vec![( - Column::new("col", schema_left.index_of("col")?), - Column::new("col", schema_right.index_of("col")?), - )]; - - let schema_left = Arc::new(schema_left); - let schema_right = Arc::new(schema_right); - for join_type in &[ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftAnti, - JoinType::RightAnti, - JoinType::LeftSemi, - JoinType::RightSemi, - ] { - for partition_mode in - &[PartitionMode::Partitioned, PartitionMode::CollectLeft] - { - roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, schema_left.clone())), - Arc::new(EmptyExec::new(false, schema_right.clone())), - on.clone(), - None, - join_type, - *partition_mode, - false, - )?))?; - } - } - Ok(()) - } - - #[test] - fn rountrip_aggregate() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - let aggregates: Vec> = - vec![Arc::new(Avg::new_with_pre_cast( - col("b", &schema)?, - "AVG(b)".to_string(), - DataType::Float64, - DataType::Float64, - true, - ))]; - - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups.clone()), - aggregates.clone(), - vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), - schema, - )?)) - } - - #[test] - fn roundtrip_aggregate_udaf() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - #[derive(Debug)] - struct Example; - impl Accumulator for Example { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(0))]) - } - - fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { - Ok(()) - } - - fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { - Ok(()) - } - - fn evaluate(&self) -> Result { - Ok(ScalarValue::Int64(Some(0))) - } - - fn size(&self) -> usize { - 0 - } - } - - let rt_func: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Int64))); - let accumulator: AccumulatorFunctionImplementation = - Arc::new(|_| Ok(Box::new(Example))); - let st_func: StateTypeFunction = - Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64]))); - - let udaf = AggregateUDF::new( - "example", - &Signature::exact(vec![DataType::Int64], Volatility::Immutable), - &rt_func, - &accumulator, - &st_func, - ); - - let ctx = SessionContext::new(); - ctx.register_udaf(udaf.clone()); - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - let aggregates: Vec> = vec![udaf::create_aggregate_expr( - &udaf, - &[col("b", &schema)?], - &schema, - "example_agg", - )?]; - - roundtrip_test_with_context( - Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups.clone()), - aggregates.clone(), - vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), - schema, - )?), - ctx, - ) - } - - #[test] - fn roundtrip_filter_with_not_and_in_list() -> Result<()> { - let field_a = Field::new("a", DataType::Boolean, false); - let field_b = Field::new("b", DataType::Int64, false); - let field_c = Field::new("c", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); - let not = Arc::new(NotExpr::new(col("a", &schema)?)); - let in_list = in_list( - col("b", &schema)?, - vec![ - lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Int64(Some(2))), - ], - &false, - schema.as_ref(), - )?; - let and = binary(not, Operator::And, in_list, &schema)?; - roundtrip_test(Arc::new(FilterExec::try_new( - and, - Arc::new(EmptyExec::new(false, schema.clone())), - )?)) - } - - #[test] - fn roundtrip_sort() -> Result<()> { - let field_a = Field::new("a", DataType::Boolean, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = vec![ - PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: true, - nulls_first: false, - }, - }, - PhysicalSortExpr { - expr: col("b", &schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }, - ]; - roundtrip_test(Arc::new(SortExec::new( - sort_exprs, - Arc::new(EmptyExec::new(false, schema)), - ))) - } - - #[test] - fn roundtrip_sort_preserve_partitioning() -> Result<()> { - let field_a = Field::new("a", DataType::Boolean, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = vec![ - PhysicalSortExpr { - expr: col("a", &schema)?, - options: SortOptions { - descending: true, - nulls_first: false, - }, - }, - PhysicalSortExpr { - expr: col("b", &schema)?, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }, - ]; - - roundtrip_test(Arc::new(SortExec::new( - sort_exprs.clone(), - Arc::new(EmptyExec::new(false, schema.clone())), - )))?; - - roundtrip_test(Arc::new( - SortExec::new(sort_exprs, Arc::new(EmptyExec::new(false, schema))) - .with_preserve_partitioning(true), - )) - } - - #[test] - fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { - let scan_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: Arc::new(Schema::new(vec![Field::new( - "col", - DataType::Utf8, - false, - )])), - file_groups: vec![vec![PartitionedFile::new( - "/path/to/file.parquet".to_string(), - 1024, - )]], - statistics: Statistics { - num_rows: Some(100), - total_byte_size: Some(1024), - column_statistics: None, - is_exact: false, - }, - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }; - - let predicate = Arc::new(BinaryExpr::new( - Arc::new(Column::new("col", 1)), - Operator::Eq, - lit("1"), - )); - roundtrip_test(Arc::new(ParquetExec::new( - scan_config, - Some(predicate), - None, - ))) - } - - #[test] - fn roundtrip_builtin_scalar_function() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let input = Arc::new(EmptyExec::new(false, schema.clone())); - - let execution_props = ExecutionProps::new(); - - let fun_expr = functions::create_physical_fun( - &BuiltinScalarFunction::Abs, - &execution_props, - )?; - - let expr = ScalarFunctionExpr::new( - "abs", - fun_expr, - vec![col("a", &schema)?], - &DataType::Int64, - ); - - let project = - ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; - - roundtrip_test(Arc::new(project)) - } - - #[test] - fn roundtrip_scalar_udf() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let input = Arc::new(EmptyExec::new(false, schema.clone())); - - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); - - let udf = create_udf( - "dummy", - vec![DataType::Int64], - Arc::new(DataType::Int64), - Volatility::Immutable, - scalar_fn.clone(), - ); - - let expr = ScalarFunctionExpr::new( - "dummy", - scalar_fn, - vec![col("a", &schema)?], - &DataType::Int64, - ); - - let project = - ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; - - let ctx = SessionContext::new(); - - ctx.register_udf(udf); - - roundtrip_test_with_context(Arc::new(project), ctx) - } - - #[test] - fn roundtrip_distinct_count() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let aggregates: Vec> = vec![Arc::new(DistinctCount::new( - DataType::Int64, - col("b", &schema)?, - "COUNT(DISTINCT b)".to_string(), - ))]; - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups), - aggregates.clone(), - vec![None], - vec![None], - Arc::new(EmptyExec::new(false, schema.clone())), - schema, - )?)) - } - - #[test] - fn roundtrip_like() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - ]); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); - let like_expr = like( - false, - false, - col("a", &schema)?, - col("b", &schema)?, - &schema, - )?; - let plan = Arc::new(ProjectionExec::try_new( - vec![(like_expr, "result".to_string())], - input, - )?); - roundtrip_test(plan) - } - - #[test] - fn roundtrip_get_indexed_field() -> Result<()> { - let fields = vec![ - Field::new("id", DataType::Int64, true), - Field::new_list("a", Field::new("item", DataType::Float64, true), true), - ]; - - let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); - - let col_a = col("a", &schema)?; - let key = ScalarValue::Int64(Some(1)); - let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new(col_a, key)); - - let plan = Arc::new(ProjectionExec::try_new( - vec![(get_indexed_field_expr, "result".to_string())], - input, - )?); - - roundtrip_test(plan) - } -} diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 0910ddaad0c73..ea00b726b9d68 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,162 +22,95 @@ use std::{ sync::Arc, }; -use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; -use datafusion::physical_plan::ColumnStatistics; -use datafusion::physical_plan::{ - expressions::{ - CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr, - }, - Statistics, -}; - -use datafusion::datasource::listing::{FileRange, PartitionedFile}; -use datafusion::datasource::physical_plan::FileScanConfig; - -use datafusion::physical_plan::expressions::{Count, DistinctCount, Literal}; - -use datafusion::physical_plan::expressions::{ - Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Column, LikeExpr, Max, Min, - Sum, -}; -use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; - -use crate::protobuf; +use crate::protobuf::{self, physical_window_expr_node, scalar_value::Value}; use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, ScalarValue, }; + +use datafusion::datasource::{ + file_format::json::JsonSink, + listing::{FileRange, PartitionedFile}, + physical_plan::FileScanConfig, + physical_plan::FileSinkConfig, +}; use datafusion::logical_expr::BuiltinScalarFunction; -use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr}; +use datafusion::physical_expr::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; +use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; -use datafusion::physical_plan::joins::utils::JoinSide; +use datafusion::physical_plan::expressions::{ + ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, + ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, + DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, + Min, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, + Regr, RegrType, RowNumber, Stddev, StddevPop, Sum, TryCastExpr, Variance, + VariancePop, WindowShift, +}; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion_common::{DataFusionError, Result}; +use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; +use datafusion::physical_plan::{ + AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, +}; +use datafusion_common::{ + file_options::{ + arrow_writer::ArrowWriterOptions, avro_writer::AvroWriterOptions, + csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions, + parquet_writer::ParquetWriterOptions, + }, + internal_err, not_impl_err, + parsers::CompressionTypeVariant, + stats::Precision, + DataFusionError, FileTypeWriterOptions, JoinSide, Result, +}; impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { - use datafusion::physical_plan::expressions; - use protobuf::AggregateFunction; - let expressions: Vec = a .expressions() .iter() .map(|e| e.clone().try_into()) .collect::>>()?; - let mut distinct = false; - let aggr_function = if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Avg.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Sum.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Count.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitAnd.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitOr.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BitXor.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BoolAnd.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::BoolOr.into()) - } else if a.as_any().downcast_ref::().is_some() { - distinct = true; - Ok(AggregateFunction::Count.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Min.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Max.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxDistinct.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::ArrayAgg.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Variance.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::VariancePop.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::Covariance.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::CovariancePop.into()) - } else if a.as_any().downcast_ref::().is_some() { - Ok(AggregateFunction::Stddev.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::StddevPop.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::Correlation.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxPercentileCont.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxPercentileContWithWeight.into()) - } else if a - .as_any() - .downcast_ref::() - .is_some() - { - Ok(AggregateFunction::ApproxMedian.into()) - } else { - if let Some(a) = a.as_any().downcast_ref::() { - return Ok(protobuf::PhysicalExprNode { + let ordering_req: Vec = a + .order_bys() + .unwrap_or(&[]) + .iter() + .map(|e| e.clone().try_into()) + .collect::>>()?; + + if let Some(a) = a.as_any().downcast_ref::() { + let name = a.fun().name().to_string(); + return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, - distinct, + ordering_req, + distinct: false, }, )), }); - } + } - Err(DataFusionError::NotImplemented(format!( - "Aggregate function not supported: {a:?}" - ))) - }?; + let AggrFn { + inner: aggr_function, + distinct, + } = aggr_expr_to_aggr_fn(a.as_ref())?; Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some( physical_aggregate_expr_node::AggregateFunction::AggrFunction( - aggr_function, + aggr_function as i32, ), ), expr: expressions, + ordering_req, distinct, }, )), @@ -185,6 +118,253 @@ impl TryFrom> for protobuf::PhysicalExprNode { } } +impl TryFrom> for protobuf::PhysicalWindowExprNode { + type Error = DataFusionError; + + fn try_from( + window_expr: Arc, + ) -> std::result::Result { + let expr = window_expr.as_any(); + + let mut args = window_expr.expressions().to_vec(); + let window_frame = window_expr.get_window_frame(); + + let window_function = if let Some(built_in_window_expr) = + expr.downcast_ref::() + { + let expr = built_in_window_expr.get_built_in_func_expr(); + let built_in_fn_expr = expr.as_any(); + + let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { + protobuf::BuiltInWindowFunction::RowNumber + } else if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() { + match rank_expr.get_type() { + RankType::Basic => protobuf::BuiltInWindowFunction::Rank, + RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, + RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, + } + } else if built_in_fn_expr.downcast_ref::().is_some() { + protobuf::BuiltInWindowFunction::CumeDist + } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { + args.insert( + 0, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + ntile_expr.get_n() as i64, + )))), + ); + protobuf::BuiltInWindowFunction::Ntile + } else if let Some(window_shift_expr) = + built_in_fn_expr.downcast_ref::() + { + args.insert( + 1, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + window_shift_expr.get_shift_offset(), + )))), + ); + if let Some(default_value) = window_shift_expr.get_default_value() { + args.insert(2, Arc::new(Literal::new(default_value))); + } + if window_shift_expr.get_shift_offset() >= 0 { + protobuf::BuiltInWindowFunction::Lag + } else { + protobuf::BuiltInWindowFunction::Lead + } + } else if let Some(nth_value_expr) = + built_in_fn_expr.downcast_ref::() + { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new( + datafusion_common::ScalarValue::Int64(Some(n)), + )), + ); + protobuf::BuiltInWindowFunction::NthValue + } + } + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; + + physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32) + } else if let Some(plain_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + plain_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } + + if !window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } + + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else if let Some(sliding_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + sliding_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } + + if window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } + + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else { + return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + }; + + let args = args + .into_iter() + .map(|e| e.try_into()) + .collect::>>()?; + + let partition_by = window_expr + .partition_by() + .iter() + .map(|p| p.clone().try_into()) + .collect::>>()?; + + let order_by = window_expr + .order_by() + .iter() + .map(|o| o.clone().try_into()) + .collect::>>()?; + + let window_frame: protobuf::WindowFrame = window_frame + .as_ref() + .try_into() + .map_err(|e| DataFusionError::Internal(format!("{e}")))?; + + let name = window_expr.name().to_string(); + + Ok(protobuf::PhysicalWindowExprNode { + args, + partition_by, + order_by, + window_frame: Some(window_frame), + window_function: Some(window_function), + name, + }) + } +} + +struct AggrFn { + inner: protobuf::AggregateFunction, + distinct: bool, +} + +fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { + let aggr_expr = expr.as_any(); + let mut distinct = false; + + let inner = if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Count + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::Count + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Grouping + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitAnd + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitOr + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BitXor + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::BitXor + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BoolAnd + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::BoolOr + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Sum + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::Sum + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxDistinct + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + distinct = true; + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ArrayAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Min + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Max + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Avg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Variance + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::VariancePop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Covariance + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::CovariancePop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Stddev + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::StddevPop + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Correlation + } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { + match regr_expr.get_regr_type() { + RegrType::Slope => protobuf::AggregateFunction::RegrSlope, + RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, + RegrType::Count => protobuf::AggregateFunction::RegrCount, + RegrType::R2 => protobuf::AggregateFunction::RegrR2, + RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, + RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, + RegrType::SXX => protobuf::AggregateFunction::RegrSxx, + RegrType::SYY => protobuf::AggregateFunction::RegrSyy, + RegrType::SXY => protobuf::AggregateFunction::RegrSxy, + } + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxPercentileCont + } else if aggr_expr + .downcast_ref::() + .is_some() + { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::ApproxMedian + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::Median + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::FirstValueAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::LastValueAgg + } else { + return not_impl_err!("Aggregate function not supported: {expr:?}"); + }; + + Ok(AggrFn { inner, distinct }) +} + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; @@ -350,20 +530,6 @@ impl TryFrom> for protobuf::PhysicalExprNode { )), }) } - } else if let Some(expr) = expr.downcast_ref::() { - let dti_expr = Box::new(protobuf::PhysicalDateTimeIntervalExprNode { - l: Some(Box::new(expr.lhs().to_owned().try_into()?)), - r: Some(Box::new(expr.rhs().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::DateTimeIntervalExpr( - dti_expr, - ), - ), - }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr( @@ -376,20 +542,37 @@ impl TryFrom> for protobuf::PhysicalExprNode { )), }) } else if let Some(expr) = expr.downcast_ref::() { + let field = match expr.field() { + GetFieldAccessExpr::NamedStructField{name} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(protobuf::NamedStructFieldExpr { + name: Some(ScalarValue::try_from(name)?) + }) + ), + GetFieldAccessExpr::ListIndex{key} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::ListIndexExpr(Box::new(protobuf::ListIndexExpr { + key: Some(Box::new(key.to_owned().try_into()?)) + })) + ), + GetFieldAccessExpr::ListRange{start, stop} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr { + start: Some(Box::new(start.to_owned().try_into()?)), + stop: Some(Box::new(stop.to_owned().try_into()?)), + })) + ), + }; + Ok(protobuf::PhysicalExprNode { expr_type: Some( protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( Box::new(protobuf::PhysicalGetIndexedFieldExprNode { arg: Some(Box::new(expr.arg().to_owned().try_into()?)), - key: Some(ScalarValue::try_from(expr.key())?), + field, }), ), ), }) } else { - Err(DataFusionError::Internal(format!( - "physical_plan::to_proto() unsupported expression {value:?}" - ))) + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } } @@ -408,10 +591,16 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { type Error = DataFusionError; fn try_from(pf: &PartitionedFile) -> Result { + let last_modified = pf.object_meta.last_modified; + let last_modified_ns = last_modified.timestamp_nanos_opt().ok_or_else(|| { + DataFusionError::Plan(format!( + "Invalid timestamp on PartitionedFile::ObjectMeta: {last_modified}" + )) + })? as u64; Ok(protobuf::PartitionedFile { path: pf.object_meta.location.as_ref().to_owned(), size: pf.object_meta.size as u64, - last_modified_ns: pf.object_meta.last_modified.timestamp_nanos() as u64, + last_modified_ns, partition_values: pf .partition_values .iter() @@ -446,29 +635,66 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { } } -impl From<&ColumnStatistics> for protobuf::ColumnStats { - fn from(cs: &ColumnStatistics) -> protobuf::ColumnStats { - protobuf::ColumnStats { - min_value: cs.min_value.as_ref().map(|m| m.try_into().unwrap()), - max_value: cs.max_value.as_ref().map(|m| m.try_into().unwrap()), - null_count: cs.null_count.map(|n| n as u32).unwrap_or(0), - distinct_count: cs.distinct_count.map(|n| n as u32).unwrap_or(0), +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: Some(ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: Some(ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(ScalarValue { value: None }), + }, + } + } +} + +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: val.try_into().ok(), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: val.try_into().ok(), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(ScalarValue { value: None }), + }, } } } impl From<&Statistics> for protobuf::Statistics { fn from(s: &Statistics) -> protobuf::Statistics { - let none_value = -1_i64; - let column_stats = match &s.column_statistics { - None => vec![], - Some(column_stats) => column_stats.iter().map(|s| s.into()).collect(), - }; + let column_stats = s.column_statistics.iter().map(|s| s.into()).collect(); protobuf::Statistics { - num_rows: s.num_rows.map(|n| n as i64).unwrap_or(none_value), - total_byte_size: s.total_byte_size.map(|n| n as i64).unwrap_or(none_value), + num_rows: Some(protobuf::Precision::from(&s.num_rows)), + total_byte_size: Some(protobuf::Precision::from(&s.total_byte_size)), column_stats, - is_exact: s.is_exact, + } + } +} + +impl From<&ColumnStatistics> for protobuf::ColumnStats { + fn from(s: &ColumnStatistics) -> protobuf::ColumnStats { + protobuf::ColumnStats { + min_value: Some(protobuf::Precision::from(&s.min_value)), + max_value: Some(protobuf::Precision::from(&s.max_value)), + null_count: Some(protobuf::Precision::from(&s.null_count)), + distinct_count: Some(protobuf::Precision::from(&s.distinct_count)), } } } @@ -515,7 +741,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { table_partition_cols: conf .table_partition_cols .iter() - .map(|x| x.0.clone()) + .map(|x| x.name().clone()) .collect::>(), object_store_url: conf.object_store_url.to_string(), output_ordering: output_orderings @@ -577,3 +803,98 @@ impl TryFrom for protobuf::PhysicalSortExprNode { }) } } + +impl TryFrom<&JsonSink> for protobuf::JsonSink { + type Error = DataFusionError; + + fn try_from(value: &JsonSink) -> Result { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { + type Error = DataFusionError; + + fn try_from(conf: &FileSinkConfig) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(TryInto::try_into) + .collect::>>()?; + let table_paths = conf + .table_paths + .iter() + .map(ToString::to_string) + .collect::>(); + let table_partition_cols = conf + .table_partition_cols + .iter() + .map(|(name, data_type)| { + Ok(protobuf::PartitionColumn { + name: name.to_owned(), + arrow_type: Some(data_type.try_into()?), + }) + }) + .collect::>>()?; + let file_type_writer_options = &conf.file_type_writer_options; + Ok(Self { + object_store_url: conf.object_store_url.to_string(), + file_groups, + table_paths, + output_schema: Some(conf.output_schema.as_ref().try_into()?), + table_partition_cols, + single_file_output: conf.single_file_output, + unbounded_input: conf.unbounded_input, + overwrite: conf.overwrite, + file_type_writer_options: Some(file_type_writer_options.try_into()?), + }) + } +} + +impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { + fn from(value: &CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { + type Error = DataFusionError; + + fn try_from(opts: &FileTypeWriterOptions) -> Result { + let file_type = match opts { + #[cfg(feature = "parquet")] + FileTypeWriterOptions::Parquet(ParquetWriterOptions { + writer_options: _, + }) => return not_impl_err!("Parquet file sink protobuf serialization"), + FileTypeWriterOptions::CSV(CsvWriterOptions { + writer_options: _, + compression: _, + }) => return not_impl_err!("CSV file sink protobuf serialization"), + FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::file_type_writer_options::FileType::JsonOptions( + protobuf::JsonWriterOptions { + compression: compression.into(), + }, + ) + } + FileTypeWriterOptions::Avro(AvroWriterOptions {}) => { + return not_impl_err!("Avro file sink protobuf serialization") + } + FileTypeWriterOptions::Arrow(ArrowWriterOptions {}) => { + return not_impl_err!("Arrow file sink protobuf serialization") + } + }; + Ok(Self { + file_type: Some(file_type), + }) + } +} diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs new file mode 100644 index 0000000000000..b17289205f3de --- /dev/null +++ b/datafusion/proto/tests/cases/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod roundtrip_logical_plan; +mod roundtrip_physical_plan; +mod serialize; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs new file mode 100644 index 0000000000000..8e15b5d0d4808 --- /dev/null +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -0,0 +1,1667 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::datatypes::{ + DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, +}; + +use prost::Message; + +use datafusion::datasource::provider::TableProviderFactory; +use datafusion::datasource::TableProvider; +use datafusion::execution::context::SessionState; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; +use datafusion::test_util::{TestTableFactory, TestTableProvider}; +use datafusion_common::Result; +use datafusion_common::{internal_err, not_impl_err, plan_err}; +use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; +use datafusion_expr::expr::{ + self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, + Sort, +}; +use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; +use datafusion_expr::{ + col, create_udaf, lit, Accumulator, AggregateFunction, + BuiltinScalarFunction::{Sqrt, Substr}, + Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, +}; +use datafusion_proto::bytes::{ + logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, + logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, +}; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::logical_plan::{from_proto, to_proto}; +use datafusion_proto::protobuf; + +#[cfg(feature = "json")] +fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { + let string = serde_json::to_string(proto).unwrap(); + let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); + assert_eq!(proto, &back); +} + +#[cfg(not(feature = "json"))] +fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} + +// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test +// equality. +fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) +where + for<'a> &'a T: TryInto + Debug, + E: Debug, +{ + let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); + let round_trip: Expr = from_proto::parse_expr(&proto, &ctx).unwrap(); + + assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); + + roundtrip_json_test(&proto); +} + +fn new_arc_field(name: &str, dt: DataType, nullable: bool) -> Arc { + Arc::new(Field::new(name, dt, nullable)) +} + +#[tokio::test] +async fn roundtrip_logical_plan() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let scan = ctx.table("t1").await?.into_optimized_plan()?; + let topk_plan = LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), + }); + let extension_codec = TopKExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; + assert_eq!(format!("{topk_plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} + +#[derive(Clone, PartialEq, Eq, ::prost::Message)] +pub struct TestTableProto { + /// URL of the table root + #[prost(string, tag = "1")] + pub url: String, +} + +#[derive(Debug)] +pub struct TestTableProviderCodec {} + +impl LogicalExtensionCodec for TestTableProviderCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("No extension codec provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + let msg = TestTableProto::decode(buf).map_err(|_| { + DataFusionError::Internal("Error decoding test table".to_string()) + })?; + let provider = TestTableProvider { + url: msg.url, + schema, + }; + Ok(Arc::new(provider)) + } + + fn try_encode_table_provider( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + let table = node + .as_ref() + .as_any() + .downcast_ref::() + .expect("Can't encode non-test tables"); + let msg = TestTableProto { + url: table.url.clone(), + }; + msg.encode(buf).map_err(|_| { + DataFusionError::Internal("Error encoding test table".to_string()) + }) + } +} + +#[tokio::test] +async fn roundtrip_custom_tables() -> Result<()> { + let mut table_factories: HashMap> = + HashMap::new(); + table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); + let cfg = RuntimeConfig::new(); + let env = RuntimeEnv::new(cfg).unwrap(); + let ses = SessionConfig::new(); + let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + // replace factories + *state.table_factories_mut() = table_factories; + let ctx = SessionContext::new_with_state(state); + + let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; + ctx.sql(sql).await.unwrap(); + + let codec = TestTableProviderCodec {}; + let scan = ctx.table("t").await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{scan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_custom_memory_tables() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let query = "CREATE TABLE sales_global_with_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0)"; + + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_custom_listing_tables() -> Result<()> { + let ctx = SessionContext::new(); + + let query = "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER DEFAULT 1*2 + 3, + b INTEGER DEFAULT NULL, + c INTEGER, + d INTEGER, + primary key(c) + ) + STORED AS CSV + WITH HEADER ROW + WITH ORDER (a ASC, b ASC) + WITH ORDER (c ASC) + LOCATION '../core/tests/data/window_2.csv';"; + + let plan = ctx.state().create_logical_plan(query).await?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + // Use exact matching to verify everything. Make sure during round-trip, + // information like constraints, column defaults, and other aspects of the plan are preserved. + assert_eq!(plan, logical_round_trip); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_aggregation_with_pk() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql( + "CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + primary key(c) + ) + STORED AS CSV + WITH HEADER ROW + WITH ORDER (a ASC, b ASC) + WITH ORDER (c ASC) + LOCATION '../core/tests/data/window_2.csv';", + ) + .await?; + + let query = "SELECT c, b, SUM(d) + FROM multiple_ordered_table_with_pk + GROUP BY c"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_aggregation() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_distinct_on() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT DISTINCT ON (a % 2) a, b * 2 FROM t1 ORDER BY a % 2 DESC, b"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_single_count_distinct() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT a, COUNT(DISTINCT b) as b_cd FROM t1 GROUP BY a"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_with_extension() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let plan = ctx.table("t1").await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_with_view_scan() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + ctx.sql("CREATE VIEW view_t1(a, b) AS SELECT a, b FROM t1") + .await?; + + // SELECT + let plan = ctx + .sql("SELECT * FROM view_t1") + .await? + .into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + // DROP + let plan = ctx.sql("DROP VIEW view_t1").await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +pub mod proto { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct TopKPlanProto { + #[prost(uint64, tag = "1")] + pub k: u64, + + #[prost(message, optional, tag = "2")] + pub expr: ::core::option::Option, + } + + #[derive(Clone, PartialEq, Eq, ::prost::Message)] + pub struct TopKExecProto { + #[prost(uint64, tag = "1")] + pub k: u64, + } +} + +#[derive(PartialEq, Eq, Hash)] +struct TopKPlanNode { + k: usize, + input: LogicalPlan, + /// The sort expression (this example only supports a single sort + /// expr) + expr: Expr, +} + +impl TopKPlanNode { + pub fn new(k: usize, input: LogicalPlan, expr: Expr) -> Self { + Self { k, input, expr } + } +} + +impl Debug for TopKPlanNode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } +} + +impl UserDefinedLogicalNodeCore for TopKPlanNode { + fn name(&self) -> &str { + "TopK" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for TopK is the same as the input + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![self.expr.clone()] + } + + /// For example: `TopK: k=10` + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "TopK: k={}", self.k) + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + assert_eq!(exprs.len(), 1, "expression size inconsistent"); + Self { + k: self.k, + input: inputs[0].clone(), + expr: exprs[0].clone(), + } + } +} + +#[derive(Debug)] +pub struct TopKExtensionCodec {} + +impl LogicalExtensionCodec for TopKExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &SessionContext, + ) -> Result { + if let Some((input, _)) = inputs.split_first() { + let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to decode logical plan: {e:?}")) + })?; + + if let Some(expr) = proto.expr.as_ref() { + let node = TopKPlanNode::new( + proto.k as usize, + input.clone(), + from_proto::parse_expr(expr, ctx)?, + ); + + Ok(Extension { + node: Arc::new(node), + }) + } else { + internal_err!("invalid plan, no expr") + } + } else { + internal_err!("invalid plan, no input") + } + } + + fn try_encode(&self, node: &Extension, buf: &mut Vec) -> Result<()> { + if let Some(exec) = node.node.as_any().downcast_ref::() { + let proto = proto::TopKPlanProto { + k: exec.k as u64, + expr: Some((&exec.expr).try_into()?), + }; + + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode logical plan: {e:?}")) + })?; + + Ok(()) + } else { + internal_err!("unsupported plan type") + } + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + internal_err!("unsupported plan type") + } + + fn try_encode_table_provider( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + internal_err!("unsupported plan type") + } +} + +#[test] +fn scalar_values_error_serialization() { + let should_fail_on_seralize: Vec = vec![ + // Should fail due to empty values + ScalarValue::Struct( + Some(vec![]), + vec![Field::new("item", DataType::Int16, true)].into(), + ), + ]; + + for test_case in should_fail_on_seralize.into_iter() { + let proto: Result = + (&test_case).try_into(); + + // Validation is also done on read, so if serialization passed + // also try to convert back to ScalarValue + if let Ok(proto) = proto { + let res: Result = (&proto).try_into(); + assert!( + res.is_err(), + "The value {test_case:?} unexpectedly serialized without error:{res:?}" + ); + } + } +} + +#[test] +fn round_trip_scalar_values() { + let should_pass: Vec = vec![ + ScalarValue::Boolean(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), + ScalarValue::Date32(None), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Float32(Some(1.0)), + ScalarValue::Float32(Some(f32::MAX)), + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(-2000.0)), + ScalarValue::Float64(Some(1.0)), + ScalarValue::Float64(Some(f64::MAX)), + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(-2000.0)), + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ScalarValue::Int8(Some(0)), + ScalarValue::Int8(Some(-15)), + ScalarValue::Int16(Some(i16::MIN)), + ScalarValue::Int16(Some(i16::MAX)), + ScalarValue::Int16(Some(0)), + ScalarValue::Int16(Some(-15)), + ScalarValue::Int32(Some(i32::MIN)), + ScalarValue::Int32(Some(i32::MAX)), + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(-15)), + ScalarValue::Int64(Some(i64::MIN)), + ScalarValue::Int64(Some(i64::MAX)), + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(-15)), + ScalarValue::UInt8(Some(u8::MAX)), + ScalarValue::UInt8(Some(0)), + ScalarValue::UInt16(Some(u16::MAX)), + ScalarValue::UInt16(Some(0)), + ScalarValue::UInt32(Some(u32::MAX)), + ScalarValue::UInt32(Some(0)), + ScalarValue::UInt64(Some(u64::MAX)), + ScalarValue::UInt64(Some(0)), + ScalarValue::Utf8(Some(String::from("Test string "))), + ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), + ScalarValue::Date32(Some(0)), + ScalarValue::Date32(Some(i32::MAX)), + ScalarValue::Date32(None), + ScalarValue::Date64(Some(0)), + ScalarValue::Date64(Some(i64::MAX)), + ScalarValue::Date64(None), + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(i32::MAX)), + ScalarValue::Time32Second(None), + ScalarValue::Time32Millisecond(Some(0)), + ScalarValue::Time32Millisecond(Some(i32::MAX)), + ScalarValue::Time32Millisecond(None), + ScalarValue::Time64Microsecond(Some(0)), + ScalarValue::Time64Microsecond(Some(i64::MAX)), + ScalarValue::Time64Microsecond(None), + ScalarValue::Time64Nanosecond(Some(0)), + ScalarValue::Time64Nanosecond(Some(i64::MAX)), + ScalarValue::Time64Nanosecond(None), + ScalarValue::TimestampNanosecond(Some(0), None), + ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + ScalarValue::TimestampNanosecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampNanosecond(None, None), + ScalarValue::TimestampMicrosecond(Some(0), None), + ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), + ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampMillisecond(Some(0), None), + ScalarValue::TimestampMillisecond(Some(i64::MAX), None), + ScalarValue::TimestampMillisecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampMillisecond(None, None), + ScalarValue::TimestampSecond(Some(0), None), + ScalarValue::TimestampSecond(Some(i64::MAX), None), + ScalarValue::TimestampSecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampSecond(None, None), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( + i32::MAX, + i32::MAX, + ))), + ScalarValue::IntervalDayTime(None), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 0, 0, 0, + ))), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 1, 2, 3, + ))), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + i32::MAX, + i32::MAX, + i64::MAX, + ))), + ScalarValue::IntervalMonthDayNano(None), + ScalarValue::List(ScalarValue::new_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list( + &[ + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), + ScalarValue::List(ScalarValue::new_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::List(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< + Int32Type, + _, + _, + >( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ))), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("foo")), + ), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(None)), + ), + ScalarValue::Binary(Some(b"bar".to_vec())), + ScalarValue::Binary(None), + ScalarValue::LargeBinary(Some(b"bar".to_vec())), + ScalarValue::LargeBinary(None), + ScalarValue::Struct( + Some(vec![ + ScalarValue::Int32(Some(23)), + ScalarValue::Boolean(Some(false)), + ]), + Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + ]), + ), + ScalarValue::Struct( + None, + Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("a", DataType::Boolean, false), + ]), + ), + ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), + ScalarValue::FixedSizeBinary(0, None), + ScalarValue::FixedSizeBinary(5, None), + ]; + + for test_case in should_pass.into_iter() { + let proto: protobuf::ScalarValue = (&test_case) + .try_into() + .expect("failed conversion to protobuf"); + + let roundtrip: ScalarValue = (&proto) + .try_into() + .expect("failed conversion from protobuf"); + + assert_eq!( + test_case, roundtrip, + "ScalarValue was not the same after round trip!\n\n\ + Input: {test_case:?}\n\nRoundtrip: {roundtrip:?}" + ); + } +} + +#[test] +fn round_trip_scalar_types() { + let should_pass: Vec = vec![ + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Utf8, + DataType::LargeUtf8, + // Recursive list tests + DataType::List(new_arc_field("level1", DataType::Boolean, true)), + DataType::List(new_arc_field( + "Level1", + DataType::List(new_arc_field("level2", DataType::Date32, true)), + true, + )), + ]; + + for test_case in should_pass.into_iter() { + let field = Field::new("item", test_case, true); + let proto: protobuf::Field = (&field).try_into().unwrap(); + let roundtrip: Field = (&proto).try_into().unwrap(); + assert_eq!(format!("{field:?}"), format!("{roundtrip:?}")); + } +} + +#[test] +fn round_trip_datatype() { + let test_cases: Vec = vec![ + DataType::Null, + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Microsecond), + DataType::Time32(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Second), + DataType::Time64(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Second), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Nanosecond), + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Binary, + DataType::FixedSizeBinary(0), + DataType::FixedSizeBinary(1234), + DataType::FixedSizeBinary(-432), + DataType::LargeBinary, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Decimal128(7, 12), + // Recursive list tests + DataType::List(new_arc_field("Level1", DataType::Binary, true)), + DataType::List(new_arc_field( + "Level1", + DataType::List(new_arc_field( + "Level2", + DataType::FixedSizeBinary(53), + false, + )), + true, + )), + // Fixed size lists + DataType::FixedSizeList(new_arc_field("Level1", DataType::Binary, true), 4), + DataType::FixedSizeList( + new_arc_field( + "Level1", + DataType::List(new_arc_field( + "Level2", + DataType::FixedSizeBinary(53), + false, + )), + true, + ), + 41, + ), + // Struct Testing + DataType::Struct(Fields::from(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ])), + DataType::Struct(Fields::from(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + Field::new( + "nested_struct", + DataType::Struct(Fields::from(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ])), + true, + ), + ])), + DataType::Union( + UnionFields::new( + vec![7, 5, 3], + vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ], + ), + UnionMode::Sparse, + ), + DataType::Union( + UnionFields::new( + vec![5, 8, 1], + vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + Field::new_struct( + "nested_struct", + vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ], + true, + ), + ], + ), + UnionMode::Dense, + ), + DataType::Dictionary( + Box::new(DataType::Utf8), + Box::new(DataType::Struct(Fields::from(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ]))), + ), + DataType::Dictionary( + Box::new(DataType::Decimal128(10, 50)), + Box::new(DataType::FixedSizeList( + new_arc_field("Level1", DataType::Binary, true), + 4, + )), + ), + DataType::Map( + new_arc_field( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + ])), + true, + ), + false, + ), + ]; + + for test_case in test_cases.into_iter() { + let proto: protobuf::ArrowType = (&test_case).try_into().unwrap(); + let roundtrip: DataType = (&proto).try_into().unwrap(); + assert_eq!(format!("{test_case:?}"), format!("{roundtrip:?}")); + } +} + +#[test] +fn roundtrip_dict_id() -> Result<()> { + let dict_id = 42; + let field = Field::new( + "keys", + DataType::List(Arc::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + dict_id, + false, + ))), + false, + ); + let schema = Arc::new(Schema::new(vec![field])); + + // encode + let mut buf: Vec = vec![]; + let schema_proto: datafusion_proto::generated::datafusion::Schema = + schema.try_into().unwrap(); + schema_proto.encode(&mut buf).unwrap(); + + // decode + let schema_proto = + datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let decoded: Schema = (&schema_proto).try_into()?; + + // assert + let keys = decoded.fields().iter().last().unwrap(); + match keys.data_type() { + DataType::List(field) => { + assert_eq!(field.dict_id(), Some(dict_id), "dict_id should be retained"); + } + _ => panic!("Invalid type"), + } + + Ok(()) +} + +#[test] +fn roundtrip_null_scalar_values() { + let test_types = vec![ + ScalarValue::Boolean(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ScalarValue::Date32(None), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampNanosecond(None, None), + ]; + + for test_case in test_types.into_iter() { + let proto_scalar: protobuf::ScalarValue = (&test_case).try_into().unwrap(); + let returned_scalar: datafusion::scalar::ScalarValue = + (&proto_scalar).try_into().unwrap(); + assert_eq!(format!("{:?}", &test_case), format!("{returned_scalar:?}")); + } +} + +#[test] +fn roundtrip_field() { + let field = Field::new("f", DataType::Int32, true).with_metadata(HashMap::from([ + (String::from("k1"), String::from("v1")), + (String::from("k2"), String::from("v2")), + ])); + let proto_field: protobuf::Field = (&field).try_into().unwrap(); + let returned_field: Field = (&proto_field).try_into().unwrap(); + assert_eq!(field, returned_field); +} + +#[test] +fn roundtrip_schema() { + let schema = Schema::new_with_metadata( + vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Decimal128(15, 2), true) + .with_metadata(HashMap::from([(String::from("k1"), String::from("v1"))])), + ], + HashMap::from([ + (String::from("k2"), String::from("v2")), + (String::from("k3"), String::from("v3")), + ]), + ); + let proto_schema: protobuf::Schema = (&schema).try_into().unwrap(); + let returned_schema: Schema = (&proto_schema).try_into().unwrap(); + assert_eq!(schema, returned_schema); +} + +#[test] +fn roundtrip_dfschema() { + let dfschema = DFSchema::new_with_metadata( + vec![ + DFField::new_unqualified("a", DataType::Int64, false), + DFField::new(Some("t"), "b", DataType::Decimal128(15, 2), true) + .with_metadata(HashMap::from([(String::from("k1"), String::from("v1"))])), + ], + HashMap::from([ + (String::from("k2"), String::from("v2")), + (String::from("k3"), String::from("v3")), + ]), + ) + .unwrap(); + let proto_dfschema: protobuf::DfSchema = (&dfschema).try_into().unwrap(); + let returned_dfschema: DFSchema = (&proto_dfschema).try_into().unwrap(); + assert_eq!(dfschema, returned_dfschema); + + let arc_dfschema = Arc::new(dfschema.clone()); + let proto_dfschema: protobuf::DfSchema = (&arc_dfschema).try_into().unwrap(); + let returned_arc_dfschema: DFSchemaRef = proto_dfschema.try_into().unwrap(); + assert_eq!(arc_dfschema, returned_arc_dfschema); + assert_eq!(dfschema, *returned_arc_dfschema); +} + +#[test] +fn roundtrip_not() { + let test_expr = Expr::Not(Box::new(lit(1.0_f32))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_is_null() { + let test_expr = Expr::IsNull(Box::new(col("id"))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_is_not_null() { + let test_expr = Expr::IsNotNull(Box::new(col("id"))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_between() { + let test_expr = Expr::Between(Between::new( + Box::new(lit(1.0_f32)), + true, + Box::new(lit(2.0_f32)), + Box::new(lit(3.0_f32)), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_binary_op() { + fn test(op: Operator) { + let test_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(1.0_f32)), + op, + Box::new(lit(2.0_f32)), + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + test(Operator::ArrowAt); + test(Operator::AtArrow); + test(Operator::StringConcat); + test(Operator::RegexNotIMatch); + test(Operator::RegexNotMatch); + test(Operator::RegexIMatch); + test(Operator::RegexMatch); + test(Operator::BitwiseShiftRight); + test(Operator::BitwiseShiftLeft); + test(Operator::BitwiseAnd); + test(Operator::BitwiseOr); + test(Operator::BitwiseXor); + test(Operator::IsDistinctFrom); + test(Operator::IsNotDistinctFrom); + test(Operator::And); + test(Operator::Or); + test(Operator::Eq); + test(Operator::NotEq); + test(Operator::Lt); + test(Operator::LtEq); + test(Operator::Gt); + test(Operator::GtEq); +} + +#[test] +fn roundtrip_case() { + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(lit(4.0_f32))), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_case_with_null() { + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(Expr::Literal(ScalarValue::Null))), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_null_literal() { + let test_expr = Expr::Literal(ScalarValue::Null); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_cast() { + let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_try_cast() { + let test_expr = + Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + + let test_expr = + Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_sort_expr() { + let test_expr = Expr::Sort(Sort::new(Box::new(lit(1.0_f32)), true, true)); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_negative() { + let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_inlist() { + let test_expr = Expr::InList(InList::new( + Box::new(lit(1.0_f32)), + vec![lit(2.0_f32)], + true, + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_wildcard() { + let test_expr = Expr::Wildcard { qualifier: None }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_qualified_wildcard() { + let test_expr = Expr::Wildcard { + qualifier: Some("foo".into()), + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_sqrt() { + let test_expr = Expr::ScalarFunction(ScalarFunction::new(Sqrt, vec![col("col")])); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_like() { + fn like(negated: bool, escape_char: Option) { + let test_expr = Expr::Like(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + false, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + like(true, Some('X')); + like(false, Some('\\')); + like(true, None); + like(false, None); +} + +#[test] +fn roundtrip_ilike() { + fn ilike(negated: bool, escape_char: Option) { + let test_expr = Expr::Like(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + true, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + ilike(true, Some('X')); + ilike(false, Some('\\')); + ilike(true, None); + ilike(false, None); +} + +#[test] +fn roundtrip_similar_to() { + fn similar_to(negated: bool, escape_char: Option) { + let test_expr = Expr::SimilarTo(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + false, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + similar_to(true, Some('X')); + similar_to(false, Some('\\')); + similar_to(true, None); + similar_to(false, None); +} + +#[test] +fn roundtrip_count() { + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("bananas")], + false, + None, + None, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_count_distinct() { + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("bananas")], + true, + None, + None, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_approx_percentile_cont() { + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::ApproxPercentileCont, + vec![col("bananas"), lit(0.42_f32)], + false, + None, + None, + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_aggregate_udf() { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Arc::new(dummy_agg.clone()), + vec![lit(1.0_f64)], + false, + Some(Box::new(lit(true))), + None, + )); + + let ctx = SessionContext::new(); + ctx.register_udaf(dummy_agg); + + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_scalar_udf() { + let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); + + let scalar_fn = make_scalar_function(fn_impl); + + let udf = create_udf( + "dummy", + vec![DataType::Utf8], + Arc::new(DataType::Utf8), + Volatility::Immutable, + scalar_fn, + ); + + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf.clone()), + vec![lit("")], + )); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_grouping_sets() { + let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("a")], + vec![col("b")], + vec![col("a"), col("b")], + ])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_rollup() { + let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_cube() { + let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); +} + +#[test] +fn roundtrip_substr() { + // substr(string, position) + let test_expr = + Expr::ScalarFunction(ScalarFunction::new(Substr, vec![col("col"), lit(1_i64)])); + + // substr(string, position, count) + let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( + Substr, + vec![col("col"), lit(1_i64), lit(1_i64)], + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx.clone()); + roundtrip_expr_test(test_expr_with_count, ctx); +} +#[test] +fn roundtrip_window() { + let ctx = SessionContext::new(); + + // 1. without window_frame + let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + vec![], + vec![col("col1")], + vec![col("col2")], + WindowFrame::new(true), + )); + + // 2. with default window_frame + let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + vec![], + vec![col("col1")], + vec![col("col2")], + WindowFrame::new(true), + )); + + // 3. with window_frame with row numbers + let range_number_frame = WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + vec![], + vec![col("col1")], + vec![col("col2")], + range_number_frame, + )); + + // 4. test with AggregateFunction + let row_number_frame = WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateFunction(AggregateFunction::Max), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], + row_number_frame.clone(), + )); + + // 5. test with AggregateUDF + #[derive(Debug)] + struct DummyAggr {} + + impl Accumulator for DummyAggr { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(DummyAggr {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], + row_number_frame.clone(), + )); + ctx.register_udaf(dummy_agg); + + // 6. test with WindowUDF + #[derive(Clone, Debug)] + struct DummyWindow {} + + impl PartitionEvaluator for DummyWindow { + fn uses_window_frame(&self) -> bool { + true + } + + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &std::ops::Range, + ) -> Result { + Ok(ScalarValue::Float64(None)) + } + } + + fn return_type(arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ); + } + Ok(Arc::new(arg_types[0].clone())) + } + + fn make_partition_evaluator() -> Result> { + Ok(Box::new(DummyWindow {})) + } + + let dummy_window_udf = WindowUDF::new( + "dummy_udwf", + &Signature::exact(vec![DataType::Float64], Volatility::Immutable), + &(Arc::new(return_type) as _), + &(Arc::new(make_partition_evaluator) as _), + ); + + let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], + row_number_frame, + )); + + ctx.register_udwf(dummy_window_udf); + + roundtrip_expr_test(test_expr1, ctx.clone()); + roundtrip_expr_test(test_expr2, ctx.clone()); + roundtrip_expr_test(test_expr3, ctx.clone()); + roundtrip_expr_test(test_expr4, ctx.clone()); + roundtrip_expr_test(test_expr5, ctx.clone()); + roundtrip_expr_test(test_expr6, ctx); +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs new file mode 100644 index 0000000000000..da76209dbb496 --- /dev/null +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -0,0 +1,834 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ops::Deref; +use std::sync::Arc; + +use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::compute::kernels::sort::SortOptions; +use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; +use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; +use datafusion::datasource::object_store::ObjectStoreUrl; +use datafusion::datasource::physical_plan::{ + FileScanConfig, FileSinkConfig, ParquetExec, +}; +use datafusion::execution::context::ExecutionProps; +use datafusion::logical_expr::{ + create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, +}; +use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::analyze::AnalyzeExec; +use datafusion::physical_plan::empty::EmptyExec; +use datafusion::physical_plan::expressions::{ + binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, + GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum, +}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::physical_plan::insert::FileSinkExec; +use datafusion::physical_plan::joins::{ + HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, +}; +use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::windows::{ + BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, +}; +use datafusion::physical_plan::{ + functions, udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, +}; +use datafusion::prelude::SessionContext; +use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::stats::Precision; +use datafusion_common::{FileTypeWriterOptions, Result}; +use datafusion_expr::{ + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, + StateTypeFunction, WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; +use datafusion_proto::protobuf; + +fn roundtrip_test(exec_plan: Arc) -> Result<()> { + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto: protobuf::PhysicalPlanNode = + protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) + .expect("to proto"); + let runtime = ctx.runtime_env(); + let result_exec_plan: Arc = proto + .try_into_physical_plan(&ctx, runtime.deref(), &codec) + .expect("from proto"); + assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + Ok(()) +} + +fn roundtrip_test_with_context( + exec_plan: Arc, + ctx: SessionContext, +) -> Result<()> { + let codec = DefaultPhysicalExtensionCodec {}; + let proto: protobuf::PhysicalPlanNode = + protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) + .expect("to proto"); + let runtime = ctx.runtime_env(); + let result_exec_plan: Arc = proto + .try_into_physical_plan(&ctx, runtime.deref(), &codec) + .expect("from proto"); + assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + Ok(()) +} + +#[test] +fn roundtrip_empty() -> Result<()> { + roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) +} + +#[test] +fn roundtrip_date_time_interval() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("some_date", DataType::Date32, false), + Field::new( + "some_interval", + DataType::Interval(IntervalUnit::DayTime), + false, + ), + ]); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); + let date_expr = col("some_date", &schema)?; + let literal_expr = col("some_interval", &schema)?; + let date_time_interval_expr = + binary(date_expr, Operator::Plus, literal_expr, &schema)?; + let plan = Arc::new(ProjectionExec::try_new( + vec![(date_time_interval_expr, "result".to_string())], + input, + )?); + roundtrip_test(plan) +} + +#[test] +fn roundtrip_local_limit() -> Result<()> { + roundtrip_test(Arc::new(LocalLimitExec::new( + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + 25, + ))) +} + +#[test] +fn roundtrip_global_limit() -> Result<()> { + roundtrip_test(Arc::new(GlobalLimitExec::new( + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + 0, + Some(25), + ))) +} + +#[test] +fn roundtrip_global_skip_no_limit() -> Result<()> { + roundtrip_test(Arc::new(GlobalLimitExec::new( + Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), + 10, + None, // no limit + ))) +} + +#[test] +fn roundtrip_hash_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + for partition_mode in &[PartitionMode::Partitioned, PartitionMode::CollectLeft] { + roundtrip_test(Arc::new(HashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + *partition_mode, + false, + )?))?; + } + } + Ok(()) +} + +#[test] +fn roundtrip_nested_loop_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + roundtrip_test(Arc::new(NestedLoopJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + None, + join_type, + )?))?; + } + Ok(()) +} + +#[test] +fn roundtrip_window() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let window_frame = WindowFrame { + units: datafusion_expr::WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)), + end_bound: WindowFrameBound::CurrentRow, + }; + + let builtin_window_expr = Arc::new(BuiltInWindowExpr::new( + Arc::new(NthValue::first( + "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + col("a", &schema)?, + DataType::Int64, + )), + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(window_frame), + )); + + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( + Arc::new(Avg::new( + cast(col("b", &schema)?, &schema, DataType::Float64)?, + "AVG(b)".to_string(), + DataType::Float64, + )), + &[], + &[], + Arc::new(WindowFrame::new(false)), + )); + + let window_frame = WindowFrame { + units: datafusion_expr::WindowFrameUnits::Range, + start_bound: WindowFrameBound::CurrentRow, + end_bound: WindowFrameBound::Preceding(ScalarValue::Int64(None)), + }; + + let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(Sum::new( + cast(col("a", &schema)?, &schema, DataType::Float64)?, + "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", + DataType::Float64, + )), + &[], + &[], + Arc::new(window_frame), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(WindowAggExec::try_new( + vec![ + builtin_window_expr, + plain_aggr_window_expr, + sliding_aggr_window_expr, + ], + input, + vec![col("b", &schema)?], + )?)) +} + +#[test] +fn rountrip_aggregate() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + cast(col("b", &schema)?, &schema, DataType::Float64)?, + "AVG(b)".to_string(), + DataType::Float64, + ))]; + + roundtrip_test(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?)) +} + +#[test] +fn roundtrip_aggregate_udaf() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + #[derive(Debug)] + struct Example; + impl Accumulator for Example { + fn state(&self) -> Result> { + Ok(vec![ScalarValue::Int64(Some(0))]) + } + + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(0))) + } + + fn size(&self) -> usize { + 0 + } + } + + let rt_func: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64))); + let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); + let st_func: StateTypeFunction = + Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64]))); + + let udaf = AggregateUDF::new( + "example", + &Signature::exact(vec![DataType::Int64], Volatility::Immutable), + &rt_func, + &accumulator, + &st_func, + ); + + let ctx = SessionContext::new(); + ctx.register_udaf(udaf.clone()); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates: Vec> = vec![udaf::create_aggregate_expr( + &udaf, + &[col("b", &schema)?], + &schema, + "example_agg", + )?]; + + roundtrip_test_with_context( + Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?), + ctx, + ) +} + +#[test] +fn roundtrip_filter_with_not_and_in_list() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let field_c = Field::new("c", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let not = Arc::new(NotExpr::new(col("a", &schema)?)); + let in_list = in_list( + col("b", &schema)?, + vec![ + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Int64(Some(2))), + ], + &false, + schema.as_ref(), + )?; + let and = binary(not, Operator::And, in_list, &schema)?; + roundtrip_test(Arc::new(FilterExec::try_new( + and, + Arc::new(EmptyExec::new(schema.clone())), + )?)) +} + +#[test] +fn roundtrip_sort() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + ]; + roundtrip_test(Arc::new(SortExec::new( + sort_exprs, + Arc::new(EmptyExec::new(schema)), + ))) +} + +#[test] +fn roundtrip_sort_preserve_partitioning() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + ]; + + roundtrip_test(Arc::new(SortExec::new( + sort_exprs.clone(), + Arc::new(EmptyExec::new(schema.clone())), + )))?; + + roundtrip_test(Arc::new( + SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))) + .with_preserve_partitioning(true), + )) +} + +#[test] +fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { + let scan_config = FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_schema: Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + false, + )])), + file_groups: vec![vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )]], + statistics: Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&Arc::new(Schema::new(vec![ + Field::new("col", DataType::Utf8, false), + ]))), + }, + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }; + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("col", 1)), + Operator::Eq, + lit("1"), + )); + roundtrip_test(Arc::new(ParquetExec::new( + scan_config, + Some(predicate), + None, + ))) +} + +#[test] +fn roundtrip_builtin_scalar_function() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + let execution_props = ExecutionProps::new(); + + let fun_expr = + functions::create_physical_fun(&BuiltinScalarFunction::Acos, &execution_props)?; + + let expr = ScalarFunctionExpr::new( + "acos", + fun_expr, + vec![col("a", &schema)?], + DataType::Int64, + None, + ); + + let project = + ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; + + roundtrip_test(Arc::new(project)) +} + +#[test] +fn roundtrip_scalar_udf() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); + + let scalar_fn = make_scalar_function(fn_impl); + + let udf = create_udf( + "dummy", + vec![DataType::Int64], + Arc::new(DataType::Int64), + Volatility::Immutable, + scalar_fn.clone(), + ); + + let expr = ScalarFunctionExpr::new( + "dummy", + scalar_fn, + vec![col("a", &schema)?], + DataType::Int64, + None, + ); + + let project = + ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; + + let ctx = SessionContext::new(); + + ctx.register_udf(udf); + + roundtrip_test_with_context(Arc::new(project), ctx) +} + +#[test] +fn roundtrip_distinct_count() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let aggregates: Vec> = vec![Arc::new(DistinctCount::new( + DataType::Int64, + col("b", &schema)?, + "COUNT(DISTINCT b)".to_string(), + ))]; + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + roundtrip_test(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups), + aggregates.clone(), + vec![None], + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?)) +} + +#[test] +fn roundtrip_like() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); + let like_expr = like( + false, + false, + col("a", &schema)?, + col("b", &schema)?, + &schema, + )?; + let plan = Arc::new(ProjectionExec::try_new( + vec![(like_expr, "result".to_string())], + input, + )?); + roundtrip_test(plan) +} + +#[test] +fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { + let fields = vec![ + Field::new("id", DataType::Int64, true), + Field::new_struct( + "arg", + Fields::from(vec![Field::new("name", DataType::Float64, true)]), + true, + ), + ]; + + let schema = Schema::new(fields); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); + + let col_arg = col("arg", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::NamedStructField { + name: ScalarValue::from("name"), + }, + )); + + let plan = Arc::new(ProjectionExec::try_new( + vec![(get_indexed_field_expr, "result".to_string())], + input, + )?); + + roundtrip_test(plan) +} + +#[test] +fn roundtrip_get_indexed_field_list_index() -> Result<()> { + let fields = vec![ + Field::new("id", DataType::Int64, true), + Field::new_list("arg", Field::new("item", DataType::Float64, true), true), + Field::new("key", DataType::Int64, true), + ]; + + let schema = Schema::new(fields); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); + + let col_arg = col("arg", &schema)?; + let col_key = col("key", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::ListIndex { key: col_key }, + )); + + let plan = Arc::new(ProjectionExec::try_new( + vec![(get_indexed_field_expr, "result".to_string())], + input, + )?); + + roundtrip_test(plan) +} + +#[test] +fn roundtrip_get_indexed_field_list_range() -> Result<()> { + let fields = vec![ + Field::new("id", DataType::Int64, true), + Field::new_list("arg", Field::new("item", DataType::Float64, true), true), + Field::new("start", DataType::Int64, true), + Field::new("stop", DataType::Int64, true), + ]; + + let schema = Schema::new(fields); + let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); + + let col_arg = col("arg", &schema)?; + let col_start = col("start", &schema)?; + let col_stop = col("stop", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::ListRange { + start: col_start, + stop: col_stop, + }, + )); + + let plan = Arc::new(ProjectionExec::try_new( + vec![(get_indexed_field_expr, "result".to_string())], + input, + )?); + + roundtrip_test(plan) +} + +#[test] +fn roundtrip_analyze() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Schema::new(vec![field_a, field_b]); + let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); + + roundtrip_test(Arc::new(AnalyzeExec::new( + false, + false, + input, + Arc::new(schema), + ))) +} + +#[test] +fn roundtrip_json_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + unbounded_input: false, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::JSON(JsonWriterOptions::new( + CompressionTypeVariant::UNCOMPRESSED, + )), + }; + let data_sink = Arc::new(JsonSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + +#[test] +fn roundtrip_sym_hash_join() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let on = vec![( + Column::new("col", schema_left.index_of("col")?), + Column::new("col", schema_right.index_of("col")?), + )]; + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + for partition_mode in &[ + StreamJoinPartitionMode::Partitioned, + StreamJoinPartitionMode::SinglePartition, + ] { + roundtrip_test(Arc::new( + datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + false, + *partition_mode, + )?, + ))?; + } + } + Ok(()) +} + +#[test] +fn roundtrip_union() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let left = EmptyExec::new(Arc::new(schema_left)); + let right = EmptyExec::new(Arc::new(schema_right)); + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let union = UnionExec::new(inputs); + roundtrip_test(Arc::new(union)) +} + +#[test] +fn roundtrip_interleave() -> Result<()> { + let field_a = Field::new("col", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_a]); + let partition = Partitioning::Hash(vec![], 3); + let left = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_left))), + partition.clone(), + )?; + let right = RepartitionExec::try_new( + Arc::new(EmptyExec::new(Arc::new(schema_right))), + partition.clone(), + )?; + let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; + let interleave = InterleaveExec::try_new(inputs)?; + roundtrip_test(Arc::new(interleave)) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs new file mode 100644 index 0000000000000..5b890accd81f2 --- /dev/null +++ b/datafusion/proto/tests/cases/serialize.rs @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; + +use datafusion::execution::FunctionRegistry; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::prelude::SessionContext; +use datafusion_expr::{col, create_udf, lit}; +use datafusion_expr::{Expr, Volatility}; +use datafusion_proto::bytes::Serializeable; + +#[test] +#[should_panic( + expected = "Error decoding expr as protobuf: failed to decode Protobuf message" +)] +fn bad_decode() { + Expr::from_bytes(b"Leet").unwrap(); +} + +#[test] +#[cfg(feature = "json")] +fn plan_to_json() { + use datafusion_common::DFSchema; + use datafusion_expr::{logical_plan::EmptyRelation, LogicalPlan}; + use datafusion_proto::bytes::logical_plan_to_json; + + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }); + let actual = logical_plan_to_json(&plan).unwrap(); + let expected = r#"{"emptyRelation":{}}"#.to_string(); + assert_eq!(actual, expected); +} + +#[test] +#[cfg(feature = "json")] +fn json_to_plan() { + use datafusion_expr::LogicalPlan; + use datafusion_proto::bytes::logical_plan_from_json; + + let input = r#"{"emptyRelation":{}}"#.to_string(); + let ctx = SessionContext::new(); + let actual = logical_plan_from_json(&input, &ctx).unwrap(); + let result = matches!(actual, LogicalPlan::EmptyRelation(_)); + assert!(result, "Should parse empty relation"); +} + +#[test] +fn udf_roundtrip_with_registry() { + let ctx = context_with_udf(); + + let expr = ctx + .udf("dummy") + .expect("could not find udf") + .call(vec![lit("")]); + + let bytes = expr.to_bytes().unwrap(); + let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + + assert_eq!(expr, deserialized_expr); +} + +#[test] +#[should_panic( + expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'" +)] +fn udf_roundtrip_without_registry() { + let ctx = context_with_udf(); + + let expr = ctx + .udf("dummy") + .expect("could not find udf") + .call(vec![lit("")]); + + let bytes = expr.to_bytes().unwrap(); + // should explode + Expr::from_bytes(&bytes).unwrap(); +} + +fn roundtrip_expr(expr: &Expr) -> Expr { + let bytes = expr.to_bytes().unwrap(); + Expr::from_bytes(&bytes).unwrap() +} + +#[test] +fn exact_roundtrip_linearized_binary_expr() { + // (((A AND B) AND C) AND D) + let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D")); + assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered)); + + // Ensure that no other variation becomes equal + let other_variants = vec![ + // (((B AND A) AND C) AND D) + col("B").and(col("A")).and(col("C")).and(col("D")), + // (((A AND C) AND B) AND D) + col("A").and(col("C")).and(col("B")).and(col("D")), + // (((A AND B) AND D) AND C) + col("A").and(col("B")).and(col("D")).and(col("C")), + // A AND (B AND (C AND D))) + col("A").and(col("B").and(col("C").and(col("D")))), + ]; + for case in other_variants { + // Each variant is still equal to itself + assert_eq!(case, roundtrip_expr(&case)); + + // But non of them is equal to the original + assert_ne!(expr_ordered, roundtrip_expr(&case)); + assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(&case)); + } +} + +#[test] +fn roundtrip_qualified_alias() { + let qual_alias = col("c1").alias_qualified(Some("my_table"), "my_column"); + assert_eq!(qual_alias, roundtrip_expr(&qual_alias)); +} + +#[test] +fn roundtrip_deeply_nested_binary_expr() { + // We need more stack space so this doesn't overflow in dev builds + std::thread::Builder::new() + .stack_size(10_000_000) + .spawn(|| { + let n = 100; + // a < 5 + let basic_expr = col("a").lt(lit(5i32)); + // (a < 5) OR (a < 5) OR (a < 5) OR ... + let or_chain = + (0..n).fold(basic_expr.clone(), |expr, _| expr.or(basic_expr.clone())); + // (a < 5) OR (a < 5) AND (a < 5) OR (a < 5) AND (a < 5) AND (a < 5) OR ... + let expr = + (0..n).fold(or_chain.clone(), |expr, _| expr.and(or_chain.clone())); + + // Should work fine. + let bytes = expr.to_bytes().unwrap(); + + let decoded_expr = Expr::from_bytes(&bytes) + .expect("serialization worked, so deserialization should work as well"); + assert_eq!(decoded_expr, expr); + }) + .expect("spawning thread") + .join() + .expect("joining thread"); +} + +#[test] +fn roundtrip_deeply_nested_binary_expr_reverse_order() { + // We need more stack space so this doesn't overflow in dev builds + std::thread::Builder::new() + .stack_size(10_000_000) + .spawn(|| { + let n = 100; + + // a < 5 + let expr_base = col("a").lt(lit(5i32)); + + // ((a < 5 AND a < 5) AND a < 5) AND ... + let and_chain = + (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone())); + + // a < 5 AND (a < 5 AND (a < 5 AND ...)) + let expr = expr_base.and(and_chain); + + // Should work fine. + let bytes = expr.to_bytes().unwrap(); + + let decoded_expr = Expr::from_bytes(&bytes) + .expect("serialization worked, so deserialization should work as well"); + assert_eq!(decoded_expr, expr); + }) + .expect("spawning thread") + .join() + .expect("joining thread"); +} + +#[test] +fn roundtrip_deeply_nested() { + // we need more stack space so this doesn't overflow in dev builds + std::thread::Builder::new().stack_size(20_000_000).spawn(|| { + // don't know what "too much" is, so let's slowly try to increase complexity + let n_max = 100; + + for n in 1..n_max { + println!("testing: {n}"); + + let expr_base = col("a").lt(lit(5i32)); + // Generate a tree of AND and OR expressions (no subsequent ANDs or ORs). + let expr = (0..n).fold(expr_base.clone(), |expr, n| if n % 2 == 0 { expr.and(expr_base.clone()) } else { expr.or(expr_base.clone()) }); + + // Convert it to an opaque form + let bytes = match expr.to_bytes() { + Ok(bytes) => bytes, + Err(_) => { + // found expression that is too deeply nested + return; + } + }; + + // Decode bytes from somewhere (over network, etc. + let decoded_expr = Expr::from_bytes(&bytes).expect("serialization worked, so deserialization should work as well"); + assert_eq!(expr, decoded_expr); + } + + panic!("did not find a 'too deeply nested' expression, tested up to a depth of {n_max}") + }).expect("spawning thread").join().expect("joining thread"); +} + +/// return a `SessionContext` with a `dummy` function registered as a UDF +fn context_with_udf() -> SessionContext { + let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); + + let scalar_fn = make_scalar_function(fn_impl); + + let udf = create_udf( + "dummy", + vec![DataType::Utf8], + Arc::new(DataType::Utf8), + Volatility::Immutable, + scalar_fn, + ); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + ctx +} diff --git a/datafusion/proto/tests/proto_integration.rs b/datafusion/proto/tests/proto_integration.rs new file mode 100644 index 0000000000000..6ce41c9de71a8 --- /dev/null +++ b/datafusion/proto/tests/proto_integration.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Run all tests that are found in the `cases` directory +mod cases; diff --git a/datafusion/proto/testdata/test.csv b/datafusion/proto/tests/testdata/test.csv similarity index 100% rename from datafusion/proto/testdata/test.csv rename to datafusion/proto/tests/testdata/test.csv diff --git a/datafusion/row/src/accessor.rs b/datafusion/row/src/accessor.rs deleted file mode 100644 index a0b5a70df9933..0000000000000 --- a/datafusion/row/src/accessor.rs +++ /dev/null @@ -1,384 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RowAccessor`] provides a Read/Write/Modify access for row with all fixed-sized fields: - -use crate::layout::RowLayout; -use crate::validity::NullBitsFormatter; -use crate::{fn_get_idx, fn_get_idx_opt, fn_set_idx}; -use arrow::datatypes::{DataType, Schema}; -use arrow::util::bit_util::{get_bit_raw, set_bit_raw}; -use datafusion_common::ScalarValue; -use std::ops::{BitAnd, BitOr, BitXor}; -use std::sync::Arc; - -//TODO: DRY with reader and writer - -/// Provides read/write/modify access to a tuple stored in Row format -/// at `data[base_offset..]` -/// -/// ```text -/// Set / Update data -/// in [u8] -/// ─ ─ ─ ─ ─ ─ ─ ┐ Read data out as native -/// │ types or ScalarValues -/// │ -/// │ ┌───────────────────────┐ -/// │ │ -/// └ ▶│ [u8] │─ ─ ─ ─ ─ ─ ─ ─▶ -/// │ │ -/// └───────────────────────┘ -/// ``` -pub struct RowAccessor<'a> { - /// Layout on how to read each field - layout: Arc, - /// Raw bytes slice where the tuple stores - data: &'a mut [u8], - /// Start position for the current tuple in the raw bytes slice. - base_offset: usize, -} - -impl<'a> std::fmt::Debug for RowAccessor<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.null_free() { - write!(f, "null_free") - } else { - let null_bits = self.null_bits(); - write!( - f, - "{:?}", - NullBitsFormatter::new(null_bits, self.layout.field_count) - ) - } - } -} - -#[macro_export] -macro_rules! fn_add_idx { - ($NATIVE: ident) => { - paste::item! { - /// add field at `idx` with `value` - #[inline(always)] - pub fn [](&mut self, idx: usize, value: $NATIVE) { - if self.is_valid_at(idx) { - self.[](idx, value + self.[](idx)); - } else { - self.set_non_null_at(idx); - self.[](idx, value); - } - } - } - }; -} - -macro_rules! fn_max_min_idx { - ($NATIVE: ident, $OP: ident) => { - paste::item! { - /// check max then update - #[inline(always)] - pub fn [<$OP _ $NATIVE>](&mut self, idx: usize, value: $NATIVE) { - if self.is_valid_at(idx) { - let v = value.$OP(self.[](idx)); - self.[](idx, v); - } else { - self.set_non_null_at(idx); - self.[](idx, value); - } - } - } - }; -} - -macro_rules! fn_bit_and_or_xor_idx { - ($NATIVE: ident, $OP: ident) => { - paste::item! { - /// check bit_and then update - #[inline(always)] - pub fn [<$OP _ $NATIVE>](&mut self, idx: usize, value: $NATIVE) { - if self.is_valid_at(idx) { - let v = value.$OP(self.[](idx)); - self.[](idx, v); - } else { - self.set_non_null_at(idx); - self.[](idx, value); - } - } - } - }; -} - -macro_rules! fn_get_idx_scalar { - ($NATIVE: ident, $SCALAR:ident) => { - paste::item! { - #[inline(always)] - pub fn [](&self, idx: usize) -> ScalarValue { - if self.is_valid_at(idx) { - ScalarValue::$SCALAR(Some(self.[](idx))) - } else { - ScalarValue::$SCALAR(None) - } - } - } - }; -} - -impl<'a> RowAccessor<'a> { - /// new - pub fn new(schema: &Schema) -> Self { - Self { - layout: Arc::new(RowLayout::new(schema)), - data: &mut [], - base_offset: 0, - } - } - - pub fn new_from_layout(layout: Arc) -> Self { - Self { - layout, - data: &mut [], - base_offset: 0, - } - } - - /// Update this row to point to position `offset` in `base` - pub fn point_to(&mut self, offset: usize, data: &'a mut [u8]) { - self.base_offset = offset; - self.data = data; - } - - #[inline] - fn assert_index_valid(&self, idx: usize) { - assert!(idx < self.layout.field_count); - } - - #[inline(always)] - fn field_offsets(&self) -> &[usize] { - &self.layout.field_offsets - } - - #[inline(always)] - fn null_free(&self) -> bool { - self.layout.null_free - } - - #[inline(always)] - fn null_bits(&self) -> &[u8] { - if self.null_free() { - &[] - } else { - let start = self.base_offset; - &self.data[start..start + self.layout.null_width] - } - } - - fn is_valid_at(&self, idx: usize) -> bool { - unsafe { get_bit_raw(self.null_bits().as_ptr(), idx) } - } - - // ------------------------------ - // ----- Fixed Sized getters ---- - // ------------------------------ - - fn get_bool(&self, idx: usize) -> bool { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - let value = &self.data[self.base_offset + offset..]; - value[0] != 0 - } - - fn get_u8(&self, idx: usize) -> u8 { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[self.base_offset + offset] - } - - fn_get_idx!(u16, 2); - fn_get_idx!(u32, 4); - fn_get_idx!(u64, 8); - fn_get_idx!(i8, 1); - fn_get_idx!(i16, 2); - fn_get_idx!(i32, 4); - fn_get_idx!(i64, 8); - fn_get_idx!(f32, 4); - fn_get_idx!(f64, 8); - fn_get_idx!(i128, 16); - - fn_get_idx_opt!(bool); - fn_get_idx_opt!(u8); - fn_get_idx_opt!(u16); - fn_get_idx_opt!(u32); - fn_get_idx_opt!(u64); - fn_get_idx_opt!(i8); - fn_get_idx_opt!(i16); - fn_get_idx_opt!(i32); - fn_get_idx_opt!(i64); - fn_get_idx_opt!(f32); - fn_get_idx_opt!(f64); - fn_get_idx_opt!(i128); - - fn_get_idx_scalar!(bool, Boolean); - fn_get_idx_scalar!(u8, UInt8); - fn_get_idx_scalar!(u16, UInt16); - fn_get_idx_scalar!(u32, UInt32); - fn_get_idx_scalar!(u64, UInt64); - fn_get_idx_scalar!(i8, Int8); - fn_get_idx_scalar!(i16, Int16); - fn_get_idx_scalar!(i32, Int32); - fn_get_idx_scalar!(i64, Int64); - fn_get_idx_scalar!(f32, Float32); - fn_get_idx_scalar!(f64, Float64); - - fn get_decimal128_scalar(&self, idx: usize, p: u8, s: i8) -> ScalarValue { - if self.is_valid_at(idx) { - ScalarValue::Decimal128(Some(self.get_i128(idx)), p, s) - } else { - ScalarValue::Decimal128(None, p, s) - } - } - - pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue { - match dt { - DataType::Boolean => self.get_bool_scalar(index), - DataType::Int8 => self.get_i8_scalar(index), - DataType::Int16 => self.get_i16_scalar(index), - DataType::Int32 => self.get_i32_scalar(index), - DataType::Int64 => self.get_i64_scalar(index), - DataType::UInt8 => self.get_u8_scalar(index), - DataType::UInt16 => self.get_u16_scalar(index), - DataType::UInt32 => self.get_u32_scalar(index), - DataType::UInt64 => self.get_u64_scalar(index), - DataType::Float32 => self.get_f32_scalar(index), - DataType::Float64 => self.get_f64_scalar(index), - DataType::Decimal128(p, s) => self.get_decimal128_scalar(index, *p, *s), - _ => unreachable!(), - } - } - - // ------------------------------ - // ----- Fixed Sized setters ---- - // ------------------------------ - - pub(crate) fn set_non_null_at(&mut self, idx: usize) { - assert!( - !self.null_free(), - "Unexpected call to set_non_null_at on null-free row writer" - ); - let null_bits = &mut self.data[0..self.layout.null_width]; - unsafe { - set_bit_raw(null_bits.as_mut_ptr(), idx); - } - } - - fn set_bool(&mut self, idx: usize, value: bool) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = u8::from(value); - } - - fn set_u8(&mut self, idx: usize, value: u8) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = value; - } - - fn_set_idx!(u16, 2); - fn_set_idx!(u32, 4); - fn_set_idx!(u64, 8); - fn_set_idx!(i16, 2); - fn_set_idx!(i32, 4); - fn_set_idx!(i64, 8); - fn_set_idx!(f32, 4); - fn_set_idx!(f64, 8); - fn_set_idx!(i128, 16); - - fn set_i8(&mut self, idx: usize, value: i8) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = value.to_le_bytes()[0]; - } - - // ------------------------------ - // ---- Fixed sized updaters ---- - // ------------------------------ - - fn_add_idx!(u8); - fn_add_idx!(u16); - fn_add_idx!(u32); - fn_add_idx!(u64); - fn_add_idx!(i8); - fn_add_idx!(i16); - fn_add_idx!(i32); - fn_add_idx!(i64); - fn_add_idx!(f32); - fn_add_idx!(f64); - fn_add_idx!(i128); - - fn_max_min_idx!(bool, max); - fn_max_min_idx!(u8, max); - fn_max_min_idx!(u16, max); - fn_max_min_idx!(u32, max); - fn_max_min_idx!(u64, max); - fn_max_min_idx!(i8, max); - fn_max_min_idx!(i16, max); - fn_max_min_idx!(i32, max); - fn_max_min_idx!(i64, max); - fn_max_min_idx!(f32, max); - fn_max_min_idx!(f64, max); - fn_max_min_idx!(i128, max); - - fn_max_min_idx!(bool, min); - fn_max_min_idx!(u8, min); - fn_max_min_idx!(u16, min); - fn_max_min_idx!(u32, min); - fn_max_min_idx!(u64, min); - fn_max_min_idx!(i8, min); - fn_max_min_idx!(i16, min); - fn_max_min_idx!(i32, min); - fn_max_min_idx!(i64, min); - fn_max_min_idx!(f32, min); - fn_max_min_idx!(f64, min); - fn_max_min_idx!(i128, min); - - fn_bit_and_or_xor_idx!(bool, bitand); - fn_bit_and_or_xor_idx!(u8, bitand); - fn_bit_and_or_xor_idx!(u16, bitand); - fn_bit_and_or_xor_idx!(u32, bitand); - fn_bit_and_or_xor_idx!(u64, bitand); - fn_bit_and_or_xor_idx!(i8, bitand); - fn_bit_and_or_xor_idx!(i16, bitand); - fn_bit_and_or_xor_idx!(i32, bitand); - fn_bit_and_or_xor_idx!(i64, bitand); - - fn_bit_and_or_xor_idx!(bool, bitor); - fn_bit_and_or_xor_idx!(u8, bitor); - fn_bit_and_or_xor_idx!(u16, bitor); - fn_bit_and_or_xor_idx!(u32, bitor); - fn_bit_and_or_xor_idx!(u64, bitor); - fn_bit_and_or_xor_idx!(i8, bitor); - fn_bit_and_or_xor_idx!(i16, bitor); - fn_bit_and_or_xor_idx!(i32, bitor); - fn_bit_and_or_xor_idx!(i64, bitor); - - fn_bit_and_or_xor_idx!(u8, bitxor); - fn_bit_and_or_xor_idx!(u16, bitxor); - fn_bit_and_or_xor_idx!(u32, bitxor); - fn_bit_and_or_xor_idx!(u64, bitxor); - fn_bit_and_or_xor_idx!(i8, bitxor); - fn_bit_and_or_xor_idx!(i16, bitxor); - fn_bit_and_or_xor_idx!(i32, bitxor); - fn_bit_and_or_xor_idx!(i64, bitxor); -} diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs deleted file mode 100644 index 71471327536a0..0000000000000 --- a/datafusion/row/src/layout.rs +++ /dev/null @@ -1,157 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Various row layouts for different use case - -use crate::schema_null_free; -use arrow::datatypes::{DataType, Schema}; -use arrow::util::bit_util::{ceil, round_upto_power_of_2}; - -/// Row layout stores one or multiple 8-byte word(s) per field for CPU-friendly -/// and efficient processing. -/// -/// It is mainly used to represent the rows with frequently updated content, -/// for example, grouping state for hash aggregation. -/// -/// Each tuple consists of two parts: "`null bit set`" and "`values`". -/// -/// For null-free tuples, the null bit set can be omitted. -/// -/// The null bit set, when present, is aligned to 8 bytes. It stores one bit per field. -/// -/// In the region of the values, we store the fields in the order they are defined in the schema. -/// Each field is stored in one or multiple 8-byte words. -/// -/// ```plaintext -/// ┌─────────────────┬─────────────────────┐ -/// │Validity Bitmask │ Fields │ -/// │ (8-byte aligned)│ (8-byte words) │ -/// └─────────────────┴─────────────────────┘ -/// ``` -/// -/// For example, given the schema (Int8, Float32, Int64) with a null-free tuple -/// -/// Encoding the tuple (1, 3.14, 42) -/// -/// Requires 24 bytes (3 fields * 8 bytes each): -/// -/// ```plaintext -/// ┌──────────────────────┬──────────────────────┬──────────────────────┐ -/// │ 0x01 │ 0x4048F5C3 │ 0x0000002A │ -/// └──────────────────────┴──────────────────────┴──────────────────────┘ -/// 0 8 16 24 -/// ``` -/// -/// If the schema allows null values and the tuple is (1, NULL, 42) -/// -/// Encoding the tuple requires 32 bytes (1 * 8 bytes for the null bit set + 3 fields * 8 bytes each): -/// -/// ```plaintext -/// ┌──────────────────────────┬──────────────────────┬──────────────────────┬──────────────────────┐ -/// │ 0b00000101 │ 0x01 │ 0x00000000 │ 0x0000002A │ -/// │ (7 bytes padding after) │ │ │ │ -/// └──────────────────────────┴──────────────────────┴──────────────────────┴──────────────────────┘ -/// 0 8 16 24 32 -/// ``` -#[derive(Debug, Clone)] -pub struct RowLayout { - /// If a row is null free according to its schema - pub(crate) null_free: bool, - /// The number of bytes used to store null bits for each field. - pub(crate) null_width: usize, - /// Length in bytes for `values` part of the current tuple. - pub(crate) values_width: usize, - /// Total number of fields for each tuple. - pub(crate) field_count: usize, - /// Starting offset for each fields in the raw bytes. - pub(crate) field_offsets: Vec, -} - -impl RowLayout { - /// new - pub fn new(schema: &Schema) -> Self { - assert!( - row_supported(schema), - "Row with {schema:?} not supported yet.", - ); - let null_free = schema_null_free(schema); - let field_count = schema.fields().len(); - let null_width = if null_free { - 0 - } else { - round_upto_power_of_2(ceil(field_count, 8), 8) - }; - let (field_offsets, values_width) = word_aligned_offsets(null_width, schema); - Self { - null_free, - null_width, - values_width, - field_count, - field_offsets, - } - } - - /// Get fixed part width for this layout - #[inline(always)] - pub fn fixed_part_width(&self) -> usize { - self.null_width + self.values_width - } -} - -fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec, usize) { - let mut offsets = vec![]; - let mut offset = null_width; - for f in schema.fields() { - offsets.push(offset); - assert!(!matches!(f.data_type(), DataType::Decimal256(_, _))); - // All of the current support types can fit into one single 8-bytes word except for Decimal128. - // For Decimal128, its width is of two 8-bytes words. - match f.data_type() { - DataType::Decimal128(_, _) => offset += 16, - _ => offset += 8, - } - } - (offsets, offset - null_width) -} - -/// Return true of data in `schema` can be converted to raw-bytes -/// based rows. -/// -/// Note all schemas can be supported in the row format -pub fn row_supported(schema: &Schema) -> bool { - schema.fields().iter().all(|f| { - let dt = f.data_type(); - use DataType::*; - matches!( - dt, - Boolean - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Int8 - | Int16 - | Int32 - | Int64 - | Float32 - | Float64 - | Date32 - | Date64 - | Decimal128(_, _) - ) - }) -} diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs deleted file mode 100644 index 902fa881b19bc..0000000000000 --- a/datafusion/row/src/lib.rs +++ /dev/null @@ -1,303 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains code to translate arrays back and forth to a -//! row based format. The row based format is backed by raw bytes -//! ([`[u8]`]) and used to optimize certain operations. -//! -//! In general, DataFusion is a so called "vectorized" execution -//! model, specifically it uses the optimized calculation kernels in -//! [`arrow`] to amortize dispatch overhead. -//! -//! However, as mentioned in [this paper], there are some "row -//! oriented" operations in a database that are not typically amenable -//! to vectorization. The "classics" are: hash table updates in joins -//! and hash aggregates, as well as comparing tuples in sort / -//! merging. -//! -//! [this paper]: https://db.in.tum.de/~kersten/vectorization_vs_compilation.pdf - -use arrow::array::{make_builder, ArrayBuilder, ArrayRef}; -use arrow::datatypes::Schema; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; -pub use layout::row_supported; -use std::sync::Arc; - -pub mod accessor; -pub mod layout; -pub mod reader; -mod validity; -pub mod writer; - -/// Tell if schema contains no nullable field -pub(crate) fn schema_null_free(schema: &Schema) -> bool { - schema.fields().iter().all(|f| !f.is_nullable()) -} - -/// Columnar Batch buffer that assists creating `RecordBatches` -pub struct MutableRecordBatch { - arrays: Vec>, - schema: Arc, -} - -impl MutableRecordBatch { - /// new - pub fn new(target_batch_size: usize, schema: Arc) -> Self { - let arrays = new_arrays(&schema, target_batch_size); - Self { arrays, schema } - } - - /// Finalize the batch, output and reset this buffer - pub fn output(&mut self) -> ArrowResult { - let result = make_batch(self.schema.clone(), self.arrays.drain(..).collect()); - result - } - - pub fn output_as_columns(&mut self) -> Vec { - get_columns(self.arrays.drain(..).collect()) - } -} - -fn new_arrays(schema: &Schema, batch_size: usize) -> Vec> { - schema - .fields() - .iter() - .map(|field| { - let dt = field.data_type(); - make_builder(dt, batch_size) - }) - .collect::>() -} - -fn make_batch( - schema: Arc, - mut arrays: Vec>, -) -> ArrowResult { - let columns = arrays.iter_mut().map(|array| array.finish()).collect(); - RecordBatch::try_new(schema, columns) -} - -fn get_columns(mut arrays: Vec>) -> Vec { - arrays.iter_mut().map(|array| array.finish()).collect() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::layout::RowLayout; - use crate::reader::read_as_batch; - use crate::writer::write_batch_unchecked; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; - use datafusion_common::Result; - use DataType::*; - - macro_rules! fn_test_single_type { - ($ARRAY: ident, $TYPE: expr, $VEC: expr) => { - paste::item! { - #[test] - #[allow(non_snake_case)] - fn []() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, true)])); - let record_width = RowLayout::new(schema.as_ref()).fixed_part_width(); - let a = $ARRAY::from($VEC); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; - let mut vector = vec![0; record_width * batch.num_rows()]; - let row_offsets = - { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&vector, schema, &row_offsets)? }; - assert_eq!(batch, output_batch); - Ok(()) - } - - #[test] - #[allow(non_snake_case)] - fn []() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, false)])); - let record_width = RowLayout::new(schema.as_ref()).fixed_part_width(); - let v = $VEC.into_iter().filter(|o| o.is_some()).collect::>(); - let a = $ARRAY::from(v); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; - let mut vector = vec![0; record_width * batch.num_rows()]; - let row_offsets = - { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&vector, schema, &row_offsets)? }; - assert_eq!(batch, output_batch); - Ok(()) - } - } - }; - } - - fn_test_single_type!( - BooleanArray, - Boolean, - vec![Some(true), Some(false), None, Some(true), None] - ); - - fn_test_single_type!( - Int8Array, - Int8, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - Int16Array, - Int16, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - Int32Array, - Int32, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - Int64Array, - Int64, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - UInt8Array, - UInt8, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - UInt16Array, - UInt16, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - UInt32Array, - UInt32, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - UInt64Array, - UInt64, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - Float32Array, - Float32, - vec![Some(5.0), Some(7.0), None, Some(0.0), Some(111.0)] - ); - - fn_test_single_type!( - Float64Array, - Float64, - vec![Some(5.0), Some(7.0), None, Some(0.0), Some(111.0)] - ); - - fn_test_single_type!( - Date32Array, - Date32, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - fn_test_single_type!( - Date64Array, - Date64, - vec![Some(5), Some(7), None, Some(0), Some(111)] - ); - - #[test] - fn test_single_decimal128() -> Result<()> { - let v = vec![ - Some(0), - Some(1), - None, - Some(-1), - Some(i128::MIN), - Some(i128::MAX), - ]; - let schema = - Arc::new(Schema::new(vec![Field::new("a", Decimal128(38, 10), true)])); - let record_width = RowLayout::new(schema.as_ref()).fixed_part_width(); - let a = Decimal128Array::from(v); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; - let mut vector = vec![0; record_width * batch.num_rows()]; - let row_offsets = - { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&vector, schema, &row_offsets)? }; - assert_eq!(batch, output_batch); - Ok(()) - } - - #[test] - fn test_single_decimal128_null_free() -> Result<()> { - let v = vec![ - Some(0), - Some(1), - None, - Some(-1), - Some(i128::MIN), - Some(i128::MAX), - ]; - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - Decimal128(38, 10), - false, - )])); - let record_width = RowLayout::new(schema.as_ref()).fixed_part_width(); - let v = v.into_iter().filter(|o| o.is_some()).collect::>(); - let a = Decimal128Array::from(v); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?; - let mut vector = vec![0; record_width * batch.num_rows()]; - let row_offsets = - { write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) }; - let output_batch = { read_as_batch(&vector, schema, &row_offsets)? }; - assert_eq!(batch, output_batch); - Ok(()) - } - - #[test] - #[should_panic(expected = "not supported yet")] - fn test_unsupported_type() { - let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"])); - let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); - let schema = batch.schema(); - let mut vector = vec![0; 1024]; - write_batch_unchecked(&mut vector, 0, &batch, 0, schema); - } - - #[test] - #[should_panic(expected = "not supported yet")] - fn test_unsupported_type_write() { - let a: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); - let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); - let schema = batch.schema(); - let mut vector = vec![0; 1024]; - write_batch_unchecked(&mut vector, 0, &batch, 0, schema); - } - - #[test] - #[should_panic(expected = "not supported yet")] - fn test_unsupported_type_read() { - let schema = Arc::new(Schema::new(vec![Field::new("a", Utf8, false)])); - let vector = vec![0; 1024]; - let row_offsets = vec![0]; - read_as_batch(&vector, schema, &row_offsets).unwrap(); - } -} diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs deleted file mode 100644 index 10c9896df70ab..0000000000000 --- a/datafusion/row/src/reader.rs +++ /dev/null @@ -1,366 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`read_as_batch`] converts raw bytes to [`RecordBatch`] - -use crate::layout::RowLayout; -use crate::validity::{all_valid, NullBitsFormatter}; -use crate::MutableRecordBatch; -use arrow::array::*; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use arrow::util::bit_util::get_bit_raw; -use datafusion_common::{DataFusionError, Result}; -use std::sync::Arc; - -/// Read raw-bytes from `data` rows starting at `offsets` out to a [`RecordBatch`] -/// -/// -/// ```text -/// Read data to RecordBatch ┌──────────────────┐ -/// │ │ -/// │ │ -/// ┌───────────────────────┐ │ │ -/// │ │ │ RecordBatch │ -/// │ [u8] │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ▶│ │ -/// │ │ │ (... N Rows ...) │ -/// └───────────────────────┘ │ │ -/// │ │ -/// │ │ -/// └──────────────────┘ -/// ``` -pub fn read_as_batch( - data: &[u8], - schema: Arc, - offsets: &[usize], -) -> Result { - let row_num = offsets.len(); - let mut output = MutableRecordBatch::new(row_num, schema.clone()); - let mut row = RowReader::new(&schema); - - for offset in offsets.iter().take(row_num) { - row.point_to(*offset, data); - read_row(&row, &mut output, &schema); - } - - output.output().map_err(DataFusionError::ArrowError) -} - -#[macro_export] -macro_rules! get_idx { - ($NATIVE: ident, $SELF: ident, $IDX: ident, $WIDTH: literal) => {{ - $SELF.assert_index_valid($IDX); - let offset = $SELF.field_offsets()[$IDX]; - let start = $SELF.base_offset + offset; - let end = start + $WIDTH; - $NATIVE::from_le_bytes($SELF.data[start..end].try_into().unwrap()) - }}; -} - -#[macro_export] -macro_rules! fn_get_idx { - ($NATIVE: ident, $WIDTH: literal) => { - paste::item! { - fn [](&self, idx: usize) -> $NATIVE { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - let start = self.base_offset + offset; - let end = start + $WIDTH; - $NATIVE::from_le_bytes(self.data[start..end].try_into().unwrap()) - } - } - }; -} - -#[macro_export] -macro_rules! fn_get_idx_opt { - ($NATIVE: ident) => { - paste::item! { - pub fn [](&self, idx: usize) -> Option<$NATIVE> { - if self.is_valid_at(idx) { - Some(self.[](idx)) - } else { - None - } - } - } - }; -} - -/// Read the tuple `data[base_offset..]` we are currently pointing to -pub struct RowReader<'a> { - /// Layout on how to read each field - layout: RowLayout, - /// Raw bytes slice where the tuple stores - data: &'a [u8], - /// Start position for the current tuple in the raw bytes slice. - base_offset: usize, -} - -impl<'a> std::fmt::Debug for RowReader<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.null_free() { - write!(f, "null_free") - } else { - let null_bits = self.null_bits(); - write!( - f, - "{:?}", - NullBitsFormatter::new(null_bits, self.layout.field_count) - ) - } - } -} - -impl<'a> RowReader<'a> { - /// new - pub fn new(schema: &Schema) -> Self { - Self { - layout: RowLayout::new(schema), - data: &[], - base_offset: 0, - } - } - - /// Update this row to point to position `offset` in `base` - pub fn point_to(&mut self, offset: usize, data: &'a [u8]) { - self.base_offset = offset; - self.data = data; - } - - #[inline] - fn assert_index_valid(&self, idx: usize) { - assert!(idx < self.layout.field_count); - } - - #[inline(always)] - fn field_offsets(&self) -> &[usize] { - &self.layout.field_offsets - } - - #[inline(always)] - fn null_free(&self) -> bool { - self.layout.null_free - } - - #[inline(always)] - fn null_bits(&self) -> &[u8] { - if self.null_free() { - &[] - } else { - let start = self.base_offset; - &self.data[start..start + self.layout.null_width] - } - } - - #[inline(always)] - fn all_valid(&self) -> bool { - if self.null_free() { - true - } else { - let null_bits = self.null_bits(); - all_valid(null_bits, self.layout.field_count) - } - } - - fn is_valid_at(&self, idx: usize) -> bool { - unsafe { get_bit_raw(self.null_bits().as_ptr(), idx) } - } - - fn get_bool(&self, idx: usize) -> bool { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - let value = &self.data[self.base_offset + offset..]; - value[0] != 0 - } - - fn get_u8(&self, idx: usize) -> u8 { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[self.base_offset + offset] - } - - fn_get_idx!(u16, 2); - fn_get_idx!(u32, 4); - fn_get_idx!(u64, 8); - fn_get_idx!(i8, 1); - fn_get_idx!(i16, 2); - fn_get_idx!(i32, 4); - fn_get_idx!(i64, 8); - fn_get_idx!(f32, 4); - fn_get_idx!(f64, 8); - - fn get_date32(&self, idx: usize) -> i32 { - get_idx!(i32, self, idx, 4) - } - - fn get_date64(&self, idx: usize) -> i64 { - get_idx!(i64, self, idx, 8) - } - - fn get_decimal128(&self, idx: usize) -> i128 { - get_idx!(i128, self, idx, 16) - } - - fn_get_idx_opt!(bool); - fn_get_idx_opt!(u8); - fn_get_idx_opt!(u16); - fn_get_idx_opt!(u32); - fn_get_idx_opt!(u64); - fn_get_idx_opt!(i8); - fn_get_idx_opt!(i16); - fn_get_idx_opt!(i32); - fn_get_idx_opt!(i64); - fn_get_idx_opt!(f32); - fn_get_idx_opt!(f64); - - fn get_date32_opt(&self, idx: usize) -> Option { - if self.is_valid_at(idx) { - Some(self.get_date32(idx)) - } else { - None - } - } - - fn get_date64_opt(&self, idx: usize) -> Option { - if self.is_valid_at(idx) { - Some(self.get_date64(idx)) - } else { - None - } - } - - fn get_decimal128_opt(&self, idx: usize) -> Option { - if self.is_valid_at(idx) { - Some(self.get_decimal128(idx)) - } else { - None - } - } -} - -/// Read the row currently pointed by RowWriter to the output columnar batch buffer -pub fn read_row(row: &RowReader, batch: &mut MutableRecordBatch, schema: &Schema) { - if row.all_valid() { - for ((col_idx, to), field) in batch - .arrays - .iter_mut() - .enumerate() - .zip(schema.fields().iter()) - { - read_field_null_free(to, field.data_type(), col_idx, row) - } - } else { - for ((col_idx, to), field) in batch - .arrays - .iter_mut() - .enumerate() - .zip(schema.fields().iter()) - { - read_field(to, field.data_type(), col_idx, row) - } - } -} - -macro_rules! fn_read_field { - ($NATIVE: ident, $ARRAY: ident) => { - paste::item! { - pub(crate) fn [](to: &mut Box, col_idx: usize, row: &RowReader) { - let to = to - .as_any_mut() - .downcast_mut::<$ARRAY>() - .unwrap(); - to.append_option(row.[](col_idx)); - } - - pub(crate) fn [](to: &mut Box, col_idx: usize, row: &RowReader) { - let to = to - .as_any_mut() - .downcast_mut::<$ARRAY>() - .unwrap(); - to.append_value(row.[](col_idx)); - } - } - }; -} - -fn_read_field!(bool, BooleanBuilder); -fn_read_field!(u8, UInt8Builder); -fn_read_field!(u16, UInt16Builder); -fn_read_field!(u32, UInt32Builder); -fn_read_field!(u64, UInt64Builder); -fn_read_field!(i8, Int8Builder); -fn_read_field!(i16, Int16Builder); -fn_read_field!(i32, Int32Builder); -fn_read_field!(i64, Int64Builder); -fn_read_field!(f32, Float32Builder); -fn_read_field!(f64, Float64Builder); -fn_read_field!(date32, Date32Builder); -fn_read_field!(date64, Date64Builder); -fn_read_field!(decimal128, Decimal128Builder); - -fn read_field( - to: &mut Box, - dt: &DataType, - col_idx: usize, - row: &RowReader, -) { - use DataType::*; - match dt { - Boolean => read_field_bool(to, col_idx, row), - UInt8 => read_field_u8(to, col_idx, row), - UInt16 => read_field_u16(to, col_idx, row), - UInt32 => read_field_u32(to, col_idx, row), - UInt64 => read_field_u64(to, col_idx, row), - Int8 => read_field_i8(to, col_idx, row), - Int16 => read_field_i16(to, col_idx, row), - Int32 => read_field_i32(to, col_idx, row), - Int64 => read_field_i64(to, col_idx, row), - Float32 => read_field_f32(to, col_idx, row), - Float64 => read_field_f64(to, col_idx, row), - Date32 => read_field_date32(to, col_idx, row), - Date64 => read_field_date64(to, col_idx, row), - Decimal128(_, _) => read_field_decimal128(to, col_idx, row), - _ => unimplemented!(), - } -} - -fn read_field_null_free( - to: &mut Box, - dt: &DataType, - col_idx: usize, - row: &RowReader, -) { - use DataType::*; - match dt { - Boolean => read_field_bool_null_free(to, col_idx, row), - UInt8 => read_field_u8_null_free(to, col_idx, row), - UInt16 => read_field_u16_null_free(to, col_idx, row), - UInt32 => read_field_u32_null_free(to, col_idx, row), - UInt64 => read_field_u64_null_free(to, col_idx, row), - Int8 => read_field_i8_null_free(to, col_idx, row), - Int16 => read_field_i16_null_free(to, col_idx, row), - Int32 => read_field_i32_null_free(to, col_idx, row), - Int64 => read_field_i64_null_free(to, col_idx, row), - Float32 => read_field_f32_null_free(to, col_idx, row), - Float64 => read_field_f64_null_free(to, col_idx, row), - Date32 => read_field_date32_null_free(to, col_idx, row), - Date64 => read_field_date64_null_free(to, col_idx, row), - Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row), - _ => unimplemented!(), - } -} diff --git a/datafusion/row/src/validity.rs b/datafusion/row/src/validity.rs deleted file mode 100644 index 45f5e19f1894f..0000000000000 --- a/datafusion/row/src/validity.rs +++ /dev/null @@ -1,161 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Row format validity utilities - -use arrow::util::bit_util::get_bit_raw; -use std::fmt::Write; - -const ALL_VALID_MASK: [u8; 8] = [1, 3, 7, 15, 31, 63, 127, 255]; - -/// Returns if all fields are valid -pub fn all_valid(data: &[u8], n: usize) -> bool { - for item in data.iter().take(n / 8) { - if *item != ALL_VALID_MASK[7] { - return false; - } - } - if n % 8 == 0 { - true - } else { - data[n / 8] == ALL_VALID_MASK[n % 8 - 1] - } -} - -/// Show null bit for each field in a tuple, 1 for valid and 0 for null. -/// For a tuple with nine total fields, valid at field 0, 6, 7, 8 shows as `[10000011, 1]`. -pub struct NullBitsFormatter<'a> { - null_bits: &'a [u8], - field_count: usize, -} - -impl<'a> NullBitsFormatter<'a> { - /// new - pub fn new(null_bits: &'a [u8], field_count: usize) -> Self { - Self { - null_bits, - field_count, - } - } -} - -impl<'a> std::fmt::Debug for NullBitsFormatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut is_first = true; - let data = self.null_bits; - for i in 0..self.field_count { - if is_first { - f.write_char('[')?; - is_first = false; - } else if i % 8 == 0 { - f.write_str(", ")?; - } - if unsafe { get_bit_raw(data.as_ptr(), i) } { - f.write_char('1')?; - } else { - f.write_char('0')?; - } - } - f.write_char(']')?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::util::bit_util::{ceil, set_bit_raw, unset_bit_raw}; - use rand::Rng; - - fn test_validity(bs: &[bool]) { - let n = bs.len(); - let mut data = vec![0; ceil(n, 8)]; - for (i, b) in bs.iter().enumerate() { - if *b { - let data_argument = &mut data; - unsafe { - set_bit_raw(data_argument.as_mut_ptr(), i); - }; - } else { - let data_argument = &mut data; - unsafe { - unset_bit_raw(data_argument.as_mut_ptr(), i); - }; - } - } - let expected = bs.iter().all(|f| *f); - assert_eq!(all_valid(&data, bs.len()), expected); - } - - #[test] - fn test_all_valid() { - let sizes = [4, 8, 12, 16, 19, 23, 32, 44]; - for i in sizes { - { - // contains false - let input = { - let mut rng = rand::thread_rng(); - let mut input: Vec = vec![false; i]; - rng.fill(&mut input[..]); - input - }; - test_validity(&input); - } - - { - // all true - let input = vec![true; i]; - test_validity(&input); - } - } - } - - #[test] - fn test_formatter() -> std::fmt::Result { - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[0b11000001], 8)), - "[10000011]" - ); - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[0b11000001, 1], 9)), - "[10000011, 1]" - ); - assert_eq!(format!("{:?}", NullBitsFormatter::new(&[1], 2)), "[10]"); - assert_eq!(format!("{:?}", NullBitsFormatter::new(&[1], 3)), "[100]"); - assert_eq!(format!("{:?}", NullBitsFormatter::new(&[1], 4)), "[1000]"); - assert_eq!(format!("{:?}", NullBitsFormatter::new(&[1], 5)), "[10000]"); - assert_eq!(format!("{:?}", NullBitsFormatter::new(&[1], 6)), "[100000]"); - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[1], 7)), - "[1000000]" - ); - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[1], 8)), - "[10000000]" - ); - // extra bytes are ignored - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[0b11000001, 1, 1, 1], 9)), - "[10000011, 1]" - ); - assert_eq!( - format!("{:?}", NullBitsFormatter::new(&[0b11000001, 1, 1], 16)), - "[10000011, 10000000]" - ); - Ok(()) - } -} diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs deleted file mode 100644 index 14ce6afe6832e..0000000000000 --- a/datafusion/row/src/writer.rs +++ /dev/null @@ -1,333 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RowWriter`] writes [`RecordBatch`]es to `Vec` to stitch attributes together - -use crate::layout::RowLayout; -use arrow::array::*; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use arrow::util::bit_util::{set_bit_raw, unset_bit_raw}; -use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array}; -use datafusion_common::Result; -use std::sync::Arc; - -/// Append batch from `row_idx` to `output` buffer start from `offset` -/// # Panics -/// -/// This function will panic if the output buffer doesn't have enough space to hold all the rows -pub fn write_batch_unchecked( - output: &mut [u8], - offset: usize, - batch: &RecordBatch, - row_idx: usize, - schema: Arc, -) -> Vec { - let mut writer = RowWriter::new(&schema); - let mut current_offset = offset; - let mut offsets = vec![]; - let columns = batch.columns(); - for cur_row in row_idx..batch.num_rows() { - offsets.push(current_offset); - let row_width = write_row(&mut writer, cur_row, &schema, columns); - output[current_offset..current_offset + row_width] - .copy_from_slice(writer.get_row()); - current_offset += row_width; - writer.reset() - } - offsets -} - -/// Bench interpreted version write -#[inline(never)] -pub fn bench_write_batch( - batches: &[Vec], - schema: Arc, -) -> Result> { - let mut writer = RowWriter::new(&schema); - let mut lengths = vec![]; - - for batch in batches.iter().flatten() { - let columns = batch.columns(); - for cur_row in 0..batch.num_rows() { - let row_width = write_row(&mut writer, cur_row, &schema, columns); - lengths.push(row_width); - writer.reset() - } - } - - Ok(lengths) -} - -#[macro_export] -macro_rules! set_idx { - ($WIDTH: literal, $SELF: ident, $IDX: ident, $VALUE: ident) => {{ - $SELF.assert_index_valid($IDX); - let offset = $SELF.field_offsets()[$IDX]; - $SELF.data[offset..offset + $WIDTH].copy_from_slice(&$VALUE.to_le_bytes()); - }}; -} - -#[macro_export] -macro_rules! fn_set_idx { - ($NATIVE: ident, $WIDTH: literal) => { - paste::item! { - fn [](&mut self, idx: usize, value: $NATIVE) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset..offset + $WIDTH].copy_from_slice(&value.to_le_bytes()); - } - } - }; -} - -/// Reusable row writer backed by `Vec` -/// -/// ```text -/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ -/// RowWriter │ -/// ┌───────────────────────┐ │ [RowFormat] -/// │ │ │ -/// │ │ │(copy from Array -/// │ │ to [u8]) │ ┌───────────────────────┐ -/// │ RecordBatch │ └ ─ ─ ─ ─ ─ ─ ─ ─ │ RowFormat │ -/// │ │──────────────────────────────▶│ Vec │ -/// │ (... N Rows ...) │ │ │ -/// │ │ └───────────────────────┘ -/// │ │ -/// │ │ -/// └───────────────────────┘ -/// ``` -pub struct RowWriter { - /// Layout on how to write each field - layout: RowLayout, - /// Buffer for the current tuple being written. - data: Vec, - /// Length in bytes for the current tuple, 8-bytes word aligned. - pub(crate) row_width: usize, -} - -impl RowWriter { - /// New - pub fn new(schema: &Schema) -> Self { - let layout = RowLayout::new(schema); - let init_capacity = layout.fixed_part_width(); - Self { - layout, - data: vec![0; init_capacity], - row_width: init_capacity, - } - } - - /// Reset the row writer state for new tuple - pub fn reset(&mut self) { - self.data.fill(0); - self.row_width = self.layout.fixed_part_width(); - } - - #[inline] - fn assert_index_valid(&self, idx: usize) { - assert!(idx < self.layout.field_count); - } - - #[inline(always)] - fn field_offsets(&self) -> &[usize] { - &self.layout.field_offsets - } - - #[inline(always)] - fn null_free(&self) -> bool { - self.layout.null_free - } - - pub(crate) fn set_null_at(&mut self, idx: usize) { - assert!( - !self.null_free(), - "Unexpected call to set_null_at on null-free row writer" - ); - let null_bits = &mut self.data[0..self.layout.null_width]; - unsafe { - unset_bit_raw(null_bits.as_mut_ptr(), idx); - } - } - - pub(crate) fn set_non_null_at(&mut self, idx: usize) { - assert!( - !self.null_free(), - "Unexpected call to set_non_null_at on null-free row writer" - ); - let null_bits = &mut self.data[0..self.layout.null_width]; - unsafe { - set_bit_raw(null_bits.as_mut_ptr(), idx); - } - } - - fn set_bool(&mut self, idx: usize, value: bool) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = u8::from(value); - } - - fn set_u8(&mut self, idx: usize, value: u8) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = value; - } - - fn_set_idx!(u16, 2); - fn_set_idx!(u32, 4); - fn_set_idx!(u64, 8); - fn_set_idx!(i16, 2); - fn_set_idx!(i32, 4); - fn_set_idx!(i64, 8); - fn_set_idx!(f32, 4); - fn_set_idx!(f64, 8); - - fn set_i8(&mut self, idx: usize, value: i8) { - self.assert_index_valid(idx); - let offset = self.field_offsets()[idx]; - self.data[offset] = value.to_le_bytes()[0]; - } - - fn set_date32(&mut self, idx: usize, value: i32) { - set_idx!(4, self, idx, value) - } - - fn set_date64(&mut self, idx: usize, value: i64) { - set_idx!(8, self, idx, value) - } - - fn set_decimal128(&mut self, idx: usize, value: i128) { - set_idx!(16, self, idx, value) - } - - /// Get raw bytes - pub fn get_row(&self) -> &[u8] { - &self.data[0..self.row_width] - } -} - -/// Stitch attributes of tuple in `batch` at `row_idx` and returns the tuple width -pub fn write_row( - row_writer: &mut RowWriter, - row_idx: usize, - schema: &Schema, - columns: &[ArrayRef], -) -> usize { - // Get the row from the batch denoted by row_idx - if row_writer.null_free() { - for ((i, f), col) in schema.fields().iter().enumerate().zip(columns.iter()) { - write_field(i, row_idx, col, f.data_type(), row_writer); - } - } else { - for ((i, f), col) in schema.fields().iter().enumerate().zip(columns.iter()) { - if !col.is_null(row_idx) { - row_writer.set_non_null_at(i); - write_field(i, row_idx, col, f.data_type(), row_writer); - } else { - row_writer.set_null_at(i); - } - } - } - - row_writer.row_width -} - -macro_rules! fn_write_field { - ($NATIVE: ident, $ARRAY: ident) => { - paste::item! { - pub(crate) fn [](to: &mut RowWriter, from: &Arc, col_idx: usize, row_idx: usize) { - let from = from - .as_any() - .downcast_ref::<$ARRAY>() - .unwrap(); - to.[](col_idx, from.value(row_idx)); - } - } - }; -} - -fn_write_field!(bool, BooleanArray); -fn_write_field!(u8, UInt8Array); -fn_write_field!(u16, UInt16Array); -fn_write_field!(u32, UInt32Array); -fn_write_field!(u64, UInt64Array); -fn_write_field!(i8, Int8Array); -fn_write_field!(i16, Int16Array); -fn_write_field!(i32, Int32Array); -fn_write_field!(i64, Int64Array); -fn_write_field!(f32, Float32Array); -fn_write_field!(f64, Float64Array); - -pub(crate) fn write_field_date32( - to: &mut RowWriter, - from: &Arc, - col_idx: usize, - row_idx: usize, -) { - match as_date32_array(from) { - Ok(from) => to.set_date32(col_idx, from.value(row_idx)), - Err(e) => panic!("{e}"), - }; -} - -pub(crate) fn write_field_date64( - to: &mut RowWriter, - from: &Arc, - col_idx: usize, - row_idx: usize, -) { - let from = as_date64_array(from).unwrap(); - to.set_date64(col_idx, from.value(row_idx)); -} - -pub(crate) fn write_field_decimal128( - to: &mut RowWriter, - from: &Arc, - col_idx: usize, - row_idx: usize, -) { - let from = as_decimal128_array(from).unwrap(); - to.set_decimal128(col_idx, from.value(row_idx)); -} - -fn write_field( - col_idx: usize, - row_idx: usize, - col: &Arc, - dt: &DataType, - row: &mut RowWriter, -) { - use DataType::*; - match dt { - Boolean => write_field_bool(row, col, col_idx, row_idx), - UInt8 => write_field_u8(row, col, col_idx, row_idx), - UInt16 => write_field_u16(row, col, col_idx, row_idx), - UInt32 => write_field_u32(row, col, col_idx, row_idx), - UInt64 => write_field_u64(row, col, col_idx, row_idx), - Int8 => write_field_i8(row, col, col_idx, row_idx), - Int16 => write_field_i16(row, col, col_idx, row_idx), - Int32 => write_field_i32(row, col, col_idx, row_idx), - Int64 => write_field_i64(row, col, col_idx, row_idx), - Float32 => write_field_f32(row, col, col_idx, row_idx), - Float64 => write_field_f64(row, col, col_idx, row_idx), - Date32 => write_field_date32(row, col, col_idx, row_idx), - Date64 => write_field_date64(row, col, col_idx, row_idx), - Decimal128(_, _) => write_field_decimal128(row, col, col_idx, row_idx), - _ => unimplemented!(), - } -} diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index af49b0cba314a..b91a2ac1fbd7e 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-sql" description = "DataFusion SQL Query Planner" keywords = ["datafusion", "sql", "parser", "planner"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -39,13 +39,13 @@ unicode_expressions = [] [dependencies] arrow = { workspace = true } arrow-schema = { workspace = true } -datafusion-common = { path = "../common", version = "26.0.0" } -datafusion-expr = { path = "../expr", version = "26.0.0" } -log = "^0.4" -sqlparser = "0.34" +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +log = { workspace = true } +sqlparser = { workspace = true } [dev-dependencies] -ctor = "0.2.0" -env_logger = "0.10" +ctor = { workspace = true } +env_logger = { workspace = true } paste = "^1.0" -rstest = "0.17" +rstest = "0.18" diff --git a/datafusion/sql/README.md b/datafusion/sql/README.md index 2ad994e4eba5c..256fa774b4105 100644 --- a/datafusion/sql/README.md +++ b/datafusion/sql/README.md @@ -20,7 +20,7 @@ # DataFusion SQL Query Planner This crate provides a general purpose SQL query planner that can parse SQL and translate queries into logical -plans. Although this crate is used by the [DataFusion](df) query engine, it was designed to be easily usable from any +plans. Although this crate is used by the [DataFusion][df] query engine, it was designed to be easily usable from any project that requires a SQL query planner and does not make any assumptions about how the resulting logical plan will be translated to a physical plan. For example, there is no concept of row-based versus columnar execution in the logical plan. diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index b0a0a73e7ce74..9df65b99a748a 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -17,7 +17,8 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; @@ -48,20 +49,20 @@ fn main() { let statement = &ast[0]; // create a logical query plan - let schema_provider = MySchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = MyContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // show the plan println!("{plan:?}"); } -struct MySchemaProvider { +struct MyContextProvider { options: ConfigOptions, tables: HashMap>, } -impl MySchemaProvider { +impl MyContextProvider { fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -103,14 +104,11 @@ fn create_table_source(fields: Vec) -> Arc { ))) } -impl ContextProvider for MySchemaProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), - _ => Err(DataFusionError::Plan(format!( - "Table not found: {}", - name.table() - ))), + _ => plan_err!("Table not found: {}", name.table()), } } @@ -126,6 +124,10 @@ impl ContextProvider for MySchemaProvider { None } + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + fn options(&self) -> &ConfigOptions { &self.options } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 607a9fb432f4a..ade8b96b5cc21 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -18,11 +18,14 @@ //! Implementation of the `arrow_cast` function that allows //! casting to arbitrary arrow types (rather than SQL types) -use std::{fmt::Display, iter::Peekable, str::Chars}; +use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use datafusion_common::{ + plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_common::plan_err; use datafusion_expr::{Expr, ExprSchemable}; pub const ARROW_CAST_NAME: &str = "arrow_cast"; @@ -51,21 +54,18 @@ pub const ARROW_CAST_NAME: &str = "arrow_cast"; /// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { if args.len() != 2 { - return Err(DataFusionError::Plan(format!( - "arrow_cast needs 2 arguments, {} provided", - args.len() - ))); + return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } let arg1 = args.pop().unwrap(); let arg0 = args.pop().unwrap(); - // arg1 must be a stirng + // arg1 must be a string let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { v } else { - return Err(DataFusionError::Plan(format!( + return plan_err!( "arrow_cast requires its second argument to be a constant string, got {arg1}" - ))); + ); }; // do the actual lookup to the appropriate data type @@ -100,9 +100,7 @@ pub fn parse_data_type(val: &str) -> Result { } fn make_error(val: &str, msg: &str) -> DataFusionError { - DataFusionError::Plan( - format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) - ) + plan_datafusion_err!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) } fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { @@ -150,6 +148,8 @@ impl<'a> Parser<'a> { Token::Decimal128 => self.parse_decimal_128(), Token::Decimal256 => self.parse_decimal_256(), Token::Dictionary => self.parse_dictionary(), + Token::List => self.parse_list(), + Token::LargeList => self.parse_large_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -157,6 +157,26 @@ impl<'a> Parser<'a> { } } + /// Parses the List type + fn parse_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::List(Arc::new(Field::new( + "item", data_type, true, + )))) + } + + /// Parses the LargeList type + fn parse_large_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", data_type, true, + )))) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? { @@ -486,6 +506,9 @@ impl<'a> Tokenizer<'a> { "Date32" => Token::SimpleType(DataType::Date32), "Date64" => Token::SimpleType(DataType::Date64), + "List" => Token::List, + "LargeList" => Token::LargeList, + "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), @@ -573,12 +596,16 @@ enum Token { None, Integer(i64), DoubleQuotedString(String), + List, + LargeList, } impl Display for Token { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Token::SimpleType(t) => write!(f, "{t}"), + Token::List => write!(f, "List"), + Token::LargeList => write!(f, "LargeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sql/src/expr/binary_op.rs b/datafusion/sql/src/expr/binary_op.rs index c5d2238ac0b77..d9c85663e50e2 100644 --- a/datafusion/sql/src/expr/binary_op.rs +++ b/datafusion/sql/src/expr/binary_op.rs @@ -16,7 +16,7 @@ // under the License. use crate::planner::{ContextProvider, SqlToRel}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_expr::Operator; use sqlparser::ast::BinaryOperator; @@ -43,12 +43,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + BinaryOperator::PGBitwiseXor => Ok(Operator::BitwiseXor), BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), BinaryOperator::StringConcat => Ok(Operator::StringConcat), - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported SQL binary operator {op:?}" - ))), + _ => not_impl_err!("Unsupported SQL binary operator {op:?}"), } } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 104a65832dcde..73de4fa439071 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -16,11 +16,12 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, DataFusionError, Result}; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; -use datafusion_expr::function_err::suggest_valid_function; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::window_frame::regularize; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, +}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::function::suggest_valid_function; +use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFunction, @@ -35,66 +36,91 @@ use super::arrow_cast::ARROW_CAST_NAME; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, - mut function: SQLFunction, + function: SQLFunction, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let name = if function.name.0.len() > 1 { + let SQLFunction { + name, + args, + over, + distinct, + filter, + null_treatment, + special: _, // true if not called with trailing parens + order_by, + } = function; + + if let Some(null_treatment) = null_treatment { + return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"); + } + + let name = if name.0.len() > 1 { // DF doesn't handle compound identifiers // (e.g. "foo.bar") for function names yet - function.name.to_string() + name.to_string() } else { - crate::utils::normalize_ident(function.name.0[0].clone()) + crate::utils::normalize_ident(name.0[0].clone()) }; + // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function + if let Some(fm) = self.context_provider.get_function_meta(&name) { + let args = self.function_args_to_expr(args, schema, planner_context)?; + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); + } + // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); }; // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. - let is_function_window = function.over.is_some(); - if !function.order_by.is_empty() && is_function_window { - return Err(DataFusionError::Plan( - "Aggregate ORDER BY is not implemented for window functions".to_string(), - )); + let is_function_window = over.is_some(); + if !order_by.is_empty() && is_function_window { + return plan_err!( + "Aggregate ORDER BY is not implemented for window functions" + ); } // then, window function - if let Some(WindowType::WindowSpec(window)) = function.over.take() { + if let Some(WindowType::WindowSpec(window)) = over { let partition_by = window .partition_by .into_iter() .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; - let order_by = - self.order_by_to_sort_expr(&window.order_by, schema, planner_context)?; + let mut order_by = self.order_by_to_sort_expr( + &window.order_by, + schema, + planner_context, + // Numeric literals in window function ORDER BY are treated as constants + false, + )?; let window_frame = window .window_frame .as_ref() .map(|window_frame| { let window_frame = window_frame.clone().try_into()?; - regularize(window_frame, order_by.len()) + check_window_frame(&window_frame, order_by.len()) + .map(|_| window_frame) }) .transpose()?; + let window_frame = if let Some(window_frame) = window_frame { + regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else { WindowFrame::new(!order_by.is_empty()) }; + if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { - let (aggregate_fun, args) = self.aggregate_fn_to_expr( - aggregate_fun, - function.args, - schema, - planner_context, - )?; + let args = + self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(aggregate_fun), @@ -106,11 +132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, - self.function_args_to_expr( - function.args, - schema, - planner_context, - )?, + self.function_args_to_expr(args, schema, planner_context)?, partition_by, order_by, window_frame, @@ -119,55 +141,40 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(expr); } } else { + // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function + if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let args = self.function_args_to_expr(args, schema, planner_context)?; + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, + ))); + } + // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { - let distinct = function.distinct; - let order_by = self.order_by_to_sort_expr( - &function.order_by, - schema, - planner_context, - )?; + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; let order_by = (!order_by.is_empty()).then_some(order_by); - let (fun, args) = self.aggregate_fn_to_expr( - fun, - function.args, - schema, - planner_context, - )?; + let args = self.function_args_to_expr(args, schema, planner_context)?; + let filter: Option> = filter + .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) + .transpose()? + .map(Box::new); + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, order_by, + fun, args, distinct, filter, order_by, ))); }; - // finally, user-defined functions (UDF) and UDAF - if let Some(fm) = self.schema_provider.get_function_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); - } - - // User defined aggregate functions - if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, - ))); - } - // Special case arrow_cast (as its type is dependent on its argument value) if name == ARROW_CAST_NAME { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; return super::arrow_cast::create_arrow_cast(args, schema); } } // Could not find the relevant function, so return an error let suggested_func_name = suggest_valid_function(&name, is_function_window); - Err(DataFusionError::Plan(format!( - "Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?" - ))) + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } pub(super) fn sql_named_function_to_expr( @@ -183,13 +190,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn find_window_func(&self, name: &str) -> Result { window_function::find_df_window_func(name) + // next check user defined aggregates .or_else(|| { - self.schema_provider + self.context_provider .get_aggregate_meta(name) .map(WindowFunction::AggregateUDF) }) + // next check user defined window functions + .or_else(|| { + self.context_provider + .get_window_meta(name) + .map(WindowFunction::WindowUDF) + }) .ok_or_else(|| { - DataFusionError::Plan(format!("There is no window function named {name}")) + plan_datafusion_err!("There is no window function named {name}") }) } @@ -207,14 +221,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, - } => Ok(Expr::Wildcard), + } => Ok(Expr::Wildcard { qualifier: None }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard), - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported qualified wildcard argument: {sql:?}" - ))), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Wildcard { qualifier: None }) + } + _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } @@ -228,28 +242,4 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>() } - - pub(super) fn aggregate_fn_to_expr( - &self, - fun: AggregateFunction, - args: Vec, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result<(AggregateFunction, Vec)> { - let args = match fun { - // Special case rewrite COUNT(*) to COUNT(constant) - AggregateFunction::Count => args - .into_iter() - .map(|a| match a { - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) - } - _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), - }) - .collect::>>()?, - _ => self.function_args_to_expr(args, schema, planner_context)?, - }; - - Ok((fun, args)) - } } diff --git a/datafusion/sql/src/expr/grouping_set.rs b/datafusion/sql/src/expr/grouping_set.rs index c5a0b6da7dc0a..254f5079b7b11 100644 --- a/datafusion/sql/src/expr/grouping_set.rs +++ b/datafusion/sql/src/expr/grouping_set.rs @@ -16,6 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::{Expr, GroupingSet}; use sqlparser::ast::Expr as SQLExpr; @@ -48,10 +49,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|v| { if v.len() != 1 { - Err(DataFusionError::Internal( + plan_err!( "Tuple expressions are not supported for Rollup expressions" - .to_string(), - )) + ) } else { self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) } @@ -70,10 +70,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|v| { if v.len() != 1 { - Err(DataFusionError::Internal( - "Tuple expressions not are supported for Cube expressions" - .to_string(), - )) + plan_err!("Tuple expressions not are supported for Cube expressions") } else { self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index fdf3d0f20bdda..9f53ff579e7c8 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -17,9 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Column, DFField, DFSchema, DataFusionError, Result, ScalarValue, TableReference, + internal_err, plan_datafusion_err, Column, DFField, DFSchema, DataFusionError, + Result, TableReference, }; -use datafusion_expr::{Case, Expr, GetIndexedField}; +use datafusion_expr::{Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -33,12 +34,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; let ty = self - .schema_provider + .context_provider .get_variable_type(&var_names) .ok_or_else(|| { - DataFusionError::Execution(format!( - "variable {var_names:?} has no type information" - )) + plan_datafusion_err!("variable {var_names:?} has no type information") })?; Ok(Expr::ScalarVariable(ty, var_names)) } else { @@ -90,9 +89,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { if ids.len() < 2 { - return Err(DataFusionError::Internal(format!( - "Not a compound identifier: {ids:?}" - ))); + return internal_err!("Not a compound identifier: {ids:?}"); } if ids[0].value.starts_with('@') { @@ -101,7 +98,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|id| self.normalizer.normalize(id)) .collect(); let ty = self - .schema_provider + .context_provider .get_variable_type(&var_names) .ok_or_else(|| { DataFusionError::Execution(format!( @@ -119,9 +116,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Though ideally once that support is in place, this code should work with it // TODO: remove when can support multiple nested identifiers if ids.len() > 5 { - return Err(DataFusionError::Internal(format!( - "Unsupported compound identifier: {ids:?}" - ))); + return internal_err!("Unsupported compound identifier: {ids:?}"); } let search_result = search_dfschema(&ids, schema); @@ -130,16 +125,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some((field, nested_names)) if !nested_names.is_empty() => { // TODO: remove when can support multiple nested identifiers if nested_names.len() > 1 { - return Err(DataFusionError::Internal(format!( + return internal_err!( "Nested identifiers not yet supported for column {}", field.qualified_column().quoted_flat_name() - ))); + ); } let nested_name = nested_names[0].to_string(); - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(Expr::Column(field.qualified_column())), - ScalarValue::Utf8(Some(nested_name)), - ))) + Ok(Expr::Column(field.qualified_column()).field(nested_name)) } // found matching field with no spare identifier(s) Some((field, _nested_names)) => { @@ -149,9 +141,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // return default where use all identifiers to not have a nested field // this len check is because at 5 identifiers will have to have a nested field if ids.len() == 5 { - Err(DataFusionError::Internal(format!( - "Unsupported compound identifier: {ids:?}" - ))) + internal_err!("Unsupported compound identifier: {ids:?}") } else { // check the outer_query_schema and try to find a match if let Some(outer) = planner_context.outer_query_schema() { @@ -162,10 +152,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if !nested_names.is_empty() => { // TODO: remove when can support nested identifiers for OuterReferenceColumn - Err(DataFusionError::Internal(format!( + internal_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", field.qualified_column().quoted_flat_name() - ))) + ) } // found matching field with no spare identifier(s) Some((field, _nested_names)) => { @@ -272,10 +262,7 @@ fn form_identifier(idents: &[String]) -> Result<(Option, &String }), &idents[3], )), - _ => Err(DataFusionError::Internal(format!( - "Incorrect number of identifiers: {}", - idents.len() - ))), + _ => internal_err!("Incorrect number of identifiers: {}", idents.len()), } } @@ -441,10 +428,10 @@ mod test { #[test] fn test_form_identifier() -> Result<()> { let err = form_identifier(&[]).expect_err("empty identifiers didn't fail"); - let expected = "Internal error: Incorrect number of identifiers: 0. \ + let expected = "Internal error: Incorrect number of identifiers: 0.\n\ This was likely caused by a bug in DataFusion's code and we would \ welcome that you file an bug report in our issue tracker"; - assert_eq!(err.to_string(), expected); + assert!(expected.starts_with(&err.strip_backtrace())); let ids = vec!["a".to_string()]; let (qualifier, column) = form_identifier(&ids)?; @@ -479,10 +466,10 @@ mod test { "e".to_string(), ]) .expect_err("too many identifiers didn't fail"); - let expected = "Internal error: Incorrect number of identifiers: 5. \ + let expected = "Internal error: Incorrect number of identifiers: 5.\n\ This was likely caused by a bug in DataFusion's code and we would \ welcome that you file an bug report in our issue tracker"; - assert_eq!(err.to_string(), expected); + assert!(expected.starts_with(&err.strip_backtrace())); Ok(()) } diff --git a/datafusion/sql/src/expr/json_access.rs b/datafusion/sql/src/expr/json_access.rs new file mode 100644 index 0000000000000..681b72b4e71ac --- /dev/null +++ b/datafusion/sql/src/expr/json_access.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::planner::{ContextProvider, SqlToRel}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_expr::Operator; +use sqlparser::ast::JsonOperator; + +impl<'a, S: ContextProvider> SqlToRel<'a, S> { + pub(crate) fn parse_sql_json_access(&self, op: JsonOperator) -> Result { + match op { + JsonOperator::AtArrow => Ok(Operator::AtArrow), + JsonOperator::ArrowAt => Ok(Operator::ArrowAt), + _ => not_impl_err!("Unsupported SQL json operator {op:?}"), + } + } +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 66422aa43f2ad..27351e10eb34e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -20,6 +20,7 @@ mod binary_op; mod function; mod grouping_set; mod identifier; +mod json_access; mod order_by; mod subquery; mod substring; @@ -28,15 +29,19 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue}; +use arrow_schema::TimeUnit; +use datafusion_common::{ + internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, + ScalarValue, +}; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr::{InList, Placeholder}; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, - Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast, + Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, Interval, TrimWhereField, Value}; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; use sqlparser::parser::ParserError::ParserError; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -70,6 +75,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); } + SQLExpr::JsonAccess { + left, + operator, + right, + } => { + let op = self.parse_sql_json_access(operator)?; + stack.push(StackEntry::Operator(op)); + stack.push(StackEntry::SQLExpr(right)); + stack.push(StackEntry::SQLExpr(left)); + } _ => { let expr = self.sql_expr_to_logical_expr_internal( *sql_expr, @@ -108,7 +123,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?; expr = self.rewrite_partial_qualifier(expr, schema); self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; - let expr = infer_placeholder_types(expr, schema)?; + let expr = expr.infer_placeholder_types(schema)?; Ok(expr) } @@ -152,76 +167,99 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Value(value) => { self.parse_value(value, planner_context.prepare_param_data_types()) } - SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction (ScalarFunction::new( - BuiltinScalarFunction::DatePart, - vec![ - Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))), - self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, - ], - ))), + SQLExpr::Extract { field, expr } => { + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::DatePart, + vec![ + Expr::Literal(ScalarValue::from(format!("{field}"))), + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ], + ))) + } SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), - SQLExpr::Interval(Interval { - value, - leading_field, - leading_precision, - last_field, - fractional_seconds_precision, - })=> self.sql_interval_to_expr( - *value, - schema, - planner_context, - leading_field, - leading_precision, - last_field, - fractional_seconds_precision, - ), - SQLExpr::Identifier(id) => self.sql_identifier_to_expr(id, schema, planner_context), + SQLExpr::Interval(interval) => { + self.sql_interval_to_expr(false, interval, schema, planner_context) + } + SQLExpr::Identifier(id) => { + self.sql_identifier_to_expr(id, schema, planner_context) + } SQLExpr::MapAccess { column, keys } => { if let SQLExpr::Identifier(id) = *column { - plan_indexed(col(self.normalizer.normalize(id)), keys) + self.plan_indexed( + col(self.normalizer.normalize(id)), + keys, + schema, + planner_context, + ) } else { - Err(DataFusionError::NotImplemented(format!( + not_impl_err!( "map access requires an identifier, found column {column} instead" - ))) + ) } } SQLExpr::ArrayIndex { obj, indexes } => { - let expr = self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; - plan_indexed(expr, indexes) + let expr = + self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; + self.plan_indexed(expr, indexes, schema, planner_context) } - SQLExpr::CompoundIdentifier(ids) => self.sql_compound_identifier_to_expr(ids, schema, planner_context), + SQLExpr::CompoundIdentifier(ids) => { + self.sql_compound_identifier_to_expr(ids, schema, planner_context) + } SQLExpr::Case { operand, conditions, results, else_result, - } => self.sql_case_identifier_to_expr(operand, conditions, results, else_result, schema, planner_context), + } => self.sql_case_identifier_to_expr( + operand, + conditions, + results, + else_result, + schema, + planner_context, + ), SQLExpr::Cast { - expr, - data_type, - } => Ok(Expr::Cast(Cast::new( - Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), - self.convert_data_type(&data_type)?, - ))), + expr, data_type, .. + } => { + let dt = self.convert_data_type(&data_type)?; + let expr = + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; + + // numeric constants are treated as seconds (rather as nanoseconds) + // to align with postgres / duckdb semantics + let expr = match &dt { + DataType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == DataType::Int64 => + { + Expr::Cast(Cast::new( + Box::new(expr), + DataType::Timestamp(TimeUnit::Second, tz.clone()), + )) + } + _ => expr, + }; + + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) + } SQLExpr::TryCast { - expr, - data_type, + expr, data_type, .. } => Ok(Expr::TryCast(TryCast::new( - Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), + Box::new(self.sql_expr_to_logical_expr( + *expr, + schema, + planner_context, + )?), self.convert_data_type(&data_type)?, ))), - SQLExpr::TypedString { - data_type, - value, - } => Ok(Expr::Cast(Cast::new( + SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast(Cast::new( Box::new(lit(value)), self.convert_data_type(&data_type)?, ))), @@ -234,31 +272,65 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), - SQLExpr::IsDistinctFrom(left, right) => Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(self.sql_expr_to_logical_expr(*left, schema, planner_context)?), - Operator::IsDistinctFrom, - Box::new(self.sql_expr_to_logical_expr(*right, schema, planner_context)?), - ))), + SQLExpr::IsDistinctFrom(left, right) => { + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?), + Operator::IsDistinctFrom, + Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?), + ))) + } - SQLExpr::IsNotDistinctFrom(left, right) => Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(self.sql_expr_to_logical_expr(*left, schema, planner_context)?), - Operator::IsNotDistinctFrom, - Box::new(self.sql_expr_to_logical_expr(*right, schema, planner_context)?), - ))), + SQLExpr::IsNotDistinctFrom(left, right) => { + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?), + Operator::IsNotDistinctFrom, + Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?), + ))) + } - SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), + SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new( + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, + ))), - SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(op, *expr, schema, planner_context), + SQLExpr::UnaryOp { op, expr } => { + self.parse_sql_unary_op(op, *expr, schema, planner_context) + } SQLExpr::Between { expr, @@ -266,10 +338,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { low, high, } => Ok(Expr::Between(Between::new( - Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), + Box::new(self.sql_expr_to_logical_expr( + *expr, + schema, + planner_context, + )?), negated, Box::new(self.sql_expr_to_logical_expr(*low, schema, planner_context)?), - Box::new(self.sql_expr_to_logical_expr(*high, schema, planner_context)?), + Box::new(self.sql_expr_to_logical_expr( + *high, + schema, + planner_context, + )?), ))), SQLExpr::InList { @@ -278,18 +358,52 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated, } => self.sql_in_list_to_expr(*expr, list, negated, schema, planner_context), - SQLExpr::Like { negated, expr, pattern, escape_char } => self.sql_like_to_expr(negated, *expr, *pattern, escape_char, schema, planner_context), + SQLExpr::Like { + negated, + expr, + pattern, + escape_char, + } => self.sql_like_to_expr( + negated, + *expr, + *pattern, + escape_char, + schema, + planner_context, + false, + ), - SQLExpr::ILike { negated, expr, pattern, escape_char } => self.sql_ilike_to_expr(negated, *expr, *pattern, escape_char, schema, planner_context), + SQLExpr::ILike { + negated, + expr, + pattern, + escape_char, + } => self.sql_like_to_expr( + negated, + *expr, + *pattern, + escape_char, + schema, + planner_context, + true, + ), - SQLExpr::SimilarTo { negated, expr, pattern, escape_char } => self.sql_similarto_to_expr(negated, *expr, *pattern, escape_char, schema, planner_context), + SQLExpr::SimilarTo { + negated, + expr, + pattern, + escape_char, + } => self.sql_similarto_to_expr( + negated, + *expr, + *pattern, + escape_char, + schema, + planner_context, + ), - SQLExpr::BinaryOp { - .. - } => { - Err(DataFusionError::Internal( - "binary_op should be handled by sql_expr_to_logical_expr.".to_string() - )) + SQLExpr::BinaryOp { .. } => { + internal_err!("binary_op should be handled by sql_expr_to_logical_expr.") } #[cfg(feature = "unicode_expressions")] @@ -297,42 +411,132 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, substring_from, substring_for, - } => self.sql_substring_to_expr(expr, substring_from, substring_for, schema, planner_context), + special: false, + } => self.sql_substring_to_expr( + expr, + substring_from, + substring_for, + schema, + planner_context, + ), #[cfg(not(feature = "unicode_expressions"))] - SQLExpr::Substring { - .. - } => { - Err(DataFusionError::Internal( - "statement substring requires compilation with feature flag: unicode_expressions.".to_string() - )) + SQLExpr::Substring { .. } => { + internal_err!( + "statement substring requires compilation with feature flag: unicode_expressions." + ) } - SQLExpr::Trim { expr, trim_where, trim_what } => self.sql_trim_to_expr(*expr, trim_where, trim_what, schema, planner_context), + SQLExpr::Trim { + expr, + trim_where, + trim_what, + .. + } => self.sql_trim_to_expr( + *expr, + trim_where, + trim_what, + schema, + planner_context, + ), + + SQLExpr::AggregateExpressionWithFilter { expr, filter } => { + self.sql_agg_with_filter_to_expr(*expr, *filter, schema, planner_context) + } - SQLExpr::AggregateExpressionWithFilter { expr, filter } => self.sql_agg_with_filter_to_expr(*expr, *filter, schema, planner_context), + SQLExpr::Function(function) => { + self.sql_function_to_expr(function, schema, planner_context) + } - SQLExpr::Function(function) => self.sql_function_to_expr(function, schema, planner_context), + SQLExpr::Rollup(exprs) => { + self.sql_rollup_to_expr(exprs, schema, planner_context) + } + SQLExpr::Cube(exprs) => self.sql_cube_to_expr(exprs, schema, planner_context), + SQLExpr::GroupingSets(exprs) => { + self.sql_grouping_sets_to_expr(exprs, schema, planner_context) + } - SQLExpr::Rollup(exprs) => self.sql_rollup_to_expr(exprs, schema, planner_context), - SQLExpr::Cube(exprs) => self.sql_cube_to_expr(exprs,schema, planner_context), - SQLExpr::GroupingSets(exprs) => self.sql_grouping_sets_to_expr(exprs, schema, planner_context), + SQLExpr::Floor { + expr, + field: _field, + } => self.sql_named_function_to_expr( + *expr, + BuiltinScalarFunction::Floor, + schema, + planner_context, + ), + SQLExpr::Ceil { + expr, + field: _field, + } => self.sql_named_function_to_expr( + *expr, + BuiltinScalarFunction::Ceil, + schema, + planner_context, + ), + SQLExpr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => self.sql_overlay_to_expr( + *expr, + *overlay_what, + *overlay_from, + overlay_for, + schema, + planner_context, + ), + SQLExpr::Nested(e) => { + self.sql_expr_to_logical_expr(*e, schema, planner_context) + } - SQLExpr::Floor { expr, field: _field } => self.sql_named_function_to_expr(*expr, BuiltinScalarFunction::Floor, schema, planner_context), - SQLExpr::Ceil { expr, field: _field } => self.sql_named_function_to_expr(*expr, BuiltinScalarFunction::Ceil, schema, planner_context), + SQLExpr::Exists { subquery, negated } => { + self.parse_exists_subquery(*subquery, negated, schema, planner_context) + } + SQLExpr::InSubquery { + expr, + subquery, + negated, + } => { + self.parse_in_subquery(*expr, *subquery, negated, schema, planner_context) + } + SQLExpr::Subquery(subquery) => { + self.parse_scalar_subquery(*subquery, schema, planner_context) + } - SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, planner_context), + SQLExpr::ArrayAgg(array_agg) => { + self.parse_array_agg(array_agg, schema, planner_context) + } - SQLExpr::Exists { subquery, negated } => self.parse_exists_subquery(*subquery, negated, schema, planner_context), - SQLExpr::InSubquery { expr, subquery, negated } => self.parse_in_subquery(*expr, *subquery, negated, schema, planner_context), - SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(*subquery, schema, planner_context), + SQLExpr::Struct { values, fields } => { + self.parse_struct(values, fields, schema, planner_context) + } - SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema, planner_context), + _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), + } + } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported ast node in sqltorel: {sql:?}" - ))), + fn parse_struct( + &self, + values: Vec, + fields: Vec, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + if !fields.is_empty() { + return not_impl_err!("Struct fields are not supported yet"); } + let args = values + .into_iter() + .map(|value| { + self.sql_expr_to_logical_expr(value, input_schema, planner_context) + }) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::Struct, + args, + ))) } fn parse_array_agg( @@ -351,21 +555,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = array_agg; let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr(&order_by, input_schema, planner_context)?) + Some(self.order_by_to_sort_expr( + &order_by, + input_schema, + planner_context, + true, + )?) } else { None }; if let Some(limit) = limit { - return Err(DataFusionError::NotImplemented(format!( - "LIMIT not supported in ARRAY_AGG: {limit}" - ))); + return not_impl_err!("LIMIT not supported in ARRAY_AGG: {limit}"); } if within_group { - return Err(DataFusionError::NotImplemented( - "WITHIN GROUP not supported in ARRAY_AGG".to_string(), - )); + return not_impl_err!("WITHIN GROUP not supported in ARRAY_AGG"); } let args = @@ -398,6 +603,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + #[allow(clippy::too_many_arguments)] fn sql_like_to_expr( &self, negated: bool, @@ -406,43 +612,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, + case_insensitive: bool, ) -> Result { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { - return Err(DataFusionError::Plan( - "Invalid pattern in LIKE expression".to_string(), - )); + return plan_err!("Invalid pattern in LIKE expression"); } Ok(Expr::Like(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), escape_char, - ))) - } - - fn sql_ilike_to_expr( - &self, - negated: bool, - expr: SQLExpr, - pattern: SQLExpr, - escape_char: Option, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; - let pattern_type = pattern.get_type(schema)?; - if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { - return Err(DataFusionError::Plan( - "Invalid pattern in ILIKE expression".to_string(), - )); - } - Ok(Expr::ILike(Like::new( - negated, - Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), - Box::new(pattern), - escape_char, + case_insensitive, ))) } @@ -458,15 +640,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { - return Err(DataFusionError::Plan( - "Invalid pattern in SIMILAR TO expression".to_string(), - )); + return plan_err!("Invalid pattern in SIMILAR TO expression"); } Ok(Expr::SimilarTo(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), Box::new(pattern), escape_char, + false, ))) } @@ -496,6 +677,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } + fn sql_overlay_to_expr( + &self, + expr: SQLExpr, + overlay_what: SQLExpr, + overlay_from: SQLExpr, + overlay_for: Option>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::OverLay; + let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let what_arg = + self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; + let from_arg = + self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; + let args = match overlay_for { + Some(for_expr) => { + let for_expr = + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; + vec![arg, what_arg, from_arg, for_expr] + } + None => vec![arg, what_arg, from_arg], + }; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } + fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, @@ -505,7 +712,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, @@ -521,81 +728,76 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?)), order_by, ))), - _ => Err(DataFusionError::Internal( + _ => plan_err!( "AggregateExpressionWithFilter expression was not an AggregateFunction" - .to_string(), - )), + ), } } -} -// modifies expr if it is a placeholder with datatype of right -fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { - if data_type.is_none() { - let other_dt = other.get_type(schema); - match other_dt { - Err(e) => { - return Err(e.context(format!( - "Can not find type of {other} needed to infer type of {expr}" - )))?; - } - Ok(dt) => { - *data_type = Some(dt); - } + fn plan_indices( + &self, + expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let field = match expr.clone() { + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => GetFieldAccess::NamedStructField { + name: ScalarValue::from(s), + }, + SQLExpr::JsonAccess { + left, + operator: JsonOperator::Colon, + right, + } => { + let start = Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?); + let stop = Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?); + + GetFieldAccess::ListRange { start, stop } } + _ => GetFieldAccess::ListIndex { + key: Box::new(self.sql_expr_to_logical_expr( + expr, + schema, + planner_context, + )?), + }, }; - } - Ok(()) -} -/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try -/// to infer their [`DataType`] from the context of their use. -fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { - expr.transform(&|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - Ok(Transformed::Yes(expr)) - }) -} + Ok(field) + } -fn plan_key(key: SQLExpr) -> Result { - let scalar = match key { - SQLExpr::Value(Value::Number(s, _)) => ScalarValue::Int64(Some( - s.parse() - .map_err(|_| ParserError(format!("Cannot parse {s} as i64.")))?, - )), - SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => { - ScalarValue::Utf8(Some(s)) - } - _ => { - return Err(DataFusionError::SQL(ParserError(format!( - "Unsuported index key expression: {key:?}" - )))); - } - }; + fn plan_indexed( + &self, + expr: Expr, + mut keys: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let indices = keys.pop().ok_or_else(|| { + ParserError("Internal error: Missing index key expression".to_string()) + })?; - Ok(scalar) -} + let expr = if !keys.is_empty() { + self.plan_indexed(expr, keys, schema, planner_context)? + } else { + expr + }; -fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { - let key = keys.pop().ok_or_else(|| { - ParserError("Internal error: Missing index key expression".to_string()) - })?; - - let expr = if !keys.is_empty() { - plan_indexed(expr, keys)? - } else { - expr - }; - - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - plan_key(key)?, - ))) + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(expr), + self.plan_indices(indices, schema, planner_context)?, + ))) + } } #[cfg(test)] @@ -611,16 +813,16 @@ mod tests { use datafusion_common::config::ConfigOptions; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource}; + use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use crate::TableReference; - struct TestSchemaProvider { + struct TestContextProvider { options: ConfigOptions, tables: HashMap>, } - impl TestSchemaProvider { + impl TestContextProvider { pub fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -639,17 +841,11 @@ mod tests { } } - impl ContextProvider for TestSchemaProvider { - fn get_table_provider( - &self, - name: TableReference, - ) -> Result> { + impl ContextProvider for TestContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), - _ => Err(DataFusionError::Plan(format!( - "Table not found: {}", - name.table() - ))), + _ => plan_err!("Table not found: {}", name.table()), } } @@ -668,6 +864,10 @@ mod tests { fn options(&self) -> &ConfigOptions { &self.options } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } } fn create_table_source(fields: Vec) -> Arc { @@ -695,8 +895,8 @@ mod tests { .unwrap(); let sql_expr = parser.parse_expr().unwrap(); - let schema_provider = TestSchemaProvider::new(); - let sql_to_rel = SqlToRel::new(&schema_provider); + let context_provider = TestContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); // Should not stack overflow sql_to_rel.sql_expr_to_logical_expr( diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index b32388f1bcdf7..772255bd9773a 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -16,18 +16,25 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{ + plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, +}; use datafusion_expr::expr::Sort; use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// convert sql [OrderByExpr] to `Vec` + /// Convert sql [OrderByExpr] to `Vec`. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index + /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). + /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, exprs: &[OrderByExpr], schema: &DFSchema, planner_context: &mut PlannerContext, + literal_to_column: bool, ) -> Result> { let mut expr_vec = vec![]; for e in exprs { @@ -38,21 +45,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v .parse::() - .map_err(|err| DataFusionError::Plan(err.to_string()))?; + .map_err(|err| plan_datafusion_err!("{}", err))?; if field_index == 0 { - return Err(DataFusionError::Plan( - "Order by index starts at 1 for column indexes".to_string(), - )); + return plan_err!( + "Order by index starts at 1 for column indexes" + ); } else if schema.fields().len() < field_index { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Order by column out of bounds, specified: {}, max: {}", field_index, schema.fields().len() - ))); + ); } let field = schema.field(field_index - 1); diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 1a95266615425..71b2a11cd4143 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,6 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{BuiltinScalarFunction, Expr}; @@ -60,11 +61,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, substring_from: None, substring_for: None, + special: false, }; - return Err(DataFusionError::Plan(format!( - "Substring without for/from is not valid {orig_sql:?}" - ))); + return plan_err!("Substring without for/from is not valid {orig_sql:?}"); } }; diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index d24fc71540369..08ff6f2c3622a 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -16,8 +16,8 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, DataFusionError, Result}; -use datafusion_expr::{lit, Expr}; +use datafusion_common::{not_impl_err, DFSchema, DataFusionError, Result}; +use datafusion_expr::Expr; use sqlparser::ast::{Expr as SQLExpr, UnaryOperator, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -39,22 +39,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match expr { // optimization: if it's a number literal, we apply the negative operator // here directly to calculate the new literal. - SQLExpr::Value(Value::Number(n, _)) => match n.parse::() { - Ok(n) => Ok(lit(-n)), - Err(_) => Ok(lit(-n - .parse::() - .map_err(|_e| { - DataFusionError::Internal(format!( - "negative operator can be only applied to integer and float operands, got: {n}")) - })?)), - }, + SQLExpr::Value(Value::Number(n, _)) => { + self.parse_sql_number(&n, true) + } + SQLExpr::Interval(interval) => { + self.sql_interval_to_expr(true, interval, schema, planner_context) + } // not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?))), + _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( + expr, + schema, + planner_context, + )?))), } } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported SQL unary operator {op:?}" - ))), + _ => not_impl_err!("Unsupported SQL unary operator {op:?}"), } } } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3d959f17ce0cd..708f7c60011a5 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -17,14 +17,19 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; +use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::{lit, Expr, Operator}; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; -use sqlparser::ast::{BinaryOperator, DateTimeField, Expr as SQLExpr, Value}; +use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; -use std::collections::HashSet; +use std::borrow::Cow; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn parse_value( @@ -33,53 +38,55 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { param_data_types: &[DataType], ) -> Result { match value { - Value::Number(n, _) => self.parse_sql_number(&n), + Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), Value::Null => Ok(Expr::Literal(ScalarValue::Null)), Value::Boolean(n) => Ok(lit(n)), Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) } - _ => Err(DataFusionError::Plan(format!( - "Unsupported Value '{value:?}'", - ))), + Value::HexStringLiteral(s) => { + if let Some(v) = try_decode_hex_literal(&s) { + Ok(lit(v)) + } else { + plan_err!("Invalid HexStringLiteral '{s}'") + } + } + _ => plan_err!("Unsupported Value '{value:?}'"), } } /// Parse number in sql string, convert to Expr::Literal - fn parse_sql_number(&self, n: &str) -> Result { - if let Ok(n) = n.parse::() { - Ok(lit(n)) - } else if let Ok(n) = n.parse::() { - Ok(lit(n)) - } else if self.options.parse_float_as_decimal { - // remove leading zeroes - let str = n.trim_start_matches('0'); - if let Some(i) = str.find('.') { - let p = str.len() - 1; - let s = str.len() - i - 1; - let str = str.replace('.', ""); - let n = str.parse::().map_err(|_| { - DataFusionError::from(ParserError(format!( - "Cannot parse {str} as i128 when building decimal" - ))) - })?; - Ok(Expr::Literal(ScalarValue::Decimal128( - Some(n), - p as u8, - s as i8, - ))) - } else { - let number = n.parse::().map_err(|_| { - DataFusionError::from(ParserError(format!( - "Cannot parse {n} as i128 when building decimal" - ))) - })?; - Ok(Expr::Literal(ScalarValue::Decimal128(Some(number), 38, 0))) + pub(super) fn parse_sql_number( + &self, + unsigned_number: &str, + negative: bool, + ) -> Result { + let signed_number: Cow = if negative { + Cow::Owned(format!("-{unsigned_number}")) + } else { + Cow::Borrowed(unsigned_number) + }; + + // Try to parse as i64 first, then u64 if negative is false, then decimal or f64 + + if let Ok(n) = signed_number.parse::() { + return Ok(lit(n)); + } + + if !negative { + if let Ok(n) = unsigned_number.parse::() { + return Ok(lit(n)); } + } + + if self.options.parse_float_as_decimal { + parse_decimal_128(unsigned_number, negative) } else { - n.parse::().map(lit).map_err(|_| { - DataFusionError::from(ParserError(format!("Cannot parse {n} as f64"))) + signed_number.parse::().map(lit).map_err(|_| { + DataFusionError::from(ParserError(format!( + "Cannot parse {signed_number} as f64" + ))) }) } } @@ -95,15 +102,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let index = param[1..].parse::(); let idx = match index { Ok(0) => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid placeholder, zero is not a valid index: {param}" - ))); + ); } Ok(index) => index - 1, Err(_) => { - return Err(DataFusionError::Plan(format!( - "Invalid placeholder, not a number: {param}" - ))); + return if param_data_types.is_empty() { + Ok(Expr::Placeholder(Placeholder::new(param, None))) + } else { + // when PREPARE Statement, param_data_types length is always 0 + plan_err!("Invalid placeholder, not a number: {param}") + }; } }; // Check if the placeholder is in the parameter list @@ -133,73 +143,78 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, &mut PlannerContext::new(), )?; + match value { - Expr::Literal(scalar) => { - values.push(scalar); + Expr::Literal(_) => { + values.push(value); + } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + .. + }) => { + if fun == BuiltinScalarFunction::MakeArray { + values.push(value); + } else { + return not_impl_err!( + "ScalarFunctions without MakeArray are not supported: {value}" + ); + } } _ => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Arrays with elements other than literal are not supported: {value}" - ))); + ); } } } - let data_types: HashSet = - values.iter().map(|e| e.get_datatype()).collect(); - - if data_types.is_empty() { - Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) - } else if data_types.len() > 1 { - Err(DataFusionError::NotImplemented(format!( - "Arrays with different types are not supported: {data_types:?}", - ))) - } else { - let data_type = values[0].get_datatype(); - - Ok(lit(ScalarValue::new_list(Some(values), data_type))) - } + Ok(Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + values, + ))) } /// Convert a SQL interval expression to a DataFusion logical plan /// expression - /// - /// Waiting for this issue to be resolved: - /// `` - #[allow(clippy::too_many_arguments)] pub(super) fn sql_interval_to_expr( &self, - value: SQLExpr, + negative: bool, + interval: Interval, schema: &DFSchema, planner_context: &mut PlannerContext, - leading_field: Option, - leading_precision: Option, - last_field: Option, - fractional_seconds_precision: Option, ) -> Result { - if leading_precision.is_some() { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported Interval Expression with leading_precision {leading_precision:?}" - ))); + if interval.leading_precision.is_some() { + return not_impl_err!( + "Unsupported Interval Expression with leading_precision {:?}", + interval.leading_precision + ); } - if last_field.is_some() { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported Interval Expression with last_field {last_field:?}" - ))); + if interval.last_field.is_some() { + return not_impl_err!( + "Unsupported Interval Expression with last_field {:?}", + interval.last_field + ); } - if fractional_seconds_precision.is_some() { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported Interval Expression with fractional_seconds_precision {fractional_seconds_precision:?}" - ))); + if interval.fractional_seconds_precision.is_some() { + return not_impl_err!( + "Unsupported Interval Expression with fractional_seconds_precision {:?}", + interval.fractional_seconds_precision + ); } // Only handle string exprs for now - let value = match value { + let value = match *interval.value { SQLExpr::Value( Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => s, + ) => { + if negative { + format!("-{s}") + } else { + s + } + } // Support expressions like `interval '1 month' + date/timestamp`. // Such expressions are parsed like this by sqlparser-rs // @@ -221,30 +236,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::Plus => Operator::Plus, BinaryOperator::Minus => Operator::Minus, _ => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported interval operator: {op:?}" - ))); + return not_impl_err!("Unsupported interval operator: {op:?}"); } }; - match (leading_field, left.as_ref(), right.as_ref()) { + match (interval.leading_field, left.as_ref(), right.as_ref()) { (_, _, SQLExpr::Value(_)) => { let left_expr = self.sql_interval_to_expr( - *left, + negative, + Interval { + value: left, + leading_field: interval.leading_field, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, schema, planner_context, - leading_field, - None, - None, - None, )?; let right_expr = self.sql_interval_to_expr( - *right, + false, + Interval { + value: right, + leading_field: interval.leading_field, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, schema, planner_context, - leading_field, - None, - None, - None, )?; return Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left_expr), @@ -259,13 +278,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // is not a value. (None, _, _) => { let left_expr = self.sql_interval_to_expr( - *left, + negative, + Interval { + value: left, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, schema, planner_context, - None, - None, - None, - None, )?; let right_expr = self.sql_expr_to_logical_expr( *right, @@ -280,16 +302,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } _ => { let value = SQLExpr::BinaryOp { left, op, right }; - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Unsupported interval argument. Expected string literal, got: {value:?}" - ))); + ); } } } _ => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported interval argument. Expected string literal, got: {value:?}" - ))); + return not_impl_err!( + "Unsupported interval argument. Expected string literal, got: {:?}", + interval.value + ); } }; @@ -301,7 +324,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { // leading_field really means the unit if specified // for example, "month" in `INTERVAL '5' month` - match leading_field.as_ref() { + match interval.leading_field.as_ref() { Some(leading_field) => { format!("{value} {leading_field}") } @@ -345,3 +368,104 @@ fn has_units(val: &str) -> bool { || val.ends_with("nanosecond") || val.ends_with("nanoseconds") } + +/// Try to decode bytes from hex literal string. +/// +/// None will be returned if the input literal is hex-invalid. +fn try_decode_hex_literal(s: &str) -> Option> { + let hex_bytes = s.as_bytes(); + + let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2); + + let start_idx = hex_bytes.len() % 2; + if start_idx > 0 { + // The first byte is formed of only one char. + decoded_bytes.push(try_decode_hex_char(hex_bytes[0])?); + } + + for i in (start_idx..hex_bytes.len()).step_by(2) { + let high = try_decode_hex_char(hex_bytes[i])?; + let low = try_decode_hex_char(hex_bytes[i + 1])?; + decoded_bytes.push(high << 4 | low); + } + + Some(decoded_bytes) +} + +/// Try to decode a byte from a hex char. +/// +/// None will be returned if the input char is hex-invalid. +const fn try_decode_hex_char(c: u8) -> Option { + match c { + b'A'..=b'F' => Some(c - b'A' + 10), + b'a'..=b'f' => Some(c - b'a' + 10), + b'0'..=b'9' => Some(c - b'0'), + _ => None, + } +} + +/// Parse Decimal128 from a string +/// +/// TODO: support parsing from scientific notation +fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { + // remove leading zeroes + let trimmed = unsigned_number.trim_start_matches('0'); + // parse precision and scale, remove decimal point if exists + let (precision, scale, replaced_str) = if trimmed == "." { + // special cases for numbers such as “0.”, “000.”, and so on. + (1, 0, Cow::Borrowed("0")) + } else if let Some(i) = trimmed.find('.') { + ( + trimmed.len() - 1, + trimmed.len() - i - 1, + Cow::Owned(trimmed.replace('.', "")), + ) + } else { + // no decimal point, keep as is + (trimmed.len(), 0, Cow::Borrowed(trimmed)) + }; + + let number = replaced_str.parse::().map_err(|e| { + DataFusionError::from(ParserError(format!( + "Cannot parse {replaced_str} as i128 when building decimal: {e}" + ))) + })?; + + // check precision overflow + if precision as u8 > DECIMAL128_MAX_PRECISION { + return Err(DataFusionError::from(ParserError(format!( + "Cannot parse {replaced_str} as i128 when building decimal: precision overflow" + )))); + } + + Ok(Expr::Literal(ScalarValue::Decimal128( + Some(if negative { -number } else { number }), + precision as u8, + scale as i8, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decode_hex_literal() { + let cases = [ + ("", Some(vec![])), + ("FF00", Some(vec![255, 0])), + ("a00a", Some(vec![160, 10])), + ("FF0", Some(vec![15, 240])), + ("f", Some(vec![15])), + ("FF0X", None), + ("X0", None), + ("XX", None), + ("x", None), + ]; + + for (input, expect) in cases { + let output = try_decode_hex_literal(input); + assert_eq!(output, expect); + } + } +} diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index c0c1a4ac91186..d805f61397e90 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -15,8 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! This module provides a SQL parser that translates SQL queries into an abstract syntax -//! tree (AST), and a SQL query planner that creates a logical plan from the AST. +//! This module provides: +//! +//! 1. A SQL parser, [`DFParser`], that translates SQL query text into +//! an abstract syntax tree (AST), [`Statement`]. +//! +//! 2. A SQL query planner [`SqlToRel`] that creates [`LogicalPlan`]s +//! from [`Statement`]s. +//! +//! [`DFParser`]: parser::DFParser +//! [`Statement`]: parser::Statement +//! [`SqlToRel`]: planner::SqlToRel +//! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan mod expr; pub mod parser; diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 38dacf35be134..9c104ff18a9b3 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion SQL Parser based on [`sqlparser`] +//! [`DFParser`]: DataFusion SQL Parser based on [`sqlparser`] use datafusion_common::parsers::CompressionTypeVariant; use sqlparser::ast::{OrderByExpr, Query, Value}; @@ -44,6 +44,35 @@ fn parse_file_type(s: &str) -> Result { Ok(s.to_uppercase()) } +/// DataFusion specific EXPLAIN (needed so we can EXPLAIN datafusion +/// specific COPY and other statements) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ExplainStatement { + pub analyze: bool, + pub verbose: bool, + pub statement: Box, +} + +impl fmt::Display for ExplainStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self { + analyze, + verbose, + statement, + } = self; + + write!(f, "EXPLAIN ")?; + if *analyze { + write!(f, "ANALYZE ")?; + } + if *verbose { + write!(f, "VERBOSE ")?; + } + + write!(f, "{statement}") + } +} + /// DataFusion extension DDL for `COPY` /// /// # Syntax: @@ -62,7 +91,7 @@ fn parse_file_type(s: &str) -> Result { /// (format parquet, /// partitions 16, /// row_group_limit_rows 100000, -// row_group_limit_bytes 200000 +/// row_group_limit_bytes 200000 /// ) /// /// COPY (SELECT l_orderkey from lineitem) to 'lineitem.parquet'; @@ -74,7 +103,7 @@ pub struct CopyToStatement { /// The URL to where the data is heading pub target: String, /// Target specific options - pub options: HashMap, + pub options: Vec<(String, Value)>, } impl fmt::Display for CopyToStatement { @@ -88,10 +117,8 @@ impl fmt::Display for CopyToStatement { write!(f, "COPY {source} TO {target}")?; if !options.is_empty() { - let mut opts: Vec<_> = - options.iter().map(|(k, v)| format!("{k} {v}")).collect(); + let opts: Vec<_> = options.iter().map(|(k, v)| format!("{k} {v}")).collect(); // print them in sorted order - opts.sort_unstable(); write!(f, " ({})", opts.join(", "))?; } @@ -170,13 +197,15 @@ pub struct CreateExternalTable { pub unbounded: bool, /// Table(provider) specific options pub options: HashMap, + /// A table-level constraint + pub constraints: Vec, } impl fmt::Display for CreateExternalTable { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "CREATE EXTERNAL TABLE ")?; if self.if_not_exists { - write!(f, "IF NOT EXSISTS ")?; + write!(f, "IF NOT EXISTS ")?; } write!(f, "{} ", self.name)?; write!(f, "STORED AS {} ", self.file_type)?; @@ -191,9 +220,13 @@ pub struct DescribeTableStmt { pub table_name: ObjectName, } -/// DataFusion Statement representations. +/// DataFusion SQL Statement. /// -/// Tokens parsed by [`DFParser`] are converted into these values. +/// This can either be a [`Statement`] from [`sqlparser`] from a +/// standard SQL dialect, or a DataFusion extension such as `CREATE +/// EXTERAL TABLE`. See [`DFParser`] for more information. +/// +/// [`Statement`]: sqlparser::ast::Statement #[derive(Debug, Clone, PartialEq, Eq)] pub enum Statement { /// ANSI SQL AST node (from sqlparser-rs) @@ -204,6 +237,8 @@ pub enum Statement { DescribeTableStmt(DescribeTableStmt), /// Extension: `COPY TO` CopyTo(CopyToStatement), + /// EXPLAIN for extensions + Explain(ExplainStatement), } impl fmt::Display for Statement { @@ -213,14 +248,19 @@ impl fmt::Display for Statement { Statement::CreateExternalTable(stmt) => write!(f, "{stmt}"), Statement::DescribeTableStmt(_) => write!(f, "DESCRIBE TABLE ..."), Statement::CopyTo(stmt) => write!(f, "{stmt}"), + Statement::Explain(stmt) => write!(f, "{stmt}"), } } } -/// DataFusion SQL Parser based on [`sqlparser`] +/// Datafusion SQL Parser based on [`sqlparser`] +/// +/// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. /// -/// This parser handles DataFusion specific statements, delegating to -/// [`Parser`](sqlparser::parser::Parser) for other SQL statements. +/// DataFusion mostly follows existing SQL dialects via +/// `sqlparser`. However, certain statements such as `COPY` and +/// `CREATE EXTERNAL TABLE` have special syntax in DataFusion. See +/// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { parser: Parser<'a>, } @@ -298,24 +338,24 @@ impl<'a> DFParser<'a> { Token::Word(w) => { match w.keyword { Keyword::CREATE => { - // move one token forward - self.parser.next_token(); - // use custom parsing + self.parser.next_token(); // CREATE self.parse_create() } Keyword::COPY => { - // move one token forward - self.parser.next_token(); + self.parser.next_token(); // COPY self.parse_copy() } Keyword::DESCRIBE => { - // move one token forward - self.parser.next_token(); - // use custom parsing + self.parser.next_token(); // DESCRIBE self.parse_describe() } + Keyword::EXPLAIN => { + // (TODO parse all supported statements) + self.parser.next_token(); // EXPLAIN + self.parse_explain() + } _ => { - // use the native parser + // use sqlparser-rs parser Ok(Statement::Statement(Box::from( self.parser.parse_statement()?, ))) @@ -360,7 +400,7 @@ impl<'a> DFParser<'a> { let options = if self.parser.peek_token().token == Token::LParen { self.parse_value_options()? } else { - HashMap::new() + vec![] }; Ok(Statement::CopyTo(CopyToStatement { @@ -412,6 +452,19 @@ impl<'a> DFParser<'a> { } } + /// Parse a SQL `EXPLAIN` + pub fn parse_explain(&mut self) -> Result { + let analyze = self.parser.parse_keyword(Keyword::ANALYZE); + let verbose = self.parser.parse_keyword(Keyword::VERBOSE); + let statement = self.parse_statement()?; + + Ok(Statement::Explain(ExplainStatement { + statement: Box::new(statement), + analyze, + verbose, + })) + } + /// Parse a SQL `CREATE` statement handling `CREATE EXTERNAL TABLE` pub fn parse_create(&mut self) -> Result { if self.parser.parse_keyword(Keyword::EXTERNAL) { @@ -578,7 +631,7 @@ impl<'a> DFParser<'a> { self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parser.parse_object_name()?; - let (columns, _) = self.parse_columns()?; + let (columns, constraints) = self.parse_columns()?; #[derive(Default)] struct Builder { @@ -664,8 +717,7 @@ impl<'a> DFParser<'a> { break; } else { return Err(ParserError::ParserError(format!( - "Unexpected token {}", - token + "Unexpected token {token}" ))); } } @@ -698,6 +750,7 @@ impl<'a> DFParser<'a> { .unwrap_or(CompressionTypeVariant::UNCOMPRESSED), unbounded, options: builder.options.unwrap_or(HashMap::new()), + constraints, }; Ok(Statement::CreateExternalTable(create)) } @@ -707,7 +760,7 @@ impl<'a> DFParser<'a> { let token = self.parser.next_token(); match &token.token { Token::Word(w) => parse_file_type(&w.value), - _ => self.expected("one of PARQUET, NDJSON, or CSV", token), + _ => self.expected("one of ARROW, PARQUET, NDJSON, or CSV", token), } } @@ -750,14 +803,14 @@ impl<'a> DFParser<'a> { /// Unlike [`Self::parse_string_options`], this method supports /// keywords as key names as well as multiple value types such as /// Numbers as well as Strings. - fn parse_value_options(&mut self) -> Result, ParserError> { - let mut options = HashMap::new(); + fn parse_value_options(&mut self) -> Result, ParserError> { + let mut options = vec![]; self.parser.expect_token(&Token::LParen)?; loop { let key = self.parse_option_key()?; let value = self.parse_option_value()?; - options.insert(key, value); + options.push((key, value)); let comma = self.parser.consume_token(&Token::Comma); if self.parser.consume_token(&Token::RParen) { // allow a trailing comma, even though it's not in standard @@ -849,6 +902,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -867,6 +921,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -886,6 +941,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -905,6 +961,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -924,6 +981,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -946,6 +1004,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -973,6 +1032,7 @@ mod tests { )?, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -992,6 +1052,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1010,6 +1071,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1028,6 +1090,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1047,6 +1110,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1071,6 +1135,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::from([("k1".into(), "v1".into())]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1093,11 +1158,12 @@ mod tests { ("k1".into(), "v1".into()), ("k2".into(), "v2".into()), ]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; // Ordered Col - let sqls = vec!["CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1) LOCATION 'foo.csv'", + let sqls = ["CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1) LOCATION 'foo.csv'", "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 NULLS FIRST) LOCATION 'foo.csv'", "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 NULLS LAST) LOCATION 'foo.csv'", "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 ASC) LOCATION 'foo.csv'", @@ -1138,6 +1204,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; } @@ -1178,6 +1245,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1214,6 +1282,7 @@ mod tests { file_compression_type: UNCOMPRESSED, unbounded: false, options: HashMap::new(), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1262,6 +1331,7 @@ mod tests { ("ROW_GROUP_SIZE".into(), "1024".into()), ("TRUNCATE".into(), "NO".into()), ]), + constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1277,13 +1347,39 @@ mod tests { let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), - options: HashMap::new(), + options: vec![], }); assert_eq!(verified_stmt(sql), expected); Ok(()) } + #[test] + fn explain_copy_to_table_to_table() -> Result<(), ParserError> { + let cases = vec![ + ("EXPLAIN COPY foo TO bar", false, false), + ("EXPLAIN ANALYZE COPY foo TO bar", true, false), + ("EXPLAIN VERBOSE COPY foo TO bar", false, true), + ("EXPLAIN ANALYZE VERBOSE COPY foo TO bar", true, true), + ]; + for (sql, analyze, verbose) in cases { + println!("sql: {sql}, analyze: {analyze}, verbose: {verbose}"); + + let expected_copy = Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + options: vec![], + }); + let expected = Statement::Explain(ExplainStatement { + analyze, + verbose, + statement: Box::new(expected_copy), + }); + assert_eq!(verified_stmt(sql), expected); + } + Ok(()) + } + #[test] fn copy_to_query_to_table() -> Result<(), ParserError> { let statement = verified_stmt("SELECT 1"); @@ -1305,7 +1401,7 @@ mod tests { let expected = Statement::CopyTo(CopyToStatement { source: CopyToSource::Query(query), target: "bar".to_string(), - options: HashMap::new(), + options: vec![], }); assert_eq!(verified_stmt(sql), expected); Ok(()) @@ -1317,10 +1413,10 @@ mod tests { let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), - options: HashMap::from([( + options: vec![( "row_group_size".to_string(), Value::Number("55".to_string(), false), - )]), + )], }); assert_eq!(verified_stmt(sql), expected); Ok(()) @@ -1328,17 +1424,11 @@ mod tests { #[test] fn copy_to_multi_options() -> Result<(), ParserError> { + // order of options is preserved let sql = "COPY foo TO bar (format parquet, row_group_size 55, compression snappy)"; - // canonical order is alphabetical - let canonical = - "COPY foo TO bar (compression snappy, format parquet, row_group_size 55)"; - let expected_options = HashMap::from([ - ( - "compression".to_string(), - Value::UnQuotedString("snappy".to_string()), - ), + let expected_options = vec![ ( "format".to_string(), Value::UnQuotedString("parquet".to_string()), @@ -1347,14 +1437,17 @@ mod tests { "row_group_size".to_string(), Value::Number("55".to_string(), false), ), - ]); + ( + "compression".to_string(), + Value::UnQuotedString("snappy".to_string()), + ), + ]; - let options = - if let Statement::CopyTo(copy_to) = one_statement_parses_to(sql, canonical) { - copy_to.options - } else { - panic!("Expected copy"); - }; + let options = if let Statement::CopyTo(copy_to) = verified_stmt(sql) { + copy_to.options + } else { + panic!("Expected copy"); + }; assert_eq!(options, expected_options); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ceec01037425f..c5c30e3a22536 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -15,37 +15,58 @@ // specific language governing permissions and limitations // under the License. -//! SQL Query Planner (produces logical plan from SQL AST) +//! [`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST) use std::collections::HashMap; use std::sync::Arc; use std::vec; use arrow_schema::*; -use datafusion_common::field_not_found; -use sqlparser::ast::ExactNumberInfo; +use datafusion_common::{ + field_not_found, internal_err, plan_datafusion_err, SchemaError, +}; +use datafusion_expr::WindowUDF; use sqlparser::ast::TimezoneInfo; +use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{unqualified_field_not_found, DFSchema, DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, + Result, +}; use datafusion_common::{OwnedTableReference, TableReference}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; use datafusion_expr::TableSource; -use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF, SubqueryAlias}; +use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; use crate::utils::make_decimal_type; /// The ContextProvider trait allows the query planner to obtain meta-data about tables and /// functions referenced in SQL statements pub trait ContextProvider { + #[deprecated(since = "32.0.0", note = "please use `get_table_source` instead")] + fn get_table_provider(&self, name: TableReference) -> Result> { + self.get_table_source(name) + } /// Getter for a datasource - fn get_table_provider(&self, name: TableReference) -> Result>; + fn get_table_source(&self, name: TableReference) -> Result>; + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type fn get_variable_type(&self, variable_names: &[String]) -> Option; @@ -179,22 +200,22 @@ impl PlannerContext { /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { - pub(crate) schema_provider: &'a S, + pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) normalizer: IdentNormalizer, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner - pub fn new(schema_provider: &'a S) -> Self { - Self::new_with_options(schema_provider, ParserOptions::default()) + pub fn new(context_provider: &'a S) -> Self { + Self::new_with_options(context_provider, ParserOptions::default()) } /// Create a new query planner - pub fn new_with_options(schema_provider: &'a S, options: ParserOptions) -> Self { + pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let normalize = options.enable_ident_normalization; SqlToRel { - schema_provider, + context_provider, options, normalizer: IdentNormalizer::new(normalize), } @@ -219,18 +240,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Schema::new(fields)) } - /// Apply the given TableAlias to the top-level projection. + /// Returns a vector of (column_name, default_expr) pairs + pub(super) fn build_column_defaults( + &self, + columns: &Vec, + planner_context: &mut PlannerContext, + ) -> Result> { + let mut column_defaults = vec![]; + // Default expressions are restricted, column references are not allowed + let empty_schema = DFSchema::empty(); + let error_desc = |e: DataFusionError| match e { + DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }) => { + plan_datafusion_err!( + "Column reference is not allowed in the DEFAULT expression : {}", + e + ) + } + _ => e, + }; + + for column in columns { + if let Some(default_sql_expr) = + column.options.iter().find_map(|o| match &o.option { + ColumnOption::Default(expr) => Some(expr), + _ => None, + }) + { + let default_expr = self + .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) + .map_err(error_desc)?; + column_defaults + .push((self.normalizer.normalize(column.name.clone()), default_expr)); + } + } + Ok(column_defaults) + } + + /// Apply the given TableAlias to the input plan pub(crate) fn apply_table_alias( &self, plan: LogicalPlan, alias: TableAlias, ) -> Result { - let apply_name_plan = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - plan, - self.normalizer.normalize(alias.name), - )?); + let plan = self.apply_expr_alias(plan, alias.columns)?; - self.apply_expr_alias(apply_name_plan, alias.columns) + LogicalPlanBuilder::from(plan) + .alias(self.normalizer.normalize(alias.name))? + .build() } pub(crate) fn apply_expr_alias( @@ -241,11 +297,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if idents.is_empty() { Ok(plan) } else if idents.len() != plan.schema().fields().len() { - Err(DataFusionError::Plan(format!( + plan_err!( "Source table contains {} columns but only {} names given as column alias", plan.schema().fields().len(), - idents.len(), - ))) + idents.len() + ) } else { let fields = plan.schema().fields().clone(); LogicalPlanBuilder::from(plan) @@ -281,46 +337,47 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map_err(|_: DataFusionError| { field_not_found(col.relation.clone(), col.name.as_str(), schema) }), - _ => Err(DataFusionError::Internal("Not a column".to_string())), + _ => internal_err!("Not a column"), }) } pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Array(Some(inner_sql_type)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { let data_type = self.convert_simple_data_type(inner_sql_type)?; Ok(DataType::List(Arc::new(Field::new( "field", data_type, true, )))) } - SQLDataType::Array(None) => Err(DataFusionError::NotImplemented( - "Arrays with unspecified type is not supported".to_string(), - )), + SQLDataType::Array(ArrayElemTypeDef::None) => { + not_impl_err!("Arrays with unspecified type is not supported") + } other => self.convert_simple_data_type(other), } } fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Boolean => Ok(DataType::Boolean), + SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean), SQLDataType::TinyInt(_) => Ok(DataType::Int8), - SQLDataType::SmallInt(_) => Ok(DataType::Int16), - SQLDataType::Int(_) | SQLDataType::Integer(_) => Ok(DataType::Int32), - SQLDataType::BigInt(_) => Ok(DataType::Int64), + SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16), + SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(DataType::Int32), + SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64), SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8), - SQLDataType::UnsignedSmallInt(_) => Ok(DataType::UInt16), - SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => { + SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(DataType::UInt16), + SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => { Ok(DataType::UInt32) } - SQLDataType::UnsignedBigInt(_) => Ok(DataType::UInt64), + SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), SQLDataType::Float(_) => Ok(DataType::Float32), - SQLDataType::Real => Ok(DataType::Float32), - SQLDataType::Double | SQLDataType::DoublePrecision => Ok(DataType::Float64), + SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32), + SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text - | SQLDataType::String => Ok(DataType::Utf8), + | SQLDataType::String(_) => Ok(DataType::Utf8), SQLDataType::Timestamp(None, tz_info) => { let tz = if matches!(tz_info, TimezoneInfo::Tz) || matches!(tz_info, TimezoneInfo::WithTimeZone) @@ -328,7 +385,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp With Time Zone // INPUT : [SQLDataType] TimestampTz + [RuntimeConfig] Time Zone // OUTPUT: [ArrowDataType] Timestamp - self.schema_provider.options().execution.time_zone.clone() + self.context_provider.options().execution.time_zone.clone() } else { // Timestamp Without Time zone None @@ -343,9 +400,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(DataType::Time64(TimeUnit::Nanosecond)) } else { // We dont support TIMETZ and TIME WITH TIME ZONE for now - Err(DataFusionError::NotImplemented(format!( + not_impl_err!( "Unsupported SQL type {sql_type:?}" - ))) + ) } } SQLDataType::Numeric(exact_number_info) @@ -390,9 +447,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Dec(_) | SQLDataType::BigNumeric(_) | SQLDataType::BigDecimal(_) - | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!( + | SQLDataType::Clob(_) + | SQLDataType::Bytes(_) + | SQLDataType::Int64 + | SQLDataType::Float64 + | SQLDataType::Struct(_) + => not_impl_err!( "Unsupported SQL type {sql_type:?}" - ))), + ), } } @@ -459,10 +521,7 @@ pub(crate) fn idents_to_table_reference( let catalog = taker.take(enable_normalization); Ok(OwnedTableReference::full(catalog, schema, table)) } - _ => Err(DataFusionError::Plan(format!( - "Unsupported compound identifier '{:?}'", - taker.0, - ))), + _ => plan_err!("Unsupported compound identifier '{:?}'", taker.0), } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 2d7771d8c753c..dd4cab126261e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,9 +19,11 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, @@ -52,18 +54,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process CTEs from top to bottom // do not allow self-references if with.recursive { - return Err(DataFusionError::NotImplemented( - "Recursive CTEs are not supported".to_string(), - )); + return not_impl_err!("Recursive CTEs are not supported"); } for cte in with.cte_tables { // A `WITH` block can't use the same name more than once let cte_name = self.normalizer.normalize(cte.alias.name.clone()); if planner_context.contains_cte(&cte_name) { - return Err(DataFusionError::SQL(ParserError(format!( + return sql_err!(ParserError(format!( "WITH query name {cte_name:?} specified more than once" - )))); + ))); } // create logical plan & pass backreferencing CTEs // CTE expr don't need extend outer_query_schema @@ -86,10 +86,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let select_into = select.into.unwrap(); LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { name: self.object_name_to_table_reference(select_into.name)?, - primary_key: Vec::new(), + constraints: Constraints::empty(), input: Arc::new(plan), if_not_exists: false, or_replace: false, + column_defaults: vec![], })) } _ => plan, @@ -117,15 +118,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )? { Expr::Literal(ScalarValue::Int64(Some(s))) => { if s < 0 { - return Err(DataFusionError::Plan(format!( - "Offset must be >= 0, '{s}' was provided." - ))); + return plan_err!("Offset must be >= 0, '{s}' was provided."); } Ok(s as usize) } - _ => Err(DataFusionError::Plan( - "Unexpected expression in OFFSET clause".to_string(), - )), + _ => plan_err!("Unexpected expression in OFFSET clause"), }?, _ => 0, }; @@ -142,9 +139,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Literal(ScalarValue::Int64(Some(n))) if n >= 0 => { Ok(n as usize) } - _ => Err(DataFusionError::Plan( - "LIMIT must not be negative".to_string(), - )), + _ => plan_err!("LIMIT must not be negative"), }?; Some(n) } @@ -166,7 +161,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; + + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + // In case of `DISTINCT ON` we must capture the sort expressions since during the plan + // optimization we're effectively doing a `first_value` aggregation according to them. + let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + } else { + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + } } } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index eedae28385fc6..b119672eae5f9 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -16,7 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{Column, DataFusionError, Result}; +use datafusion_common::{not_impl_err, Column, DataFusionError, Result}; use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins}; use std::collections::HashSet; @@ -106,9 +106,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.parse_join(left, right, constraint, JoinType::Full, planner_context) } JoinOperator::CrossJoin => self.parse_cross_join(left, right), - other => Err(DataFusionError::NotImplemented(format!( - "Unsupported JOIN operator {other:?}" - ))), + other => not_impl_err!("Unsupported JOIN operator {other:?}"), } } @@ -134,12 +132,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // parse ON expression let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?; LogicalPlanBuilder::from(left) - .join( - right, - join_type, - (Vec::::new(), Vec::::new()), - Some(expr), - )? + .join_on(right, join_type, Some(expr))? .build() } JoinConstraint::Using(idents) => { @@ -174,9 +167,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } } - JoinConstraint::None => Err(DataFusionError::NotImplemented( - "NONE constraint is not supported".to_string(), - )), + JoinConstraint::None => not_impl_err!("NONE constraint is not supported"), } } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 3cc0e5d77701d..b233f47a058fb 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -16,9 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ + not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference, +}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::TableFactor; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let (plan, alias) = match relation { - TableFactor::Table { name, alias, .. } => { - // normalize name and alias - let table_ref = self.object_name_to_table_reference(name)?; - let table_name = table_ref.to_string(); - let cte = planner_context.get_cte(&table_name); - ( - match ( - cte, - self.schema_provider.get_table_provider(table_ref.clone()), - ) { - (Some(cte_plan), _) => Ok(cte_plan.clone()), - (_, Ok(provider)) => { - LogicalPlanBuilder::scan(table_ref, provider, None)?.build() - } - (None, Err(e)) => Err(e), - }?, - alias, - ) + TableFactor::Table { + name, alias, args, .. + } => { + if let Some(func_args) = args { + let tbl_func_name = name.0.first().unwrap().value.to_string(); + let args = func_args + .into_iter() + .flat_map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg + { + self.sql_expr_to_logical_expr( + expr, + &DFSchema::empty(), + planner_context, + ) + } else { + plan_err!("Unsupported function argument type: {:?}", arg) + } + }) + .collect::>(); + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + let plan = LogicalPlanBuilder::scan( + TableReference::Bare { + table: std::borrow::Cow::Borrowed("tmp_table"), + }, + provider, + None, + )? + .build()?; + (plan, alias) + } else { + // normalize name and alias + let table_ref = self.object_name_to_table_reference(name)?; + let table_name = table_ref.to_string(); + let cte = planner_context.get_cte(&table_name); + ( + match ( + cte, + self.context_provider.get_table_source(table_ref.clone()), + ) { + (Some(cte_plan), _) => Ok(cte_plan.clone()), + (_, Ok(provider)) => { + LogicalPlanBuilder::scan(table_ref, provider, None)? + .build() + } + (None, Err(e)) => Err(e), + }?, + alias, + ) + } } TableFactor::Derived { subquery, alias, .. @@ -64,9 +100,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ), // @todo Support TableFactory::TableFunction? _ => { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "Unsupported ast node {relation:?} in create_relation" - ))); + ); } }; if let Some(alias) = alias { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index f12830df4254c..a0819e4aaf8e8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -15,12 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs, resolve_columns, resolve_positions_to_exprs, }; -use datafusion_common::{DataFusionError, Result}; + +use datafusion_common::Column; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, }; @@ -29,15 +35,14 @@ use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, }; -use datafusion_expr::Expr::Alias; use datafusion_expr::{ Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; - -use sqlparser::ast::{Distinct, Expr as SQLExpr, WildcardAdditionalOptions, WindowType}; +use sqlparser::ast::{ + Distinct, Expr as SQLExpr, GroupByExpr, ReplaceSelectItem, WildcardAdditionalOptions, + WindowType, +}; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; -use std::collections::HashSet; -use std::sync::Arc; impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logic plan from an SQL select @@ -48,19 +53,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // check for unsupported syntax first if !select.cluster_by.is_empty() { - return Err(DataFusionError::NotImplemented("CLUSTER BY".to_string())); + return not_impl_err!("CLUSTER BY"); } if !select.lateral_views.is_empty() { - return Err(DataFusionError::NotImplemented("LATERAL VIEWS".to_string())); + return not_impl_err!("LATERAL VIEWS"); } if select.qualify.is_some() { - return Err(DataFusionError::NotImplemented("QUALIFY".to_string())); + return not_impl_err!("QUALIFY"); } if select.top.is_some() { - return Err(DataFusionError::NotImplemented("TOP".to_string())); + return not_impl_err!("TOP"); } if !select.sort_by.is_empty() { - return Err(DataFusionError::NotImplemented("SORT BY".to_string())); + return not_impl_err!("SORT BY"); } // process `from` clause @@ -68,7 +73,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection(select.selection, plan, planner_context)?; + let base_plan = self.plan_selection(select.selection, plan, planner_context)?; // handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -76,16 +81,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( - &plan, + &base_plan, select.projection, empty_from, planner_context, )?; // having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; let mut combined_schema = (**projected_plan.schema()).clone(); - combined_schema.merge(plan.schema()); + combined_schema.merge(base_plan.schema()); // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); @@ -129,29 +134,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); // All of the group by expressions - let group_by_exprs = select - .group_by - .into_iter() - .map(|e| { - let group_by_expr = - self.sql_expr_to_logical_expr(e, &combined_schema, planner_context)?; - // aliases from the projection can conflict with same-named expressions in the input - let mut alias_map = alias_map.clone(); - for f in plan.schema().fields() { - alias_map.remove(f.name()); - } - let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; - let group_by_expr = - resolve_positions_to_exprs(&group_by_expr, &select_exprs) - .unwrap_or(group_by_expr); - let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; - self.validate_schema_satisfies_exprs( - plan.schema(), - &[group_by_expr.clone()], - )?; - Ok(group_by_expr) - }) - .collect::>>()?; + let group_by_exprs = if let GroupByExpr::Expressions(exprs) = select.group_by { + exprs + .into_iter() + .map(|e| { + let group_by_expr = self.sql_expr_to_logical_expr( + e, + &combined_schema, + planner_context, + )?; + // aliases from the projection can conflict with same-named expressions in the input + let mut alias_map = alias_map.clone(); + for f in base_plan.schema().fields() { + alias_map.remove(f.name()); + } + let group_by_expr = + resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; + let group_by_expr = + resolve_positions_to_exprs(&group_by_expr, &select_exprs) + .unwrap_or(group_by_expr); + let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; + self.validate_schema_satisfies_exprs( + base_plan.schema(), + &[group_by_expr.clone()], + )?; + Ok(group_by_expr) + }) + .collect::>>()? + } else { + // 'group by all' groups wrt. all select expressions except 'AggregateFunction's. + // Filter and collect non-aggregate select expressions + select_exprs + .iter() + .filter(|select_expr| match select_expr { + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } + _ => true, + }) + .cloned() + .collect() + }; // process group by, aggregation or having let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs @@ -159,17 +183,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { || !aggr_exprs.is_empty() { self.aggregate( - plan, + &base_plan, &select_exprs, having_expr_opt.as_ref(), - group_by_exprs, - aggr_exprs, + &group_by_exprs, + &aggr_exprs, )? } else { match having_expr_opt { - Some(having_expr) => return Err(DataFusionError::Plan( - format!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"))), - None => (plan, select_exprs, having_expr_opt) + Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) } }; @@ -202,21 +225,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = project(plan, select_exprs_post_aggr)?; // process distinct clause - let distinct = select - .distinct - .map(|distinct| match distinct { - Distinct::Distinct => Ok(true), - Distinct::On(_) => Err(DataFusionError::NotImplemented( - "DISTINCT ON Exprs not supported".to_string(), - )), - }) - .transpose()? - .unwrap_or(false); + let plan = match select.distinct { + None => Ok(plan), + Some(Distinct::Distinct) => { + LogicalPlanBuilder::from(plan).distinct()?.build() + } + Some(Distinct::On(on_expr)) => { + if !aggr_exprs.is_empty() + || !group_by_exprs.is_empty() + || !window_func_exprs.is_empty() + { + return not_impl_err!("DISTINCT ON expressions with GROUP BY, aggregation or window functions are not supported "); + } - let plan = if distinct { - LogicalPlanBuilder::from(plan).distinct()?.build() - } else { - Ok(plan) + let on_expr = on_expr + .into_iter() + .map(|e| { + self.sql_expr_to_logical_expr(e, plan.schema(), planner_context) + }) + .collect::>>()?; + + // Build the final plan + return LogicalPlanBuilder::from(base_plan) + .distinct_on(on_expr, select_exprs, None)? + .build(); + } }?; // DISTRIBUTE BY @@ -348,29 +381,59 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let expr = Alias(Box::new(col), self.normalizer.normalize(alias)); + let name = self.normalizer.normalize(alias); + // avoiding adding an alias if the column name is the same. + let expr = match &col { + Expr::Column(column) if column.name.eq(&name) => col, + _ => col.alias(name), + }; Ok(vec![expr]) } SelectItem::Wildcard(options) => { Self::check_wildcard_options(&options)?; if empty_from { - return Err(DataFusionError::Plan( - "SELECT * with no tables specified is not valid".to_string(), - )); + return plan_err!("SELECT * with no tables specified is not valid"); } // do not expand from outer schema - expand_wildcard(plan.schema().as_ref(), plan, Some(options)) + let expanded_exprs = + expand_wildcard(plan.schema().as_ref(), plan, Some(&options))?; + // If there is a REPLACE statement, replace that column with the given + // replace expression. Column name remains the same. + if let Some(replace) = options.opt_replace { + self.replace_columns( + plan, + empty_from, + planner_context, + expanded_exprs, + replace, + ) + } else { + Ok(expanded_exprs) + } } SelectItem::QualifiedWildcard(ref object_name, options) => { Self::check_wildcard_options(&options)?; let qualifier = format!("{object_name}"); // do not expand from outer schema - expand_qualified_wildcard( + let expanded_exprs = expand_qualified_wildcard( &qualifier, plan.schema().as_ref(), - Some(options), - ) + Some(&options), + )?; + // If there is a REPLACE statement, replace that column with the given + // replace expression. Column name remains the same. + if let Some(replace) = options.opt_replace { + self.replace_columns( + plan, + empty_from, + planner_context, + expanded_exprs, + replace, + ) + } else { + Ok(expanded_exprs) + } } } } @@ -381,18 +444,55 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { opt_exclude: _opt_exclude, opt_except: _opt_except, opt_rename, - opt_replace, + opt_replace: _opt_replace, } = options; - if opt_rename.is_some() || opt_replace.is_some() { + if opt_rename.is_some() { Err(DataFusionError::NotImplemented( - "wildcard * with RENAME or REPLACE not supported ".to_string(), + "wildcard * with RENAME not supported ".to_string(), )) } else { Ok(()) } } + /// If there is a REPLACE statement in the projected expression in the form of + /// "REPLACE (some_column_within_an_expr AS some_column)", this function replaces + /// that column with the given replace expression. Column name remains the same. + /// Multiple REPLACEs are also possible with comma separations. + fn replace_columns( + &self, + plan: &LogicalPlan, + empty_from: bool, + planner_context: &mut PlannerContext, + mut exprs: Vec, + replace: ReplaceSelectItem, + ) -> Result> { + for expr in exprs.iter_mut() { + if let Expr::Column(Column { name, .. }) = expr { + if let Some(item) = replace + .items + .iter() + .find(|item| item.column_name.value == *name) + { + let new_expr = self.sql_select_to_rex( + SelectItem::UnnamedExpr(item.expr.clone()), + plan, + empty_from, + planner_context, + )?[0] + .clone(); + *expr = Expr::Alias(Alias { + expr: Box::new(new_expr), + relation: None, + name: name.clone(), + }); + } + } + } + Ok(exprs) + } + /// Wrap a plan in a projection fn project(&self, input: LogicalPlan, expr: Vec) -> Result { self.validate_schema_satisfies_exprs(input.schema(), &expr)?; @@ -425,17 +525,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// the aggregate fn aggregate( &self, - input: LogicalPlan, + input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, - group_by_exprs: Vec, - aggr_exprs: Vec, + group_by_exprs: &[Expr], + aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; + let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { + &agg.group_expr + } else { + unreachable!(); + }; + // in this next section of code we are re-writing the projection to refer to columns // output by the aggregate plan. For example, if the projection contains the expression // `SUM(a)` then we replace that with a reference to a column `SUM(a)` produced by @@ -444,7 +550,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // combine the original grouping and aggregate expressions into one list (note that // we do not add the "having" expression since that is not part of the projection) let mut aggr_projection_exprs = vec![]; - for expr in &group_by_exprs { + for expr in group_by_exprs { match expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { aggr_projection_exprs.extend_from_slice(exprs) @@ -460,25 +566,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => aggr_projection_exprs.push(expr.clone()), } } - aggr_projection_exprs.extend_from_slice(&aggr_exprs); + aggr_projection_exprs.extend_from_slice(aggr_exprs); // now attempt to resolve columns and replace with fully-qualified columns let aggr_projection_exprs = aggr_projection_exprs .iter() - .map(|expr| resolve_columns(expr, &input)) + .map(|expr| resolve_columns(expr, input)) .collect::>>()?; // next we replace any expressions that are not a column with a column referencing // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() - .map(|expr| expr_as_column_expr(expr, &input)) + .map(|expr| expr_as_column_expr(expr, input)) .collect::>>()?; // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() - .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) + .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input)) .collect::>>()?; // finally, we have some validation that the re-written projection can be resolved @@ -493,7 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // aggregation. let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt { let having_expr_post_aggr = - rebase_expr(having_expr, &aggr_projection_exprs, &input)?; + rebase_expr(having_expr, &aggr_projection_exprs, input)?; check_columns_satisfy_exprs( &column_exprs_post_aggr, @@ -515,10 +621,10 @@ fn check_conflicting_windows(window_defs: &[NamedWindowDefinition]) -> Result<() for (i, window_def_i) in window_defs.iter().enumerate() { for window_def_j in window_defs.iter().skip(i + 1) { if window_def_i.0 == window_def_j.0 { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The window {} is defined multiple times!", window_def_i.0 - ))); + ); } } } @@ -547,9 +653,7 @@ fn match_window_definitions( } // All named windows must be defined with a WindowSpec. if let Some(WindowType::NamedWindow(ident)) = &f.over { - return Err(DataFusionError::Plan(format!( - "The window {ident} is not defined!" - ))); + return plan_err!("The window {ident} is not defined!"); } } } diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 48a4fdddc24be..7300d49be0f55 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -16,7 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier}; @@ -38,6 +38,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let all = match set_quantifier { SetQuantifier::All => true, SetQuantifier::Distinct | SetQuantifier::None => false, + SetQuantifier::ByName => { + return not_impl_err!("UNION BY NAME not implemented"); + } + SetQuantifier::AllByName => { + return not_impl_err!("UNION ALL BY NAME not implemented") + } + SetQuantifier::DistinctByName => { + return not_impl_err!("UNION DISTINCT BY NAME not implemented") + } }; let left_plan = self.set_expr_to_plan(*left, planner_context)?; @@ -64,9 +73,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } SetExpr::Query(q) => self.query_to_plan(*q, planner_context), - _ => Err(DataFusionError::NotImplemented(format!( - "Query {set_expr} not implemented yet" - ))), + _ => not_impl_err!("Query {set_expr} not implemented yet"), } } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 7f914c6b91368..12083554f0932 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -15,20 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::Arc; + use crate::parser::{ - CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, LexOrdering, - Statement as DFStatement, + CopyToSource, CopyToStatement, CreateExternalTable, DFParser, DescribeTableStmt, + ExplainStatement, LexOrdering, Statement as DFStatement, }; use crate::planner::{ object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel, }; use crate::utils::normalize_ident; + use arrow_schema::DataType; +use datafusion_common::file_options::StatementOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ExprSchema, - OwnedTableReference, Result, SchemaReference, TableReference, ToDFSchema, + not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, Column, + Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + Result, ScalarValue, SchemaReference, TableReference, ToDFSchema, }; +use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -44,15 +51,11 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SchemaName, - SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, - TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, + Assignment, ColumnDef, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, + SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, + TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, }; - -use datafusion_expr::expr::Placeholder; use sqlparser::parser::ParserError::ParserError; -use std::collections::{BTreeMap, HashMap, HashSet}; -use std::sync::Arc; fn ident_to_string(ident: &Ident) -> String { normalize_ident(ident.to_owned()) @@ -79,6 +82,54 @@ fn get_schema_name(schema_name: &SchemaName) -> String { } } +/// Construct `TableConstraint`(s) for the given columns by iterating over +/// `columns` and extracting individual inline constraint definitions. +fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { + let mut constraints = vec![]; + for column in columns { + for ast::ColumnOptionDef { name, option } in &column.options { + match option { + ast::ColumnOption::Unique { is_primary } => { + constraints.push(ast::TableConstraint::Unique { + name: name.clone(), + columns: vec![column.name.clone()], + is_primary: *is_primary, + }) + } + ast::ColumnOption::ForeignKey { + foreign_table, + referred_columns, + on_delete, + on_update, + } => constraints.push(ast::TableConstraint::ForeignKey { + name: name.clone(), + columns: vec![], + foreign_table: foreign_table.clone(), + referred_columns: referred_columns.to_vec(), + on_delete: *on_delete, + on_update: *on_update, + }), + ast::ColumnOption::Check(expr) => { + constraints.push(ast::TableConstraint::Check { + name: name.clone(), + expr: Box::new(expr.clone()), + }) + } + // Other options are not constraint related. + ast::ColumnOption::Default(_) + | ast::ColumnOption::Null + | ast::ColumnOption::NotNull + | ast::ColumnOption::DialectSpecific(_) + | ast::ColumnOption::CharacterSet(_) + | ast::ColumnOption::Generated { .. } + | ast::ColumnOption::Comment(_) + | ast::ColumnOption::OnUpdate(_) => {} + } + } + } + constraints +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an DataFusion SQL statement pub fn statement_to_plan(&self, statement: DFStatement) -> Result { @@ -87,16 +138,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { DFStatement::Statement(s) => self.sql_statement_to_plan(*s), DFStatement::DescribeTableStmt(s) => self.describe_table_to_plan(s), DFStatement::CopyTo(s) => self.copy_to_plan(s), + DFStatement::Explain(ExplainStatement { + verbose, + analyze, + statement, + }) => self.explain_to_plan(verbose, analyze, *statement), } } /// Generate a logical plan from an SQL statement pub fn sql_statement_to_plan(&self, statement: Statement) -> Result { - self.sql_statement_to_plan_with_context(statement, &mut PlannerContext::new()) + self.sql_statement_to_plan_with_context_impl( + statement, + &mut PlannerContext::new(), + ) } /// Generate a logical plan from an SQL statement - fn sql_statement_to_plan_with_context( + pub fn sql_statement_to_plan_with_context( + &self, + statement: Statement, + planner_context: &mut PlannerContext, + ) -> Result { + self.sql_statement_to_plan_with_context_impl(statement, planner_context) + } + + fn sql_statement_to_plan_with_context_impl( &self, statement: Statement, planner_context: &mut PlannerContext, @@ -110,7 +177,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { format: _, describe_alias: _, .. - } => self.explain_statement_to_plan(verbose, analyze, *statement), + } => { + self.explain_to_plan(verbose, analyze, DFStatement::Statement(statement)) + } Statement::Query(query) => self.query_to_plan(*query, planner_context), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), Statement::SetVariable { @@ -130,71 +199,89 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, .. - } if table_properties.is_empty() && with_options.is_empty() => match query { - Some(query) => { - let primary_key = Self::primary_key_from_constraints(&constraints)?; - - let plan = self.query_to_plan(*query, planner_context)?; - let input_schema = plan.schema(); - - let plan = if !columns.is_empty() { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; - if schema.fields().len() != input_schema.fields().len() { - return Err(DataFusionError::Plan(format!( + } if table_properties.is_empty() && with_options.is_empty() => { + // Merge inline constraints and existing constraints + let mut all_constraints = constraints; + let inline_constraints = calc_inline_constraints_from_columns(&columns); + all_constraints.extend(inline_constraints); + // Build column default values + let column_defaults = + self.build_column_defaults(&columns, planner_context)?; + match query { + Some(query) => { + let plan = self.query_to_plan(*query, planner_context)?; + let input_schema = plan.schema(); + + let plan = if !columns.is_empty() { + let schema = self.build_schema(columns)?.to_dfschema_ref()?; + if schema.fields().len() != input_schema.fields().len() { + return plan_err!( "Mismatch: {} columns specified, but result has {} columns", schema.fields().len(), input_schema.fields().len() - ))); - } - let input_fields = input_schema.fields(); - let project_exprs = schema - .fields() - .iter() - .zip(input_fields) - .map(|(field, input_field)| { - cast(col(input_field.name()), field.data_type().clone()) + ); + } + let input_fields = input_schema.fields(); + let project_exprs = schema + .fields() + .iter() + .zip(input_fields) + .map(|(field, input_field)| { + cast( + col(input_field.name()), + field.data_type().clone(), + ) .alias(field.name()) - }) - .collect::>(); - LogicalPlanBuilder::from(plan.clone()) - .project(project_exprs)? - .build()? - } else { - plan - }; - - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( - CreateMemoryTable { - name: self.object_name_to_table_reference(name)?, - primary_key, - input: Arc::new(plan), - if_not_exists, - or_replace, - }, - ))) - } + }) + .collect::>(); + LogicalPlanBuilder::from(plan.clone()) + .project(project_exprs)? + .build()? + } else { + plan + }; + + let constraints = Constraints::new_from_table_constraints( + &all_constraints, + plan.schema(), + )?; + + Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + CreateMemoryTable { + name: self.object_name_to_table_reference(name)?, + constraints, + input: Arc::new(plan), + if_not_exists, + or_replace, + column_defaults, + }, + ))) + } - None => { - let primary_key = Self::primary_key_from_constraints(&constraints)?; - - let schema = self.build_schema(columns)?.to_dfschema_ref()?; - let plan = EmptyRelation { - produce_one_row: false, - schema, - }; - let plan = LogicalPlan::EmptyRelation(plan); - - Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( - CreateMemoryTable { - name: self.object_name_to_table_reference(name)?, - primary_key, - input: Arc::new(plan), - if_not_exists, - or_replace, - }, - ))) + None => { + let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let plan = EmptyRelation { + produce_one_row: false, + schema, + }; + let plan = LogicalPlan::EmptyRelation(plan); + let constraints = Constraints::new_from_table_constraints( + &all_constraints, + plan.schema(), + )?; + Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + CreateMemoryTable { + name: self.object_name_to_table_reference(name)?, + constraints, + input: Arc::new(plan), + if_not_exists, + or_replace, + column_defaults, + }, + ))) + } } - }, + } Statement::CreateView { or_replace, @@ -216,9 +303,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Statement::ShowCreate { obj_type, obj_name } => match obj_type { ShowCreateObject::Table => self.show_create_table_to_plan(obj_name), - _ => Err(DataFusionError::NotImplemented( - "Only `SHOW CREATE TABLE ...` statement is supported".to_string(), - )), + _ => { + not_impl_err!("Only `SHOW CREATE TABLE ...` statement is supported") + } }, Statement::CreateSchema { schema_name, @@ -248,6 +335,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { cascade, restrict: _, purge: _, + temporary: _, } => { // We don't support cascade and purge for now. // nor do we support multiple object names @@ -289,10 +377,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { cascade, schema: DFSchemaRef::new(DFSchema::empty()), })))}, - _ => Err(DataFusionError::NotImplemented( + _ => not_impl_err!( "Only `DROP TABLE/VIEW/SCHEMA ...` statement is supported currently" - .to_string(), - )), + ), } } Statement::Prepare { @@ -311,7 +398,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .with_prepare_param_data_types(data_types.clone()); // Build logical plan for inner statement of the prepare statement - let plan = self.sql_statement_to_plan_with_context( + let plan = self.sql_statement_to_plan_with_context_impl( *statement, &mut planner_context, )?; @@ -348,46 +435,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table, on, returning, + ignore, } => { if or.is_some() { - Err(DataFusionError::Plan( - "Inserts with or clauses not supported".to_owned(), - ))?; - } - if overwrite { - Err(DataFusionError::Plan( - "Insert overwrite is not supported".to_owned(), - ))?; + plan_err!("Inserts with or clauses not supported")?; } if partitioned.is_some() { - Err(DataFusionError::Plan( - "Partitioned inserts not yet supported".to_owned(), - ))?; + plan_err!("Partitioned inserts not yet supported")?; } if !after_columns.is_empty() { - Err(DataFusionError::Plan( - "After-columns clause not supported".to_owned(), - ))?; + plan_err!("After-columns clause not supported")?; } if table { - Err(DataFusionError::Plan( - "Table clause not supported".to_owned(), - ))?; + plan_err!("Table clause not supported")?; } if on.is_some() { - Err(DataFusionError::Plan( - "Insert-on clause not supported".to_owned(), - ))?; + plan_err!("Insert-on clause not supported")?; } if returning.is_some() { - Err(DataFusionError::Plan( - "Insert-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Insert-returning clause not supported")?; } + if ignore { + plan_err!("Insert-ignore clause not supported")?; + } + let Some(source) = source else { + plan_err!("Inserts without a source not supported")? + }; let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source) + self.insert_to_plan(table_name, columns, source, overwrite) } - Statement::Update { table, assignments, @@ -396,9 +472,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, } => { if returning.is_some() { - Err(DataFusionError::Plan( - "Update-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Update-returning clause not yet supported")?; } self.update_to_plan(table, assignments, from, selection) } @@ -409,28 +483,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { selection, returning, from, + order_by, + limit, } => { if !tables.is_empty() { - return Err(DataFusionError::NotImplemented( - "DELETE not supported".to_string(), - )); + plan_err!("DELETE
not supported")?; } if using.is_some() { - Err(DataFusionError::Plan( - "Using clause not supported".to_owned(), - ))?; + plan_err!("Using clause not supported")?; } + if returning.is_some() { - Err(DataFusionError::Plan( - "Delete-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Delete-returning clause not yet supported")?; } + + if !order_by.is_empty() { + plan_err!("Delete-order-by clause not yet supported")?; + } + + if limit.is_some() { + plan_err!("Delete-limit clause not yet supported")?; + } + let table_name = self.get_delete_target(from)?; self.delete_to_plan(table_name, selection) } - Statement::StartTransaction { modes } => { + Statement::StartTransaction { + modes, + begin: false, + } => { let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &ast::TransactionMode| match m { @@ -486,7 +569,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::Rollback { chain } => { + Statement::Rollback { chain, savepoint } => { + if savepoint.is_some() { + plan_err!("Savepoints not supported")?; + } let statement = PlanStatement::TransactionEnd(TransactionEnd { conclusion: TransactionConclusion::Rollback, chain, @@ -495,29 +581,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(LogicalPlan::Statement(statement)) } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported SQL statement: {sql:?}" - ))), + _ => not_impl_err!("Unsupported SQL statement: {sql:?}"), } } fn get_delete_target(&self, mut from: Vec) -> Result { if from.len() != 1 { - return Err(DataFusionError::NotImplemented(format!( + return not_impl_err!( "DELETE FROM only supports single table, got {}: {from:?}", from.len() - ))); + ); } let table_factor = from.pop().unwrap(); if !table_factor.joins.is_empty() { - return Err(DataFusionError::NotImplemented( - "DELETE FROM only supports single table, got: joins".to_string(), - )); + return not_impl_err!("DELETE FROM only supports single table, got: joins"); } - let TableFactor::Table{name, ..} = table_factor.relation else { - return Err(DataFusionError::NotImplemented(format!( + let TableFactor::Table { name, .. } = table_factor.relation else { + return not_impl_err!( "DELETE FROM only supports single table, got: {table_factor:?}" - ))) + ); }; Ok(name) @@ -535,9 +617,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // we only support the basic "SHOW TABLES" // https://github.com/apache/arrow-datafusion/issues/3188 if db_name.is_some() || filter.is_some() || full || extended { - Err(DataFusionError::Plan( - "Unsupported parameters to SHOW TABLES".to_string(), - )) + plan_err!("Unsupported parameters to SHOW TABLES") } else { let query = "SELECT * FROM information_schema.tables;"; let mut rewrite = DFParser::parse_sql(query)?; @@ -545,10 +625,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1 } } else { - Err(DataFusionError::Plan( - "SHOW TABLES is not supported unless information_schema is enabled" - .to_string(), - )) + plan_err!("SHOW TABLES is not supported unless information_schema is enabled") } } @@ -559,21 +636,62 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let DescribeTableStmt { table_name } = statement; let table_ref = self.object_name_to_table_reference(table_name)?; - let table_source = self.schema_provider.get_table_provider(table_ref)?; + let table_source = self.context_provider.get_table_source(table_ref)?; let schema = table_source.schema(); + let output_schema = DFSchema::try_from(LogicalPlan::describe_schema()).unwrap(); + Ok(LogicalPlan::DescribeTable(DescribeTable { schema, - dummy_schema: DFSchemaRef::new(DFSchema::empty()), + output_schema: Arc::new(output_schema), })) } - fn copy_to_plan(&self, _statement: CopyToStatement) -> Result { - // TODO: implement as part of https://github.com/apache/arrow-datafusion/issues/5654 - Err(DataFusionError::NotImplemented( - "`COPY .. TO ..` statement is not yet supported".to_string(), - )) + fn copy_to_plan(&self, statement: CopyToStatement) -> Result { + // determine if source is table or query and handle accordingly + let copy_source = statement.source; + let input = match copy_source { + CopyToSource::Relation(object_name) => { + let table_ref = + self.object_name_to_table_reference(object_name.clone())?; + let table_source = self.context_provider.get_table_source(table_ref)?; + LogicalPlanBuilder::scan( + object_name_to_string(&object_name), + table_source, + None, + )? + .build()? + } + CopyToSource::Query(query) => { + self.query_to_plan(query, &mut PlannerContext::new())? + } + }; + + // TODO, parse options as Vec<(String, String)> to avoid this conversion + let options = statement + .options + .iter() + .map(|(s, v)| (s.to_owned(), v.to_string())) + .collect::>(); + + let mut statement_options = StatementOptions::new(options); + let file_format = statement_options.try_infer_file_type(&statement.target)?; + let single_file_output = + statement_options.take_bool_option("single_file_output")?; + + // COPY defaults to outputting a single file if not otherwise specified + let single_file_output = single_file_output.unwrap_or(true); + + let copy_options = CopyOptions::SQLOptions(statement_options); + + Ok(LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: statement.target, + file_format, + single_file_output, + copy_options, + })) } fn build_order_by( @@ -584,24 +702,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result>> { // Ask user to provide a schema if schema is empty. if !order_exprs.is_empty() && schema.fields().is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Provide a schema before specifying the order while creating a table." - .to_owned(), - )); + ); } let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: - let expr_vec = self.order_by_to_sort_expr(&expr, schema, planner_context)?; + let expr_vec = + self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.to_columns()?.iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: - return Err(DataFusionError::Plan(format!( - "Column {column} is not in schema" - ))); + return plan_err!("Column {column} is not in schema"); } } } @@ -630,33 +746,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs, unbounded, options, + constraints, } = statement; - // semantic checks - if file_type == "PARQUET" && !columns.is_empty() { - Err(DataFusionError::Plan( - "Column definitions can not be specified for PARQUET files.".into(), - ))?; - } + // Merge inline constraints and existing constraints + let mut all_constraints = constraints; + let inline_constraints = calc_inline_constraints_from_columns(&columns); + all_constraints.extend(inline_constraints); - if file_type != "CSV" - && file_type != "JSON" + if (file_type == "PARQUET" || file_type == "AVRO" || file_type == "ARROW") && file_compression_type != CompressionTypeVariant::UNCOMPRESSED { - Err(DataFusionError::Plan( - "File compression type can be specified for CSV/JSON files.".into(), - ))?; + plan_err!( + "File compression type cannot be set for PARQUET, AVRO, or ARROW files." + )?; } + let mut planner_context = PlannerContext::new(); + + let column_defaults = self + .build_column_defaults(&columns, &mut planner_context)? + .into_iter() + .collect(); + let schema = self.build_schema(columns)?; let df_schema = schema.to_dfschema_ref()?; let ordered_exprs = - self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; + self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; // External tables do not support schemas at the moment, so the name is just a table name let name = OwnedTableReference::bare(name); - + let constraints = + Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, @@ -672,19 +794,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: ordered_exprs, unbounded, options, + constraints, + column_defaults, }, ))) } /// Generate a plan for EXPLAIN ... that will print out a plan /// - fn explain_statement_to_plan( + /// Note this is the sqlparser explain statement, not the + /// datafusion `EXPLAIN` statement. + fn explain_to_plan( &self, verbose: bool, analyze: bool, - statement: Statement, + statement: DFStatement, ) -> Result { - let plan = self.sql_statement_to_plan(statement)?; + let plan = self.statement_to_plan(statement)?; + if matches!(plan, LogicalPlan::Explain(_)) { + return plan_err!("Nested EXPLAINs are not supported"); + } let plan = Arc::new(plan); let schema = LogicalPlan::explain_schema(); let schema = schema.to_dfschema_ref()?; @@ -709,29 +838,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn show_variable_to_plan(&self, variable: &[Ident]) -> Result { - let variable = object_name_to_string(&ObjectName(variable.to_vec())); - if !self.has_table("information_schema", "df_settings") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW [VARIABLE] is not supported unless information_schema is enabled" - .to_string(), - )); + ); } - let variable_lower = variable.to_lowercase(); + let verbose = variable + .last() + .map(|s| ident_to_string(s) == "verbose") + .unwrap_or(false); + let mut variable_vec = variable.to_vec(); + let mut columns: String = "name, value".to_owned(); + + if verbose { + columns = format!("{columns}, description"); + variable_vec = variable_vec.split_at(variable_vec.len() - 1).0.to_vec(); + } - let query = if variable_lower == "all" { + let variable = object_name_to_string(&ObjectName(variable_vec)); + let base_query = format!("SELECT {columns} FROM information_schema.df_settings"); + let query = if variable == "all" { // Add an ORDER BY so the output comes out in a consistent order - String::from( - "SELECT name, setting FROM information_schema.df_settings ORDER BY name", - ) - } else if variable_lower == "timezone" || variable_lower == "time.zone" { + format!("{base_query} ORDER BY name") + } else if variable == "timezone" || variable == "time.zone" { // we could introduce alias in OptionDefinition if this string matching thing grows - String::from("SELECT name, setting FROM information_schema.df_settings WHERE name = 'datafusion.execution.time_zone'") + format!("{base_query} WHERE name = 'datafusion.execution.time_zone'") } else { - format!( - "SELECT name, setting FROM information_schema.df_settings WHERE name = '{variable}'" - ) + format!("{base_query} WHERE name = '{variable}'") }; let mut rewrite = DFParser::parse_sql(&query)?; @@ -748,15 +882,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { value: Vec, ) -> Result { if local { - return Err(DataFusionError::NotImplemented( - "LOCAL is not supported".to_string(), - )); + return not_impl_err!("LOCAL is not supported"); } if hivevar { - return Err(DataFusionError::NotImplemented( - "HIVEVAR is not supported".to_string(), - )); + return not_impl_err!("HIVEVAR is not supported"); } let variable = object_name_to_string(variable); @@ -784,10 +914,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | Value::HexStringLiteral(_) | Value::Null | Value::Placeholder(_) => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }, // for capture signed number e.g. +8, -8 @@ -795,17 +922,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { UnaryOperator::Plus => format!("+{expr}"), UnaryOperator::Minus => format!("-{expr}"), _ => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }, _ => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }; @@ -825,12 +946,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; - let provider = self.schema_provider.get_table_provider(table_ref.clone())?; - let schema = (*provider.schema()).clone(); + let table_source = self.context_provider.get_table_source(table_ref.clone())?; + let schema = (*table_source.schema()).clone(); let schema = DFSchema::try_from(schema)?; - let scan = - LogicalPlanBuilder::scan(object_name_to_string(&table_name), provider, None)? - .build()?; + let scan = LogicalPlanBuilder::scan( + object_name_to_string(&table_name), + table_source, + None, + )? + .build()?; let mut planner_context = PlannerContext::new(); let source = match predicate_expr { @@ -866,52 +990,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from: Option, predicate_expr: Option, ) -> Result { - let table_name = match &table.relation { - TableFactor::Table { name, .. } => name.clone(), - _ => Err(DataFusionError::Plan( - "Cannot update non-table relation!".to_string(), - ))?, + let (table_name, table_alias) = match &table.relation { + TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), + _ => plan_err!("Cannot update non-table relation!")?, }; // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; - let provider = self - .schema_provider - .get_table_provider(table_name.clone())?; - let arrow_schema = (*provider.schema()).clone(); - let table_schema = Arc::new(DFSchema::try_from(arrow_schema)?); - let values = table_schema.fields().iter().map(|f| { - ( - f.name().clone(), - ast::Expr::Identifier(ast::Ident::from(f.name().as_str())), - ) - }); + let table_source = self.context_provider.get_table_source(table_name.clone())?; + let table_schema = Arc::new(DFSchema::try_from_qualified_schema( + table_name.clone(), + &table_source.schema(), + )?); // Overwrite with assignment expressions let mut planner_context = PlannerContext::new(); let mut assign_map = assignments .iter() .map(|assign| { - let col_name: &Ident = assign.id.iter().last().ok_or_else(|| { - DataFusionError::Plan("Empty column id".to_string()) - })?; + let col_name: &Ident = assign + .id + .iter() + .last() + .ok_or_else(|| plan_datafusion_err!("Empty column id"))?; // Validate that the assignment target column exists table_schema.field_with_unqualified_name(&col_name.value)?; Ok((col_name.value.clone(), assign.value.clone())) }) .collect::>>()?; - let values = values - .into_iter() - .map(|(k, v)| { - let val = assign_map.remove(&k).unwrap_or(v); - (k, val) - }) - .collect::>(); - - // Build scan - let from = from.unwrap_or(table); - let scan = self.plan_from_tables(vec![from], &mut planner_context)?; + // Build scan, join with from table if it exists. + let mut input_tables = vec![table]; + input_tables.extend(from); + let scan = self.plan_from_tables(input_tables, &mut planner_context)?; // Filter let source = match predicate_expr { @@ -919,43 +1030,59 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr( predicate_expr, - &table_schema, + scan.schema(), &mut planner_context, )?; let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[&table_schema]], + &[&[scan.schema()]], &[using_columns], )?; LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; - // Projection - let mut exprs = vec![]; - for (col_name, expr) in values.into_iter() { - let expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?; - let expr = match expr { - datafusion_expr::Expr::Placeholder(Placeholder { - ref id, - ref data_type, - }) => match data_type { + // Build updated values for each column, using the previous value if not modified + let exprs = table_schema + .fields() + .iter() + .map(|field| { + let expr = match assign_map.remove(field.name()) { + Some(new_value) => { + let mut expr = self.sql_to_expr( + new_value, + source.schema(), + &mut planner_context, + )?; + // Update placeholder's datatype to the type of the target column + if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr + { + placeholder.data_type = placeholder + .data_type + .take() + .or_else(|| Some(field.data_type().clone())); + } + // Cast to target column type, if necessary + expr.cast_to(field.data_type(), source.schema())? + } None => { - let dt = table_schema.data_type(&Column::from_name(&col_name))?; - datafusion_expr::Expr::Placeholder(Placeholder::new( - id.clone(), - Some(dt.clone()), - )) + // If the target table has an alias, use it to qualify the column name + if let Some(alias) = &table_alias { + datafusion_expr::Expr::Column(Column::new( + Some(self.normalizer.normalize(alias.name.clone())), + field.name(), + )) + } else { + datafusion_expr::Expr::Column(field.qualified_column()) + } } - Some(_) => expr, - }, - _ => expr, - }; - let expr = expr.alias(col_name); - exprs.push(expr); - } + }; + Ok(expr.alias(field.name())) + }) + .collect::>>()?; + let source = project(source, exprs)?; let plan = LogicalPlan::Dml(DmlStatement { @@ -972,33 +1099,52 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_name: ObjectName, columns: Vec, source: Box, + overwrite: bool, ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; - let provider = self - .schema_provider - .get_table_provider(table_name.clone())?; - let arrow_schema = (*provider.schema()).clone(); + let table_source = self.context_provider.get_table_source(table_name.clone())?; + let arrow_schema = (*table_source.schema()).clone(); let table_schema = DFSchema::try_from(arrow_schema)?; - let fields = if columns.is_empty() { + // Get insert fields and target table's value indices + // + // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // derived from the j-th output of the source. + // + // if value_indices[i] = None, it means that the value of the i-th target table's column is + // not provided, and should be filled with a default value later. + let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table - table_schema.fields().clone() + ( + table_schema.fields().clone(), + (0..table_schema.fields().len()) + .map(Some) + .collect::>(), + ) } else { + let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns - .iter() - .map(|c| { - Ok(table_schema - .field_with_unqualified_name( - &self.normalizer.normalize(c.clone()), - )? - .clone()) + .into_iter() + .map(|c| self.normalizer.normalize(c)) + .enumerate() + .map(|(i, c)| { + let column_index = table_schema + .index_of_column_by_name(None, &c)? + .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; + if value_indices[column_index].is_some() { + return Err(DataFusionError::SchemaError( + datafusion_common::SchemaError::DuplicateUnqualifiedField { + name: c, + }, + )); + } else { + value_indices[column_index] = Some(i); + } + Ok(table_schema.field(column_index).clone()) }) .collect::>>()?; - // Validate no duplicate fields - let table_schema = - DFSchema::new_with_metadata(fields, table_schema.metadata().clone())?; - table_schema.fields().clone() + (fields, value_indices) }; // infer types for Values clause... other types should be resolvable the regular way @@ -1009,15 +1155,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let ast::Expr::Value(Value::Placeholder(name)) = val { let name = name.replace('$', "").parse::().map_err(|_| { - DataFusionError::Plan(format!( - "Can't parse placeholder: {name}" - )) + plan_datafusion_err!("Can't parse placeholder: {name}") })? - 1; let field = fields.get(idx).ok_or_else(|| { - DataFusionError::Plan(format!( + plan_datafusion_err!( "Placeholder ${} refers to a non existent column", idx + 1 - )) + ) })?; let dt = field.field().data_type().clone(); let _ = prepare_param_data_types.insert(name, dt); @@ -1032,27 +1176,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { PlannerContext::new().with_prepare_param_data_types(prepare_param_data_types); let source = self.query_to_plan(*source, &mut planner_context)?; if fields.len() != source.schema().fields().len() { - Err(DataFusionError::Plan( - "Column count doesn't match insert query!".to_owned(), - ))?; + plan_err!("Column count doesn't match insert query!")?; } - let exprs = fields - .iter() - .zip(source.schema().fields().iter()) - .map(|(target_field, source_field)| { - let expr = - datafusion_expr::Expr::Column(source_field.unqualified_column()) - .cast_to(target_field.data_type(), source.schema())? - .alias(target_field.name()); - Ok(expr) + + let exprs = value_indices + .into_iter() + .enumerate() + .map(|(i, value_index)| { + let target_field = table_schema.field(i); + let expr = match value_index { + Some(v) => { + let source_field = source.schema().field(v); + datafusion_expr::Expr::Column(source_field.qualified_column()) + .cast_to(target_field.data_type(), source.schema())? + } + // The value is not specified. Fill in the default value for the column. + None => table_source + .get_column_default(target_field.name()) + .cloned() + .unwrap_or_else(|| { + // If there is no default for the column, then the default is NULL + datafusion_expr::Expr::Literal(ScalarValue::Null) + }) + .cast_to(target_field.data_type(), &DFSchema::empty())?, + }; + Ok(expr.alias(target_field.name())) }) .collect::>>()?; let source = project(source, exprs)?; + let op = if overwrite { + WriteOp::InsertOverwrite + } else { + WriteOp::InsertInto + }; + let plan = LogicalPlan::Dml(DmlStatement { table_name, table_schema: Arc::new(table_schema), - op: WriteOp::Insert, + op, input: Arc::new(source), }); Ok(plan) @@ -1066,16 +1228,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter: Option, ) -> Result { if filter.is_some() { - return Err(DataFusionError::Plan( - "SHOW COLUMNS with WHERE or LIKE is not supported".to_string(), - )); + return plan_err!("SHOW COLUMNS with WHERE or LIKE is not supported"); } if !self.has_table("information_schema", "columns") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW COLUMNS is not supported unless information_schema is enabled" - .to_string(), - )); + ); } // Figure out the where clause let where_clause = object_name_to_qualifier( @@ -1085,7 +1244,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self.schema_provider.get_table_provider(table_ref)?; + let _ = self.context_provider.get_table_source(table_ref)?; // treat both FULL and EXTENDED as the same let select_list = if full || extended { @@ -1108,10 +1267,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { sql_table_name: ObjectName, ) -> Result { if !self.has_table("information_schema", "tables") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW CREATE TABLE is not supported unless information_schema is enabled" - .to_string(), - )); + ); } // Figure out the where clause let where_clause = object_name_to_qualifier( @@ -1121,7 +1279,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self.schema_provider.get_table_provider(table_ref)?; + let _ = self.context_provider.get_table_source(table_ref)?; let query = format!( "SELECT table_catalog, table_schema, table_name, definition FROM information_schema.views WHERE {where_clause}" @@ -1138,58 +1296,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: schema.into(), table: table.into(), }; - self.schema_provider - .get_table_provider(tables_reference) + self.context_provider + .get_table_source(tables_reference) .is_ok() } - - fn primary_key_from_constraints( - constraints: &[TableConstraint], - ) -> Result> { - let pk: Result>> = constraints - .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { - columns, - is_primary, - .. - } => match is_primary { - true => Ok(columns), - false => Err(DataFusionError::Plan( - "Non-primary unique constraints are not supported".to_string(), - )), - }, - TableConstraint::ForeignKey { .. } => Err(DataFusionError::Plan( - "Foreign key constraints are not currently supported".to_string(), - )), - TableConstraint::Check { .. } => Err(DataFusionError::Plan( - "Check constraints are not currently supported".to_string(), - )), - TableConstraint::Index { .. } => Err(DataFusionError::Plan( - "Indexes are not currently supported".to_string(), - )), - TableConstraint::FulltextOrSpatial { .. } => Err(DataFusionError::Plan( - "Indexes are not currently supported".to_string(), - )), - }) - .collect(); - let pk = pk?; - let pk = match pk.as_slice() { - [] => return Ok(vec![]), - [pk] => pk, - _ => { - return Err(DataFusionError::Plan( - "Only one primary key is supported!".to_string(), - ))? - } - }; - let primary_key: Vec = pk - .iter() - .map(|c| Column { - relation: None, - name: c.value.clone(), - }) - .collect(); - Ok(primary_key) - } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 0929aec6e5eb4..616a2fc749328 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,31 +17,31 @@ //! SQL Utility Functions -use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE}; +use arrow_schema::{ + DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, +}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use sqlparser::ast::Ident; +use datafusion_common::{exec_err, internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{ - AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, GetIndexedField, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, ScalarUDF, - WindowFunction, -}; -use datafusion_expr::expr::{Cast, Sort}; +use datafusion_expr::expr::{Alias, GroupingSet, WindowFunction}; +use datafusion_expr::expr_vec_fmt; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; -use datafusion_expr::{Expr, LogicalPlan, TryCast}; +use datafusion_expr::{Expr, LogicalPlan}; use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - clone_with_replacement(expr, &|nested_expr| { + expr.clone().transform_up(&|nested_expr| { match nested_expr { Expr::Column(col) => { - let field = plan.schema().field_from_column(col)?; - Ok(Some(Expr::Column(field.qualified_column()))) + let field = plan.schema().field_from_column(&col)?; + Ok(Transformed::Yes(Expr::Column(field.qualified_column()))) } _ => { // keep recursing - Ok(None) + Ok(Transformed::No(nested_expr)) } } }) @@ -66,11 +66,11 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - clone_with_replacement(expr, &|nested_expr| { - if base_exprs.contains(nested_expr) { - Ok(Some(expr_as_column_expr(nested_expr, plan)?)) + expr.clone().transform_up(&|nested_expr| { + if base_exprs.contains(&nested_expr) { + Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) } else { - Ok(None) + Ok(Transformed::No(nested_expr)) } }) } @@ -84,9 +84,7 @@ pub(crate) fn check_columns_satisfy_exprs( ) -> Result<()> { columns.iter().try_for_each(|c| match c { Expr::Column(_) => Ok(()), - _ => Err(DataFusionError::Internal( - "Expr::Column are required".to_string(), - )), + _ => internal_err!("Expr::Column are required"), })?; let column_exprs = find_column_exprs(exprs); for e in &column_exprs { @@ -120,322 +118,23 @@ fn check_column_satisfies_expr( message_prefix: &str, ) -> Result<()> { if !columns.contains(expr) { - return Err(DataFusionError::Plan(format!( - "{}: Expression {:?} could not be resolved from available columns: {}", + return plan_err!( + "{}: Expression {} could not be resolved from available columns: {}", message_prefix, expr, - columns - .iter() - .map(|e| format!("{e}")) - .collect::>() - .join(", ") - ))); + expr_vec_fmt!(columns) + ); } Ok(()) } -/// Returns a cloned `Expr`, but any of the `Expr`'s in the tree may be -/// replaced/customized by the replacement function. -/// -/// The replacement function is called repeatedly with `Expr`, starting with -/// the argument `expr`, then descending depth-first through its -/// descendants. The function chooses to replace or keep (clone) each `Expr`. -/// -/// The function's return type is `Result>>`, where: -/// -/// * `Ok(Some(replacement_expr))`: A replacement `Expr` is provided; it is -/// swapped in at the particular node in the tree. Any nested `Expr` are -/// not subject to cloning/replacement. -/// * `Ok(None)`: A replacement `Expr` is not provided. The `Expr` is -/// recreated, with all of its nested `Expr`'s subject to -/// cloning/replacement. -/// * `Err(err)`: Any error returned by the function is returned as-is by -/// `clone_with_replacement()`. -fn clone_with_replacement(expr: &Expr, replacement_fn: &F) -> Result -where - F: Fn(&Expr) -> Result>, -{ - let replacement_opt = replacement_fn(expr)?; - - match replacement_opt { - // If we were provided a replacement, use the replacement. Do not - // descend further. - Some(replacement) => Ok(replacement), - // No replacement was provided, clone the node and recursively call - // clone_with_replacement() on any nested expressions. - None => match expr { - Expr::AggregateFunction(AggregateFunction { - fun, - args, - distinct, - filter, - order_by, - }) => Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - args.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - *distinct, - filter.clone(), - order_by.clone(), - ))), - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - }) => Ok(Expr::WindowFunction(WindowFunction::new( - fun.clone(), - args.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - partition_by - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - order_by - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - window_frame.clone(), - ))), - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => Ok(Expr::AggregateUDF(AggregateUDF::new( - fun.clone(), - args.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - filter.clone(), - order_by.clone(), - ))), - Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias( - Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - alias_name.clone(), - )), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => Ok(Expr::Between(Between::new( - Box::new(clone_with_replacement(expr, replacement_fn)?), - *negated, - Box::new(clone_with_replacement(low, replacement_fn)?), - Box::new(clone_with_replacement(high, replacement_fn)?), - ))), - Expr::InList(InList { - expr: nested_expr, - list, - negated, - }) => Ok(Expr::InList(InList::new( - Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - list.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - *negated, - ))), - Expr::BinaryExpr(BinaryExpr { left, right, op }) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(clone_with_replacement(left, replacement_fn)?), - *op, - Box::new(clone_with_replacement(right, replacement_fn)?), - ))) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - }) => Ok(Expr::Like(Like::new( - *negated, - Box::new(clone_with_replacement(expr, replacement_fn)?), - Box::new(clone_with_replacement(pattern, replacement_fn)?), - *escape_char, - ))), - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => Ok(Expr::ILike(Like::new( - *negated, - Box::new(clone_with_replacement(expr, replacement_fn)?), - Box::new(clone_with_replacement(pattern, replacement_fn)?), - *escape_char, - ))), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - }) => Ok(Expr::SimilarTo(Like::new( - *negated, - Box::new(clone_with_replacement(expr, replacement_fn)?), - Box::new(clone_with_replacement(pattern, replacement_fn)?), - *escape_char, - ))), - Expr::Case(case) => Ok(Expr::Case(Case::new( - match &case.expr { - Some(case_expr) => { - Some(Box::new(clone_with_replacement(case_expr, replacement_fn)?)) - } - None => None, - }, - case.when_then_expr - .iter() - .map(|(a, b)| { - Ok(( - Box::new(clone_with_replacement(a, replacement_fn)?), - Box::new(clone_with_replacement(b, replacement_fn)?), - )) - }) - .collect::>>()?, - match &case.else_expr { - Some(else_expr) => { - Some(Box::new(clone_with_replacement(else_expr, replacement_fn)?)) - } - None => None, - }, - ))), - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - Ok(Expr::ScalarFunction(ScalarFunction::new( - *fun, - args.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - ))) - } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - Ok(Expr::ScalarUDF(ScalarUDF::new( - fun.clone(), - args.iter() - .map(|arg| clone_with_replacement(arg, replacement_fn)) - .collect::>>()?, - ))) - } - Expr::Negative(nested_expr) => Ok(Expr::Negative(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::Not(nested_expr) => Ok(Expr::Not(Box::new(clone_with_replacement( - nested_expr, - replacement_fn, - )?))), - Expr::IsNotNull(nested_expr) => Ok(Expr::IsNotNull(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsNull(nested_expr) => Ok(Expr::IsNull(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsTrue(nested_expr) => Ok(Expr::IsTrue(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsFalse(nested_expr) => Ok(Expr::IsFalse(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsUnknown(nested_expr) => Ok(Expr::IsUnknown(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsNotTrue(nested_expr) => Ok(Expr::IsNotTrue(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsNotFalse(nested_expr) => Ok(Expr::IsNotFalse(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::IsNotUnknown(nested_expr) => Ok(Expr::IsNotUnknown(Box::new( - clone_with_replacement(nested_expr, replacement_fn)?, - ))), - Expr::Cast(Cast { expr, data_type }) => Ok(Expr::Cast(Cast::new( - Box::new(clone_with_replacement(expr, replacement_fn)?), - data_type.clone(), - ))), - Expr::TryCast(TryCast { - expr: nested_expr, - data_type, - }) => Ok(Expr::TryCast(TryCast::new( - Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - data_type.clone(), - ))), - Expr::Sort(Sort { - expr: nested_expr, - asc, - nulls_first, - }) => Ok(Expr::Sort(Sort::new( - Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - *asc, - *nulls_first, - ))), - Expr::Column { .. } - | Expr::OuterReferenceColumn(_, _) - | Expr::Literal(_) - | Expr::ScalarVariable(_, _) - | Expr::Exists { .. } - | Expr::ScalarSubquery(_) => Ok(expr.clone()), - Expr::InSubquery(InSubquery { - expr: nested_expr, - subquery, - negated, - }) => Ok(Expr::InSubquery(InSubquery::new( - Box::new(clone_with_replacement(nested_expr, replacement_fn)?), - subquery.clone(), - *negated, - ))), - Expr::Wildcard => Ok(Expr::Wildcard), - Expr::QualifiedWildcard { .. } => Ok(expr.clone()), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), - key.clone(), - ))) - } - Expr::GroupingSet(set) => match set { - GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup( - exprs - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - ))), - GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube( - exprs - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - ))), - GroupingSet::GroupingSets(lists_of_exprs) => { - let mut new_lists_of_exprs = vec![]; - for exprs in lists_of_exprs { - new_lists_of_exprs.push( - exprs - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - ); - } - Ok(Expr::GroupingSet(GroupingSet::GroupingSets( - new_lists_of_exprs, - ))) - } - }, - Expr::Placeholder(Placeholder { id, data_type }) => Ok(Expr::Placeholder( - Placeholder::new(id.clone(), data_type.clone()), - )), - }, - } -} - /// Returns mapping of each alias (`String`) to the expression (`Expr`) it is /// aliasing. pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { exprs .iter() .filter_map(|expr| match expr { - Expr::Alias(nested_expr, alias_name) => { - Some((alias_name.clone(), *nested_expr.clone())) - } + Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())), _ => None, }) .collect::>() @@ -457,7 +156,7 @@ pub(crate) fn resolve_positions_to_exprs( let index = (position - 1) as usize; let select_expr = &select_exprs[index]; Some(match select_expr { - Expr::Alias(nested_expr, _alias_name) => *nested_expr.clone(), + Expr::Alias(Alias { expr, .. }) => *expr.clone(), _ => select_expr.clone(), }) } @@ -471,15 +170,15 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - clone_with_replacement(expr, &|nested_expr| match nested_expr { + expr.clone().transform_up(&|nested_expr| match nested_expr { Expr::Column(c) if c.relation.is_none() => { if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Some(aliased_expr.clone())) + Ok(Transformed::Yes(aliased_expr.clone())) } else { - Ok(None) + Ok(Transformed::No(Expr::Column(c))) } } - _ => Ok(None), + _ => Ok(Transformed::No(nested_expr)), }) } @@ -490,20 +189,13 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr .iter() .map(|expr| match expr { Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by), - Expr::Alias(expr, _) => { - // convert &Box to &T - match &**expr { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => { - Ok(partition_by) - } - expr => Err(DataFusionError::Execution(format!( - "Impossibly got non-window expr {expr:?}" - ))), + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(WindowFunction { partition_by, .. }) => { + Ok(partition_by) } - } - expr => Err(DataFusionError::Execution(format!( - "Impossibly got non-window expr {expr:?}" - ))), + expr => exec_err!("Impossibly got non-window expr {expr:?}"), + }, + expr => exec_err!("Impossibly got non-window expr {expr:?}"), }) .collect::>>()?; let result = all_partition_keys @@ -526,21 +218,22 @@ pub(crate) fn make_decimal_type( (Some(p), Some(s)) => (p as u8, s as i8), (Some(p), None) => (p as u8, 0), (None, Some(_)) => { - return Err(DataFusionError::Internal( - "Cannot specify only scale for decimal data type".to_string(), - )) + return plan_err!("Cannot specify only scale for decimal data type") } (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), }; - // Arrow decimal is i128 meaning 38 maximum decimal digits if precision == 0 - || precision > DECIMAL128_MAX_PRECISION + || precision > DECIMAL256_MAX_PRECISION || scale.unsigned_abs() > precision { - Err(DataFusionError::Internal(format!( - "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 38`, and `scale <= precision`." - ))) + plan_err!( + "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`." + ) + } else if precision > DECIMAL128_MAX_PRECISION + && precision <= DECIMAL256_MAX_PRECISION + { + Ok(DataType::Decimal256(precision, scale)) } else { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/sql_integration.rs similarity index 87% rename from datafusion/sql/tests/integration_test.rs rename to datafusion/sql/tests/sql_integration.rs index 7161fa481cbe1..48ba50145308c 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -26,9 +26,10 @@ use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, }; +use datafusion_common::{plan_err, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, - AggregateUDF, ScalarUDF, TableSource, + AggregateUDF, ScalarUDF, TableSource, WindowUDF, }; use datafusion_sql::{ parser::DFParser, @@ -37,13 +38,6 @@ use datafusion_sql::{ use rstest::rstest; -#[cfg(test)] -#[ctor::ctor] -fn init() { - // Enable RUST_LOG logging configuration for tests - let _ = env_logger::try_init(); -} - #[test] fn parse_decimals() { let test_data = [ @@ -60,7 +54,7 @@ fn parse_decimals() { ("18446744073709551615", "UInt64(18446744073709551615)"), ( "18446744073709551616", - "Decimal128(Some(18446744073709551616),38,0)", + "Decimal128(Some(18446744073709551616),20,0)", ), ]; for (a, b) in test_data { @@ -102,7 +96,7 @@ fn parse_ident_normalization() { ), ( "SELECT AGE FROM PERSON", - "Err(Plan(\"No table named: PERSON found\"))", + "Error during planning: No table named: PERSON found", false, ), ( @@ -127,7 +121,11 @@ fn parse_ident_normalization() { enable_ident_normalization, }, ); - assert_eq!(expected, format!("{plan:?}")); + if plan.is_ok() { + assert_eq!(expected, format!("{plan:?}")); + } else { + assert_eq!(expected, plan.unwrap_err().strip_backtrace()); + } } } @@ -197,32 +195,50 @@ fn try_cast_from_aggregation() { } #[test] -fn cast_to_invalid_decimal_type() { +fn cast_to_invalid_decimal_type_precision_0() { // precision == 0 { let sql = "SELECT CAST(10 AS DECIMAL(0))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Internal("Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 38`, and `scale <= precision`.")"##, - format!("{err:?}") + "Error during planning: Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`.", + err.strip_backtrace() ); } +} + +#[test] +fn cast_to_invalid_decimal_type_precision_gt_38() { // precision > 38 { let sql = "SELECT CAST(10 AS DECIMAL(39))"; + let plan = "Projection: CAST(Int64(10) AS Decimal256(39, 0))\n EmptyRelation"; + quick_test(sql, plan); + } +} + +#[test] +fn cast_to_invalid_decimal_type_precision_gt_76() { + // precision > 76 + { + let sql = "SELECT CAST(10 AS DECIMAL(79))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Internal("Decimal(precision = 39, scale = 0) should satisfy `0 < precision <= 38`, and `scale <= precision`.")"##, - format!("{err:?}") + "Error during planning: Decimal(precision = 79, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`.", + err.strip_backtrace() ); } +} + +#[test] +fn cast_to_invalid_decimal_type_precision_lt_scale() { // precision < scale { let sql = "SELECT CAST(10 AS DECIMAL(5, 10))"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Internal("Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 38`, and `scale <= precision`.")"##, - format!("{err:?}") + "Error during planning: Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 76`, and `scale <= precision`.", + err.strip_backtrace() ); } } @@ -231,13 +247,56 @@ fn cast_to_invalid_decimal_type() { fn plan_create_table_with_pk() { let sql = "create table person (id int, name string, primary key(id))"; let plan = r#" -CreateMemoryTable: Bare { table: "person" } primary_key=[id] +CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] + EmptyRelation + "# + .trim(); + quick_test(sql, plan); + + let sql = "create table person (id int primary key, name string)"; + let plan = r#" +CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] + EmptyRelation + "# + .trim(); + quick_test(sql, plan); + + let sql = + "create table person (id int, name string unique not null, primary key(id))"; + let plan = r#" +CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), Unique([1])] + EmptyRelation + "# + .trim(); + quick_test(sql, plan); + + let sql = "create table person (id int, name varchar, primary key(name, id));"; + let plan = r#" +CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([1, 0])] EmptyRelation "# .trim(); quick_test(sql, plan); } +#[test] +fn plan_create_table_with_multi_pk() { + let sql = "create table person (id int, name string primary key, primary key(id))"; + let plan = r#" +CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), PrimaryKey([1])] + EmptyRelation + "# + .trim(); + quick_test(sql, plan); +} + +#[test] +fn plan_create_table_with_unique() { + let sql = "create table person (id int unique, name string)"; + let plan = "CreateMemoryTable: Bare { table: \"person\" } constraints=[Unique([0])]\n EmptyRelation"; + quick_test(sql, plan); +} + #[test] fn plan_create_table_no_pk() { let sql = "create table person (id int, name string)"; @@ -250,10 +309,9 @@ CreateMemoryTable: Bare { table: "person" } } #[test] -#[should_panic(expected = "Non-primary unique constraints are not supported")] fn plan_create_table_check_constraint() { let sql = "create table person (id int, name string, unique(id))"; - let plan = ""; + let plan = "CreateMemoryTable: Bare { table: \"person\" } constraints=[Unique([0])]\n EmptyRelation"; quick_test(sql, plan); } @@ -325,23 +383,58 @@ fn plan_rollback_transaction_chained() { } #[test] -fn plan_insert() { - let sql = - "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; +fn plan_copy_to() { + let sql = "COPY test_decimal to 'output.csv'"; + let plan = r#" +CopyTo: format=csv output_url=output.csv single_file_output=true options: () + TableScan: test_decimal + "# + .trim(); + quick_test(sql, plan); +} + +#[test] +fn plan_explain_copy_to() { + let sql = "EXPLAIN COPY test_decimal to 'output.csv'"; let plan = r#" -Dml: op=[Insert] table=[person] - Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name - Values: (Int64(1), Utf8("Alan"), Utf8("Turing")) +Explain + CopyTo: format=csv output_url=output.csv single_file_output=true options: () + TableScan: test_decimal "# .trim(); quick_test(sql, plan); } +#[test] +fn plan_copy_to_query() { + let sql = "COPY (select * from test_decimal limit 10) to 'output.csv'"; + let plan = r#" +CopyTo: format=csv output_url=output.csv single_file_output=true options: () + Limit: skip=0, fetch=10 + Projection: test_decimal.id, test_decimal.price + TableScan: test_decimal + "# + .trim(); + quick_test(sql, plan); +} + +#[test] +fn plan_insert() { + let sql = + "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; + let plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (Int64(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; + quick_test(sql, plan); +} + #[test] fn plan_insert_no_target_columns() { let sql = "INSERT INTO test_decimal VALUES (1, 2), (3, 4)"; let plan = r#" -Dml: op=[Insert] table=[test_decimal] +Dml: op=[Insert Into] table=[test_decimal] Projection: CAST(column1 AS Int32) AS id, CAST(column2 AS Decimal128(10, 2)) AS price Values: (Int64(1), Int64(2)), (Int64(3), Int64(4)) "# @@ -378,10 +471,14 @@ Dml: op=[Insert] table=[test_decimal] "INSERT INTO person (id, first_name, last_name) VALUES ($2, $4, $6)", "Error during planning: Placeholder type could not be resolved" )] +#[case::placeholder_type_unresolved( + "INSERT INTO person (id, first_name, last_name) VALUES ($id, $first_name, $last_name)", + "Error during planning: Can't parse placeholder: $id" +)] #[test] fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); - assert_eq!(err.to_string(), error) + assert_eq!(err.strip_backtrace(), error) } #[test] @@ -390,7 +487,7 @@ fn plan_update() { let plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, Utf8("Kay") AS last_name, person.age AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: id = Int64(1) + Filter: person.id = Int64(1) TableScan: person "# .trim(); @@ -433,8 +530,8 @@ fn select_repeated_column() { let sql = "SELECT age, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"person.age\" at position 0 and \"person.age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"person.age\" at position 0 and \"person.age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -443,8 +540,8 @@ fn select_wildcard_with_repeated_column() { let sql = "SELECT *, age FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"person.age\" at position 3 and \"person.age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"person.age\" at position 3 and \"person.age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -513,11 +610,9 @@ fn select_compound_filter() { #[test] fn test_timestamp_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + \n Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None))\ \n TableScan: person"; - quick_test(sql, expected); } @@ -616,9 +711,9 @@ fn select_nested_with_filters() { fn table_with_column_alias() { let sql = "SELECT a, b, c FROM lineitem l (a, b, c)"; - let expected = "Projection: a, b, c\ - \n Projection: l.l_item_id AS a, l.l_description AS b, l.price AS c\ - \n SubqueryAlias: l\ + let expected = "Projection: l.a, l.b, l.c\ + \n SubqueryAlias: l\ + \n Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c\ \n TableScan: lineitem"; quick_test(sql, expected); @@ -630,8 +725,8 @@ fn table_with_column_alias_number_cols() { FROM lineitem l (a, b)"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Source table contains 3 columns but only 2 names given as column alias\")", - format!("{err:?}") + "Error during planning: Source table contains 3 columns but only 2 names given as column alias", + err.strip_backtrace() ); } @@ -640,8 +735,8 @@ fn select_with_ambiguous_column() { let sql = "SELECT id FROM person a, person b"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", - format!("{err:?}") + "Schema error: Ambiguous reference to unqualified field id", + err.strip_backtrace() ); } @@ -794,8 +889,8 @@ fn select_with_having() { HAVING age > 100 AND age < 200"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references: person.age > Int64(100) AND person.age < Int64(200) must appear in the GROUP BY clause or be used in an aggregate function\")", - format!("{err:?}") + "Error during planning: HAVING clause references: person.age > Int64(100) AND person.age < Int64(200) must appear in the GROUP BY clause or be used in an aggregate function", + err.strip_backtrace() ); } @@ -806,8 +901,8 @@ fn select_with_having_referencing_column_not_in_select() { HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references: person.first_name = Utf8(\\\"M\\\") must appear in the GROUP BY clause or be used in an aggregate function\")", - format!("{err:?}") + "Error during planning: HAVING clause references: person.first_name = Utf8(\"M\") must appear in the GROUP BY clause or be used in an aggregate function", + err.strip_backtrace() ); } @@ -819,8 +914,8 @@ fn select_with_having_refers_to_invalid_column() { HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: person.id, MAX(person.age)\")", - format!("{err:?}") + "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: person.id, MAX(person.age)", + err.strip_backtrace() ); } @@ -831,8 +926,8 @@ fn select_with_having_referencing_column_nested_in_select_expression() { HAVING age > 100"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references: person.age > Int64(100) must appear in the GROUP BY clause or be used in an aggregate function\")", - format!("{err:?}") + "Error during planning: HAVING clause references: person.age > Int64(100) must appear in the GROUP BY clause or be used in an aggregate function", + err.strip_backtrace() ); } @@ -843,8 +938,8 @@ fn select_with_having_with_aggregate_not_in_select() { HAVING MAX(age) > 100"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projection references non-aggregate values: Expression person.first_name could not be resolved from available columns: MAX(person.age)\")", - format!("{err:?}") + "Error during planning: Projection references non-aggregate values: Expression person.first_name could not be resolved from available columns: MAX(person.age)", + err.strip_backtrace() ); } @@ -879,10 +974,8 @@ fn select_aggregate_with_having_referencing_column_not_in_select() { HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references non-aggregate values: \ - Expression person.first_name could not be resolved from available columns: \ - COUNT(UInt8(1))\")", - format!("{err:?}") + "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: COUNT(*)", + err.strip_backtrace() ); } @@ -1002,10 +1095,8 @@ fn select_aggregate_with_group_by_with_having_referencing_column_not_in_group_by HAVING MAX(age) > 10 AND last_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"HAVING clause references non-aggregate values: \ - Expression person.last_name could not be resolved from available columns: \ - person.first_name, MAX(person.age)\")", - format!("{err:?}") + "Error during planning: HAVING clause references non-aggregate values: Expression person.last_name could not be resolved from available columns: person.first_name, MAX(person.age)", + err.strip_backtrace() ); } @@ -1084,8 +1175,8 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() { GROUP BY first_name HAVING MAX(age) > 100 AND COUNT(*) < 50"; let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND COUNT(UInt8(1)) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(UInt8(1))]]\ + \n Filter: MAX(person.age) > Int64(100) AND COUNT(*) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1106,6 +1197,22 @@ fn select_binary_expr_nested() { quick_test(sql, expected); } +#[test] +fn select_at_arrow_operator() { + let sql = "SELECT left @> right from array"; + let expected = "Projection: array.left @> array.right\ + \n TableScan: array"; + quick_test(sql, expected); +} + +#[test] +fn select_arrow_at_operator() { + let sql = "SELECT left <@ right from array"; + let expected = "Projection: array.left <@ array.right\ + \n TableScan: array"; + quick_test(sql, expected); +} + #[test] fn select_wildcard_with_groupby() { quick_test( @@ -1156,8 +1263,8 @@ fn select_simple_aggregate_repeated_aggregate() { let sql = "SELECT MIN(age), MIN(age) FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"MIN(person.age)\" at position 0 and \"MIN(person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"MIN(person.age)\" at position 0 and \"MIN(person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -1181,13 +1288,23 @@ fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { ); } +#[test] +fn select_simple_aggregate_respect_nulls() { + let sql = "SELECT MIN(age) RESPECT NULLS FROM person"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_contains!( + err.strip_backtrace(), + "This feature is not implemented: Null treatment in aggregate functions is not supported: RESPECT NULLS" + ); +} #[test] fn select_from_typed_string_values() { quick_test( "SELECT col1, col2 FROM (VALUES (TIMESTAMP '2021-06-10 17:01:00Z', DATE '2004-04-09')) as t (col1, col2)", - "Projection: col1, col2\ - \n Projection: t.column1 AS col1, t.column2 AS col2\ - \n SubqueryAlias: t\ + "Projection: t.col1, t.col2\ + \n SubqueryAlias: t\ + \n Projection: column1 AS col1, column2 AS col2\ \n Values: (CAST(Utf8(\"2021-06-10 17:01:00Z\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2004-04-09\") AS Date32))", ); } @@ -1197,8 +1314,8 @@ fn select_simple_aggregate_repeated_aggregate_with_repeated_aliases() { let sql = "SELECT MIN(age) AS a, MIN(age) AS a FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"MIN(person.age) AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"MIN(person.age) AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -1227,8 +1344,8 @@ fn select_simple_aggregate_with_groupby_with_aliases_repeated() { let sql = "SELECT state AS a, MIN(age) AS a FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"person.state AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"person.state AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -1248,7 +1365,7 @@ fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!("Schema error: No field named doesnotexist. Valid fields are \"SUM(person.age)\", \ person.id, person.first_name, person.last_name, person.age, person.state, \ - person.salary, person.birth_date, person.\"😀\".", format!("{err}")); + person.salary, person.birth_date, person.\"😀\".", err.strip_backtrace()); } #[test] @@ -1263,20 +1380,8 @@ fn select_interval_out_of_range() { let sql = "SELECT INTERVAL '100000000000000000 day'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "ArrowError(ParseError(\"Parsed interval field value out of range: 0 months 100000000000000000 days 0 nanos\"))", - format!("{err:?}") - ); -} - -#[test] -fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.to_string(), - r#"Arrays with different types are not supported: "# + "Arrow error: Invalid argument error: Unable to represent 100000000000000000 days in a signed 32-bit integer", + err.strip_backtrace(), ); } @@ -1291,18 +1396,8 @@ fn recursive_ctes() { select * from numbers;"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r#"NotImplemented("Recursive CTEs are not supported")"#, - format!("{err:?}") - ); -} - -#[test] -fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#, - format!("{err:?}") + "This feature is not implemented: Recursive CTEs are not supported", + err.strip_backtrace() ); } @@ -1337,15 +1432,15 @@ fn select_simple_aggregate_with_groupby_position_out_of_range() { let sql = "SELECT state, MIN(age) FROM person GROUP BY 0"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(0), MIN(person.age)\")", - format!("{err:?}") + "Error during planning: Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(0), MIN(person.age)", + err.strip_backtrace() ); let sql2 = "SELECT state, MIN(age) FROM person GROUP BY 5"; let err2 = logical_plan(sql2).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(5), MIN(person.age)\")", - format!("{err2:?}") + "Error during planning: Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(5), MIN(person.age)", + err2.strip_backtrace() ); } @@ -1364,8 +1459,8 @@ fn select_simple_aggregate_with_groupby_aggregate_repeated() { let sql = "SELECT state, MIN(age), MIN(age) FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - r##"Plan("Projections require unique expression names but the expression \"MIN(person.age)\" at position 1 and \"MIN(person.age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.")"##, - format!("{err:?}") + "Error during planning: Projections require unique expression names but the expression \"MIN(person.age)\" at position 1 and \"MIN(person.age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.", + err.strip_backtrace() ); } @@ -1422,8 +1517,8 @@ fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_not_res let sql = "SELECT ((age + 1) / 2) * (age + 9), MIN(first_name) FROM person GROUP BY age + 1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)\")", - format!("{err:?}") + "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)", + err.strip_backtrace() ); } @@ -1431,8 +1526,9 @@ fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_not_res fn select_simple_aggregate_with_groupby_non_column_expression_and_its_column_selected() { let sql = "SELECT age, MIN(first_name) FROM person GROUP BY age + 1"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!("Plan(\"Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)\")", - format!("{err:?}") + assert_eq!( + "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)", + err.strip_backtrace() ); } @@ -1590,20 +1686,24 @@ fn select_order_by_multiple_index() { #[test] fn select_order_by_index_of_0() { let sql = "SELECT id FROM person ORDER BY 0"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"Order by index starts at 1 for column indexes\")", - format!("{err:?}") + "Error during planning: Order by index starts at 1 for column indexes", + err ); } #[test] fn select_order_by_index_oob() { let sql = "SELECT id FROM person ORDER BY 2"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"Order by column out of bounds, specified: 2, max: 1\")", - format!("{err:?}") + "Error during planning: Order by column out of bounds, specified: 2, max: 1", + err ); } @@ -1665,8 +1765,8 @@ fn select_group_by_columns_not_in_select() { #[test] fn select_group_by_count_star() { let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Projection: person.state, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.state, COUNT(*)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -1697,10 +1797,8 @@ fn select_7480_2() { let sql = "SELECT c1, c13, MIN(c12) FROM aggregate_test_100 GROUP BY c1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"Projection references non-aggregate values: \ - Expression aggregate_test_100.c13 could not be resolved from available columns: \ - aggregate_test_100.c1, MIN(aggregate_test_100.c12)\")", - format!("{err:?}") + "Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c13 could not be resolved from available columns: aggregate_test_100.c1, MIN(aggregate_test_100.c12)", + err.strip_backtrace() ); } @@ -1711,6 +1809,14 @@ fn create_external_table_csv() { quick_test(sql, expected); } +#[test] +fn create_external_table_with_pk() { + let sql = "CREATE EXTERNAL TABLE t(c1 int, primary key(c1)) STORED AS CSV LOCATION 'foo.csv'"; + let expected = + "CreateExternalTable: Bare { table: \"t\" } constraints=[PrimaryKey([0])]"; + quick_test(sql, expected); +} + #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; @@ -1754,6 +1860,7 @@ fn create_external_table_with_compression_type() { "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE BZIP2 LOCATION 'foo.csv.bz2'", "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON COMPRESSION TYPE GZIP LOCATION 'foo.json.gz'", "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON COMPRESSION TYPE BZIP2 LOCATION 'foo.json.bz2'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS NONSTANDARD COMPRESSION TYPE GZIP LOCATION 'foo.unk'", ]; for sql in sqls { let expected = "CreateExternalTable: Bare { table: \"t\" }"; @@ -1766,12 +1873,14 @@ fn create_external_table_with_compression_type() { "CREATE EXTERNAL TABLE t STORED AS AVRO COMPRESSION TYPE BZIP2 LOCATION 'foo.avro'", "CREATE EXTERNAL TABLE t STORED AS PARQUET COMPRESSION TYPE GZIP LOCATION 'foo.parquet'", "CREATE EXTERNAL TABLE t STORED AS PARQUET COMPRESSION TYPE BZIP2 LOCATION 'foo.parquet'", + "CREATE EXTERNAL TABLE t STORED AS ARROW COMPRESSION TYPE GZIP LOCATION 'foo.arrow'", + "CREATE EXTERNAL TABLE t STORED AS ARROW COMPRESSION TYPE BZIP2 LOCATION 'foo.arrow'", ]; for sql in sqls { let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Plan(\"File compression type can be specified for CSV/JSON files.\")", - format!("{err:?}") + "Error during planning: File compression type cannot be set for PARQUET, AVRO, or ARROW files.", + err.strip_backtrace() ); } } @@ -1779,11 +1888,15 @@ fn create_external_table_with_compression_type() { #[test] fn create_external_table_parquet() { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS PARQUET LOCATION 'foo.parquet'"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Plan(\"Column definitions can not be specified for PARQUET files.\")", - format!("{err:?}") - ); + let expected = "CreateExternalTable: Bare { table: \"t\" }"; + quick_test(sql, expected); +} + +#[test] +fn create_external_table_parquet_sort_order() { + let sql = "create external table foo(a varchar, b varchar, c timestamp) stored as parquet location '/tmp/foo' with order (c)"; + let expected = "CreateExternalTable: Bare { table: \"foo\" }"; + quick_test(sql, expected); } #[test] @@ -1951,24 +2064,6 @@ fn union_all() { quick_test(sql, expected); } -#[test] -fn union_4_combined_in_one() { - let sql = "SELECT order_id from orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders - UNION ALL SELECT order_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); -} - #[test] fn union_with_different_column_names() { let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders"; @@ -1994,13 +2089,12 @@ fn union_values_with_no_alias() { #[test] fn union_with_incompatible_data_type() { let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"UNION Column Int64(1) (type: Int64) is \ - not compatible with column IntervalMonthDayNano\ - (\\\"950737950189618795196236955648\\\") \ - (type: Interval(MonthDayNano))\")", - format!("{err:?}") + "Error during planning: UNION Column Int64(1) (type: Int64) is not compatible with column IntervalMonthDayNano(\"950737950189618795196236955648\") (type: Interval(MonthDayNano))", + err ); } @@ -2103,10 +2197,12 @@ fn union_with_aliases() { #[test] fn union_with_incompatible_data_types() { let sql = "SELECT 'a' a UNION ALL SELECT true a"; - let err = logical_plan(sql).expect_err("query should have failed"); + let err = logical_plan(sql) + .expect_err("query should have failed") + .strip_backtrace(); assert_eq!( - "Plan(\"UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)\")", - format!("{err:?}") + "Error during planning: UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)", + err ); } @@ -2582,7 +2678,7 @@ fn prepare_stmt_quick_test( fn prepare_stmt_replace_params_quick_test( plan: LogicalPlan, - param_values: Vec, + param_values: impl Into, expected_plan: &str, ) -> LogicalPlan { // replace params @@ -2599,7 +2695,7 @@ struct MockContextProvider { } impl ContextProvider for MockContextProvider { - fn get_table_provider(&self, name: TableReference) -> Result> { + fn get_table_source(&self, name: TableReference) -> Result> { let schema = match name.table() { "test" => Ok(Schema::new(vec![ Field::new("t_date32", DataType::Date32, false), @@ -2643,6 +2739,18 @@ impl ContextProvider for MockContextProvider { Field::new("price", DataType::Float64, false), Field::new("delivered", DataType::Boolean, false), ])), + "array" => Ok(Schema::new(vec![ + Field::new( + "left", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "right", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])), "lineitem" => Ok(Schema::new(vec![ Field::new("l_item_id", DataType::UInt32, false), Field::new("l_description", DataType::Utf8, false), @@ -2667,10 +2775,7 @@ impl ContextProvider for MockContextProvider { Field::new("Id", DataType::UInt32, false), Field::new("lower", DataType::UInt32, false), ])), - _ => Err(DataFusionError::Plan(format!( - "No table named: {} found", - name.table() - ))), + _ => plan_err!("No table named: {} found", name.table()), }; match schema { @@ -2691,6 +2796,10 @@ impl ContextProvider for MockContextProvider { unimplemented!() } + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + fn options(&self) -> &ConfigOptions { &self.options } @@ -2737,7 +2846,31 @@ fn cte_use_same_name_multiple_times() { let expected = "SQL error: ParserError(\"WITH query name \\\"a\\\" specified more than once\")"; let result = logical_plan(sql).err().unwrap(); - assert_eq!(result.to_string(), expected); + assert_eq!(result.strip_backtrace(), expected); +} + +#[test] +fn negative_interval_plus_interval_in_projection() { + let sql = "select -interval '2 days' + interval '5 days';"; + let expected = + "Projection: IntervalMonthDayNano(\"79228162477370849446124847104\") + IntervalMonthDayNano(\"92233720368547758080\")\n EmptyRelation"; + quick_test(sql, expected); +} + +#[test] +fn complex_interval_expression_in_projection() { + let sql = "select -interval '2 days' + interval '5 days'+ (-interval '3 days' + interval '5 days');"; + let expected = + "Projection: IntervalMonthDayNano(\"79228162477370849446124847104\") + IntervalMonthDayNano(\"92233720368547758080\") + IntervalMonthDayNano(\"79228162458924105372415295488\") + IntervalMonthDayNano(\"92233720368547758080\")\n EmptyRelation"; + quick_test(sql, expected); +} + +#[test] +fn negative_sum_intervals_in_projection() { + let sql = "select -((interval '2 days' + interval '5 days') + -(interval '4 days' + interval '7 days'));"; + let expected = + "Projection: (- IntervalMonthDayNano(\"36893488147419103232\") + IntervalMonthDayNano(\"92233720368547758080\") + (- IntervalMonthDayNano(\"73786976294838206464\") + IntervalMonthDayNano(\"129127208515966861312\")))\n EmptyRelation"; + quick_test(sql, expected); } #[test] @@ -2884,8 +3017,8 @@ fn scalar_subquery_reference_outer_field() { let expected = "Projection: j1.j1_string, j2.j2_string\ \n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < ()\ \n Subquery:\ - \n Projection: COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Projection: COUNT(*)\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ \n CrossJoin:\ \n TableScan: j1\ @@ -2940,9 +3073,9 @@ fn cte_with_column_names() { ) \ SELECT * FROM numbers;"; - let expected = "Projection: a, b, c\ - \n Projection: numbers.Int64(1) AS a, numbers.Int64(2) AS b, numbers.Int64(3) AS c\ - \n SubqueryAlias: numbers\ + let expected = "Projection: numbers.a, numbers.b, numbers.c\ + \n SubqueryAlias: numbers\ + \n Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c\ \n Projection: Int64(1), Int64(2), Int64(3)\ \n EmptyRelation"; @@ -2958,9 +3091,9 @@ fn cte_with_column_aliases_precedence() { ) \ SELECT * FROM numbers;"; - let expected = "Projection: a, b, c\ - \n Projection: numbers.x AS a, numbers.y AS b, numbers.z AS c\ - \n SubqueryAlias: numbers\ + let expected = "Projection: numbers.a, numbers.b, numbers.c\ + \n SubqueryAlias: numbers\ + \n Projection: x AS a, y AS b, z AS c\ \n Projection: Int64(1) AS x, Int64(2) AS y, Int64(3) AS z\ \n EmptyRelation"; quick_test(sql, expected) @@ -2976,15 +3109,15 @@ fn cte_unbalanced_number_of_columns() { let expected = "Error during planning: Source table contains 3 columns but only 1 names given as column alias"; let result = logical_plan(sql).err().unwrap(); - assert_eq!(result.to_string(), expected); + assert_eq!(result.strip_backtrace(), expected); } #[test] fn aggregate_with_rollup() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -2993,8 +3126,8 @@ fn aggregate_with_rollup() { fn aggregate_with_rollup_with_grouping() { let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3025,8 +3158,8 @@ fn rank_partition_grouping() { fn aggregate_with_cube() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3042,8 +3175,8 @@ fn round_decimal() { #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3107,7 +3240,7 @@ fn order_by_ambiguous_name() { let expected = "Schema error: Ambiguous reference to unqualified field age"; let err = logical_plan(sql).unwrap_err(); - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); } #[test] @@ -3116,7 +3249,7 @@ fn group_by_ambiguous_name() { let expected = "Schema error: Ambiguous reference to unqualified field age"; let err = logical_plan(sql).unwrap_err(); - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); } #[test] @@ -3379,7 +3512,7 @@ fn test_select_distinct_order_by() { let result = logical_plan(sql); assert!(result.is_err()); let err = result.err().unwrap(); - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); } #[rstest] @@ -3406,20 +3539,31 @@ fn test_select_distinct_order_by() { #[test] fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) { let err = logical_plan(sql).unwrap_err(); - assert_eq!(err.to_string(), error) + assert_eq!(err.strip_backtrace(), error) } #[test] fn select_order_by_with_cast() { let sql = "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ - \n Projection: first_name AS first_name\ - \n Projection: person.first_name AS first_name\ + let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ + \n Projection: person.first_name\ + \n Projection: person.first_name\ \n TableScan: person"; quick_test(sql, expected); } +#[test] +fn test_avoid_add_alias() { + // avoiding adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. @@ -3463,7 +3607,7 @@ fn test_ambiguous_column_references_in_on_join() { let result = logical_plan(sql); assert!(result.is_err()); let err = result.err().unwrap(); - assert_eq!(err.to_string(), expected); + assert_eq!(err.strip_backtrace(), expected); } #[test] @@ -3483,41 +3627,45 @@ fn test_ambiguous_column_references_with_in_using_join() { } #[test] -#[should_panic(expected = "value: Plan(\"Invalid placeholder, not a number: $foo\"")] fn test_prepare_statement_to_plan_panic_param_format() { // param is not number following the $ sign // panic due to error returned from the parser let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; - logical_plan(sql).unwrap(); + assert_eq!( + logical_plan(sql).unwrap_err().strip_backtrace(), + "Error during planning: Invalid placeholder, not a number: $foo" + ); } #[test] -#[should_panic( - expected = "value: Plan(\"Invalid placeholder, zero is not a valid index: $0\"" -)] fn test_prepare_statement_to_plan_panic_param_zero() { // param is zero following the $ sign // panic due to error returned from the parser let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; - logical_plan(sql).unwrap(); + assert_eq!( + logical_plan(sql).unwrap_err().strip_backtrace(), + "Error during planning: Invalid placeholder, zero is not a valid index: $0" + ); } #[test] -#[should_panic(expected = "value: SQL(ParserError(\"Expected AS, found: SELECT\"))")] fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { // param is not number following the $ sign // panic due to error returned from the parser let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; - logical_plan(sql).unwrap(); + assert_eq!( + logical_plan(sql).unwrap_err().strip_backtrace(), + "SQL error: ParserError(\"Expected AS, found: SELECT\")" + ) } #[test] -#[should_panic( - expected = "value: SchemaError(FieldNotFound { field: Column { relation: None, name: \"id\" }, valid_fields: [] })" -)] fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - logical_plan(sql).unwrap(); + assert_eq!( + logical_plan(sql).unwrap_err().strip_backtrace(), + "Schema error: No field named id." + ) } #[test] @@ -3533,6 +3681,19 @@ fn test_prepare_statement_should_infer_types() { assert_eq!(actual_types, expected_types); } +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + #[test] #[should_panic( expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" @@ -3580,7 +3741,7 @@ fn test_prepare_statement_to_plan_no_param() { /////////////////// // replace params with values - let param_values = vec![]; + let param_values: Vec = vec![]; let expected_plan = "Projection: person.id, person.age\ \n Filter: person.age = Int64(10)\ \n TableScan: person"; @@ -3589,41 +3750,48 @@ fn test_prepare_statement_to_plan_no_param() { } #[test] -#[should_panic(expected = "value: Internal(\"Expected 1 parameters, got 0\")")] fn test_prepare_statement_to_plan_one_param_no_value_panic() { // no embedded parameter but still declare it let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 - let param_values = vec![]; - let expected_plan = "whatever"; - prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + let param_values: Vec = vec![]; + assert_eq!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + "Error during planning: Expected 1 parameters, got 0" + ); } #[test] -#[should_panic( - expected = "value: Internal(\"Expected parameter of type Int32, got Float64 at index 0\")" -)] fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { // no embedded parameter but still declare it let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 let param_values = vec![ScalarValue::Float64(Some(20.0))]; - let expected_plan = "whatever"; - prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + assert_eq!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + "Error during planning: Expected parameter of type Int32, got Float64 at index 0" + ); } #[test] -#[should_panic(expected = "value: Internal(\"Expected 0 parameters, got 1\")")] fn test_prepare_statement_to_plan_no_param_on_value_panic() { // no embedded parameter but still declare it let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; let plan = logical_plan(sql).unwrap(); // declare 1 param but provide 0 let param_values = vec![ScalarValue::Int32(Some(10))]; - let expected_plan = "whatever"; - prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + assert_eq!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + "Error during planning: Expected 0 parameters, got 1" + ); } #[test] @@ -3700,7 +3868,7 @@ Projection: person.id, orders.order_id assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, orders.order_id Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) @@ -3732,7 +3900,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; + let param_values = vec![ScalarValue::Int32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = Int32(10) @@ -3744,6 +3912,41 @@ Projection: person.id, person.age prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + "# + .trim(); + + let expected_dt = "[Int32]"; + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int32)), + ]); + assert_eq!(actual_types, expected_types); + + // replace params with values + let param_values = + vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))].into(); + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + .trim(); + let plan = plan.replace_params_with_values(¶m_values).unwrap(); + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_infer_types_subquery() { let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; @@ -3768,7 +3971,7 @@ Projection: person.id, person.age assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; + let param_values = vec![ScalarValue::UInt32(Some(10))].into(); let expected_plan = r#" Projection: person.id, person.age Filter: person.age = () @@ -3792,7 +3995,7 @@ fn test_prepare_statement_update_infer() { let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: id = $2 + Filter: person.id = $2 TableScan: person "# .trim(); @@ -3808,11 +4011,12 @@ Dml: op=[Update] table=[person] assert_eq!(actual_types, expected_types); // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; + let param_values = + vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))].into(); let expected_plan = r#" Dml: op=[Update] table=[person] Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: id = UInt32(1) + Filter: person.id = UInt32(1) TableScan: person "# .trim(); @@ -3825,12 +4029,11 @@ Dml: op=[Update] table=[person] fn test_prepare_statement_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let expected_plan = r#" -Dml: op=[Insert] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: ($1, $2, $3) - "# - .trim(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: ($1, $2, $3)"; let expected_dt = "[Int32]"; let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); @@ -3846,15 +4049,15 @@ Dml: op=[Insert] table=[person] // replace params with values let param_values = vec![ ScalarValue::UInt32(Some(1)), - ScalarValue::Utf8(Some("Alan".to_string())), - ScalarValue::Utf8(Some("Turing".to_string())), - ]; - let expected_plan = r#" -Dml: op=[Insert] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name - Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) - "# - .trim(); + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + .into(); + let expected_plan = "Dml: op=[Insert Into] table=[person]\ + \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ + CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ + CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ + \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -3928,11 +4131,11 @@ fn test_prepare_statement_to_plan_multi_params() { // replace params with values let param_values = vec![ ScalarValue::Int32(Some(10)), - ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::from("abc"), ScalarValue::Float64(Some(100.0)), ScalarValue::Int32(Some(20)), ScalarValue::Float64(Some(200.0)), - ScalarValue::Utf8(Some("xyz".to_string())), + ScalarValue::from("xyz"), ]; let expected_plan = "Projection: person.id, person.age, Utf8(\"xyz\")\ @@ -3986,9 +4189,9 @@ fn test_prepare_statement_to_plan_value_list() { let sql = "PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter);"; let expected_plan = "Prepare: \"my_plan\" [Utf8, Utf8] \ - \n Projection: num, letter\ - \n Projection: t.column1 AS num, t.column2 AS letter\ - \n SubqueryAlias: t\ + \n Projection: t.num, t.letter\ + \n SubqueryAlias: t\ + \n Projection: column1 AS num, column2 AS letter\ \n Values: (Int64(1), $1), (Int64(2), $2)"; let expected_dt = "[Utf8, Utf8]"; @@ -3998,12 +4201,12 @@ fn test_prepare_statement_to_plan_value_list() { /////////////////// // replace params with values let param_values = vec![ - ScalarValue::Utf8(Some("a".to_string())), - ScalarValue::Utf8(Some("b".to_string())), + ScalarValue::from("a".to_string()), + ScalarValue::from("b".to_string()), ]; - let expected_plan = "Projection: num, letter\ - \n Projection: t.column1 AS num, t.column2 AS letter\ - \n SubqueryAlias: t\ + let expected_plan = "Projection: t.num, t.letter\ + \n SubqueryAlias: t\ + \n Projection: column1 AS num, column2 AS letter\ \n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -4034,9 +4237,9 @@ fn test_table_alias() { (select age from person) t2 \ ) as f (c1, c2)"; - let expected = "Projection: c1, c2\ - \n Projection: f.id AS c1, f.age AS c2\ - \n SubqueryAlias: f\ + let expected = "Projection: f.c1, f.c2\ + \n SubqueryAlias: f\ + \n Projection: t1.id AS c1, t2.age AS c2\ \n CrossJoin:\ \n SubqueryAlias: t1\ \n Projection: person.id\ @@ -4125,3 +4328,10 @@ impl TableSource for EmptyTable { self.table_schema.clone() } } + +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for tests + let _ = env_logger::try_init(); +} diff --git a/datafusion/sqllogictest/.gitignore b/datafusion/sqllogictest/.gitignore new file mode 100644 index 0000000000000..e90171b0acca2 --- /dev/null +++ b/datafusion/sqllogictest/.gitignore @@ -0,0 +1,2 @@ +*.py +test_files/tpch/data \ No newline at end of file diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml new file mode 100644 index 0000000000000..436c6159e7a36 --- /dev/null +++ b/datafusion/sqllogictest/Cargo.toml @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +name = "datafusion-sqllogictest" +readme = "README.md" +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lib] +name = "datafusion_sqllogictest" +path = "src/lib.rs" + +[dependencies] +arrow = { workspace = true } +async-trait = { workspace = true } +bigdecimal = { workspace = true } +bytes = { version = "1.4.0", optional = true } +chrono = { workspace = true, optional = true } +datafusion = { path = "../core", version = "33.0.0" } +datafusion-common = { workspace = true } +futures = { version = "0.3.28" } +half = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } +object_store = { workspace = true } +postgres-protocol = { version = "0.6.4", optional = true } +postgres-types = { version = "0.2.4", optional = true } +rust_decimal = { version = "1.27.0" } +sqllogictest = "0.19.0" +sqlparser = { workspace = true } +tempfile = { workspace = true } +thiserror = { workspace = true } +tokio = { version = "1.0" } +tokio-postgres = { version = "0.7.7", optional = true } + +[features] +avro = ["datafusion/avro"] +postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] + +[dev-dependencies] +env_logger = { workspace = true } +num_cpus = { workspace = true } + +[[test]] +harness = false +name = "sqllogictests" +path = "bin/sqllogictests.rs" diff --git a/datafusion/core/tests/sqllogictests/README.md b/datafusion/sqllogictest/README.md similarity index 68% rename from datafusion/core/tests/sqllogictests/README.md rename to datafusion/sqllogictest/README.md index 3ce00bf8d55a0..bda00a2dce0f8 100644 --- a/datafusion/core/tests/sqllogictests/README.md +++ b/datafusion/sqllogictest/README.md @@ -17,40 +17,53 @@ under the License. --> -#### Overview +# DataFusion sqllogictest -This is the Datafusion implementation of [sqllogictest](https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). We -use [sqllogictest-rs](https://github.com/risinglightdb/sqllogictest-rs) as a parser/runner of `.slt` files -in [`test_files`](test_files). +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -#### Running tests: TLDR Examples +This crate is a submodule of DataFusion that contains an implementation of [sqllogictest](https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). + +[df]: https://crates.io/crates/datafusion + +## Overview + +This crate uses [sqllogictest-rs](https://github.com/risinglightdb/sqllogictest-rs) to parse and run `.slt` files in the +[`test_files`](test_files) directory of this crate. + +## Testing setup + +1. `rustup update stable` DataFusion uses the latest stable release of rust +2. `git submodule init` +3. `git submodule update` + +## Running tests: TLDR Examples ```shell # Run all tests -cargo test -p datafusion --test sqllogictests +cargo test --test sqllogictests ``` ```shell # Run all tests, with debug logging enabled -RUST_LOG=debug cargo test -p datafusion --test sqllogictests +RUST_LOG=debug cargo test --test sqllogictests ``` ```shell # Run only the tests in `information_schema.slt` -cargo test -p datafusion --test sqllogictests -- information_schema +cargo test --test sqllogictests -- information_schema ``` ```shell # Automatically update ddl.slt with expected output -cargo test -p datafusion --test sqllogictests -- ddl --complete +cargo test --test sqllogictests -- ddl --complete ``` ```shell # Run ddl.slt, printing debug logging to stdout -RUST_LOG=debug cargo test -p datafusion --test sqllogictests -- ddl +RUST_LOG=debug cargo test --test sqllogictests -- ddl ``` -#### Cookbook: Adding Tests +## Cookbook: Adding Tests 1. Add queries @@ -70,7 +83,7 @@ SELECT * from foo; Running the following command will update `my_awesome_test.slt` with the expected output: ```shell -cargo test -p datafusion --test sqllogictests -- my_awesome_test --complete +cargo test --test sqllogictests -- my_awesome_test --complete ``` 3. Verify the content @@ -89,35 +102,35 @@ SELECT * from foo; Assuming it looks good, check it in! -#### Reference +# Reference -#### Running tests: Validation Mode +## Running tests: Validation Mode -In this model, `sqllogictests` runs the statements and queries in a `.slt` file, comparing the expected output in the +In this mode, `sqllogictests` runs the statements and queries in a `.slt` file, comparing the expected output in the file to the output produced by that run. For example, to run all tests suites in validation mode ```shell -cargo test -p datafusion --test sqllogictests +cargo test --test sqllogictests ``` sqllogictests also supports `cargo test` style substring matches on file names to restrict which tests to run ```shell # information_schema.slt matches due to substring matching `information` -cargo test -p datafusion --test sqllogictests -- information +cargo test --test sqllogictests -- information ``` -#### Running tests: Postgres compatibility +## Running tests: Postgres compatibility Test files that start with prefix `pg_compat_` verify compatibility -with Postgres by running the same script files both with DataFusion and with Posgres +with Postgres by running the same script files both with DataFusion and with Postgres In order to run the sqllogictests running against a previously running Postgres instance, do: ```shell -PG_COMPAT=true PG_URI="postgresql://postgres@127.0.0.1/postgres" cargo test -p datafusion --test sqllogictests +PG_COMPAT=true PG_URI="postgresql://postgres@127.0.0.1/postgres" cargo test --features=postgres --test sqllogictests ``` The environemnt variables: @@ -139,7 +152,7 @@ docker run \ postgres ``` -#### Running Tests: `tpch` +## Running Tests: `tpch` Test files in `tpch` directory runs against the `TPCH` data set (SF = 0.1), which must be generated before running. You can use following @@ -147,18 +160,19 @@ command to generate tpch data, assuming you are in the repository root: ```shell +mkdir -p datafusion/sqllogictest/test_files/tpch/data docker run -it \ - -v "$(realpath datafusion/core/tests/sqllogictests/test_files/tpch/data)":/data \ + -v "$(realpath datafusion/sqllogictest/test_files/tpch/data)":/data \ ghcr.io/databloom-ai/tpch-docker:main -vf -s 0.1 ``` Then you need to add `INCLUDE_TPCH=true` to run tpch tests: ```shell -INCLUDE_TPCH=true cargo test -p datafusion --test sqllogictests +INCLUDE_TPCH=true cargo test --test sqllogictests ``` -#### Updating tests: Completion Mode +## Updating tests: Completion Mode In test script completion mode, `sqllogictests` reads a prototype script and runs the statements and queries against the database engine. The output is a full script that is a copy of the prototype script with result inserted. @@ -167,17 +181,35 @@ You can update the tests / generate expected output by passing the `--complete` ```shell # Update ddl.slt with output from running -cargo test -p datafusion --test sqllogictests -- ddl --complete +cargo test --test sqllogictests -- ddl --complete ``` -#### sqllogictests +## Running tests: `scratchdir` + +The DataFusion sqllogictest runner automatically creates a directory +named `test_files/scratch/`, creating it if needed and +clearing any file contents if it exists. + +For example, the `test_files/copy.slt` file should use scratch +directory `test_files/scratch/copy`. + +Tests that need to write temporary files should write (only) to this +directory to ensure they do not interfere with others concurrently +running tests. + +## `.slt` file format + +[`sqllogictest`] was originally written for SQLite to verify the +correctness of SQL queries against the SQLite engine. The format is designed +engine-agnostic and can parse sqllogictest files (`.slt`), runs +queries against an SQL engine and compares the output to the expected +output. -sqllogictest is a program originally written for SQLite to verify the correctness of SQL queries against the SQLite -engine. The program is engine-agnostic and can parse sqllogictest files (`.slt`), runs queries against an SQL engine and -compare the output to the expected output. +[`sqllogictest`]: https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki -Tests in the `.slt` file are a sequence of query record generally starting with `CREATE` statements to populate tables -and then further queries to test the populated data (arrow-datafusion exception). +Tests in the `.slt` file are a sequence of query records generally +starting with `CREATE` statements to populate tables and then further +queries to test the populated data. Each `.slt` file runs in its own, isolated `SessionContext`, to make the test setup explicit and so they can run in parallel. Thus it important to keep the tests from having externally visible side effects (like writing to a global @@ -208,7 +240,7 @@ query - NULL values are rendered as `NULL`, - empty strings are rendered as `(empty)`, - boolean values are rendered as `true`/`false`, - - this list can be not exhaustive, check the `datafusion/core/tests/sqllogictests/src/engines/conversion.rs` for + - this list can be not exhaustive, check the `datafusion/sqllogictest/src/engines/conversion.rs` for details. - `sort_mode`: If included, it must be one of `nosort` (**default**), `rowsort`, or `valuesort`. In `nosort` mode, the results appear in exactly the order in which they were received from the database engine. The `nosort` mode should @@ -222,7 +254,7 @@ query > :warning: It is encouraged to either apply `order by`, or use `rowsort` for queries without explicit `order by` > clauses. -##### Example +### Example ```sql # group_by_distinct diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/sqllogictest/bin/sqllogictests.rs similarity index 72% rename from datafusion/core/tests/sqllogictests/src/main.rs rename to datafusion/sqllogictest/bin/sqllogictests.rs index d93d59fb3e1a2..618e3106c6292 100644 --- a/datafusion/core/tests/sqllogictests/src/main.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -16,26 +16,19 @@ // under the License. use std::ffi::OsStr; +use std::fs; use std::path::{Path, PathBuf}; #[cfg(target_family = "windows")] use std::thread; +use datafusion_sqllogictest::{DataFusion, TestContext}; use futures::stream::StreamExt; use log::info; use sqllogictest::strict_column_validator; -use tempfile::TempDir; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_err, DataFusionError, Result}; -use crate::engines::datafusion::DataFusion; -use crate::engines::postgres::Postgres; - -mod engines; -mod setup; -mod utils; - -const TEST_DIRECTORY: &str = "tests/sqllogictests/test_files/"; +const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; #[cfg(target_family = "windows")] @@ -62,6 +55,24 @@ pub async fn main() -> Result<()> { run_tests().await } +/// Sets up an empty directory at test_files/scratch/ +/// creating it if needed and clearing any file contents if it exists +/// This allows tests for inserting to external tables or copy to +/// to persist data to disk and have consistent state when running +/// a new test +fn setup_scratch_dir(name: &Path) -> Result<()> { + // go from copy.slt --> copy + let file_stem = name.file_stem().expect("File should have a stem"); + let path = PathBuf::from("test_files").join("scratch").join(file_stem); + + info!("Creating scratch dir in {path:?}"); + if path.exists() { + fs::remove_dir_all(&path)?; + } + fs::create_dir_all(&path)?; + Ok(()) +} + async fn run_tests() -> Result<()> { // Enable logging (e.g. set RUST_LOG=debug to see debug logs) env_logger::init(); @@ -110,10 +121,7 @@ async fn run_tests() -> Result<()> { for e in &errors { println!("{e}"); } - Err(DataFusionError::Execution(format!( - "{} failures", - errors.len() - ))) + exec_err!("{} failures", errors.len()) } else { Ok(()) } @@ -125,12 +133,17 @@ async fn run_test_file(test_file: TestFile) -> Result<()> { relative_path, } = test_file; info!("Running with DataFusion runner: {}", path.display()); - let Some(test_ctx) = context_for_test_file(&relative_path).await else { + let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else { info!("Skipping: {}", path.display()); return Ok(()); }; - let ctx = test_ctx.session_ctx().clone(); - let mut runner = sqllogictest::Runner::new(DataFusion::new(ctx, relative_path)); + setup_scratch_dir(&relative_path)?; + let mut runner = sqllogictest::Runner::new(|| async { + Ok(DataFusion::new( + test_ctx.session_ctx().clone(), + relative_path.clone(), + )) + }); runner.with_column_validator(strict_column_validator); runner .run_file_async(path) @@ -138,16 +151,16 @@ async fn run_test_file(test_file: TestFile) -> Result<()> { .map_err(|e| DataFusionError::External(Box::new(e))) } +#[cfg(feature = "postgres")] async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { + use datafusion_sqllogictest::Postgres; let TestFile { path, relative_path, } = test_file; info!("Running with Postgres runner: {}", path.display()); - let postgres_client = Postgres::connect(relative_path) - .await - .map_err(|e| DataFusionError::External(Box::new(e)))?; - let mut runner = sqllogictest::Runner::new(postgres_client); + let mut runner = + sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); runner.with_column_validator(strict_column_validator); runner .run_file_async(path) @@ -156,6 +169,12 @@ async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { Ok(()) } +#[cfg(not(feature = "postgres"))] +async fn run_test_file_with_postgres(_test_file: TestFile) -> Result<()> { + use datafusion_common::plan_err; + plan_err!("Can not run with postgres as postgres feature is not enabled") +} + async fn run_complete_file(test_file: TestFile) -> Result<()> { let TestFile { path, @@ -165,13 +184,16 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { info!("Using complete mode to complete: {}", path.display()); - let Some(test_ctx) = context_for_test_file(&relative_path).await else { + let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else { info!("Skipping: {}", path.display()); return Ok(()); }; - let ctx = test_ctx.session_ctx().clone(); - let mut runner = - sqllogictest::Runner::new(DataFusion::new(ctx, relative_path.clone())); + let mut runner = sqllogictest::Runner::new(|| async { + Ok(DataFusion::new( + test_ctx.session_ctx().clone(), + relative_path.clone(), + )) + }); let col_separator = " "; runner .update_test_file( @@ -249,85 +271,6 @@ fn read_dir_recursive>(path: P) -> Box Option { - let config = SessionConfig::new() - // hardcode target partitions so plans are deterministic - .with_target_partitions(4); - - let test_ctx = TestContext::new(SessionContext::with_config(config)); - - let file_name = relative_path.file_name().unwrap().to_str().unwrap(); - match file_name { - "aggregate.slt" => { - info!("Registering aggregate tables"); - setup::register_aggregate_tables(test_ctx.session_ctx()).await; - } - "scalar.slt" => { - info!("Registering scalar tables"); - setup::register_scalar_tables(test_ctx.session_ctx()).await; - } - "avro.slt" => { - #[cfg(feature = "avro")] - { - let mut test_ctx = test_ctx; - info!("Registering avro tables"); - setup::register_avro_tables(&mut test_ctx).await; - return Some(test_ctx); - } - #[cfg(not(feature = "avro"))] - { - info!("Skipping {file_name} because avro feature is not enabled"); - return None; - } - } - _ => { - info!("Using default SessionContext"); - } - }; - Some(test_ctx) -} - -/// Context for running tests -pub struct TestContext { - /// Context for running queries - ctx: SessionContext, - /// Temporary directory created and cleared at the end of the test - test_dir: Option, -} - -impl TestContext { - pub fn new(ctx: SessionContext) -> Self { - Self { - ctx, - test_dir: None, - } - } - - /// Enables the test directory feature. If not enabled, - /// calling `testdir_path` will result in a panic. - pub fn enable_testdir(&mut self) { - if self.test_dir.is_none() { - self.test_dir = Some(TempDir::new().expect("failed to create testdir")); - } - } - - /// Returns the path to the test directory. Panics if the test - /// directory feature is not enabled via `enable_testdir`. - pub fn testdir_path(&self) -> &Path { - self.test_dir.as_ref().expect("testdir not enabled").path() - } - - /// Returns a reference to the internal SessionContext - fn session_ctx(&self) -> &SessionContext { - &self.ctx - } -} - /// Parsed command line options struct Options { // regex like diff --git a/datafusion/core/tests/sqllogictests/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs similarity index 64% rename from datafusion/core/tests/sqllogictests/src/engines/conversion.rs rename to datafusion/sqllogictest/src/engines/conversion.rs index c069c2d4a48df..909539b3131bc 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{Decimal128Type, DecimalType}; +use arrow::datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType}; use bigdecimal::BigDecimal; use half::f16; use rust_decimal::prelude::*; -use rust_decimal::Decimal; +/// Represents a constant for NULL string in your database. pub const NULL_STR: &str = "NULL"; -pub fn bool_to_str(value: bool) -> String { +pub(crate) fn bool_to_str(value: bool) -> String { if value { "true".to_string() } else { @@ -31,7 +31,7 @@ pub fn bool_to_str(value: bool) -> String { } } -pub fn varchar_to_str(value: &str) -> String { +pub(crate) fn varchar_to_str(value: &str) -> String { if value.is_empty() { "(empty)".to_string() } else { @@ -39,8 +39,10 @@ pub fn varchar_to_str(value: &str) -> String { } } -pub fn f16_to_str(value: f16) -> String { +pub(crate) fn f16_to_str(value: f16) -> String { if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. "NaN".to_string() } else if value == f16::INFINITY { "Infinity".to_string() @@ -51,8 +53,10 @@ pub fn f16_to_str(value: f16) -> String { } } -pub fn f32_to_str(value: f32) -> String { +pub(crate) fn f32_to_str(value: f32) -> String { if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. "NaN".to_string() } else if value == f32::INFINITY { "Infinity".to_string() @@ -63,8 +67,10 @@ pub fn f32_to_str(value: f32) -> String { } } -pub fn f64_to_str(value: f64) -> String { +pub(crate) fn f64_to_str(value: f64) -> String { if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. "NaN".to_string() } else if value == f64::INFINITY { "Infinity".to_string() @@ -75,17 +81,25 @@ pub fn f64_to_str(value: f64) -> String { } } -pub fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { +pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale)) .unwrap(), ) } -pub fn decimal_to_str(value: Decimal) -> String { +pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String { + big_decimal_to_str( + BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale)) + .unwrap(), + ) +} + +#[cfg(feature = "postgres")] +pub(crate) fn decimal_to_str(value: Decimal) -> String { big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) } -pub fn big_decimal_to_str(value: BigDecimal) -> String { +pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { value.round(12).normalized().to_string() } diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/error.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs similarity index 70% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/error.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/error.rs index ed6d1eda17c3a..5bb40aca2ab8f 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/error.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs @@ -28,45 +28,21 @@ pub type Result = std::result::Result; pub enum DFSqlLogicTestError { /// Error from sqllogictest-rs #[error("SqlLogicTest error(from sqllogictest-rs crate): {0}")] - SqlLogicTest(TestError), + SqlLogicTest(#[from] TestError), /// Error from datafusion #[error("DataFusion error: {0}")] - DataFusion(DataFusionError), + DataFusion(#[from] DataFusionError), /// Error returned when SQL is syntactically incorrect. #[error("SQL Parser error: {0}")] - Sql(ParserError), + Sql(#[from] ParserError), /// Error from arrow-rs #[error("Arrow error: {0}")] - Arrow(ArrowError), + Arrow(#[from] ArrowError), /// Generic error #[error("Other Error: {0}")] Other(String), } -impl From for DFSqlLogicTestError { - fn from(value: TestError) -> Self { - DFSqlLogicTestError::SqlLogicTest(value) - } -} - -impl From for DFSqlLogicTestError { - fn from(value: DataFusionError) -> Self { - DFSqlLogicTestError::DataFusion(value) - } -} - -impl From for DFSqlLogicTestError { - fn from(value: ParserError) -> Self { - DFSqlLogicTestError::Sql(value) - } -} - -impl From for DFSqlLogicTestError { - fn from(value: ArrowError) -> Self { - DFSqlLogicTestError::Arrow(value) - } -} - impl From for DFSqlLogicTestError { fn from(value: String) -> Self { DFSqlLogicTestError::Other(value) diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs new file mode 100644 index 0000000000000..663bbdd5a3c7c --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// DataFusion engine implementation for sqllogictest. +mod error; +mod normalize; +mod runner; + +pub use error::*; +pub use normalize::*; +pub use runner::*; diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs similarity index 78% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index 6dd4e17d7dd70..c0db111bc60d8 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::DFField; use datafusion_common::DataFusionError; -use lazy_static::lazy_static; use std::path::PathBuf; +use std::sync::OnceLock; use crate::engines::output::DFColumnType; @@ -27,7 +29,7 @@ use super::super::conversion::*; use super::error::{DFSqlLogicTestError, Result}; /// Converts `batches` to a result as expected by sqllogicteset. -pub fn convert_batches(batches: Vec) -> Result>> { +pub(crate) fn convert_batches(batches: Vec) -> Result>> { if batches.is_empty() { Ok(vec![]) } else { @@ -105,7 +107,7 @@ fn expand_row(mut row: Vec) -> impl Iterator> { }) .collect(); - Either::Right(once(row).chain(new_lines.into_iter())) + Either::Right(once(row).chain(new_lines)) } else { Either::Left(once(row)) } @@ -113,18 +115,18 @@ fn expand_row(mut row: Vec) -> impl Iterator> { /// normalize path references /// -/// ``` +/// ```text /// CsvExec: files={1 group: [[path/to/datafusion/testing/data/csv/aggregate_test_100.csv]]}, ... /// ``` /// /// into: /// -/// ``` +/// ```text /// CsvExec: files={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, ... /// ``` fn normalize_paths(mut row: Vec) -> Vec { row.iter_mut().for_each(|s| { - let workspace_root: &str = WORKSPACE_ROOT.as_ref(); + let workspace_root: &str = workspace_root().as_ref(); if s.contains(workspace_root) { *s = s.replace(workspace_root, "WORKSPACE_ROOT"); } @@ -133,33 +135,32 @@ fn normalize_paths(mut row: Vec) -> Vec { } /// return the location of the datafusion checkout -fn workspace_root() -> object_store::path::Path { - // e.g. /Software/arrow-datafusion/datafusion/core - let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); +fn workspace_root() -> &'static object_store::path::Path { + static WORKSPACE_ROOT_LOCK: OnceLock = OnceLock::new(); + WORKSPACE_ROOT_LOCK.get_or_init(|| { + // e.g. /Software/arrow-datafusion/datafusion/core + let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - // e.g. /Software/arrow-datafusion/datafusion - let workspace_root = dir - .parent() - .expect("Can not find parent of datafusion/core") - // e.g. /Software/arrow-datafusion - .parent() - .expect("parent of datafusion") - .to_string_lossy(); + // e.g. /Software/arrow-datafusion/datafusion + let workspace_root = dir + .parent() + .expect("Can not find parent of datafusion/core") + // e.g. /Software/arrow-datafusion + .parent() + .expect("parent of datafusion") + .to_string_lossy(); - let sanitized_workplace_root = if cfg!(windows) { - // Object store paths are delimited with `/`, e.g. `D:/a/arrow-datafusion/arrow-datafusion/testing/data/csv/aggregate_test_100.csv`. - // The default windows delimiter is `\`, so the workplace path is `D:\a\arrow-datafusion\arrow-datafusion`. - workspace_root.replace(std::path::MAIN_SEPARATOR, object_store::path::DELIMITER) - } else { - workspace_root.to_string() - }; + let sanitized_workplace_root = if cfg!(windows) { + // Object store paths are delimited with `/`, e.g. `D:/a/arrow-datafusion/arrow-datafusion/testing/data/csv/aggregate_test_100.csv`. + // The default windows delimiter is `\`, so the workplace path is `D:\a\arrow-datafusion\arrow-datafusion`. + workspace_root + .replace(std::path::MAIN_SEPARATOR, object_store::path::DELIMITER) + } else { + workspace_root.to_string() + }; - object_store::path::Path::parse(sanitized_workplace_root).unwrap() -} - -// holds the root directory -lazy_static! { - static ref WORKSPACE_ROOT: object_store::path::Path = workspace_root(); + object_store::path::Path::parse(sanitized_workplace_root).unwrap() + }) } /// Convert a single batch to a `Vec>` for comparison @@ -199,6 +200,7 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { Ok(NULL_STR.to_string()) } else { match col.data_type() { + DataType::Null => Ok(NULL_STR.to_string()), DataType::Boolean => { Ok(bool_to_str(get_row_value!(array::BooleanArray, col, row))) } @@ -215,6 +217,10 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { let value = get_row_value!(array::Decimal128Array, col, row); Ok(i128_to_str(value, precision, scale)) } + DataType::Decimal256(precision, scale) => { + let value = get_row_value!(array::Decimal256Array, col, row); + Ok(i256_to_str(value, precision, scale)) + } DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!( array::LargeStringArray, col, @@ -223,14 +229,17 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { DataType::Utf8 => { Ok(varchar_to_str(get_row_value!(array::StringArray, col, row))) } - _ => arrow::util::display::array_value_to_string(col, row), + _ => { + let f = ArrayFormatter::try_new(col.as_ref(), &DEFAULT_FORMAT_OPTIONS); + Ok(f.unwrap().value(row).to_string()) + } } .map_err(DFSqlLogicTestError::Arrow) } } /// Converts columns to a result as expected by sqllogicteset. -pub fn convert_schema_to_types(columns: &[DFField]) -> Vec { +pub(crate) fn convert_schema_to_types(columns: &[DFField]) -> Vec { columns .iter() .map(|f| f.data_type()) diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs similarity index 91% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index dd30ef494d497..afd0a241ca5ef 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -15,21 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::path::PathBuf; -use std::time::Duration; +use std::{path::PathBuf, time::Duration}; -use crate::engines::output::{DFColumnType, DFOutput}; - -use self::error::{DFSqlLogicTestError, Result}; +use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; use log::info; use sqllogictest::DBOutput; -mod error; -mod normalize; -mod util; +use super::{error::Result, normalize, DFSqlLogicTestError}; + +use crate::engines::output::{DFColumnType, DFOutput}; pub struct DataFusion { ctx: SessionContext, @@ -61,7 +57,7 @@ impl sqllogictest::AsyncDB for DataFusion { "DataFusion" } - /// [`Runner`] calls this function to perform sleep. + /// [`DataFusion`] calls this function to perform sleep. /// /// The default implementation is `std::thread::sleep`, which is universal to any async runtime /// but would block the current thread. If you are running in tokio runtime, you should override diff --git a/datafusion/sqllogictest/src/engines/mod.rs b/datafusion/sqllogictest/src/engines/mod.rs new file mode 100644 index 0000000000000..a6a0886332ed7 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/mod.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Implementation of sqllogictest for datafusion. +mod conversion; +mod datafusion_engine; +mod output; + +pub use datafusion_engine::DataFusion; + +#[cfg(feature = "postgres")] +mod postgres_engine; + +#[cfg(feature = "postgres")] +pub use postgres_engine::Postgres; diff --git a/datafusion/core/tests/sqllogictests/src/engines/output.rs b/datafusion/sqllogictest/src/engines/output.rs similarity index 97% rename from datafusion/core/tests/sqllogictests/src/engines/output.rs rename to datafusion/sqllogictest/src/engines/output.rs index 0682f5df97c19..24299856e00d5 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/output.rs +++ b/datafusion/sqllogictest/src/engines/output.rs @@ -54,4 +54,4 @@ impl ColumnType for DFColumnType { } } -pub type DFOutput = DBOutput; +pub(crate) type DFOutput = DBOutput; diff --git a/datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs similarity index 99% rename from datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs rename to datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 2c6287b97bfd1..fe2785603e76d 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +/// Postgres engine implementation for sqllogictest. use std::path::{Path, PathBuf}; use std::str::FromStr; diff --git a/datafusion/core/tests/sqllogictests/src/engines/postgres/types.rs b/datafusion/sqllogictest/src/engines/postgres_engine/types.rs similarity index 100% rename from datafusion/core/tests/sqllogictests/src/engines/postgres/types.rs rename to datafusion/sqllogictest/src/engines/postgres_engine/types.rs diff --git a/datafusion/sqllogictest/src/lib.rs b/datafusion/sqllogictest/src/lib.rs new file mode 100644 index 0000000000000..1bcfd71af0fd0 --- /dev/null +++ b/datafusion/sqllogictest/src/lib.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DataFusion sqllogictest driver + +mod engines; + +pub use engines::DataFusion; + +#[cfg(feature = "postgres")] +pub use engines::Postgres; + +mod test_context; +pub use test_context::TestContext; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs new file mode 100644 index 0000000000000..f5ab8f71aaaf0 --- /dev/null +++ b/datafusion/sqllogictest/src/test_context.rs @@ -0,0 +1,329 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_trait::async_trait; +use datafusion::execution::context::SessionState; +use datafusion::logical_expr::Expr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion::{ + arrow::{ + array::{ + BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampNanosecondArray, + }, + datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}, + record_batch::RecordBatch, + }, + catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, + datasource::{MemTable, TableProvider, TableType}, + prelude::{CsvReadOptions, SessionContext}, +}; +use datafusion_common::DataFusionError; +use log::info; +use std::collections::HashMap; +use std::fs::File; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use tempfile::TempDir; + +/// Context for running tests +pub struct TestContext { + /// Context for running queries + ctx: SessionContext, + /// Temporary directory created and cleared at the end of the test + test_dir: Option, +} + +impl TestContext { + pub fn new(ctx: SessionContext) -> Self { + Self { + ctx, + test_dir: None, + } + } + + /// Create a SessionContext, configured for the specific sqllogictest + /// test(.slt file) , if possible. + /// + /// If `None` is returned (e.g. because some needed feature is not + /// enabled), the file should be skipped + pub async fn try_new_for_test_file(relative_path: &Path) -> Option { + let config = SessionConfig::new() + // hardcode target partitions so plans are deterministic + .with_target_partitions(4); + + let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); + + let file_name = relative_path.file_name().unwrap().to_str().unwrap(); + match file_name { + "scalar.slt" => { + info!("Registering scalar tables"); + register_scalar_tables(test_ctx.session_ctx()).await; + } + "information_schema_table_types.slt" => { + info!("Registering local temporary table"); + register_temp_table(test_ctx.session_ctx()).await; + } + "information_schema_columns.slt" => { + info!("Registering table with many types"); + register_table_with_many_types(test_ctx.session_ctx()).await; + } + "avro.slt" => { + #[cfg(feature = "avro")] + { + info!("Registering avro tables"); + register_avro_tables(&mut test_ctx).await; + } + #[cfg(not(feature = "avro"))] + { + info!("Skipping {file_name} because avro feature is not enabled"); + return None; + } + } + "joins.slt" => { + info!("Registering partition table tables"); + register_partition_table(&mut test_ctx).await; + } + "metadata.slt" => { + info!("Registering metadata table tables"); + register_metadata_tables(test_ctx.session_ctx()).await; + } + _ => { + info!("Using default SessionContext"); + } + }; + Some(test_ctx) + } + + /// Enables the test directory feature. If not enabled, + /// calling `testdir_path` will result in a panic. + pub fn enable_testdir(&mut self) { + if self.test_dir.is_none() { + self.test_dir = Some(TempDir::new().expect("failed to create testdir")); + } + } + + /// Returns the path to the test directory. Panics if the test + /// directory feature is not enabled via `enable_testdir`. + pub fn testdir_path(&self) -> &Path { + self.test_dir.as_ref().expect("testdir not enabled").path() + } + + /// Returns a reference to the internal SessionContext + pub fn session_ctx(&self) -> &SessionContext { + &self.ctx + } +} + +#[cfg(feature = "avro")] +pub async fn register_avro_tables(ctx: &mut crate::TestContext) { + use datafusion::prelude::AvroReadOptions; + + ctx.enable_testdir(); + + let table_path = ctx.testdir_path().join("avro"); + std::fs::create_dir(&table_path).expect("failed to create avro table path"); + + let testdata = datafusion::test_util::arrow_test_data(); + let alltypes_plain_file = format!("{testdata}/avro/alltypes_plain.avro"); + std::fs::copy( + &alltypes_plain_file, + format!("{}/alltypes_plain1.avro", table_path.display()), + ) + .unwrap(); + std::fs::copy( + &alltypes_plain_file, + format!("{}/alltypes_plain2.avro", table_path.display()), + ) + .unwrap(); + + ctx.session_ctx() + .register_avro( + "alltypes_plain_multi_files", + table_path.display().to_string().as_str(), + AvroReadOptions::default(), + ) + .await + .unwrap(); +} + +pub async fn register_scalar_tables(ctx: &SessionContext) { + register_nan_table(ctx) +} + +/// Register a table with a NaN value (different than NULL, and can +/// not be created via SQL) +fn register_nan_table(ctx: &SessionContext) { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)])); + + let data = RecordBatch::try_new( + schema, + vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(f64::NAN), + ]))], + ) + .unwrap(); + ctx.register_batch("test_float", data).unwrap(); +} + +/// Generate a partitioned CSV file and register it with an execution context +pub async fn register_partition_table(test_ctx: &mut TestContext) { + test_ctx.enable_testdir(); + let partition_count = 1; + let file_extension = "csv"; + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::UInt32, false), + Field::new("c2", DataType::UInt64, false), + Field::new("c3", DataType::Boolean, false), + ])); + // generate a partitioned file + for partition in 0..partition_count { + let filename = format!("partition-{partition}.{file_extension}"); + let file_path = test_ctx.testdir_path().join(filename); + let mut file = File::create(file_path).unwrap(); + + // generate some data + for i in 0..=10 { + let data = format!("{},{},{}\n", partition, i, i % 2 == 0); + file.write_all(data.as_bytes()).unwrap() + } + } + + // register csv file with the execution context + test_ctx + .ctx + .register_csv( + "test_partition_table", + test_ctx.testdir_path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await + .unwrap(); +} + +// registers a LOCAL TEMPORARY table. +pub async fn register_temp_table(ctx: &SessionContext) { + struct TestTable(TableType); + + #[async_trait] + impl TableProvider for TestTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn table_type(&self) -> TableType { + self.0 + } + + fn schema(&self) -> SchemaRef { + unimplemented!() + } + + async fn scan( + &self, + _state: &SessionState, + _: Option<&Vec>, + _: &[Expr], + _: Option, + ) -> Result, DataFusionError> { + unimplemented!() + } + } + + ctx.register_table( + "datafusion.public.temp", + Arc::new(TestTable(TableType::Temporary)), + ) + .unwrap(); +} + +pub async fn register_table_with_many_types(ctx: &SessionContext) { + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + catalog + .register_schema("my_schema", Arc::new(schema)) + .unwrap(); + ctx.register_catalog("my_catalog", Arc::new(catalog)); + + ctx.register_table("my_catalog.my_schema.t2", table_with_many_types()) + .unwrap(); +} + +fn table_with_many_types() -> Arc { + let schema = Schema::new(vec![ + Field::new("int32_col", DataType::Int32, false), + Field::new("float64_col", DataType::Float64, true), + Field::new("utf8_col", DataType::Utf8, true), + Field::new("large_utf8_col", DataType::LargeUtf8, false), + Field::new("binary_col", DataType::Binary, false), + Field::new("large_binary_col", DataType::LargeBinary, false), + Field::new( + "timestamp_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Float64Array::from(vec![1.0])), + Arc::new(StringArray::from(vec![Some("foo")])), + Arc::new(LargeStringArray::from(vec![Some("bar")])), + Arc::new(BinaryArray::from(vec![b"foo" as &[u8]])), + Arc::new(LargeBinaryArray::from(vec![b"foo" as &[u8]])), + Arc::new(TimestampNanosecondArray::from(vec![Some(123)])), + ], + ) + .unwrap(); + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); + Arc::new(provider) +} + +/// Registers a table_with_metadata that contains both field level and Table level metadata +pub async fn register_metadata_tables(ctx: &SessionContext) { + let id = Field::new("id", DataType::Int32, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the id field"), + )])); + let name = Field::new("name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the name field"), + )])); + + let schema = Schema::new(vec![id, name]).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )])); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, + Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + ], + ) + .unwrap(); + + ctx.register_batch("table_with_metadata", batch).unwrap(); +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt new file mode 100644 index 0000000000000..7cfc9c707d432 --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -0,0 +1,3214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +# Setup test data table +####### +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +statement ok +CREATE TABLE d_table (c1 decimal(10,3), c2 varchar) +as values +(110.000, 'A'), (110.001, 'A'), (110.002, 'A'), (110.003, 'A'), (110.004, 'A'), (110.005, 'A'), (110.006, 'A'), (110.007, 'A'), (110.008, 'A'), (110.009, 'A'), +(-100.000, 'B'),(-100.001, 'B'),(-100.002, 'B'),(-100.003, 'B'),(-100.004, 'B'),(-100.005, 'B'),(-100.006, 'B'),(-100.007, 'B'),(-100.008, 'B'),(-100.009, 'B') + +statement ok +CREATE TABLE median_table ( + col_i8 TINYINT, + col_i16 SMALLINT, + col_i32 INT, + col_i64 BIGINT, + col_u8 TINYINT UNSIGNED, + col_u16 SMALLINT UNSIGNED, + col_u32 INT UNSIGNED, + col_u64 BIGINT UNSIGNED, + col_f32 FLOAT, + col_f64 DOUBLE, + col_f64_nan DOUBLE +) as VALUES +( -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 1.1, 1.1, 1.1 ), +( -128, -32768, -2147483648, arrow_cast(-9223372036854775808,'Int64'), 0, 0, 0, arrow_cast(0,'UInt64'), 4.4, 4.4, arrow_cast('NAN','Float64') ), +( 100, 100, 100, arrow_cast(100,'Int64'), 100,100,100, arrow_cast(100,'UInt64'), 3.3, 3.3, arrow_cast('NAN','Float64') ), +( 127, 32767, 2147483647, arrow_cast(9223372036854775807,'Int64'), 255, 65535, 4294967295, 18446744073709551615, 2.2, 2.2, arrow_cast('NAN','Float64') ) + +statement ok +CREATE TABLE test (c1 BIGINT,c2 BIGINT) as values +(0,null), (1,1), (null,1), (3,2), (3,2) + +####### +# Error tests +####### + +# https://github.com/apache/arrow-datafusion/issues/3353 +statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "APPROX_DISTINCT\(aggregate_test_100\.c9\)" +SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 + +# csv_query_approx_percentile_cont_with_weight +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8, Int8, Float64\)'. You might need to add explicit type casts. +SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Utf8, Float64\)'\. You might need to add explicit type casts\. +SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Int8, Utf8\)'\. You might need to add explicit type casts\. +SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 + +# csv_query_approx_percentile_cont_with_histogram_bins +statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\). +SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Utf8\)'\. You might need to add explicit type casts\. +SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\. +SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 + +# array agg can use order by +query ? +SELECT array_agg(c13 ORDER BY c13) +FROM + (SELECT * + FROM aggregate_test_100 + ORDER BY c13 + LIMIT 5) as t1 +---- +[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8] + +statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 +SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 + + +# FIX: custom absolute values +# csv_query_avg_multi_batch + +# csv_query_avg +query R +SELECT avg(c12) FROM aggregate_test_100 +---- +0.508972509913 + +# csv_query_bit_and +query IIIII +SELECT bit_and(c5), bit_and(c6), bit_and(c7), bit_and(c8), bit_and(c9) FROM aggregate_test_100 +---- +0 0 0 0 0 + +# csv_query_bit_and_distinct +query IIIII +SELECT bit_and(distinct c5), bit_and(distinct c6), bit_and(distinct c7), bit_and(distinct c8), bit_and(distinct c9) FROM aggregate_test_100 +---- +0 0 0 0 0 + +# csv_query_bit_or +query IIIII +SELECT bit_or(c5), bit_or(c6), bit_or(c7), bit_or(c8), bit_or(c9) FROM aggregate_test_100 +---- +-1 -1 255 65535 4294967295 + +# csv_query_bit_or_distinct +query IIIII +SELECT bit_or(distinct c5), bit_or(distinct c6), bit_or(distinct c7), bit_or(distinct c8), bit_or(distinct c9) FROM aggregate_test_100 +---- +-1 -1 255 65535 4294967295 + +# csv_query_bit_xor +query IIIII +SELECT bit_xor(c5), bit_xor(c6), bit_xor(c7), bit_xor(c8), bit_xor(c9) FROM aggregate_test_100 +---- +1632751011 5960911605712039654 148 54789 169634700 + +# csv_query_bit_xor_distinct (should be different than above) +query IIIII +SELECT bit_xor(distinct c5), bit_xor(distinct c6), bit_xor(distinct c7), bit_xor(distinct c8), bit_xor(distinct c9) FROM aggregate_test_100 +---- +1632751011 5960911605712039654 196 54789 169634700 + +# csv_query_bit_xor_distinct_expr +query I +SELECT bit_xor(distinct c5 % 2) FROM aggregate_test_100 +---- +-2 + +# csv_query_covariance_1 +query R +SELECT covar_pop(c2, c12) FROM aggregate_test_100 +---- +-0.079169322354 + +# csv_query_covariance_2 +query R +SELECT covar(c2, c12) FROM aggregate_test_100 +---- +-0.079969012479 + +# single_row_query_covar_1 +query R +select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq +---- +NULL + +# single_row_query_covar_2 +query R +select covar_pop(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq +---- +0 + +# all_nulls_query_covar +query RR +with data as ( + select null::int as f, null::int as b + union all + select null::int as f, null::int as b +) +select covar_samp(f, b), covar_pop(f, b) +from data +---- +NULL NULL + +# covar_query_with_nulls +query RR +with data as ( + select 1 as f, 4 as b + union all + select null as f, 99 as b + union all + select 2 as f, 5 as b + union all + select 98 as f, null as b + union all + select 3 as f, 6 as b + union all + select null as f, null as b +) +select covar_samp(f, b), covar_pop(f, b) +from data +---- +1 0.666666666667 + +# csv_query_correlation +query R +SELECT corr(c2, c12) FROM aggregate_test_100 +---- +-0.190645441906 + +# single_row_query_correlation +query R +select corr(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq +---- +0 + +# all_nulls_query_correlation +query R +with data as ( + select null::int as f, null::int as b + union all + select null::int as f, null::int as b +) +select corr(f, b) +from data +---- +NULL + +# correlation_query_with_nulls +query R +with data as ( + select 1 as f, 4 as b + union all + select null as f, 99 as b + union all + select 2 as f, 5 as b + union all + select 98 as f, null as b + union all + select 3 as f, 6 as b + union all + select null as f, null as b +) +select corr(f, b) +from data +---- +1 + +# csv_query_variance_1 +query R +SELECT var_pop(c2) FROM aggregate_test_100 +---- +1.8675 + +# csv_query_variance_2 +query R +SELECT var_pop(c6) FROM aggregate_test_100 +---- +26156334342021890000000000000000000000 + +# csv_query_variance_3 +query R +SELECT var_pop(c12) FROM aggregate_test_100 +---- +0.092342237216 + +# csv_query_variance_4 +query R +SELECT var(c2) FROM aggregate_test_100 +---- +1.886363636364 + +# csv_query_variance_5 +query R +SELECT var_samp(c2) FROM aggregate_test_100 +---- +1.886363636364 + +# csv_query_stddev_1 +query R +SELECT stddev_pop(c2) FROM aggregate_test_100 +---- +1.366565036872 + +# csv_query_stddev_2 +query R +SELECT stddev_pop(c6) FROM aggregate_test_100 +---- +5114326382039172000 + +# csv_query_stddev_3 +query R +SELECT stddev_pop(c12) FROM aggregate_test_100 +---- +0.303878655413 + +# csv_query_stddev_4 +query R +SELECT stddev(c12) FROM aggregate_test_100 +---- +0.305409539941 + +# csv_query_stddev_5 +query R +SELECT stddev_samp(c12) FROM aggregate_test_100 +---- +0.305409539941 + +# csv_query_stddev_6 +query R +select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq +---- +0.950438495292 + +# csv_query_approx_median_1 +query I +SELECT approx_median(c2) FROM aggregate_test_100 +---- +3 + +# csv_query_approx_median_2 +query I +SELECT approx_median(c6) FROM aggregate_test_100 +---- +1146409980542786560 + +# csv_query_approx_median_3 +query R +SELECT approx_median(c12) FROM aggregate_test_100 +---- +0.555006541052 + +# csv_query_median_1 +query I +SELECT median(c2) FROM aggregate_test_100 +---- +3 + +# csv_query_median_2 +query I +SELECT median(c6) FROM aggregate_test_100 +---- +1125553990140691277 + +# csv_query_median_3 +query R +SELECT median(c12) FROM aggregate_test_100 +---- +0.551390054439 + +# median_i8 +query I +SELECT median(col_i8) FROM median_table +---- +-14 + +# median_i16 +query I +SELECT median(col_i16) FROM median_table +---- +-16334 + +# median_i32 +query I +SELECT median(col_i32) FROM median_table +---- +-1073741774 + +# median_i64 +query I +SELECT median(col_i64) FROM median_table +---- +-4611686018427387854 + +# median_u8 +query I +SELECT median(col_u8) FROM median_table +---- +50 + +# median_u16 +query I +SELECT median(col_u16) FROM median_table +---- +50 + +# median_u32 +query I +SELECT median(col_u32) FROM median_table +---- +50 + +# median_u64 +query I +SELECT median(col_u64) FROM median_table +---- +50 + +# median_f32 +query R +SELECT median(col_f32) FROM median_table +---- +2.75 + +# median_f64 +query R +SELECT median(col_f64) FROM median_table +---- +2.75 + +# median_f64_nan +query R +SELECT median(col_f64_nan) FROM median_table +---- +NaN + +# approx_median_f64_nan +query R +SELECT approx_median(col_f64_nan) FROM median_table +---- +NaN + +# median_multi +# test case for https://github.com/apache/arrow-datafusion/issues/3105 +# has an intermediate grouping +statement ok +create table cpu (host string, usage float) as select * from (values +('host0', 90.1), +('host1', 90.2), +('host1', 90.4) +); + +query TR rowsort +select host, median(usage) from cpu group by host; +---- +host0 90.1 +host1 90.3 + +statement ok +drop table cpu; + +# this test is to show create table as and select into works in the same way +statement ok +SELECT * INTO cpu +FROM (VALUES + ('host0', 90.1), + ('host1', 90.2), + ('host1', 90.4) + ) AS cpu (host, usage); + +query TR rowsort +select host, median(usage) from cpu group by host; +---- +host0 90.1 +host1 90.3 + +query R +select median(usage) from cpu; +---- +90.2 + +statement ok +drop table cpu; + +# median_multi_odd + +# data is not sorted and has an odd number of values per group +statement ok +create table cpu (host string, usage float) as select * from (values + ('host0', 90.2), + ('host1', 90.1), + ('host1', 90.5), + ('host0', 90.5), + ('host1', 90.0), + ('host1', 90.3), + ('host0', 87.9), + ('host1', 89.3) +); + +query TR rowsort +select host, median(usage) from cpu group by host; +---- +host0 90.2 +host1 90.1 + + +statement ok +drop table cpu; + +# median_multi_even +# data is not sorted and has an odd number of values per group +statement ok +create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3)); + +query TR rowsort +select host, median(usage) from cpu group by host; +---- +host0 90.35 +host1 90.25 + +statement ok +drop table cpu + +# csv_query_external_table_count +query I +SELECT COUNT(c12) FROM aggregate_test_100 +---- +100 + +# csv_query_external_table_sum +query II +SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100 +---- +13060 3017641 + +# csv_query_count +query I +SELECT count(c12) FROM aggregate_test_100 +---- +100 + +# csv_query_count_distinct +query I +SELECT count(distinct c2) FROM aggregate_test_100 +---- +5 + +# csv_query_count_distinct_expr +query I +SELECT count(distinct c2 % 2) FROM aggregate_test_100 +---- +2 + +# csv_query_count_star +query I +SELECT COUNT(*) FROM aggregate_test_100 +---- +100 + +# csv_query_count_literal +query I +SELECT COUNT(2) FROM aggregate_test_100 +---- +100 + +# csv_query_approx_count +# FIX: https://github.com/apache/arrow-datafusion/issues/3353 +# query II +# SELECT approx_distinct(c9) AS count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 +# ---- +# 100 99 + +# csv_query_approx_count_dupe_expr_aliased +query II +SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_100 +---- +100 100 + +## This test executes the APPROX_PERCENTILE_CONT aggregation against the test +## data, asserting the estimated quantiles are ±5% their actual values. +## +## Actual quantiles calculated with: +## +## ```r +## read_csv("./testing/data/csv/aggregate_test_100.csv") |> +## select_if(is.numeric) |> +## summarise_all(~ quantile(., c(0.1, 0.5, 0.9))) +## ``` +## +## Giving: +## +## ```text +## c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 +## +## 1 1 -95.3 -22925. -1882606710 -7.25e18 18.9 2671. 472608672. 1.83e18 0.109 0.0714 +## 2 3 15.5 4599 377164262 1.13e18 134. 30634 2365817608. 9.30e18 0.491 0.551 +## 3 5 102. 25334. 1991374996. 7.37e18 231 57518. 3776538487. 1.61e19 0.834 0.946 +## ``` +## +## Column `c12` is omitted due to a large relative error (~10%) due to the small +## float values. + +#csv_query_approx_percentile_cont (c2) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c3) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.1) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.5) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.9) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c4) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.1) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.5) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.9) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c5) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.1) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.9) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c6) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.1) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.5) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.9) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c7) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.1) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.5) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.9) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c8) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.1) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.5) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.9) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c9) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.1) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.5) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c10) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.1) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.5) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.9) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_approx_percentile_cont (c11) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.1) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.5) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 +---- +true + +# csv_query_cube_avg +query TIR +SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2 +---- +a 1 -17.6 +a 2 -15.333333333333 +a 3 -4.5 +a 4 -32 +a 5 -32 +a NULL -18.333333333333 +b 1 31.666666666667 +b 2 25.5 +b 3 -42 +b 4 -44.6 +b 5 -0.2 +b NULL -5.842105263158 +c 1 47.5 +c 2 -55.571428571429 +c 3 47.5 +c 4 -10.75 +c 5 12 +c NULL -1.333333333333 +d 1 -8.142857142857 +d 2 109.333333333333 +d 3 41.333333333333 +d 4 54 +d 5 -49.5 +d NULL 25.444444444444 +e 1 75.666666666667 +e 2 37.8 +e 3 48 +e 4 37.285714285714 +e 5 -11 +e NULL 40.333333333333 +NULL 1 16.681818181818 +NULL 2 8.363636363636 +NULL 3 20.789473684211 +NULL 4 1.260869565217 +NULL 5 -13.857142857143 +NULL NULL 7.81 + +# csv_query_rollup_avg +query TIIR +SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100 WHERE c1 IN ('a', 'b', NULL) GROUP BY ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3 +---- +a 1 -85 -15154 +a 1 -56 8692 +a 1 -25 15295 +a 1 -5 12636 +a 1 83 -14704 +a 1 NULL 1353 +a 2 -48 -18025 +a 2 -43 13080 +a 2 45 15673 +a 2 NULL 3576 +a 3 -72 -11122 +a 3 -12 -9168 +a 3 13 22338.5 +a 3 14 28162 +a 3 17 -22796 +a 3 NULL 4958.833333333333 +a 4 -101 11640 +a 4 -54 -2376 +a 4 -38 20744 +a 4 65 -28462 +a 4 NULL 386.5 +a 5 -101 -12484 +a 5 -31 -12907 +a 5 36 -16974 +a 5 NULL -14121.666666666666 +a NULL NULL 306.047619047619 +b 1 12 7652 +b 1 29 -18218 +b 1 54 -18410 +b 1 NULL -9658.666666666666 +b 2 -60 -21739 +b 2 31 23127 +b 2 63 21456 +b 2 68 15874 +b 2 NULL 9679.5 +b 3 -101 -13217 +b 3 17 14457 +b 3 NULL 620 +b 4 -117 19316 +b 4 -111 -1967 +b 4 -59 25286 +b 4 17 -28070 +b 4 47 20690 +b 4 NULL 7051 +b 5 -82 22080 +b 5 -44 15788 +b 5 -5 24896 +b 5 62 16337 +b 5 68 21576 +b 5 NULL 20135.4 +b NULL NULL 7732.315789473684 +NULL NULL NULL 3833.525 + +# csv_query_groupingsets_avg +query TIIR +SELECT c1, c2, c3, AVG(c4) +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b', NULL) +GROUP BY GROUPING SETS ((c1), (c1,c2), (c1,c2,c3)) +ORDER BY c1, c2, c3 +---- +a 1 -85 -15154 +a 1 -56 8692 +a 1 -25 15295 +a 1 -5 12636 +a 1 83 -14704 +a 1 NULL 1353 +a 2 -48 -18025 +a 2 -43 13080 +a 2 45 15673 +a 2 NULL 3576 +a 3 -72 -11122 +a 3 -12 -9168 +a 3 13 22338.5 +a 3 14 28162 +a 3 17 -22796 +a 3 NULL 4958.833333333333 +a 4 -101 11640 +a 4 -54 -2376 +a 4 -38 20744 +a 4 65 -28462 +a 4 NULL 386.5 +a 5 -101 -12484 +a 5 -31 -12907 +a 5 36 -16974 +a 5 NULL -14121.666666666666 +a NULL NULL 306.047619047619 +b 1 12 7652 +b 1 29 -18218 +b 1 54 -18410 +b 1 NULL -9658.666666666666 +b 2 -60 -21739 +b 2 31 23127 +b 2 63 21456 +b 2 68 15874 +b 2 NULL 9679.5 +b 3 -101 -13217 +b 3 17 14457 +b 3 NULL 620 +b 4 -117 19316 +b 4 -111 -1967 +b 4 -59 25286 +b 4 17 -28070 +b 4 47 20690 +b 4 NULL 7051 +b 5 -82 22080 +b 5 -44 15788 +b 5 -5 24896 +b 5 62 16337 +b 5 68 21576 +b 5 NULL 20135.4 +b NULL NULL 7732.315789473684 + +# csv_query_singlecol_with_rollup_avg +query TIIR +SELECT c1, c2, c3, AVG(c4) +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b', NULL) +GROUP BY c1, ROLLUP (c2, c3) +ORDER BY c1, c2, c3 +---- +a 1 -85 -15154 +a 1 -56 8692 +a 1 -25 15295 +a 1 -5 12636 +a 1 83 -14704 +a 1 NULL 1353 +a 2 -48 -18025 +a 2 -43 13080 +a 2 45 15673 +a 2 NULL 3576 +a 3 -72 -11122 +a 3 -12 -9168 +a 3 13 22338.5 +a 3 14 28162 +a 3 17 -22796 +a 3 NULL 4958.833333333333 +a 4 -101 11640 +a 4 -54 -2376 +a 4 -38 20744 +a 4 65 -28462 +a 4 NULL 386.5 +a 5 -101 -12484 +a 5 -31 -12907 +a 5 36 -16974 +a 5 NULL -14121.666666666666 +a NULL NULL 306.047619047619 +b 1 12 7652 +b 1 29 -18218 +b 1 54 -18410 +b 1 NULL -9658.666666666666 +b 2 -60 -21739 +b 2 31 23127 +b 2 63 21456 +b 2 68 15874 +b 2 NULL 9679.5 +b 3 -101 -13217 +b 3 17 14457 +b 3 NULL 620 +b 4 -117 19316 +b 4 -111 -1967 +b 4 -59 25286 +b 4 17 -28070 +b 4 47 20690 +b 4 NULL 7051 +b 5 -82 22080 +b 5 -44 15788 +b 5 -5 24896 +b 5 62 16337 +b 5 68 21576 +b 5 NULL 20135.4 +b NULL NULL 7732.315789473684 + +# csv_query_approx_percentile_cont_with_weight +query TI +SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + +# csv_query_approx_percentile_cont_with_weight (2) +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + +# csv_query_approx_percentile_cont_with_histogram_bins +query TI +SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + +# csv_query_sum_crossjoin +query TTI +SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 +---- +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 + +# csv_query_cube_sum_crossjoin +query TTI +SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1 +---- +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 +NULL a 5985 +NULL b 5415 +NULL c 5985 +NULL d 5130 +NULL e 5985 +NULL NULL 28500 + +# csv_query_cube_distinct_count +query TII +SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2 +---- +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 +NULL 1 22 +NULL 2 20 +NULL 3 17 +NULL 4 23 +NULL 5 14 +NULL NULL 80 + +# csv_query_rollup_distinct_count +query TII +SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2 +---- +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 +NULL NULL 80 + +# csv_query_rollup_sum_crossjoin +query TTI +SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1 +---- +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 +NULL NULL 28500 + +# query_count_without_from +query I +SELECT count(1 + 1) +---- +1 + +# csv_query_array_agg +query ? +SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test +---- +[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB] + +# csv_query_array_agg_empty +query ? +SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test +---- +[] + +# csv_query_array_agg_one +query ? +SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test +---- +[0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm] + +# csv_query_array_agg_with_overflow +query IIRIII +select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by c2 order by c2 +---- +1 367 16.681818181818 125 -99 22 +2 184 8.363636363636 122 -117 22 +3 395 20.789473684211 123 -101 19 +4 29 1.260869565217 123 -117 23 +5 -194 -13.857142857143 118 -101 14 + +# csv_query_array_cube_agg_with_overflow +query TIIRIII +select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2 +---- +a 1 -88 -17.6 83 -85 5 +a 2 -46 -15.333333333333 45 -48 3 +a 3 -27 -4.5 17 -72 6 +a 4 -128 -32 65 -101 4 +a 5 -96 -32 36 -101 3 +a NULL -385 -18.333333333333 83 -101 21 +b 1 95 31.666666666667 54 12 3 +b 2 102 25.5 68 -60 4 +b 3 -84 -42 17 -101 2 +b 4 -223 -44.6 47 -117 5 +b 5 -1 -0.2 68 -82 5 +b NULL -111 -5.842105263158 68 -117 19 +c 1 190 47.5 103 -24 4 +c 2 -389 -55.571428571429 29 -117 7 +c 3 190 47.5 97 -2 4 +c 4 -43 -10.75 123 -90 4 +c 5 24 12 118 -94 2 +c NULL -28 -1.333333333333 123 -117 21 +d 1 -57 -8.142857142857 125 -99 7 +d 2 328 109.333333333333 122 93 3 +d 3 124 41.333333333333 123 -76 3 +d 4 162 54 102 5 3 +d 5 -99 -49.5 -40 -59 2 +d NULL 458 25.444444444444 125 -99 18 +e 1 227 75.666666666667 120 36 3 +e 2 189 37.8 97 -61 5 +e 3 192 48 112 -95 4 +e 4 261 37.285714285714 97 -56 7 +e 5 -22 -11 64 -86 2 +e NULL 847 40.333333333333 120 -95 21 +NULL 1 367 16.681818181818 125 -99 22 +NULL 2 184 8.363636363636 122 -117 22 +NULL 3 395 20.789473684211 123 -101 19 +NULL 4 29 1.260869565217 123 -117 23 +NULL 5 -194 -13.857142857143 118 -101 14 +NULL NULL 781 7.81 125 -117 100 + +# TODO: array_agg_distinct output is non-determinisitic -- rewrite with array_sort(list_sort) +# unnest is also not available, so manually unnesting via CROSS JOIN +# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data +# +# csv_query_array_agg_distinct +query III +WITH indices AS ( + SELECT 1 AS idx UNION ALL + SELECT 2 AS idx UNION ALL + SELECT 3 AS idx UNION ALL + SELECT 4 AS idx UNION ALL + SELECT 5 AS idx +) +SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy +FROM ( + SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100 +) data + CROSS JOIN indices +ORDER BY 1 +---- +1 5 100 +2 5 100 +3 5 100 +4 5 100 +5 5 100 + +# aggregate_time_min_and_max +query TT +select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' union select '00:00:02') +---- +00:00:00 00:00:02 + +# aggregate_decimal_min +query RT +select min(c1), arrow_typeof(min(c1)) from d_table +---- +-100.009 Decimal128(10, 3) + +# aggregate_decimal_max +query RT +select max(c1), arrow_typeof(max(c1)) from d_table +---- +110.009 Decimal128(10, 3) + +# aggregate_decimal_sum +query RT +select sum(c1), arrow_typeof(sum(c1)) from d_table +---- +100 Decimal128(20, 3) + +# aggregate_decimal_avg +query RT +select avg(c1), arrow_typeof(avg(c1)) from d_table +---- +5 Decimal128(14, 7) + + +# aggregate +query II +SELECT SUM(c1), SUM(c2) FROM test +---- +7 6 + +# aggregate_empty + +query II +SELECT SUM(c1), SUM(c2) FROM test where c1 > 100000 +---- +NULL NULL + +# aggregate_avg +query RR +SELECT AVG(c1), AVG(c2) FROM test +---- +1.75 1.5 + +# aggregate_max +query II +SELECT MAX(c1), MAX(c2) FROM test +---- +3 2 + +# aggregate_min +query II +SELECT MIN(c1), MIN(c2) FROM test +---- +0 1 + +# aggregate_grouped +query II +SELECT c1, SUM(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 + +# aggregate_grouped_avg +query IR +SELECT c1, AVG(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 + +# aggregate_grouped_empty +query IR +SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1 +---- + +# aggregate_grouped_max +query II +SELECT c1, MAX(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 + +# aggregate_grouped_min +query II +SELECT c1, MIN(c2) FROM test GROUP BY c1 order by c1 +---- +0 NULL +1 1 +3 2 +NULL 1 + +# aggregate_min_max_w_custom_window_frames +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.3 PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN 0.1 PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.970671228336 +0.266717779508 0.996540038759 +0.360076636233 0.970671228336 + +# aggregate_min_max_with_custom_window_frames_unbounded_start +query RR +SELECT +MIN(c12) OVER (ORDER BY C12 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as min1, +MAX(c12) OVER (ORDER BY C11 RANGE BETWEEN UNBOUNDED PRECEDING AND 0.2 FOLLOWING) as max1 +FROM aggregate_test_100 +ORDER BY C9 +LIMIT 5 +---- +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 +0.014793053078 0.980019341044 +0.014793053078 0.996540038759 +0.014793053078 0.980019341044 + +# aggregate_avg_add +query RRRR +SELECT AVG(c1), AVG(c1) + 1, AVG(c1) + 2, 1 + AVG(c1) FROM test +---- +1.75 2.75 3.75 2.75 + +# case_sensitive_identifiers_aggregates +query I +SELECT max(c1) FROM test; +---- +3 + + + +# count_basic +query II +SELECT COUNT(c1), COUNT(c2) FROM test +---- +4 4 + +# TODO: count_partitioned + +# TODO: count_aggregated + +# TODO: count_aggregated_cube + +# count_multi_expr +query I +SELECT count(c1, c2) FROM test +---- +3 + +# count_multi_expr_group_by +query I +SELECT count(c1, c2) FROM test group by c1 order by c1 +---- +0 +1 +2 +0 + +# aggreggte_with_alias +query II +select c1, sum(c2) as `Total Salary` from test group by c1 order by c1 +---- +0 NULL +1 1 +3 4 +NULL 1 + +# simple_avg + +query R +select avg(c1) from test +---- +1.75 + +# simple_mean +query R +select mean(c1) from test +---- +1.75 + + + +# query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) +query RI +SELECT AVG(c1), SUM(DISTINCT c2) FROM test +---- +1.75 3 + +# query_sum_distinct - 2 sum(distinct) functions +query II +SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test +---- +4 3 + +# # query_count_distinct +query I +SELECT COUNT(DISTINCT c1) FROM test +---- +3 + +# TODO: count_distinct_integers_aggregated_single_partition + +# TODO: count_distinct_integers_aggregated_multiple_partitions + +# TODO: aggregate_with_alias + +# array_agg_zero +query ? +SELECT ARRAY_AGG([]) +---- +[[]] + +# array_agg_one +query ? +SELECT ARRAY_AGG([1]) +---- +[[1]] + +# test_approx_percentile_cont_decimal_support +query TI +SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 4 +b 5 +c 4 +d 4 +e 4 + + +# array_agg_zero +query ? +SELECT ARRAY_AGG([]); +---- +[[]] + +# array_agg_one +query ? +SELECT ARRAY_AGG([1]); +---- +[[1]] + +# variance_single_value +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; +---- +NULL 0 NULL 0 + +# variance_two_values +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq; +---- +2 1 1.414213562373 1 + + + +# aggregates on empty tables +statement ok +CREATE TABLE empty (column1 bigint, column2 int); + +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM empty +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping (no groups, so no output) +query IIRIIIIII +SELECT + count(column1), + sum(column1), + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM empty +GROUP BY column2 +ORDER BY column2; +---- + + +statement ok +drop table empty + +# aggregates on all nulls +statement ok +CREATE TABLE the_nulls +AS VALUES + (null::bigint, 1), + (null::bigint, 1), + (null::bigint, 2); + +query II +select * from the_nulls +---- +NULL 1 +NULL 1 +NULL 2 + +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM the_nulls +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping +query IIRIIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM the_nulls +GROUP BY column2 +ORDER BY column2; +---- +0 NULL NULL NULL NULL NULL NULL NULL 1 +0 NULL NULL NULL NULL NULL NULL NULL 2 + + +statement ok +drop table the_nulls; + +statement ok +create table bit_aggregate_functions ( + c1 SMALLINT NOT NULL, + c2 SMALLINT NOT NULL, + c3 SMALLINT, + tag varchar +) +as values + (5, 10, 11, 'A'), + (33, 11, null, 'B'), + (9, 12, null, 'A'); + +# query_bit_and, query_bit_or, query_bit_xor +query IIIIIIIII +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3) +FROM bit_aggregate_functions +---- +1 8 11 45 15 11 45 13 11 + +# query_bit_and, query_bit_or, query_bit_xor, with group +query IIIIIIIIIT +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3), + tag +FROM bit_aggregate_functions +GROUP BY tag +ORDER BY tag +---- +1 8 11 13 14 11 12 6 11 A +33 11 NULL 33 11 NULL 33 11 NULL B + + +statement ok +create table bool_aggregate_functions ( + c1 boolean not null, + c2 boolean not null, + c3 boolean not null, + c4 boolean not null, + c5 boolean, + c6 boolean, + c7 boolean, + c8 boolean, +) +as values + (true, true, false, false, true, true, null, null), + (true, false, true, false, false, null, false, null), + (true, true, false, false, null, true, false, null); + +# query_bool_and +query BBBBBBBB +SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions +---- +true false false false false true false NULL + +# query_bool_and_distinct +query BBBBBBBB +SELECT bool_and(distinct c1), bool_and(distinct c2), bool_and(distinct c3), bool_and(distinct c4), bool_and(distinct c5), bool_and(distinct c6), bool_and(distinct c7), bool_and(distinct c8) FROM bool_aggregate_functions +---- +true false false false false true false NULL + +# query_bool_or +query BBBBBBBB +SELECT bool_or(c1), bool_or(c2), bool_or(c3), bool_or(c4), bool_or(c5), bool_or(c6), bool_or(c7), bool_or(c8) FROM bool_aggregate_functions +---- +true true true false true true false NULL + +# query_bool_or_distinct +query BBBBBBBB +SELECT bool_or(distinct c1), bool_or(distinct c2), bool_or(distinct c3), bool_or(distinct c4), bool_or(distinct c5), bool_or(distinct c6), bool_or(distinct c7), bool_or(distinct c8) FROM bool_aggregate_functions +---- +true true true false true true false NULL + +# All supported timestamp types + +# "nanos" --> TimestampNanosecondArray +# "micros" --> TimestampMicrosecondArray +# "millis" --> TimestampMillisecondArray +# "secs" --> TimestampSecondArray +# "names" --> StringArray + +statement ok +create table t_source +as values + ('2018-11-13T17:11:10.011375885995', 'Row 0', 'X'), + ('2011-12-13T11:13:10.12345', 'Row 1', 'X'), + (null, 'Row 2', 'Y'), + ('2021-01-01T05:11:10.432', 'Row 3', 'Y'); + +statement ok +create table t as +select + arrow_cast(column1, 'Timestamp(Nanosecond, None)') as nanos, + arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, + arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, + arrow_cast(column1, 'Timestamp(Second, None)') as secs, + column2 as names, + column3 as tag +from t_source; + +# Demonstate the contents +query PPPPTT +select * from t; +---- +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 X +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 X +NULL NULL NULL NULL Row 2 Y +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 Y + + +# aggregate_timestamps_sum +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t; + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag; + +# aggregate_timestamps_count +query IIII +SELECT count(nanos), count(micros), count(millis), count(secs) FROM t; +---- +3 3 3 3 + +query TIIII +SELECT tag, count(nanos), count(micros), count(millis), count(secs) FROM t GROUP BY tag ORDER BY tag; +---- +X 2 2 2 2 +Y 1 1 1 1 + +# aggregate_timestamps_min +query PPPP +SELECT min(nanos), min(micros), min(millis), min(secs) FROM t; +---- +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 + +query TPPPP +SELECT tag, min(nanos), min(micros), min(millis), min(secs) FROM t GROUP BY tag ORDER BY tag; +---- +X 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 +Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 + +# aggregate_timestamps_max +query PPPP +SELECT max(nanos), max(micros), max(millis), max(secs) FROM t; +---- +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 + +query TPPPP +SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag ORDER BY tag +---- +X 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 +Y 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 + + +# aggregate_timestamps_avg +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; + + +statement ok +drop table t_source; + +statement ok +drop table t; + + +# All supported Date tpes + +# "date32" --> Date32Array +# "date64" --> Date64Array +# "names" --> StringArray + +statement ok +create table t_source +as values + ('2018-11-13', 'Row 0', 'X'), + ('2011-12-13', 'Row 1', 'X'), + (null, 'Row 2', 'Y'), + ('2021-01-01', 'Row 3', 'Y'); + +statement ok +create table t as +select + arrow_cast(column1, 'Date32') as date32, + -- Workaround https://github.com/apache/arrow-rs/issues/4512 is fixed, can use this + -- arrow_cast(column1, 'Date64') as date64, + arrow_cast(arrow_cast(column1, 'Date32'), 'Date64') as date64, + column2 as names, + column3 as tag +from t_source; + +# Demonstate the contents +query DDTT +select * from t; +---- +2018-11-13 2018-11-13T00:00:00 Row 0 X +2011-12-13 2011-12-13T00:00:00 Row 1 X +NULL NULL Row 2 Y +2021-01-01 2021-01-01T00:00:00 Row 3 Y + + +# aggregate_timestamps_sum +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +SELECT sum(date32), sum(date64) FROM t; + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +SELECT tag, sum(date32), sum(date64) FROM t GROUP BY tag ORDER BY tag; + +# aggregate_timestamps_count +query II +SELECT count(date32), count(date64) FROM t; +---- +3 3 + +query TII +SELECT tag, count(date32), count(date64) FROM t GROUP BY tag ORDER BY tag; +---- +X 2 2 +Y 1 1 + +# aggregate_timestamps_min +query DD +SELECT min(date32), min(date64) FROM t; +---- +2011-12-13 2011-12-13T00:00:00 + +query TDD +SELECT tag, min(date32), min(date64) FROM t GROUP BY tag ORDER BY tag; +---- +X 2011-12-13 2011-12-13T00:00:00 +Y 2021-01-01 2021-01-01T00:00:00 + +# aggregate_timestamps_max +query DD +SELECT max(date32), max(date64) FROM t; +---- +2021-01-01 2021-01-01T00:00:00 + +query TDD +SELECT tag, max(date32), max(date64) FROM t GROUP BY tag ORDER BY tag +---- +X 2018-11-13 2018-11-13T00:00:00 +Y 2021-01-01 2021-01-01T00:00:00 + + +# aggregate_timestamps_avg +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +SELECT avg(date32), avg(date64) FROM t + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +SELECT tag, avg(date32), avg(date64) FROM t GROUP BY tag ORDER BY tag; + + +statement ok +drop table t_source; + +statement ok +drop table t; + + +# All supported time types + +# Columns are named: +# "nanos" --> Time64NanosecondArray +# "micros" --> Time64MicrosecondArray +# "millis" --> Time32MillisecondArray +# "secs" --> Time32SecondArray +# "names" --> StringArray + +statement ok +create table t_source +as values + ('18:06:30.243620451', 'Row 0', 'A'), + ('20:08:28.161121654', 'Row 1', 'A'), + ('19:11:04.156423842', 'Row 2', 'B'), + ('21:06:28.247821084', 'Row 3', 'B'); + + +statement ok +create table t as +select + arrow_cast(column1, 'Time64(Nanosecond)') as nanos, + arrow_cast(column1, 'Time64(Microsecond)') as micros, + arrow_cast(column1, 'Time32(Millisecond)') as millis, + arrow_cast(column1, 'Time32(Second)') as secs, + column2 as names, + column3 as tag +from t_source; + +# Demonstate the contents +query DDDDTT +select * from t; +---- +18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 Row 0 A +20:08:28.161121654 20:08:28.161121 20:08:28.161 20:08:28 Row 1 A +19:11:04.156423842 19:11:04.156423 19:11:04.156 19:11:04 Row 2 B +21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 B + +# aggregate_times_sum +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag + +# aggregate_times_count +query IIII +SELECT count(nanos), count(micros), count(millis), count(secs) FROM t +---- +4 4 4 4 + +query TIIII +SELECT tag, count(nanos), count(micros), count(millis), count(secs) FROM t GROUP BY tag ORDER BY tag +---- +A 2 2 2 2 +B 2 2 2 2 + + +# aggregate_times_min +query DDDD +SELECT min(nanos), min(micros), min(millis), min(secs) FROM t +---- +18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 + +query TDDDD +SELECT tag, min(nanos), min(micros), min(millis), min(secs) FROM t GROUP BY tag ORDER BY tag +---- +A 18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 +B 19:11:04.156423842 19:11:04.156423 19:11:04.156 19:11:04 + +# aggregate_times_max +query DDDD +SELECT max(nanos), max(micros), max(millis), max(secs) FROM t +---- +21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 + +query TDDDD +SELECT tag, max(nanos), max(micros), max(millis), max(secs) FROM t GROUP BY tag ORDER BY tag +---- +A 20:08:28.161121654 20:08:28.161121 20:08:28.161 20:08:28 +B 21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 + + +# aggregate_times_avg +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t + +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; + +statement ok +drop table t_source; + +statement ok +drop table t; + + +# aggregates on strings +statement ok +create table t_source +as values + ('Foo', 1), + ('Bar', 2), + (null, 2), + ('Baz', 1); + +statement ok +create table t as +select + arrow_cast(column1, 'Utf8') as utf8, + arrow_cast(column1, 'LargeUtf8') as largeutf8, + column2 as tag +from t_source; + +# No groupy +query TTITTI +SELECT + min(utf8), + max(utf8), + count(utf8), + min(largeutf8), + max(largeutf8), + count(largeutf8) +FROM t +---- +Bar Foo 3 Bar Foo 3 + + +# with groupby +query TTITTI +SELECT + min(utf8), + max(utf8), + count(utf8), + min(largeutf8), + max(largeutf8), + count(largeutf8) +FROM t +GROUP BY tag +ORDER BY tag +---- +Baz Foo 2 Baz Foo 2 +Bar Bar 1 Bar Bar 1 + + +statement ok +drop table t_source; + +statement ok +drop table t; + + +# aggregates on binary +statement ok +create table t_source +as values + ('Foo', 1), + ('Bar', 2), + (null, 2), + ('Baz', 1); + +statement ok +create table t as +select + arrow_cast(column1, 'Binary') as binary, + arrow_cast(column1, 'LargeBinary') as largebinary, + column2 as tag +from t_source; + +# No groupy +query ??I??I +SELECT + min(binary), + max(binary), + count(binary), + min(largebinary), + max(largebinary), + count(largebinary) +FROM t +---- +426172 466f6f 3 426172 466f6f 3 + +# with groupby +query ??I??I +SELECT + min(binary), + max(binary), + count(binary), + min(largebinary), + max(largebinary), + count(largebinary) +FROM t +GROUP BY tag +ORDER BY tag +---- +42617a 466f6f 2 42617a 466f6f 2 +426172 426172 1 426172 426172 1 + + + +statement ok +drop table t_source; + +statement ok +drop table t; + + +query I +select median(a) from (select 1 as a where 1=0); +---- +NULL + +query error DataFusion error: Execution error: aggregate function needs at least one non-null element +select approx_median(a) from (select 1 as a where 1=0); + + +# aggregate_decimal_sum +query RT +select sum(c1), arrow_typeof(sum(c1)) from d_table; +---- +100 Decimal128(20, 3) + +query TRT +select c2, sum(c1), arrow_typeof(sum(c1)) from d_table GROUP BY c2 ORDER BY c2; +---- +A 1100.045 Decimal128(20, 3) +B -1000.045 Decimal128(20, 3) + + +# aggregate_decimal_avg +query RT +select avg(c1), arrow_typeof(avg(c1)) from d_table +---- +5 Decimal128(14, 7) + +query TRT +select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2 +---- +A 110.0045 Decimal128(14, 7) +B -100.0045 Decimal128(14, 7) + +# Use PostgresSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +# Creating the table +statement ok +CREATE TABLE test_table (c1 INT, c2 INT, c3 INT) + +# Inserting data +statement ok +INSERT INTO test_table VALUES + (1, 10, 50), + (1, 20, 60), + (2, 10, 70), + (2, 20, 80), + (3, 10, NULL) + +# query_group_by_with_filter +query III rowsort +SELECT + c1, + SUM(c2) FILTER (WHERE c2 >= 20), + SUM(c2) FILTER (WHERE c2 < 1) -- no rows pass filter, so the output should be NULL +FROM test_table GROUP BY c1 +---- +1 20 NULL +2 20 NULL +3 NULL NULL + +# query_group_by_avg_with_filter +query IRR rowsort +SELECT + c1, + AVG(c2) FILTER (WHERE c2 >= 20), + AVG(c2) FILTER (WHERE c2 < 1) -- no rows pass filter, so output should be null +FROM test_table GROUP BY c1 +---- +1 20 NULL +2 20 NULL +3 NULL NULL + +# query_group_by_with_multiple_filters +query IIR rowsort +SELECT + c1, + SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, + AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3 +FROM test_table GROUP BY c1 +---- +1 20 55 +2 20 70 +3 NULL NULL + +# query_group_by_distinct_with_filter +query II rowsort +SELECT + c1, + COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count +FROM test_table GROUP BY c1 +---- +1 1 +2 1 +3 0 + +# query_without_group_by_with_filter +query I rowsort +SELECT + SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 +FROM test_table +---- +40 + +# count_without_group_by_with_filter +query I rowsort +SELECT + COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2 +FROM test_table +---- +2 + +# query_with_and_without_filter +query III rowsort +SELECT + c1, + SUM(c2) FILTER (WHERE c2 >= 20) as result, + SUM(c2) as result_no_filter +FROM test_table GROUP BY c1; +---- +1 20 30 +2 20 30 +3 NULL 10 + +# query_filter_on_different_column_than_aggregate +query I rowsort +select + sum(c1) FILTER (WHERE c2 < 30) +FROM test_table; +---- +9 + +# query_test_empty_filter +query I rowsort +SELECT + SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 +FROM test_table; +---- +NULL + +# Creating the decimal table +statement ok +CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1)) + +# Inserting data +statement ok +INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL) + +# aggregate_decimal_with_group_by +query IIRRRRIIR rowsort +select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1 +---- +1 2 15.15 30.3 10.1 20.2 2 0 NULL +2 2 15.15 30.3 10.1 20.2 2 0 NULL +3 2 10.1 20.2 10.1 10.1 1 0 NULL + +# aggregate_decimal_with_group_by_decimal +query RIRRRRIR rowsort +select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3 +---- +100.1 2 10.1 20.2 10.1 10.1 0 NULL +200.2 1 20.2 20.2 20.2 20.2 0 NULL +700.1 2 15.15 30.3 10.1 20.2 0 NULL +NULL 1 10.1 10.1 10.1 10.1 0 NULL + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +## Multiple distinct aggregates and dictionaries +statement ok +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); + +query I? +select * from dict_test; +---- +1 foo +2 bar + +query II +select count(distinct column1), count(distinct column2) from dict_test group by column1; +---- +1 1 +1 1 + +statement ok +drop table dict_test; + + +# Prepare the table with dictionary values for testing +statement ok +CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2); + +statement ok +CREATE TABLE value_dict AS SELECT arrow_cast(x, 'Dictionary(Int64, Int32)') AS x_dict FROM value; + +query ? +select x_dict from value_dict; +---- +1 +2 +3 +1 +3 +4 +5 +2 + +query I +select sum(x_dict) from value_dict; +---- +21 + +query R +select avg(x_dict) from value_dict; +---- +2.625 + +query I +select min(x_dict) from value_dict; +---- +1 + +query I +select max(x_dict) from value_dict; +---- +5 + +query I +select sum(x_dict) from value_dict where x_dict > 3; +---- +9 + +query R +select avg(x_dict) from value_dict where x_dict > 3; +---- +4.5 + +query I +select min(x_dict) from value_dict where x_dict > 3; +---- +4 + +query I +select max(x_dict) from value_dict where x_dict > 3; +---- +5 + +query I +select sum(x_dict) from value_dict group by x_dict % 2 order by sum(x_dict); +---- +8 +13 + +query R +select avg(x_dict) from value_dict group by x_dict % 2 order by avg(x_dict); +---- +2.6 +2.666666666667 + +query I +select min(x_dict) from value_dict group by x_dict % 2 order by min(x_dict); +---- +1 +2 + +query I +select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); +---- +4 +5 + +query T +select arrow_typeof(x_dict) from value_dict group by x_dict; +---- +Int32 +Int32 +Int32 +Int32 +Int32 + +statement ok +drop table value + +statement ok +drop table value_dict + + +# bool aggregation +statement ok +CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3); + +query B +select min(x) from value_bool; +---- +false + +query B +select max(x) from value_bool; +---- +true + +query B +select min(x) from value_bool group by g order by g; +---- +false +false +true +NULL + +query B +select max(x) from value_bool group by g order by g; +---- +true +false +true +NULL + +# TopK aggregation +statement ok +CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES +(NULL, 0, 0), +('a', NULL, NULL), +('a', 1, 1), +('a', -1, -1), +('b', 0, 0), +('c', 1, 1), +('c', 2, 2), +('b', 3, 3); + +statement ok +set datafusion.optimizer.enable_topk_aggregation = false; + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +a -1 +NULL 0 +b 0 +c 1 + +query TII +select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; +---- +a -1 -1 +b 0 0 +NULL 0 0 +c 1 1 + +query TII +select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; +---- +a -1 -1 +NULL 0 0 +b 0 0 +c 1 1 + +statement ok +set datafusion.optimizer.enable_topk_aggregation = true; + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MIN(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MIN(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MIN(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by trace_id asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: traces.trace_id ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 +----SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +a -1 +NULL 0 +b 0 +c 1 + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; +---- +b 3 +c 2 +a 1 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; +---- +a -1 +NULL 0 +b 0 + +query TII +select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; +---- +a -1 -1 +b 0 0 +NULL 0 0 +c 1 1 + +query TII +select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; +---- +a -1 -1 +NULL 0 0 +b 0 0 +c 1 1 + +# +# Push limit into distinct group-by aggregation tests +# + +# Make results deterministic +statement ok +set datafusion.optimizer.repartition_aggregations = false; + +# +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +1 +-40 +29 +-85 +-82 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +logical_plan +Limit: skip=4, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=4, fetch=5 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +---- +5 -82 +4 -111 +3 104 +3 13 +1 38 + +# The limit should only apply to the aggregations which group by c3 +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Projection: aggregate_test_100.c3 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------Filter: aggregate_test_100.c3 >= Int16(10) AND aggregate_test_100.c3 <= Int16(20) +----------TableScan: aggregate_test_100 projection=[c2, c3], partial_filters=[aggregate_test_100.c3 >= Int16(10), aggregate_test_100.c3 <= Int16(20)] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[4] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[4] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------FilterExec: c3@1 >= 10 AND c3@1 <= 20 +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query I +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +---- +13 +17 +12 +14 + +# An aggregate expression causes the limit to not be pushed to the aggregation +query TT +EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; +---- +logical_plan +Projection: MAX(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[MAX(aggregate_test_100.c1)]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[MAX(aggregate_test_100.c1)@2 as MAX(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[MAX(aggregate_test_100.c1)] +------CoalescePartitionsExec +--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[MAX(aggregate_test_100.c1)] +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns +# in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case +query TT +EXPLAIN SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +logical_plan +Limit: skip=10, fetch=3 +--Aggregate: groupBy=[[aggregate_test_100.c3, aggregate_test_100.c2]], aggr=[[]] +----Projection: aggregate_test_100.c3, aggregate_test_100.c2 +------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +--------TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=10, fetch=3 +--AggregateExec: mode=Final, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3, c2@1 as c2], aggr=[], lim=[13] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------ProjectionExec: expr=[c3@1 as c3, c2@0 as c2] +------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] +------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +---- +57 1 +-54 4 +112 3 + +query TT +EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +logical_plan +Limit: skip=0, fetch=3 +--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +----TableScan: aggregate_test_100 projection=[c2, c3] +physical_plan +GlobalLimitExec: skip=0, fetch=3 +--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true + +query II +SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; +---- +NULL NULL +2 NULL +5 NULL + + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = false; + +# The limit should not be pushed into the aggregations +query TT +EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] +------TableScan: aggregate_test_100 projection=[c3] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] +------------CoalescePartitionsExec +--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true + +statement ok +set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + +# +# regr_*() tests +# + +# regr_*() invalid input +statement error +select regr_slope(); + +statement error +select regr_intercept(*); + +statement error +select regr_count(*) from aggregate_test_100; + +statement error +select regr_r2(1); + +statement error +select regr_avgx(1,2,3); + +statement error +select regr_avgy(1, 'foo'); + +statement error +select regr_sxx('foo', 1); + +statement error +select regr_syy('foo', 'bar'); + +statement error +select regr_sxy(NULL, 'bar'); + + + +# regr_*() NULL results +query RRRRRRRRR +select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1); +---- +NULL NULL 1 NULL 1 1 0 0 0 + +query RRRRRRRRR +select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6)); +---- +NULL NULL 3 NULL 1 4 0 8 0 + + + +# regr_*() basic tests +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,2), (2,4), (3,6)); +---- +2 0 3 1 2 4 2 8 4 + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + + + +# regr_*() functions ignore NULLs +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (2,4), (3,6)); +---- +2 0 2 1 2.5 5 0.5 2 1 + +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (NULL,4), (3,6)); +---- +NULL NULL 1 NULL 3 6 0 0 0 + +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (NULL,4), (NULL,NULL)); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query TRRRRRRRRR rowsort +select + column3, + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'), (NULL,100,'c')) +group by column3; +---- +a 2 0 2 1 1.5 3 0.5 2 1 +b 3 0 2 1 2 6 2 18 6 +c NULL NULL 1 NULL 1 10 0 0 0 + + + +# regr_*() testing merge_batch() from RegrAccumulator's internal implementation +statement ok +set datafusion.execution.batch_size = 1; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 2; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 3; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 8192; + + + +# regr_*() testing retract_batch() from RegrAccumulator's internal implementation +query RRRRRRRRR +SELECT + regr_slope(column2, column1) OVER w AS slope, + regr_intercept(column2, column1) OVER w AS intercept, + regr_count(column2, column1) OVER w AS count, + regr_r2(column2, column1) OVER w AS r2, + regr_avgx(column2, column1) OVER w AS avgx, + regr_avgy(column2, column1) OVER w AS avgy, + regr_sxx(column2, column1) OVER w AS sxx, + regr_syy(column2, column1) OVER w AS syy, + regr_sxy(column2, column1) OVER w AS sxy +FROM (VALUES (1,2), (2,4), (3,6), (4,12), (5,15), (6,18)) AS t(column1, column2) +WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW); +---- +NULL NULL 1 NULL 1 2 0 0 0 +2 0 2 1 1.5 3 0.5 2 1 +2 0 3 1 2 4 2 8 4 +4 -4.666666666667 3 0.923076923077 3 7.333333333333 2 34.666666666667 8 +4.5 -7 3 0.964285714286 4 11 2 42 9 +3 0 3 1 5 15 2 18 6 + +query RRRRRRRRR +SELECT + regr_slope(column2, column1) OVER w AS slope, + regr_intercept(column2, column1) OVER w AS intercept, + regr_count(column2, column1) OVER w AS count, + regr_r2(column2, column1) OVER w AS r2, + regr_avgx(column2, column1) OVER w AS avgx, + regr_avgy(column2, column1) OVER w AS avgy, + regr_sxx(column2, column1) OVER w AS sxx, + regr_syy(column2, column1) OVER w AS syy, + regr_sxy(column2, column1) OVER w AS sxy +FROM (VALUES (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21)) AS t(column1, column2) +WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW); +---- +NULL NULL 1 NULL 1 2 0 0 0 +2 0 2 1 1.5 3 0.5 2 1 +2 0 3 1 2 4 2 8 4 +2 0 2 1 2.5 5 0.5 2 1 +NULL NULL 1 NULL 3 6 0 0 0 +NULL NULL 1 NULL 5 15 0 0 0 +3 0 2 1 5.5 16.5 0.5 4.5 1.5 +3 0 3 1 6 18 2 18 6 + +statement error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', ',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL + +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +NULL + +statement ok +drop table strings + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1 + + +# Queries with nested count(*) + +query I +select count(*) from (select count(*) from (select 1)); +---- +1 + +query I +select count(*) from (select count(*) a, count(*) b from (select 1)); +---- +1 \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt new file mode 100644 index 0000000000000..1202a2b1e99d6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/array.slt @@ -0,0 +1,3573 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Array Expressions Tests +############# + +### Tables + +statement ok +CREATE TABLE values( + a INT, + b INT, + c INT, + d FLOAT, + e VARCHAR, + f VARCHAR +) AS VALUES + (1, 1, 2, 1.1, 'Lorem', 'A'), + (2, 3, 4, 2.2, 'ipsum', ''), + (3, 5, 6, 3.3, 'dolor', 'BB'), + (4, 7, 8, 4.4, 'sit', NULL), + (NULL, 9, 10, 5.5, 'amet', 'CCC'), + (5, NULL, 12, 6.6, ',', 'DD'), + (6, 11, NULL, 7.7, 'consectetur', 'E'), + (7, 13, 14, NULL, 'adipiscing', 'F'), + (8, 15, 16, 8.8, NULL, '') +; + +statement ok +CREATE TABLE values_without_nulls +AS VALUES + (1, 1, 2, 1.1, 'Lorem', 'A'), + (2, 3, 4, 2.2, 'ipsum', ''), + (3, 5, 6, 3.3, 'dolor', 'BB'), + (4, 7, 8, 4.4, 'sit', NULL), + (5, 9, 10, 5.5, 'amet', 'CCC'), + (6, 11, 12, 6.6, ',', 'DD'), + (7, 13, 14, 7.7, 'consectetur', 'E'), + (8, 15, 16, 8.8, 'adipiscing', 'F'), + (9, 17, 18, 9.9, 'elit', '') +; + +statement ok +CREATE TABLE arrays +AS VALUES + (make_array(make_array(NULL, 2),make_array(3, NULL)), make_array(1.1, 2.2, 3.3), make_array('L', 'o', 'r', 'e', 'm')), + (make_array(make_array(3, 4),make_array(5, 6)), make_array(NULL, 5.5, 6.6), make_array('i', 'p', NULL, 'u', 'm')), + (make_array(make_array(5, 6),make_array(7, 8)), make_array(7.7, 8.8, 9.9), make_array('d', NULL, 'l', 'o', 'r')), + (make_array(make_array(7, NULL),make_array(9, 10)), make_array(10.1, NULL, 12.2), make_array('s', 'i', 't')), + (NULL, make_array(13.3, 14.4, 15.5), make_array('a', 'm', 'e', 't')), + (make_array(make_array(11, 12),make_array(13, 14)), NULL, make_array(',')), + (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) +; + +statement ok +CREATE TABLE slices +AS VALUES + (make_array(NULL, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, NULL, 20), 2, -4), + (make_array(21, 22, 23, NULL, 25, 26, 27, 28, 29, 30), 0, 0), + (make_array(31, 32, 33, 34, 35, NULL, 37, 38, 39, 40), -4, -7), + (NULL, 4, 5), + (make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), NULL, 6), + (make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60), 5, NULL) +; + +statement ok +CREATE TABLE arrayspop +AS VALUES + (make_array(1, 2, NULL)), + (make_array(3, 4, 5, NULL)), + (make_array(6, 7, 8, NULL, 9)), + (make_array(NULL, NULL, 100)), + (NULL), + (make_array(NULL, 10, 11, 12)) +; + +statement ok +CREATE TABLE nested_arrays +AS VALUES + (make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), make_array(7, 8, 9), 2, make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]), make_array(11, 12, 13)), + (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]), make_array(121, 131, 141)) +; + +statement ok +CREATE TABLE arrays_values +AS VALUES + (make_array(NULL, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1, ','), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, NULL, 20), 12, 2, '.'), + (make_array(21, 22, 23, NULL, 25, 26, 27, 28, 29, 30), 23, 3, '-'), + (make_array(31, 32, 33, 34, 35, NULL, 37, 38, 39, 40), 34, 4, 'ok'), + (NULL, 44, 5, '@'), + (make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), NULL, 6, '$'), + (make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60), 55, NULL, '^'), + (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL) +; + +statement ok +CREATE TABLE arrays_values_v2 +AS VALUES + (make_array(NULL, 2, 3), make_array(4, 5, NULL), 12, make_array([30, 40, 50])), + (NULL, make_array(7, NULL, 8), 13, make_array(make_array(NULL,NULL,60))), + (make_array(9, NULL, 10), NULL, 14, make_array(make_array(70,NULL,NULL))), + (make_array(NULL, 1), make_array(NULL, 21), NULL, NULL), + (make_array(11, 12), NULL, NULL, NULL), + (NULL, NULL, NULL, NULL) +; + +statement ok +CREATE TABLE flatten_table +AS VALUES + (make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]), make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3, 3.4])), + (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])) +; + +statement ok +CREATE TABLE array_has_table_1D +AS VALUES + (make_array(1, 2), 1, make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3,5)), + (make_array(3, 4, 5), 2, make_array(1,2,3,4), make_array(2,5), make_array(2,4,6), make_array(1,3,5)) +; + +statement ok +CREATE TABLE array_has_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), 1.0, make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), 2.0, make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE array_has_table_1D_Boolean +AS VALUES + (make_array(true, true, true), false, make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), false, make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE array_has_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), 'bc', make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), 'defg', make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE array_has_table_2D +AS VALUES + (make_array([1,2]), make_array(1,3), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array(5), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE array_has_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE array_has_table_3D +AS VALUES + (make_array([[1,2]]), make_array([1])), + (make_array([[1,2]]), make_array([1,2])), + (make_array([[1,2]]), make_array([1,2,3])), + (make_array([[1], [2]]), make_array([2])), + (make_array([[1], [2]]), make_array([1], [2])), + (make_array([[1], [2]], [[2], [3]]), make_array([1], [2], [3])), + (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) +; + +statement ok +CREATE TABLE array_distinct_table_1D +AS VALUES + (make_array(1, 1, 2, 2, 3)), + (make_array(1, 2, 3, 4, 5)), + (make_array(3, 5, 3, 3, 3)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_UTF8 +AS VALUES + (make_array('a', 'a', 'bc', 'bc', 'def')), + (make_array('a', 'bc', 'def', 'defg', 'defg')), + (make_array('defg', 'defg', 'defg', 'defg', 'defg')) +; + +statement ok +CREATE TABLE array_distinct_table_2D +AS VALUES + (make_array([1,2], [1,2], [3,4], [3,4], [5,6])), + (make_array([1,2], [3,4], [5,6], [7,8], [9,10])), + (make_array([5,6], [5,6], NULL)) +; + +statement ok +CREATE TABLE array_distinct_table_1D_large +AS VALUES + (arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')), + (arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), + (arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)')) +; + +statement ok +CREATE TABLE array_intersect_table_1D +AS VALUES + (make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)), + (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Boolean +AS VALUES + (make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE array_intersect_table_2D +AS VALUES + (make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE array_intersect_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE array_intersect_table_3D +AS VALUES + (make_array([[1,2]]), make_array([[1]])), + (make_array([[1,2]]), make_array([[1,2]])) +; + +statement ok +CREATE TABLE arrays_values_without_nulls +AS VALUES + (make_array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1, ',', [2,3]), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, 19, 20), 12, 2, '.', [4,5]), + (make_array(21, 22, 23, 24, 25, 26, 27, 28, 29, 30), 23, 3, '-', [6,7]), + (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) +; + +statement ok +CREATE TABLE arrays_range +AS VALUES + (3, 10, 2), + (4, 13, 3) +; + +statement ok +CREATE TABLE arrays_with_repeating_elements +AS VALUES + (make_array(1, 2, 1, 3, 2, 2, 1, 3, 2, 3), 2, 4, 3), + (make_array(4, 4, 5, 5, 6, 5, 5, 5, 4, 4), 4, 7, 2), + (make_array(7, 7, 7, 8, 7, 9, 7, 8, 7, 7), 7, 10, 5), + (make_array(10, 11, 12, 10, 11, 12, 10, 11, 12, 10), 10, 13, 10) +; + +statement ok +CREATE TABLE nested_arrays_with_repeating_elements +AS VALUES + (make_array([1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]), [4, 5, 6], [10, 11, 12], 3), + (make_array([10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]), [10, 11, 12], [19, 20, 21], 2), + (make_array([19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]), [19, 20, 21], [28, 29, 30], 5), + (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) +; + +query ? +select [1, true, null] +---- +[1, 1, ] + +query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() +SELECT [now()] + +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from arrays; +---- +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# arrays table +query ??? +select column1, column2, column3 from arrays; +---- +[[, 2], [3, ]] [1.1, 2.2, 3.3] [L, o, r, e, m] +[[3, 4], [5, 6]] [, 5.5, 6.6] [i, p, , u, m] +[[5, 6], [7, 8]] [7.7, 8.8, 9.9] [d, , l, o, r] +[[7, ], [9, 10]] [10.1, , 12.2] [s, i, t] +NULL [13.3, 14.4, 15.5] [a, m, e, t] +[[11, 12], [13, 14]] NULL [,] +[[15, 16], [, 18]] [16.6, 17.7, 18.8] NULL + +# nested_arrays table +query ??I?? +select column1, column2, column3, column4, column5 from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [7, 8, 9] 2 [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] [11, 12, 13] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [10, 11, 12] 3 [[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]] [121, 131, 141] + +# values table +query IIIRT +select a, b, c, d, e from values; +---- +1 1 2 1.1 Lorem +2 3 4 2.2 ipsum +3 5 6 3.3 dolor +4 7 8 4.4 sit +NULL 9 10 5.5 amet +5 NULL 12 6.6 , +6 11 NULL 7.7 consectetur +7 13 14 NULL adipiscing +8 15 16 8.8 NULL + +# arrays_values table +query ?IIT +select column1, column2, column3, column4 from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1 1 , +[11, 12, 13, 14, 15, 16, 17, 18, , 20] 12 2 . +[21, 22, 23, , 25, 26, 27, 28, 29, 30] 23 3 - +[31, 32, 33, 34, 35, , 37, 38, 39, 40] 34 4 ok +NULL 44 5 @ +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] NULL 6 $ +[51, 52, , 54, 55, 56, 57, 58, 59, 60] 55 NULL ^ +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] 66 7 NULL + +# slices table +query ?II +select column1, column2, column3 from slices; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1 1 +[11, 12, 13, 14, 15, 16, 17, 18, , 20] 2 -4 +[21, 22, 23, , 25, 26, 27, 28, 29, 30] 0 0 +[31, 32, 33, 34, 35, , 37, 38, 39, 40] -4 -7 +NULL 4 5 +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] NULL 6 +[51, 52, , 54, 55, 56, 57, 58, 59, 60] 5 NULL + +query ??I? +select column1, column2, column3, column4 from arrays_values_v2; +---- +[, 2, 3] [4, 5, ] 12 [[30, 40, 50]] +NULL [7, , 8] 13 [[, , 60]] +[9, , 10] NULL 14 [[70, , ]] +[, 1] [, 21] NULL NULL +[11, 12] NULL NULL NULL +NULL NULL NULL NULL + +# arrays_values_without_nulls table +query ?IIT +select column1, column2, column3, column4 from arrays_values_without_nulls; +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1 1 , +[11, 12, 13, 14, 15, 16, 17, 18, 19, 20] 12 2 . +[21, 22, 23, 24, 25, 26, 27, 28, 29, 30] 23 3 - +[31, 32, 33, 34, 35, 26, 37, 38, 39, 40] 34 4 ok + +# arrays_with_repeating_elements table +query ?III +select column1, column2, column3, column4 from arrays_with_repeating_elements; +---- +[1, 2, 1, 3, 2, 2, 1, 3, 2, 3] 2 4 3 +[4, 4, 5, 5, 6, 5, 5, 5, 4, 4] 4 7 2 +[7, 7, 7, 8, 7, 9, 7, 8, 7, 7] 7 10 5 +[10, 11, 12, 10, 11, 12, 10, 11, 12, 10] 10 13 10 + +# nested_arrays_with_repeating_elements table +query ???I +select column1, column2, column3, column4 from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [4, 5, 6] [10, 11, 12] 3 +[[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [10, 11, 12] [19, 20, 21] 2 +[[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [19, 20, 21] [28, 29, 30] 5 +[[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [28, 29, 30] [37, 38, 39] 10 + + +### Array index + + +## array[i] + +# single index with scalars #1 (positive index) +query IRT +select make_array(1, 2, 3)[1], make_array(1.0, 2.0, 3.0)[2], make_array('h', 'e', 'l', 'l', 'o')[3]; +---- +1 2 l + +# single index with scalars #2 (zero index) +query I +select make_array(1, 2, 3)[0]; +---- +NULL + +# single index with scalars #3 (negative index) +query IRT +select make_array(1, 2, 3)[-1], make_array(1.0, 2.0, 3.0)[-2], make_array('h', 'e', 'l', 'l', 'o')[-3]; +---- +3 2 l + +# single index with scalars #4 (complex index) +query IRT +select make_array(1, 2, 3)[1 + 2 - 1], make_array(1.0, 2.0, 3.0)[2 * 1 * 0 - 2], make_array('h', 'e', 'l', 'l', 'o')[2 - 3]; +---- +2 2 o + +# single index with columns #1 (positive index) +query ?RT +select column1[2], column2[3], column3[1] from arrays; +---- +[3, ] 3.3 L +[5, 6] 6.6 i +[7, 8] 9.9 d +[9, 10] 12.2 s +NULL 15.5 a +[13, 14] NULL , +[, 18] 18.8 NULL + +# single index with columns #2 (zero index) +query ?RT +select column1[0], column2[0], column3[0] from arrays; +---- +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# single index with columns #3 (negative index) +query ?RT +select column1[-2], column2[-3], column3[-1] from arrays; +---- +[, 2] 1.1 m +[3, 4] NULL m +[5, 6] 7.7 r +[7, ] 10.1 t +NULL 13.3 t +[11, 12] NULL , +[15, 16] 16.6 NULL + +# single index with columns #4 (complex index) +query ?RT +select column1[9 - 7], column2[2 * 0], column3[1 - 3] from arrays; +---- +[3, ] NULL e +[5, 6] NULL u +[7, 8] NULL o +[9, 10] NULL i +NULL NULL e +[13, 14] NULL NULL +[, 18] NULL NULL + +# TODO: support index as column +# single index with columns #5 (index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and index as columns +# single index with columns #6 (argument and index as columns) +# query I +# select column1[column2] from arrays_with_repeating_elements; +# ---- + +## array[i:j] + +# multiple index with columns #1 (positive index) +query ??? +select make_array(1, 2, 3)[1:2], make_array(1.0, 2.0, 3.0)[2:3], make_array('h', 'e', 'l', 'l', 'o')[2:4]; +---- +[1, 2] [2.0, 3.0] [e, l, l] + +# multiple index with columns #2 (zero index) +query ??? +select make_array(1, 2, 3)[0:0], make_array(1.0, 2.0, 3.0)[0:2], make_array('h', 'e', 'l', 'l', 'o')[0:6]; +---- +[] [1.0, 2.0] [h, e, l, l, o] + +# TODO: support multiple negative index +# multiple index with columns #3 (negative index) +# query II +# select make_array(1, 2, 3)[-3:-1], make_array(1.0, 2.0, 3.0)[-3:-1], make_array('h', 'e', 'l', 'l', 'o')[-2:0]; +# ---- + +# TODO: support complex index +# multiple index with columns #4 (complex index) +# query III +# select make_array(1, 2, 3)[2 + 1 - 1:10], make_array(1.0, 2.0, 3.0)[2 | 2:10], make_array('h', 'e', 'l', 'l', 'o')[6 ^ 6:10]; +# ---- + +# multiple index with columns #1 (positive index) +query ??? +select column1[2:4], column2[1:4], column3[3:4] from arrays; +---- +[[3, ]] [1.1, 2.2, 3.3] [r, e] +[[5, 6]] [, 5.5, 6.6] [, u] +[[7, 8]] [7.7, 8.8, 9.9] [l, o] +[[9, 10]] [10.1, , 12.2] [t] +[] [13.3, 14.4, 15.5] [e, t] +[[13, 14]] [] [] +[[, 18]] [16.6, 17.7, 18.8] [] + +# multiple index with columns #2 (zero index) +query ??? +select column1[0:5], column2[0:3], column3[0:9] from arrays; +---- +[[, 2], [3, ]] [1.1, 2.2, 3.3] [L, o, r, e, m] +[[3, 4], [5, 6]] [, 5.5, 6.6] [i, p, , u, m] +[[5, 6], [7, 8]] [7.7, 8.8, 9.9] [d, , l, o, r] +[[7, ], [9, 10]] [10.1, , 12.2] [s, i, t] +[] [13.3, 14.4, 15.5] [a, m, e, t] +[[11, 12], [13, 14]] [] [,] +[[15, 16], [, 18]] [16.6, 17.7, 18.8] [] + +# TODO: support negative index +# multiple index with columns #3 (negative index) +# query ?RT +# select column1[-2:-4], column2[-3:-5], column3[-1:-4] from arrays; +# ---- +# [, 2] 1.1 m + +# TODO: support complex index +# multiple index with columns #4 (complex index) +# query ?RT +# select column1[9 - 7:2 + 2], column2[1 * 0:2 * 3], column3[1 + 1 - 0:5 % 3] from arrays; +# ---- + +# TODO: support first index as column +# multiple index with columns #5 (first index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2:4] from arrays_with_repeating_elements +# ---- + +# TODO: support last index as column +# multiple index with columns #6 (last index as column) +# query ?RT +# select make_array(1, 2, 3, 4, 5)[2:column3] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and indices as column +# multiple index with columns #7 (argument and indices as column) +# query ?RT +# select column1[column2:column3] from arrays_with_repeating_elements; +# ---- + + +### Array function tests + + +## make_array (aliases: `make_list`) + +# make_array scalar function #1 +query ??? +select make_array(1, 2, 3), make_array(1.0, 2.0, 3.0), make_array('h', 'e', 'l', 'l', 'o'); +---- +[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] + +# make_array scalar function #2 +query ??? +select make_array(1, 2, 3), make_array(make_array(1, 2), make_array(3, 4)), make_array([[[[1], [2]]]]); +---- +[1, 2, 3] [[1, 2], [3, 4]] [[[[[1], [2]]]]] + +# make_array scalar function #3 +query ?? +select make_array([1, 2, 3], [4, 5, 6], [7, 8, 9]), make_array([[1, 2], [3, 4]], [[5, 6], [7, 8]]); +---- +[[1, 2, 3], [4, 5, 6], [7, 8, 9]] [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + +# make_array scalar function #4 +query ?? +select make_array([1.0, 2.0], [3.0, 4.0]), make_array('h', 'e', 'l', 'l', 'o'); +---- +[[1.0, 2.0], [3.0, 4.0]] [h, e, l, l, o] + +# make_array scalar function #5 +query ? +select make_array(make_array(make_array(make_array(1, 2, 3), make_array(4, 5, 6)), make_array(make_array(7, 8, 9), make_array(10, 11, 12)))) +---- +[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] + +# make_array scalar function #6 +query ? +select make_array() +---- +[] + +# make_array scalar function #7 +query ?? +select make_array(make_array()), make_array(make_array(make_array())) +---- +[[]] [[[]]] + +# make_list scalar function #8 (function alias: `make_array`) +query ??? +select make_list(1, 2, 3), make_list(1.0, 2.0, 3.0), make_list('h', 'e', 'l', 'l', 'o'); +---- +[1, 2, 3] [1.0, 2.0, 3.0] [h, e, l, l, o] + +# make_array scalar function with nulls +query ??? +select make_array(1, NULL, 3), make_array(NULL, 2.0, NULL), make_array('h', NULL, 'l', NULL, 'o'); +---- +[1, , 3] [, 2.0, ] [h, , l, , o] + +# make_array scalar function with nulls #2 +query ?? +select make_array(1, 2, NULL), make_array(make_array(NULL, 2), make_array(NULL, 3)); +---- +[1, 2, ] [[, 2], [, 3]] + +# make_array scalar function with nulls #3 +query ??? +select make_array(NULL), make_array(NULL, NULL, NULL), make_array(make_array(NULL, NULL), make_array(NULL, NULL)); +---- +[] [, , ] [[, ], [, ]] + +# make_array with 1 columns +query ??? +select make_array(a), make_array(d), make_array(e) from values; +---- +[1] [1.1] [Lorem] +[2] [2.2] [ipsum] +[3] [3.3] [dolor] +[4] [4.4] [sit] +[] [5.5] [amet] +[5] [6.6] [,] +[6] [7.7] [consectetur] +[7] [] [adipiscing] +[8] [8.8] [] + +# make_array with 2 columns #1 +query ?? +select make_array(b, c), make_array(e, f) from values; +---- +[1, 2] [Lorem, A] +[3, 4] [ipsum, ] +[5, 6] [dolor, BB] +[7, 8] [sit, ] +[9, 10] [amet, CCC] +[, 12] [,, DD] +[11, ] [consectetur, E] +[13, 14] [adipiscing, F] +[15, 16] [, ] + +# make_array with 4 columns +query ? +select make_array(a, b, c, d) from values; +---- +[1.0, 1.0, 2.0, 1.1] +[2.0, 3.0, 4.0, 2.2] +[3.0, 5.0, 6.0, 3.3] +[4.0, 7.0, 8.0, 4.4] +[, 9.0, 10.0, 5.5] +[5.0, , 12.0, 6.6] +[6.0, 11.0, , 7.7] +[7.0, 13.0, 14.0, ] +[8.0, 15.0, 16.0, 8.8] + +# make_array with column of list +query ?? +select column1, column5 from arrays_values_without_nulls; +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] [2, 3] +[11, 12, 13, 14, 15, 16, 17, 18, 19, 20] [4, 5] +[21, 22, 23, 24, 25, 26, 27, 28, 29, 30] [6, 7] +[31, 32, 33, 34, 35, 26, 37, 38, 39, 40] [8, 9] + +query ??? +select make_array(column1), + make_array(column1, column5), + make_array(column1, make_array(50,51,52)) +from arrays_values_without_nulls; +---- +[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [2, 3]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [50, 51, 52]] +[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [4, 5]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [50, 51, 52]] +[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [6, 7]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [50, 51, 52]] +[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [8, 9]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [50, 51, 52]] + +## array_element (aliases: array_extract, list_extract, list_element) + +# array_element error +query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +select array_element(1, 2); + + +# array_element scalar function #1 (with positive index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element scalar function #2 (with positive index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); +---- +NULL NULL + +# array_element scalar function #3 (with zero) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); +---- +NULL NULL + +# array_element scalar function #4 (with NULL) +query error +select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); + +# array_element scalar function #5 (with negative index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); +---- +4 l + +# array_element scalar function #6 (with negative index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); +---- +NULL NULL + +# array_element scalar function #7 (nested array) +query ? +select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); +---- +[1, 2, 3, 4, 5] + +# array_extract scalar function #8 (function alias `array_slice`) +query IT +select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_element scalar function #9 (function alias `array_slice`) +query IT +select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_extract scalar function #10 (function alias `array_slice`) +query IT +select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element with columns +query I +select array_element(column1, column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + +# array_element with columns and scalars +query II +select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + +## array_pop_back (aliases: `list_pop_back`) + +# array_pop_back scalar function #1 +query ?? +select array_pop_back(make_array(1, 2, 3, 4, 5)), array_pop_back(make_array('h', 'e', 'l', 'l', 'o')); +---- +[1, 2, 3, 4] [h, e, l, l] + +# array_pop_back scalar function #2 (after array_pop_back, array is empty) +query ? +select array_pop_back(make_array(1)); +---- +[] + +# array_pop_back scalar function #3 (array_pop_back the empty array) +query ? +select array_pop_back(array_pop_back(make_array(1))); +---- +[] + +# array_pop_back scalar function #4 (array_pop_back the arrays which have NULL) +query ?? +select array_pop_back(make_array(1, 2, 3, 4, NULL)), array_pop_back(make_array(NULL, 'e', 'l', NULL, 'o')); +---- +[1, 2, 3, 4] [, e, l, ] + +# array_pop_back scalar function #5 (array_pop_back the nested arrays) +query ? +select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_back scalar function #6 (array_pop_back the nested arrays with NULL) +query ? +select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), NULL)); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_back scalar function #7 (array_pop_back the nested arrays with NULL) +query ? +select array_pop_back(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), NULL, make_array(1, 7, 4))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], ] + +# array_pop_back scalar function #8 (after array_pop_back, nested array is empty) +query ? +select array_pop_back(make_array(make_array(1, 2, 3))); +---- +[] + +# array_pop_back with columns +query ? +select array_pop_back(column1) from arrayspop; +---- +[1, 2] +[3, 4, 5] +[6, 7, 8, ] +[, ] +[] +[, 10, 11] + +## array_pop_front (aliases: `list_pop_front`) + +# array_pop_front scalar function #1 +query ?? +select array_pop_front(make_array(1, 2, 3, 4, 5)), array_pop_front(make_array('h', 'e', 'l', 'l', 'o')); +---- +[2, 3, 4, 5] [e, l, l, o] + +# array_pop_front scalar function #2 (after array_pop_front, array is empty) +query ? +select array_pop_front(make_array(1)); +---- +[] + +# array_pop_front scalar function #3 (array_pop_front the empty array) +query ? +select array_pop_front(array_pop_front(make_array(1))); +---- +[] + +# array_pop_front scalar function #5 (array_pop_front the nested arrays) +query ? +select array_pop_front(make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6))); +---- +[[2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] + +# array_pop_front scalar function #6 (array_pop_front the nested arrays with NULL) +query ? +select array_pop_front(make_array(NULL, make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4))); +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4]] + +# array_pop_front scalar function #8 (after array_pop_front, nested array is empty) +query ? +select array_pop_front(make_array(make_array(1, 2, 3))); +---- +[] + +## array_slice (aliases: list_slice) + +# array_slice scalar function #1 (with positive indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice scalar function #2 (with positive indexes; full array) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + +# array_slice scalar function #3 (with positive indexes; first index = second index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); +---- +[4] [l] + +# array_slice scalar function #4 (with positive indexes; first index > second_index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); +---- +[] [] + +# array_slice scalar function #5 (with positive indexes; out of bounds) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #6 (with positive indexes; nested array) +query ? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #7 (with zero and positive number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #8 (with NULL and positive number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); + +# array_slice scalar function #9 (with positive number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); + +# array_slice scalar function #10 (with zero-zero) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); +---- +[] [] + +# array_slice scalar function #11 (with NULL-NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); + +# array_slice scalar function #12 (with zero and negative number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); +---- +[1] [h, e] + +# array_slice scalar function #13 (with negative number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); + +# array_slice scalar function #14 (with NULL and negative number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); + +# array_slice scalar function #15 (with negative indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); +---- +[2, 3, 4] [l, l] + +# array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + +# array_slice scalar function #17 (with negative indexes; first index = second index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); +---- +[] [] + +# array_slice scalar function #18 (with negative indexes; first index > second_index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); +---- +[] [] + +# array_slice scalar function #19 (with negative indexes; out of bounds) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); +---- +[] [] + +# array_slice scalar function #20 (with negative indexes; nested array) +query ?? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + +# array_slice scalar function #21 (with first positive index and last negative index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); +---- +[2] [e, l] + +# array_slice scalar function #22 (with first negative index and last positive index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); +---- +[4, 5] [l, l] + +# list_slice scalar function #23 (function alias `array_slice`) +query ?? +select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice with columns +query ? +select array_slice(column1, column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + +# TODO: support NULLS in output instead of `[]` +# array_slice with columns and scalars +query ??? +select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + +# make_array with nulls +query ??????? +select make_array(make_array('a','b'), null), + make_array(make_array('a','b'), null, make_array('c','d')), + make_array(null, make_array('a','b'), null), + make_array(null, make_array('a','b'), null, null, make_array('c','d')), + make_array(['a', 'bc', 'def'], null, make_array('rust')), + make_array([1,2,3], null, make_array(4,5,6,7)), + make_array(null, 1, null, 2, null, 3, null, null, 4, 5); +---- +[[a, b], ] [[a, b], , [c, d]] [, [a, b], ] [, [a, b], , , [c, d]] [[a, bc, def], , [rust]] [[1, 2, 3], , [4, 5, 6, 7]] [, 1, , 2, , 3, , , 4, 5] + +query ? +select make_array(column5, null, column5) from arrays_values_without_nulls; +---- +[[2, 3], , [2, 3]] +[[4, 5], , [4, 5]] +[[6, 7], , [6, 7]] +[[8, 9], , [8, 9]] + +query ? +select make_array(['a','b'], null); +---- +[[a, b], ] + +## array_sort (aliases: `list_sort`) +query ??? +select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; +---- +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +NULL +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +NULL +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); +---- +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + + +## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) + +# TODO: array_append with NULLs +# array_append scalar function #1 +# query ? +# select array_append(make_array(), 4); +# ---- +# [4] + +# array_append scalar function #2 +# query ?? +# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); +# ---- +# [[]] [[4]] + +# array_append scalar function #3 +query ??? +select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_append scalar function #4 (element is list) +query ??? +select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# list_append scalar function #5 (function alias `array_append`) +query ??? +select list_append(make_array(1, 2, 3), 4), list_append(make_array(1.0, 2.0, 3.0), 4.0), list_append(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_push_back scalar function #6 (function alias `array_append`) +query ??? +select array_push_back(make_array(1, 2, 3), 4), array_push_back(make_array(1.0, 2.0, 3.0), 4.0), array_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# list_push_back scalar function #7 (function alias `array_append`) +query ??? +select list_push_back(make_array(1, 2, 3), 4), list_push_back(make_array(1.0, 2.0, 3.0), 4.0), list_push_back(make_array('h', 'e', 'l', 'l'), 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_append with columns #1 +query ? +select array_append(column1, column2) from arrays_values; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] +[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] +[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] +[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] +[44] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] +[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] + +# array_append with columns #2 (element is list) +query ? +select array_append(column1, column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + +# array_append with columns and scalars #1 +query ?? +select array_append(column2, 100.1), array_append(column3, '.') from arrays; +---- +[1.1, 2.2, 3.3, 100.1] [L, o, r, e, m, .] +[, 5.5, 6.6, 100.1] [i, p, , u, m, .] +[7.7, 8.8, 9.9, 100.1] [d, , l, o, r, .] +[10.1, , 12.2, 100.1] [s, i, t, .] +[13.3, 14.4, 15.5, 100.1] [a, m, e, t, .] +[100.1] [,, .] +[16.6, 17.7, 18.8, 100.1] [.] + +# array_append with columns and scalars #2 +query ?? +select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + +## array_prepend (aliases: `list_prepend`, `array_push_front`, `list_push_front`) + +# TODO: array_prepend with NULLs +# array_prepend scalar function #1 +# query ? +# select array_prepend(4, make_array()); +# ---- +# [4] + +# array_prepend scalar function #2 +# query ?? +# select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); +# ---- +# [[]] [[4]] + +# array_prepend scalar function #3 +query ??? +select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, 3.0, 4.0)), array_prepend('h', make_array('e', 'l', 'l', 'o')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_prepend scalar function #4 (element is list) +query ??? +select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# list_prepend scalar function #5 (function alias `array_prepend`) +query ??? +select list_prepend(1, make_array(2, 3, 4)), list_prepend(1.0, make_array(2.0, 3.0, 4.0)), list_prepend('h', make_array('e', 'l', 'l', 'o')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_push_front scalar function #6 (function alias `array_prepend`) +query ??? +select array_push_front(1, make_array(2, 3, 4)), array_push_front(1.0, make_array(2.0, 3.0, 4.0)), array_push_front('h', make_array('e', 'l', 'l', 'o')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# list_push_front scalar function #7 (function alias `array_prepend`) +query ??? +select list_push_front(1, make_array(2, 3, 4)), list_push_front(1.0, make_array(2.0, 3.0, 4.0)), list_push_front('h', make_array('e', 'l', 'l', 'o')); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array_prepend with columns #1 +query ? +select array_prepend(column2, column1) from arrays_values; +---- +[1, , 2, 3, 4, 5, 6, 7, 8, 9, 10] +[12, 11, 12, 13, 14, 15, 16, 17, 18, , 20] +[23, 21, 22, 23, , 25, 26, 27, 28, 29, 30] +[34, 31, 32, 33, 34, 35, , 37, 38, 39, 40] +[44] +[, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] +[66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] + +# array_prepend with columns #2 (element is list) +query ? +select array_prepend(column2, column1) from nested_arrays; +---- +[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + +# array_prepend with columns and scalars #1 +query ?? +select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; +---- +[100.1, 1.1, 2.2, 3.3] [., L, o, r, e, m] +[100.1, , 5.5, 6.6] [., i, p, , u, m] +[100.1, 7.7, 8.8, 9.9] [., d, , l, o, r] +[100.1, 10.1, , 12.2] [., s, i, t] +[100.1, 13.3, 14.4, 15.5] [., a, m, e, t] +[100.1] [., ,] +[100.1, 16.6, 17.7, 18.8] [.] + +# array_prepend with columns and scalars #2 (element is list) +query ?? +select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; +---- +[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + +## array_repeat (aliases: `list_repeat`) + +# array_repeat scalar function #1 +query ???????? +select + array_repeat(1, 5), + array_repeat(3.14, 3), + array_repeat('l', 4), + array_repeat(null, 2), + list_repeat(-1, 5), + list_repeat(-3.14, 0), + list_repeat('rust', 4), + list_repeat(null, 0); +---- +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] [, ] [-1, -1, -1, -1, -1] [] [rust, rust, rust, rust] [] + +# array_repeat scalar function #2 (element as list) +query ???? +select + array_repeat([1], 5), + array_repeat([1.1, 2.2, 3.3], 3), + array_repeat([null, null], 3), + array_repeat([[1, 2], [3, 4]], 2); +---- +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] + +# array_repeat with columns #1 + +statement ok +CREATE TABLE array_repeat_table +AS VALUES + (1, 1, 1.1, 'a', make_array(4, 5, 6)), + (2, null, null, null, null), + (3, 2, 2.2, 'rust', make_array(7)), + (0, 3, 3.3, 'datafusion', make_array(8, 9)); + +query ?????? +select + array_repeat(column2, column1), + array_repeat(column3, column1), + array_repeat(column4, column1), + array_repeat(column5, column1), + array_repeat(column2, 3), + array_repeat(make_array(1), column1) +from array_repeat_table; +---- +[1] [1.1] [a] [[4, 5, 6]] [1, 1, 1] [[1]] +[, ] [, ] [, ] [, ] [, , ] [[1], [1]] +[2, 2, 2] [2.2, 2.2, 2.2] [rust, rust, rust] [[7], [7], [7]] [2, 2, 2] [[1], [1], [1]] +[] [] [] [] [3, 3, 3] [] + +statement ok +drop table array_repeat_table; + +## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) + +# array_concat error +query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. +select array_concat(1, 2); + +# array_concat scalar function #1 +query ?? +select array_concat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), array_concat(make_array([1], [2]), make_array([3], [4])); +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] + +# array_concat scalar function #2 +query ? +select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array(5, 6), make_array(7, 8))); +---- +[[1, 2], [3, 4], [5, 6], [7, 8]] + +# array_concat scalar function #3 +query ? +select array_concat(make_array([1], [2], [3]), make_array([4], [5], [6]), make_array([7], [8], [9])); +---- +[[1], [2], [3], [4], [5], [6], [7], [8], [9]] + +# array_concat scalar function #4 +query ? +select array_concat(make_array([[1]]), make_array([[2]])); +---- +[[[1]], [[2]]] + +# array_concat scalar function #5 +query ? +select array_concat(make_array(2, 3), make_array()); +---- +[2, 3] + +# array_concat scalar function #6 +query ? +select array_concat(make_array(), make_array(2, 3)); +---- +[2, 3] + +# array_concat scalar function #7 (with empty arrays) +query ? +select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array())); +---- +[[1, 2], [3, 4]] + +# array_concat scalar function #8 (with empty arrays) +query ? +select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()), make_array(make_array(), make_array()), make_array(make_array(5, 6), make_array(7, 8))); +---- +[[1, 2], [3, 4], [5, 6], [7, 8]] + +# array_concat scalar function #9 (with empty arrays) +query ? +select array_concat(make_array(make_array()), make_array(make_array(1, 2), make_array(3, 4))); +---- +[[1, 2], [3, 4]] + +# array_cat scalar function #10 (function alias `array_concat`) +query ?? +select array_cat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), array_cat(make_array([1], [2]), make_array([3], [4])); +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] + +# list_concat scalar function #11 (function alias `array_concat`) +query ?? +select list_concat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), list_concat(make_array([1], [2]), make_array([3], [4])); +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] + +# list_cat scalar function #12 (function alias `array_concat`) +query ?? +select list_cat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), list_cat(make_array([1], [2]), make_array([3], [4])); +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] + +# array_concat with different dimensions #1 (2D + 1D) +query ? +select array_concat(make_array([1,2], [3,4]), make_array(5, 6)); +---- +[[1, 2], [3, 4], [5, 6]] + +# array_concat with different dimensions #2 (1D + 2D) +query ? +select array_concat(make_array(5, 6), make_array([1,2], [3,4])); +---- +[[5, 6], [1, 2], [3, 4]] + +# array_concat with different dimensions #3 (2D + 1D + 1D) +query ? +select array_concat(make_array([1,2], [3,4]), make_array(5, 6), make_array(7,8)); +---- +[[1, 2], [3, 4], [5, 6], [7, 8]] + +# array_concat with different dimensions #4 (1D + 2D + 3D) +query ? +select array_concat(make_array(10, 20), make_array([30, 40]), make_array([[50, 60]])); +---- +[[[10, 20]], [[30, 40]], [[50, 60]]] + +# array_concat with different dimensions #5 (2D + 1D + 3D) +query ? +select array_concat(make_array([30, 40]), make_array(10, 20), make_array([[50, 60]])); +---- +[[[30, 40]], [[10, 20]], [[50, 60]]] + +# array_concat with different dimensions #6 (2D + 1D + 3D + 4D + 3D) +query ? +select array_concat(make_array([30, 40]), make_array(10, 20), make_array([[50, 60]]), make_array([[[70, 80]]]), make_array([[80, 40]])); +---- +[[[[30, 40]]], [[[10, 20]]], [[[50, 60]]], [[[70, 80]]], [[[80, 40]]]] + +# array_concat column-wise #1 +query ? +select array_concat(column1, make_array(0)) from arrays_values_without_nulls; +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0] +[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0] +[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 0] +[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 0] + +# array_concat column-wise #2 +query ? +select array_concat(column1, column1) from arrays_values_without_nulls; +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] +[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] +[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 31, 32, 33, 34, 35, 26, 37, 38, 39, 40] + +# array_concat column-wise #3 +query ? +select array_concat(make_array(column2), make_array(column3)) from arrays_values_without_nulls; +---- +[1, 1] +[12, 2] +[23, 3] +[34, 4] + +# array_concat column-wise #4 +query ? +select array_concat(make_array(column2), make_array(0)) from arrays_values; +---- +[1, 0] +[12, 0] +[23, 0] +[34, 0] +[44, 0] +[, 0] +[55, 0] +[66, 0] + +# array_concat column-wise #5 +query ??? +select array_concat(column1, column1), array_concat(column2, column2), array_concat(column3, column3) from arrays; +---- +[[, 2], [3, ], [, 2], [3, ]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3] [L, o, r, e, m, L, o, r, e, m] +[[3, 4], [5, 6], [3, 4], [5, 6]] [, 5.5, 6.6, , 5.5, 6.6] [i, p, , u, m, i, p, , u, m] +[[5, 6], [7, 8], [5, 6], [7, 8]] [7.7, 8.8, 9.9, 7.7, 8.8, 9.9] [d, , l, o, r, d, , l, o, r] +[[7, ], [9, 10], [7, ], [9, 10]] [10.1, , 12.2, 10.1, , 12.2] [s, i, t, s, i, t] +NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, m, e, t] +[[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,] +[[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL + +# array_concat column-wise #6 +query ?? +select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays; +---- +[[, 2], [3, ], [1, 2], [3, 4]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3] +[[3, 4], [5, 6], [1, 2], [3, 4]] [, 5.5, 6.6, 1.1, 2.2, 3.3] +[[5, 6], [7, 8], [1, 2], [3, 4]] [7.7, 8.8, 9.9, 1.1, 2.2, 3.3] +[[7, ], [9, 10], [1, 2], [3, 4]] [10.1, , 12.2, 1.1, 2.2, 3.3] +[[1, 2], [3, 4]] [13.3, 14.4, 15.5, 1.1, 2.2, 3.3] +[[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3] +[[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3] + +# array_concat column-wise #7 +query ? +select array_concat(column3, make_array('.', '.', '.')) from arrays; +---- +[L, o, r, e, m, ., ., .] +[i, p, , u, m, ., ., .] +[d, , l, o, r, ., ., .] +[s, i, t, ., ., .] +[a, m, e, t, ., ., .] +[,, ., ., .] +[., ., .] + +# query ??I? +# select column1, column2, column3, column4 from arrays_values_v2; +# ---- +# [, 2, 3] [4, 5, ] 12 [[30, 40, 50]] +# NULL [7, , 8] 13 [[, , 60]] +# [9, , 10] NULL 14 [[70, , ]] +# [, 1] [, 21] NULL NULL +# [11, 12] NULL NULL NULL +# NULL NULL NULL NULL + + +# array_concat column-wise #8 (1D + 1D) +query ? +select array_concat(column1, column2) from arrays_values_v2; +---- +[, 2, 3, 4, 5, ] +[7, , 8] +[9, , 10] +[, 1, , 21] +[11, 12] +NULL + +# array_concat column-wise #9 (2D + 1D) +query ? +select array_concat(column4, make_array(column3)) from arrays_values_v2; +---- +[[30, 40, 50], [12]] +[[, , 60], [13]] +[[70, , ], [14]] +[[]] +[[]] +[[]] + +# array_concat column-wise #10 (3D + 2D + 1D) +query ? +select array_concat(column4, column1, column2) from nested_arrays; +---- +[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]], [[7, 8, 9]]] +[[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]], [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]], [[10, 11, 12]]] + +# array_concat column-wise #11 (2D + 1D) +query ? +select array_concat(column4, column1) from arrays_values_v2; +---- +[[30, 40, 50], [, 2, 3]] +[[, , 60], ] +[[70, , ], [9, , 10]] +[[, 1]] +[[11, 12]] +[] + +# array_concat column-wise #12 (1D + 1D + 1D) +query ? +select array_concat(make_array(column3), column1, column2) from arrays_values_v2; +---- +[12, , 2, 3, 4, 5, ] +[13, 7, , 8] +[14, 9, , 10] +[, , 1, , 21] +[, 11, 12] +[] + +## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`) + +# array_position scalar function #1 +query III +select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, 4, 5], 5), array_position([1, 1, 1], 1); +---- +3 5 1 + +# array_position scalar function #2 (with optional argument) +query III +select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); +---- +4 5 2 + +# array_position scalar function #3 (element is list) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +---- +2 2 + +# array_position scalar function #4 (element in list; with optional argument) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], 3), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 3); +---- +4 3 + +# list_position scalar function #5 (function alias `array_position`) +query III +select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1); +---- +3 5 1 + +# array_indexof scalar function #6 (function alias `array_position`) +query III +select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1); +---- +3 5 1 + +# list_indexof scalar function #7 (function alias `array_position`) +query III +select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1); +---- +3 5 1 + +# array_position with columns #1 +query II +select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; +---- +1 1 +2 2 +3 3 +4 4 + +# array_position with columns #2 (element is list) +query II +select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; +---- +3 3 +2 5 + +# array_position with columns and scalars #1 +query III +select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# array_position with columns and scalars #2 (element is list) +query III +select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; +---- +NULL 6 4 +NULL 1 NULL + +## array_positions (aliases: `list_positions`) + +# array_positions scalar function #1 +query ??? +select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); +---- +[3, 4] [5] [1, 2, 3] + +# array_positions scalar function #2 (element is list) +query ? +select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); +---- +[2, 4] + +# list_positions scalar function #3 (function alias `array_positions`) +query ??? +select list_positions(['h', 'e', 'l', 'l', 'o'], 'l'), list_positions([1, 2, 3, 4, 5], 5), list_positions([1, 1, 1], 1); +---- +[3, 4] [5] [1, 2, 3] + +# array_positions with columns #1 +query ? +select array_positions(column1, column2) from arrays_values_without_nulls; +---- +[1] +[2] +[3] +[4] + +# array_positions with columns #2 (element is list) +query ? +select array_positions(column1, column2) from nested_arrays; +---- +[3] +[2, 5] + +# array_positions with columns and scalars #1 +query ?? +select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; +---- +[4] [1] +[] [] +[] [3] +[] [] + +# array_positions with columns and scalars #2 (element is list) +query ?? +select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; +---- +[6] [] +[1] [] + +## array_replace (aliases: `list_replace`) + +# array_replace scalar function #1 +query ??? +select + array_replace(make_array(1, 2, 3, 4), 2, 3), + array_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +# array_replace scalar function #2 (element is list) +query ?? +select + array_replace( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + +# list_replace scalar function #3 (function alias `list_replace`) +query ??? +select list_replace( + make_array(1, 2, 3, 4), 2, 3), + list_replace(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + +# array_replace scalar function with columns #1 +query ? +select array_replace(column1, column2, column3) from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[7, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[13, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +# array_replace scalar function with columns #2 (element is list) +query ? +select array_replace(column1, column2, column3) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +# array_replace scalar function with columns and scalars #1 +query ??? +select + array_replace(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace(column1, 1, column3), + array_replace(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 1, 3, 2, 2, 1, 3, 2, 3] [1, 4, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +# array_replace scalar function with columns and scalars #2 (element is list) +query ??? +select + array_replace( + make_array( + [1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace(column1, make_array(1, 2, 3), column3), + array_replace(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +## array_replace_n (aliases: `list_replace_n`) + +# array_replace_n scalar function #1 +query ??? +select + array_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + array_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + array_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +# array_replace_n scalar function #2 (element is list) +query ?? +select + array_replace_n( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +# list_replace_n scalar function #3 (function alias `array_replace_n`) +query ??? +select + list_replace_n(make_array(1, 2, 3, 4), 2, 3, 2), + list_replace_n(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0, 2), + list_replace_n(make_array(1, 2, 3), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + +# array_replace_n scalar function with columns #1 +query ? +select + array_replace_n(column1, column2, column3, column4) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 2, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[10, 10, 10, 8, 10, 9, 10, 8, 7, 7] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +# array_replace_n scalar function with columns #2 (element is list) +query ? +select + array_replace_n(column1, column2, column3, column4) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +# array_replace_n scalar function with columns and scalars #1 +query ???? +select + array_replace_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3, column4), + array_replace_n(column1, 1, column3, column4), + array_replace_n(column1, column2, 4, column4), + array_replace_n(column1, column2, column3, 2) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 2, 3] [1, 4, 1, 3, 4, 2, 1, 3, 2, 3] +[1, 2, 2, 7, 5, 7, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [7, 7, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 7, 7] [10, 10, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] [13, 11, 12, 13, 11, 12, 10, 11, 12, 10] + +# array_replace_n scalar function with columns and scalars #2 (element is list) +query ???? +select + array_replace_n( + make_array( + [7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]), + column2, + column3, + column4 + ), + array_replace_n(column1, make_array(1, 2, 3), column3, column4), + array_replace_n(column1, column2, make_array(11, 12, 13), column4), + array_replace_n(column1, column2, column3, 2) +from nested_arrays_with_repeating_elements; +---- +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [10, 11, 12]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [19, 20, 21], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[28, 29, 30], [28, 29, 30], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[7, 8, 9], [2, 1, 3], [1, 5, 6], [10, 11, 12], [2, 1, 3], [7, 8, 9], [4, 5, 6]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] [[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +## array_replace_all (aliases: `list_replace_all`) + +# array_replace_all scalar function #1 +query ??? +select + array_replace_all(make_array(1, 2, 3, 4), 2, 3), + array_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + array_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +# array_replace_all scalar function #2 (element is list) +query ?? +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + +# list_replace_all scalar function #3 (function alias `array_replace_all`) +query ??? +select + list_replace_all(make_array(1, 2, 3, 4), 2, 3), + list_replace_all(make_array(1, 4, 4, 5, 4, 6, 7), 4, 0), + list_replace_all(make_array(1, 2, 3), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + +# array_replace_all scalar function with columns #1 +query ? +select + array_replace_all(column1, column2, column3) +from arrays_with_repeating_elements; +---- +[1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[7, 7, 5, 5, 6, 5, 5, 5, 7, 7] +[10, 10, 10, 8, 10, 9, 10, 8, 10, 10] +[13, 11, 12, 13, 11, 12, 13, 11, 12, 13] + +# array_replace_all scalar function with columns #2 (element is list) +query ? +select + array_replace_all(column1, column2, column3) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [10, 11, 12], [1, 2, 3], [7, 8, 9], [10, 11, 12], [7, 8, 9]] +[[19, 20, 21], [19, 20, 21], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [19, 20, 21], [19, 20, 21]] +[[28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24], [28, 29, 30], [25, 26, 27], [28, 29, 30], [22, 23, 24], [28, 29, 30], [28, 29, 30]] +[[37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39], [31, 32, 33], [34, 35, 36], [37, 38, 39]] + +# array_replace_all scalar function with columns and scalars #1 +query ??? +select + array_replace_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column3), + array_replace_all(column1, 1, column3), + array_replace_all(column1, column2, 4) +from arrays_with_repeating_elements; +---- +[1, 4, 4, 4, 5, 4, 4, 7, 7, 10, 7, 8] [4, 2, 4, 3, 2, 2, 4, 3, 2, 3] [1, 4, 1, 3, 4, 4, 1, 3, 4, 3] +[1, 2, 2, 7, 5, 7, 7, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 10, 10, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [4, 4, 4, 8, 4, 9, 4, 8, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 13, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [4, 11, 12, 4, 11, 12, 4, 11, 12, 4] + +# array_replace_all scalar function with columns and scalars #2 (element is list) +query ??? +select + array_replace_all( + make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), + column2, + column3 + ), + array_replace_all(column1, make_array(1, 2, 3), column3), + array_replace_all(column1, column2, make_array(11, 12, 13)) +from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [10, 11, 12], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [4, 5, 6], [10, 11, 12], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [11, 12, 13], [1, 2, 3], [7, 8, 9], [11, 12, 13], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [19, 20, 21], [13, 14, 15], [19, 20, 21], [19, 20, 21], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[11, 12, 13], [11, 12, 13], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] + +# array_replace with null handling + +statement ok +create table t as values + (make_array(3, 1, NULL, 3), 3, 4, 2), + (make_array(3, 1, NULL, 3), NULL, 5, 2), + (NULL, 3, 2, 1), + (make_array(3, 1, 3), 3, NULL, 1) +; + + +# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched +# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, NULL, 3] NULL is replaced with 5 +# ([NULL], 3, 2, 1) => NULL +# ([3, 1, 3], 3, NULL, 1) => [NULL, 1 3] + +query ?III? +select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t; +---- +[3, 1, , 3] 3 4 2 [4, 1, , 4] +[3, 1, , 3] NULL 5 2 [3, 1, 5, 3] +NULL 3 2 1 NULL +[3, 1, 3] 3 NULL 1 [, 1, 3] + + + +statement ok +drop table t; + + + +## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`) + +# array_to_string scalar function #1 +query TTT +select array_to_string(['h', 'e', 'l', 'l', 'o'], ','), array_to_string([1, 2, 3, 4, 5], '-'), array_to_string([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string scalar function #2 +query TTT +select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_repeat(array_repeat(array_repeat(3, 2), 2), 3), '/\'); +---- +11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3 + +# array_to_string scalar function #3 +query T +select array_to_string(make_array(), ',') +---- +(empty) + + +## array_union (aliases: `list_union`) + +# array_union scalar function #1 +query ? +select array_union([1, 2, 3, 4], [5, 6, 3, 4]); +---- +[1, 2, 3, 4, 5, 6] + +# array_union scalar function #2 +query ? +select array_union([1, 2, 3, 4], [5, 6, 7, 8]); +---- +[1, 2, 3, 4, 5, 6, 7, 8] + +# array_union scalar function #3 +query ? +select array_union([1,2,3], []); +---- +[1, 2, 3] + +# array_union scalar function #4 +query ? +select array_union([1, 2, 3, 4], [5, 4]); +---- +[1, 2, 3, 4, 5] + +# array_union scalar function #5 +statement ok +CREATE TABLE arrays_with_repeating_elements_for_union +AS VALUES + ([1], [2]), + ([2, 3], [3]), + ([3], [3, 4]) +; + +query ? +select array_union(column1, column2) from arrays_with_repeating_elements_for_union; +---- +[1, 2] +[2, 3] +[3, 4] + +statement ok +drop table arrays_with_repeating_elements_for_union; + +# array_union scalar function #6 +query ? +select array_union([], []); +---- +[] + +# array_union scalar function #7 +query ? +select array_union([[null]], []); +---- +[[]] + +# array_union scalar function #8 +query ? +select array_union([null], [null]); +---- +[] + +# array_union scalar function #9 +query ? +select array_union(null, []); +---- +[] + +# array_union scalar function #10 +query ? +select array_union(null, null); +---- +NULL + +# array_union scalar function #11 +query ? +select array_union([1.2, 3.0], [1.2, 3.0, 5.7]); +---- +[1.2, 3.0, 5.7] + +# array_union scalar function #12 +query ? +select array_union(['hello'], ['hello','datafusion']); +---- +[hello, datafusion] + + + + + + + + +# list_to_string scalar function #4 (function alias `array_to_string`) +query TTT +select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_join scalar function #5 (function alias `array_to_string`) +query TTT +select array_join(['h', 'e', 'l', 'l', 'o'], ','), array_join([1, 2, 3, 4, 5], '-'), array_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# list_join scalar function #6 (function alias `list_join`) +query TTT +select list_join(['h', 'e', 'l', 'l', 'o'], ','), list_join([1, 2, 3, 4, 5], '-'), list_join([1.0, 2.0, 3.0], '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string scalar function with nulls #1 +query TTT +select array_to_string(make_array('h', NULL, 'l', NULL, 'o'), ','), array_to_string(make_array(1, NULL, 3, NULL, 5), '-'), array_to_string(make_array(NULL, 2.0, 3.0), '|'); +---- +h,l,o 1-3-5 2|3 + +# array_to_string scalar function with nulls #2 +query TTT +select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + +# array_to_string with columns #1 + +# For reference +# select column1, column4 from arrays_values; +# ---- +# [, 2, 3, 4, 5, 6, 7, 8, 9, 10] , +# [11, 12, 13, 14, 15, 16, 17, 18, , 20] . +# [21, 22, 23, , 25, 26, 27, 28, 29, 30] - +# [31, 32, 33, 34, 35, , 37, 38, 39, 40] ok +# NULL @ +# [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] $ +# [51, 52, , 54, 55, 56, 57, 58, 59, 60] ^ +# [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] NULL + +query T +select array_to_string(column1, column4) from arrays_values; +---- +2,3,4,5,6,7,8,9,10 +11.12.13.14.15.16.17.18.20 +21-22-23-25-26-27-28-29-30 +31ok32ok33ok34ok35ok37ok38ok39ok40 +NULL +41$42$43$44$45$46$47$48$49$50 +51^52^54^55^56^57^58^59^60 +NULL + +query TT +select array_to_string(column1, '_'), array_to_string(make_array(1,2,3), '/') from arrays_values; +---- +2_3_4_5_6_7_8_9_10 1/2/3 +11_12_13_14_15_16_17_18_20 1/2/3 +21_22_23_25_26_27_28_29_30 1/2/3 +31_32_33_34_35_37_38_39_40 1/2/3 +NULL 1/2/3 +41_42_43_44_45_46_47_48_49_50 1/2/3 +51_52_54_55_56_57_58_59_60 1/2/3 +61_62_63_64_65_66_67_68_69_70 1/2/3 + +query TT +select array_to_string(column1, '_', '*'), array_to_string(make_array(make_array(1,2,3)), '.') from arrays_values; +---- +*_2_3_4_5_6_7_8_9_10 1.2.3 +11_12_13_14_15_16_17_18_*_20 1.2.3 +21_22_23_*_25_26_27_28_29_30 1.2.3 +31_32_33_34_35_*_37_38_39_40 1.2.3 +NULL 1.2.3 +41_42_43_44_45_46_47_48_49_50 1.2.3 +51_52_*_54_55_56_57_58_59_60 1.2.3 +61_62_63_64_65_66_67_68_69_70 1.2.3 + +## cardinality + +# cardinality scalar function +query III +select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinality(make_array('h', 'e', 'l', 'l', 'o')); +---- +5 3 5 + +# cardinality scalar function #2 +query II +select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); +---- +6 18 + +# cardinality scalar function #3 +query II +select cardinality(make_array()), cardinality(make_array(make_array())) +---- +NULL 0 + +# cardinality with columns +query III +select cardinality(column1), cardinality(column2), cardinality(column3) from arrays; +---- +4 3 5 +4 3 5 +4 3 5 +4 3 3 +NULL 3 4 +4 NULL 1 +4 3 NULL + +## array_remove (aliases: `list_remove`) + +# array_remove scalar function #1 +query ??? +select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), array_remove(make_array('h', 'e', 'l', 'l', 'o'), 'l'); +---- +[1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] + +query ??? +select + array_remove(make_array(1, null, 2, 3), 2), + array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), + array_remove(make_array('a', null, 'bc'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +# TODO: https://github.com/apache/arrow-datafusion/issues/7142 +# query +# select +# array_remove(make_array(1, null, 2), null), +# array_remove(make_array(1, null, 2, null), null); + +# array_remove scalar function #2 (element is list) +query ?? +select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + +# list_remove scalar function #3 (function alias `array_remove`) +query ??? +select list_remove(make_array(1, 2, 2, 1, 1), 2), list_remove(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove(make_array('h', 'e', 'l', 'l', 'o'), 'l'); +---- +[1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] + +# array_remove scalar function with columns #1 +query ? +select array_remove(column1, column2) from arrays_with_repeating_elements; +---- +[1, 1, 3, 2, 2, 1, 3, 2, 3] +[4, 5, 5, 6, 5, 5, 5, 4, 4] +[7, 7, 8, 7, 9, 7, 8, 7, 7] +[11, 12, 10, 11, 12, 10, 11, 12, 10] + +# array_remove scalar function with columns #2 (element is list) +query ? +select array_remove(column1, column2) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +# array_remove scalar function with columns and scalars #1 +query ?? +select array_remove(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove(column1, 1) from arrays_with_repeating_elements; +---- +[1, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8] [2, 1, 3, 2, 2, 1, 3, 2, 3] +[1, 2, 2, 5, 4, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 7, 10, 7, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +# array_remove scalar function with columns and scalars #2 (element is list) +query ?? +select array_remove(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), array_remove(column1, make_array(1, 2, 3)) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +## array_remove_n (aliases: `list_remove_n`) + +# array_remove_n scalar function #1 +query ??? +select array_remove_n(make_array(1, 2, 2, 1, 1), 2, 2), array_remove_n(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0, 2), array_remove_n(make_array('h', 'e', 'l', 'l', 'o'), 'l', 3); +---- +[1, 1, 1] [2.0, 2.0, 1.0] [h, e, o] + +# array_remove_n scalar function #2 (element is list) +query ?? +select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], 2), array_remove_n(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 2); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + +# list_remove_n scalar function #3 (function alias `array_remove_n`) +query ??? +select list_remove_n(make_array(1, 2, 2, 1, 1), 2, 2), list_remove_n(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0, 2), list_remove_n(make_array('h', 'e', 'l', 'l', 'o'), 'l', 3); +---- +[1, 1, 1] [2.0, 2.0, 1.0] [h, e, o] + +# array_remove_n scalar function with columns #1 +query ? +select array_remove_n(column1, column2, column4) from arrays_with_repeating_elements; +---- +[1, 1, 3, 1, 3, 2, 3] +[5, 5, 6, 5, 5, 5, 4, 4] +[8, 9, 8, 7, 7] +[11, 12, 11, 12, 11, 12] + +# array_remove_n scalar function with columns #2 (element is list) +query ? +select array_remove_n(column1, column2, column4) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [1, 2, 3], [7, 8, 9], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[22, 23, 24], [25, 26, 27], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36]] + +# array_remove_n scalar function with columns and scalars #1 +query ??? +select array_remove_n(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2, column4), array_remove_n(column1, 1, column4), array_remove_n(column1, column2, 2) from arrays_with_repeating_elements; +---- +[1, 4, 5, 4, 4, 7, 7, 10, 7, 8] [2, 3, 2, 2, 3, 2, 3] [1, 1, 3, 2, 1, 3, 2, 3] +[1, 2, 2, 5, 4, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] [5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] [7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] [11, 12, 11, 12, 10, 11, 12, 10] + +# array_remove_n scalar function with columns and scalars #2 (element is list) +query ??? +select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2, column4), array_remove_n(column1, make_array(1, 2, 3), column4), array_remove_n(column1, column2, 2) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [7, 8, 9], [4, 5, 6], [4, 5, 6], [7, 8, 9], [4, 5, 6], [7, 8, 9]] [[1, 2, 3], [1, 2, 3], [7, 8, 9], [4, 5, 6], [1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] [[13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +## array_remove_all (aliases: `list_removes`) + +# array_remove_all scalar function #1 +query ??? +select array_remove_all(make_array(1, 2, 2, 1, 1), 2), array_remove_all(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), array_remove_all(make_array('h', 'e', 'l', 'l', 'o'), 'l'); +---- +[1, 1, 1] [2.0, 2.0] [h, e, o] + +# array_remove_all scalar function #2 (element is list) +query ?? +select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove_all(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + +# list_remove_all scalar function #3 (function alias `array_remove_all`) +query ??? +select list_remove_all(make_array(1, 2, 2, 1, 1), 2), list_remove_all(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove_all(make_array('h', 'e', 'l', 'l', 'o'), 'l'); +---- +[1, 1, 1] [2.0, 2.0] [h, e, o] + +# array_remove_all scalar function with columns #1 +query ? +select array_remove_all(column1, column2) from arrays_with_repeating_elements; +---- +[1, 1, 3, 1, 3, 3] +[5, 5, 6, 5, 5, 5] +[8, 9, 8] +[11, 12, 11, 12, 11, 12] + +# array_remove_all scalar function with columns #2 (element is list) +query ? +select array_remove_all(column1, column2) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [1, 2, 3], [7, 8, 9], [1, 2, 3], [7, 8, 9], [7, 8, 9]] +[[13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15]] +[[22, 23, 24], [25, 26, 27], [22, 23, 24]] +[[31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36], [31, 32, 33], [34, 35, 36]] + +# array_remove_all scalar function with columns and scalars #1 +query ?? +select array_remove_all(make_array(1, 2, 2, 4, 5, 4, 4, 7, 7, 10, 7, 8), column2), array_remove_all(column1, 1) from arrays_with_repeating_elements; +---- +[1, 4, 5, 4, 4, 7, 7, 10, 7, 8] [2, 3, 2, 2, 3, 2, 3] +[1, 2, 2, 5, 7, 7, 10, 7, 8] [4, 4, 5, 5, 6, 5, 5, 5, 4, 4] +[1, 2, 2, 4, 5, 4, 4, 10, 8] [7, 7, 7, 8, 7, 9, 7, 8, 7, 7] +[1, 2, 2, 4, 5, 4, 4, 7, 7, 7, 8] [10, 11, 12, 10, 11, 12, 10, 11, 12, 10] + +# array_remove_all scalar function with columns and scalars #2 (element is list) +query ?? +select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]), column2), array_remove_all(column1, make_array(1, 2, 3)) from nested_arrays_with_repeating_elements; +---- +[[1, 2, 3], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[4, 5, 6], [7, 8, 9], [4, 5, 6], [4, 5, 6], [7, 8, 9], [4, 5, 6], [7, 8, 9]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [13, 14, 15], [19, 20, 21], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[10, 11, 12], [10, 11, 12], [13, 14, 15], [13, 14, 15], [16, 17, 18], [13, 14, 15], [13, 14, 15], [13, 14, 15], [10, 11, 12], [10, 11, 12]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] +[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] + +## trim_array (deprecated) + +## array_length (aliases: `list_length`) + +# array_length scalar function #1 +query III +select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3)), array_length(make_array([1, 2], [3, 4], [5, 6])); +---- +5 3 3 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + +# array_length scalar function #2 +query III +select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); +---- +5 3 3 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + +# array_length scalar function #3 +query III +select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); +---- +NULL NULL 2 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + +# array_length scalar function #4 +query II +select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); +---- +3 2 + +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + +# array_length scalar function #5 +query III +select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) +---- +0 0 NULL + +# array_length scalar function #6 nested array +query III +select array_length([[1, 2, 3, 4], [5, 6, 7, 8]]), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 1), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 2); +---- +2 2 4 + +# list_length scalar function #7 (function alias `array_length`) +query IIII +select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), list_length(make_array([1, 2], [3, 4], [5, 6])), array_length([[1, 2, 3, 4], [5, 6, 7, 8]], 3); +---- +5 3 3 NULL + +query III +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + +# array_length with columns +query I +select array_length(column1, column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +# array_length with columns and scalars +query II +select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + +## array_dims (aliases: `list_dims`) + +# array dims error +# TODO this is a separate bug +query error Internal error: could not cast value to arrow_array::array::list_array::GenericListArray\. +select array_dims(1); + +# array_dims scalar function +query ??? +select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), array_dims(make_array([[[[1], [2]]]])); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + +# array_dims scalar function #2 +query ?? +select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); +---- +[1, 2, 3] [2, 5, 4] + +# array_dims scalar function #3 +query ?? +select array_dims(make_array()), array_dims(make_array(make_array())) +---- +NULL [1, 0] + +# list_dims scalar function #4 (function alias `array_dims`) +query ??? +select list_dims(make_array(1, 2, 3)), list_dims(make_array([1, 2], [3, 4])), list_dims(make_array([[[[1], [2]]]])); +---- +[3] [2, 2] [1, 1, 1, 2, 1] + +# array_dims with columns +query ??? +select array_dims(column1), array_dims(column2), array_dims(column3) from arrays; +---- +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [5] +[2, 2] [3] [3] +NULL [3] [4] +[2, 2] NULL [1] +[2, 2] [3] NULL + +## array_ndims (aliases: `list_ndims`) + +# array_ndims scalar function #1 + +query III +select + array_ndims(1), + array_ndims(null), + array_ndims([2, 3]); +---- +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; + +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 + +# array_ndims scalar function #2 +query II +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +---- +3 21 + +# array_ndims scalar function #3 +query II +select array_ndims(make_array()), array_ndims(make_array(make_array())) +---- +1 2 + +# list_ndims scalar function #4 (function alias `array_ndims`) +query III +select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), list_ndims(make_array([[[[1], [2]]]])); +---- +1 2 5 + +query II +select array_ndims(make_array()), array_ndims(make_array(make_array())) +---- +1 2 + +# array_ndims with columns +query III +select array_ndims(column1), array_ndims(column2), array_ndims(column3) from arrays; +---- +2 1 1 +2 1 1 +2 1 1 +2 1 1 +NULL 1 1 +2 NULL 1 +2 1 NULL + +## array_has/array_has_all/array_has_any + +query BBBBBBBBBBBB +select array_has(make_array(1,2), 1), + array_has(make_array(1,2,NULL), 1), + array_has(make_array([2,3], [3,4]), make_array(2,3)), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1], [2,3])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([4,5], [6])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1])), + array_has(make_array([[[1]]]), make_array([[1]])), + array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[2]])), + array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[1], [2]])), + list_has(make_array(1,2,3), 4), + array_contains(make_array(1,2,3), 3), + list_contains(make_array(1,2,3), 0) +; +---- +true true true true true false true false true false true false + +query BBBBBBBBBBBB +select array_has(arrow_cast(make_array(1,2), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array(1,2,NULL), 'LargeList(Int64)'), 1), + array_has(arrow_cast(make_array([2,3], [3,4]), 'LargeList(List(Int64))'), make_array(2,3)), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1], [2,3])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([4,5], [6])), + array_has(arrow_cast(make_array([[1], [2,3]], [[4,5], [6]]), 'LargeList(List(List(Int64)))'), make_array([1])), + array_has(arrow_cast(make_array([[[1]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[2]])), + array_has(arrow_cast(make_array([[[1]]], [[[1], [2]]]), 'LargeList(List(List(List(Int64))))'), make_array([[1], [2]])), + list_has(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 4), + array_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 3), + list_contains(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), 0) +; +---- +true true true true true false true false true false true false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Int64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Int64)'), arrow_cast(column4, 'LargeList(Int64)')), + array_has_any(arrow_cast(column5, 'LargeList(Int64)'), arrow_cast(column6, 'LargeList(Int64)')) +from array_has_table_1D; +---- +true true true +false false false + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Float64)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Float64)'), arrow_cast(column4, 'LargeList(Float64)')), + array_has_any(arrow_cast(column5, 'LargeList(Float64)'), arrow_cast(column6, 'LargeList(Float64)')) +from array_has_table_1D_Float; +---- +true true false +false false true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Boolean)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), arrow_cast(column4, 'LargeList(Boolean)')), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), arrow_cast(column6, 'LargeList(Boolean)')) +from array_has_table_1D_Boolean; +---- +false true true +true true true + +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BBB +select array_has(arrow_cast(column1, 'LargeList(Utf8)'), column2), + array_has_all(arrow_cast(column3, 'LargeList(Utf8)'), arrow_cast(column4, 'LargeList(Utf8)')), + array_has_any(arrow_cast(column5, 'LargeList(Utf8)'), arrow_cast(column6, 'LargeList(Utf8)')) +from array_has_table_1D_UTF8; +---- +true true false +false false true + +query BB +select array_has(column1, column2), + array_has_all(column3, column4) +from array_has_table_2D; +---- +false true +true false + +query BB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), column2), + array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from array_has_table_2D; +---- +false true +true false + +query B +select array_has_all(column1, column2) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has_all(arrow_cast(column1, 'LargeList(List(Float64))'), arrow_cast(column2, 'LargeList(List(Float64))')) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has(column1, column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query B +select array_has(arrow_cast(column1, 'LargeList(List(List(Int64)))'), column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query BBBB +select array_has(column1, make_array(5, 6)), + array_has(column1, make_array(7, NULL)), + array_has(column2, 5.5), + array_has(column3, 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBB +select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), + array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(7, NULL)), + array_has(arrow_cast(column2, 'LargeList(Float64)'), 5.5), + array_has(arrow_cast(column3, 'LargeList(Utf8)'), 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false false false false + +query BBBBBBBBBBBBB +select array_has_all(make_array(1,2,3), make_array(1,3)), + array_has_all(make_array(1,2,3), make_array(1,4)), + array_has_all(make_array([1,2], [3,4]), make_array([1,2])), + array_has_all(make_array([1,2], [3,4]), make_array([1,3])), + array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), + array_has_all(make_array([[1,2,3]]), make_array([[1]])), + array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), + array_has_any(make_array(1,2,3), make_array(1,10,100)), + array_has_any(make_array(1,2,3), make_array(10,100)), + array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), + array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true + +## array_distinct + +query ? +select array_distinct(null); +---- +NULL + +query ? +select array_distinct([]); +---- +[] + +query ? +select array_distinct([[], []]); +---- +[[]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_UTF8; +---- +[a, bc, def] +[a, bc, def, defg] +[defg] + +query ? +select array_distinct(column1) +from array_distinct_table_2D; +---- +[[1, 2], [3, 4], [5, 6]] +[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] +[, [5, 6]] + +query ? +select array_distinct(column1) +from array_distinct_table_1D_large; +---- +[1, 2, 3] +[1, 2, 3, 4, 5] +[3, 5] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Float; +---- +[1.0] [1.0, 3.0] [] +[] [2.0] [1.11] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ?????? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), + array_intersect(make_array(1,3,5), make_array(2,4,6)), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array(true, false), make_array(true)), + array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect([], null); +---- +[] + +query ? +select array_intersect(null, []); +---- +[] + +query ? +select array_intersect(null, null); +---- +NULL + +query ?????? +SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), + list_intersect(make_array(1,3,5), make_array(2,4,6)), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array(true, false), make_array(true)), + list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query BBBB +select list_has_all(make_array(1,2,3), make_array(4,5,6)), + list_has_all(make_array(1,2,3), make_array(1,2)), + list_has_any(make_array(1,2,3), make_array(4,5,6)), + list_has_any(make_array(1,2,3), make_array(1,2,4)) +; +---- +false true false true + +query ??? +select range(column2), + range(column1, column2), + range(column1, column2, column3) +from arrays_range; +---- +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9] +[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10] + +query ?????? +select range(5), + range(2, 5), + range(2, 10, 3), + range(1, 5, -1), + range(1, -5, 1), + range(1, -5, -1) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] + +query ??? +select generate_series(5), + generate_series(2, 5), + generate_series(2, 10, 3) +; +---- +[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] + +## array_except + +statement ok +CREATE TABLE array_except_table +AS VALUES + ([1, 2, 2, 3], [2, 3, 4]), + ([2, 3, 3], [3]), + ([3], [3, 3, 4]), + (null, [3, 4]), + ([1, 2], null), + (null, null) +; + +query ? +select array_except(column1, column2) from array_except_table; +---- +[1] +[2] +[] +NULL +[1, 2] +NULL + +statement ok +drop table array_except_table; + +statement ok +CREATE TABLE array_except_nested_list_table +AS VALUES + ([[1, 2], [3]], [[2], [3], [4, 5]]), + ([[1, 2], [3]], [[2], [1, 2]]), + ([[1, 2], [3]], null), + (null, [[1], [2, 3], [4, 5, 6]]), + ([[1], [2, 3], [4, 5, 6]], [[2, 3], [4, 5, 6], [1]]) +; + +query ? +select array_except(column1, column2) from array_except_nested_list_table; +---- +[[1, 2]] +[[3]] +[[1, 2], [3]] +NULL +[] + +statement ok +drop table array_except_nested_list_table; + +statement ok +CREATE TABLE array_except_table_float +AS VALUES + ([1.1, 2.2, 3.3], [2.2]), + ([1.1, 2.2, 3.3], [4.4]), + ([1.1, 2.2, 3.3], [3.3, 2.2, 1.1]) +; + +query ? +select array_except(column1, column2) from array_except_table_float; +---- +[1.1, 3.3] +[1.1, 2.2, 3.3] +[] + +statement ok +drop table array_except_table_float; + +statement ok +CREATE TABLE array_except_table_ut8 +AS VALUES + (['a', 'b', 'c'], ['a']), + (['a', 'bc', 'def'], ['g', 'def']), + (['a', 'bc', 'def'], null), + (null, ['a']) +; + +query ? +select array_except(column1, column2) from array_except_table_ut8; +---- +[b, c] +[a, bc] +[a, bc, def] +NULL + +statement ok +drop table array_except_table_ut8; + +statement ok +CREATE TABLE array_except_table_bool +AS VALUES + ([true, false, false], [false]), + ([true, true, true], [false]), + ([false, false, false], [true]), + ([true, false], null), + (null, [true, false]) +; + +query ? +select array_except(column1, column2) from array_except_table_bool; +---- +[true] +[true] +[false] +[true, false] +NULL + +statement ok +drop table array_except_table_bool; + +query ? +select array_except([], null); +---- +[] + +query ? +select array_except([], []); +---- +[] + +query ? +select array_except(null, []); +---- +NULL + +query ? +select array_except(null, null) +---- +NULL + +### Array operators tests + + +## array concatenate operator + +# array concatenate operator with scalars #1 (like array_concat scalar function) +query ?? +select make_array(1, 2, 3) || make_array(4, 5, 6) || make_array(7, 8, 9), make_array([1], [2]) || make_array([3], [4]); +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]] + +# array concatenate operator with scalars #2 (like array_append scalar function) +query ??? +select make_array(1, 2, 3) || 4, make_array(1.0, 2.0, 3.0) || 4.0, make_array('h', 'e', 'l', 'l') || 'o'; +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +# array concatenate operator with scalars #3 (like array_prepend scalar function) +query ??? +select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_array('e', 'l', 'l', 'o'); +---- +[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] + +## array containment operator + +# array containment operator with scalars #1 (at arrow) +query BBBBBBB +select make_array(1,2,3) @> make_array(1,3), + make_array(1,2,3) @> make_array(1,4), + make_array([1,2], [3,4]) @> make_array([1,2]), + make_array([1,2], [3,4]) @> make_array([1,3]), + make_array([1,2], [3,4]) @> make_array([1,2], [3,4], [5,6]), + make_array([[1,2,3]]) @> make_array([[1]]), + make_array([[1,2,3]]) @> make_array([[1,2,3]]); +---- +true false true false false false true + +# array containment operator with scalars #2 (arrow at) +query BBBBBBB +select make_array(1,3) <@ make_array(1,2,3), + make_array(1,4) <@ make_array(1,2,3), + make_array([1,2]) <@ make_array([1,2], [3,4]), + make_array([1,3]) <@ make_array([1,2], [3,4]), + make_array([1,2], [3,4], [5,6]) <@ make_array([1,2], [3,4]), + make_array([[1]]) <@ make_array([[1,2,3]]), + make_array([[1,2,3]]) <@ make_array([[1,2,3]]); +---- +true false true false false false true + +### Array casting tests + + +## make_array + +# make_array scalar function #1 +query ? +select make_array(1, 2.0) +---- +[1.0, 2.0] + +# make_array scalar function #2 +query ? +select make_array(null, 1.0) +---- +[, 1.0] + +# make_array scalar function #3 +query ? +select make_array(1, 2.0, null, 3) +---- +[1.0, 2.0, , 3.0] + +# make_array scalar function #4 +query ? +select make_array(1.0, '2', null) +---- +[1.0, 2, ] + +### FixedSizeListArray + +statement ok +CREATE EXTERNAL TABLE fixed_size_list_array STORED AS PARQUET LOCATION '../core/tests/data/fixed_size_list_array.parquet'; + +query T +select arrow_typeof(f0) from fixed_size_list_array; +---- +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2) +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2) + +query ? +select * from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select f0 from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select arrow_cast(f0, 'List(Int64)') from fixed_size_list_array; +---- +[1, 2] +[3, 4] + +query ? +select make_array(arrow_cast(f0, 'List(Int64)')) from fixed_size_list_array +---- +[[1, 2]] +[[3, 4]] + +query ? +select make_array(f0) from fixed_size_list_array +---- +[[1, 2]] +[[3, 4]] + +query ? +select array_concat(column1, [7]) from arrays_values_v2; +---- +[, 2, 3, 7] +[7] +[9, , 10, 7] +[, 1, 7] +[11, 12, 7] +[7] + +# flatten +query ??? +select flatten(make_array(1, 2, 1, 3, 2)), + flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), + flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); +---- +[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] + +query ???? +select column1, column2, column3, column4 from flatten_table; +---- +[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] +[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + +query ???? +select flatten(column1), + flatten(column2), + flatten(column3), + flatten(column4) +from flatten_table; +---- +[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + +# empty scalar function #1 +query B +select empty(make_array(1)); +---- +false + +query B +select empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + +# empty scalar function #2 +query B +select empty(make_array()); +---- +true + +query B +select empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + +# empty scalar function #3 +query B +select empty(make_array(NULL)); +---- +false + +query B +select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + +# empty scalar function #4 +query B +select empty(NULL); +---- +NULL + +# empty scalar function #5 +query B +select empty(column1) from arrays; +---- +false +false +false +false +NULL +false +false + +query B +select empty(arrow_cast(column1, 'LargeList(List(Int64))')) from arrays; +---- +false +false +false +false +NULL +false +false + +query ? +SELECT string_to_array('abcxxxdef', 'xxx') +---- +[abc, def] + +query ? +SELECT string_to_array('abc', '') +---- +[abc] + +query ? +SELECT string_to_array('abc', NULL) +---- +[a, b, c] + +query ? +SELECT string_to_array('abc def', ' ', 'def') +---- +[abc, ] + +query ? +select string_to_array(e, ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +query ? +select string_to_list(e, 'm') from values; +---- +[Lore, ] +[ipsu, ] +[dolor] +[sit] +[a, et] +[,] +[consectetur] +[adipiscing] +NULL + +### Delete tables + +statement ok +drop table values; + +statement ok +drop table values_without_nulls; + +statement ok +drop table nested_arrays; + +statement ok +drop table arrays; + +statement ok +drop table slices; + +statement ok +drop table arrayspop; + +statement ok +drop table arrays_values; + +statement ok +drop table arrays_values_v2; + +statement ok +drop table array_has_table_1D; + +statement ok +drop table array_has_table_1D_Float; + +statement ok +drop table array_has_table_1D_Boolean; + +statement ok +drop table array_has_table_1D_UTF8; + +statement ok +drop table array_has_table_2D; + +statement ok +drop table array_has_table_2D_float; + +statement ok +drop table array_has_table_3D; + +statement ok +drop table array_intersect_table_1D; + +statement ok +drop table array_intersect_table_1D_Float; + +statement ok +drop table array_intersect_table_1D_Boolean; + +statement ok +drop table array_intersect_table_1D_UTF8; + +statement ok +drop table array_intersect_table_2D; + +statement ok +drop table array_intersect_table_2D_float; + +statement ok +drop table array_intersect_table_3D; + +statement ok +drop table arrays_values_without_nulls; + +statement ok +drop table arrays_range; + +statement ok +drop table arrays_with_repeating_elements; + +statement ok +drop table nested_arrays_with_repeating_elements; + +statement ok +drop table flatten_table; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt new file mode 100644 index 0000000000000..5c1b6fb726ed7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Arrow Files Format support +############# + + +statement ok + +CREATE EXTERNAL TABLE arrow_simple +STORED AS ARROW +LOCATION '../core/tests/data/example.arrow'; + + +# physical plan +query TT +EXPLAIN SELECT * FROM arrow_simple +---- +logical_plan TableScan: arrow_simple projection=[f0, f1, f2] +physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.arrow]]}, projection=[f0, f1, f2] + +# correct content +query ITB +SELECT * FROM arrow_simple +---- +1 foo true +2 bar NULL +3 baz false +4 NULL true diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt similarity index 89% rename from datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt rename to datafusion/sqllogictest/test_files/arrow_typeof.slt index 4a3d39bdebcf3..3fad4d0f61b98 100644 --- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -180,24 +180,30 @@ drop table foo statement ok create table foo as select - arrow_cast(100, 'Decimal128(5,2)') as col_d128 - -- Can't make a decimal 156: - -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)" - --arrow_cast(100, 'Decimal256(5,2)') as col_d256 + arrow_cast(100, 'Decimal128(5,2)') as col_d128, + arrow_cast(100, 'Decimal256(5,2)') as col_d256 ; ## Ensure each column in the table has the expected type -query T +query TT SELECT - arrow_typeof(col_d128) - -- arrow_typeof(col_d256), + arrow_typeof(col_d128), + arrow_typeof(col_d256) FROM foo; ---- -Decimal128(5, 2) +Decimal128(5, 2) Decimal256(5, 2) +query RR +SELECT + col_d128, + col_d256 + FROM foo; +---- +100 100 + statement ok drop table foo @@ -303,9 +309,11 @@ select arrow_cast('30 minutes', 'Interval(MonthDayNano)'); ## Duration -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nThis feature is not implemented: Can't create a scalar from array of type "Duration\(Second\)" +query ? --- select arrow_cast(interval '30 minutes', 'Duration(Second)'); +---- +0 days 0 hours 30 mins 0 secs query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); @@ -330,3 +338,41 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); + + +## List + + +query ? +select arrow_cast('1', 'List(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'List(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); +---- +List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + +## LargeList + + +query ? +select arrow_cast('1', 'LargeList(Int64)'); +---- +[1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); +---- +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt new file mode 100644 index 0000000000000..3f21274c009fa --- /dev/null +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +statement ok +CREATE EXTERNAL TABLE alltypes_plain ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_snappy ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.snappy.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_bzip2 ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.bzip2.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_xz ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.xz.avro' + +statement ok +CREATE EXTERNAL TABLE alltypes_plain_zstandard ( + id INT NOT NULL, + bool_col BOOLEAN NOT NULL, + tinyint_col TINYINT NOT NULL, + smallint_col SMALLINT NOT NULL, + int_col INT NOT NULL, + bigint_col BIGINT NOT NULL, + float_col FLOAT NOT NULL, + double_col DOUBLE NOT NULL, + date_string_col BYTEA NOT NULL, + string_col VARCHAR NOT NULL, + timestamp_col TIMESTAMP NOT NULL, +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/alltypes_plain.zstandard.avro' + +statement ok +CREATE EXTERNAL TABLE single_nan ( + mycol FLOAT +) +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/single_nan.avro' + +statement ok +CREATE EXTERNAL TABLE nested_records +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/nested_records.avro' + +statement ok +CREATE EXTERNAL TABLE simple_enum +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/simple_enum.avro' + +statement ok +CREATE EXTERNAL TABLE simple_fixed +STORED AS AVRO +WITH HEADER ROW +LOCATION '../../testing/data/avro/simple_fixed.avro' + +# test avro query +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with snappy +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_snappy +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with bzip2 +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_bzip2 +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with xz +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_xz +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro query with zstandard +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_zstandard +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro single nan schema +query R +SELECT mycol FROM single_nan +---- +NULL + +# test avro query multi files +query IT +SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# test avro nested records +query ???? +SELECT f1, f2, f3, f4 FROM nested_records +---- +{f1_1: aaa, f1_2: 10, f1_3: {f1_3_1: 3.14}} [{f2_1: true, f2_2: 1.2}, {f2_1: true, f2_2: 2.2}] {f3_1: xyz} [{f4_1: 200}, ] +{f1_1: bbb, f1_2: 20, f1_3: {f1_3_1: 3.14}} [{f2_1: false, f2_2: 10.2}] NULL [, {f4_1: 300}] + +# test avro enum +query TTT +SELECT f1, f2, f3 FROM simple_enum +---- +a g j +b h k +c e NULL +d f i + +# test avro fixed +query ??? +SELECT f1, f2, f3 FROM simple_fixed +---- +6162636465 666768696a6b6c6d6e6f 414243444546 +3132333435 31323334353637383930 NULL + +# test avro explain +query TT +EXPLAIN SELECT count(*) from alltypes_plain +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--TableScan: alltypes_plain projection=[] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]} diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt new file mode 100644 index 0000000000000..0568ada3ad7dc --- /dev/null +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -0,0 +1,267 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Tests for Binary +############# + +# Basic literals encoded as hex +query ?T +SELECT X'FF01', arrow_typeof(X'FF01'); +---- +ff01 Binary + +# Invaid hex values +query error DataFusion error: Error during planning: Invalid HexStringLiteral 'Z' +SELECT X'Z' + +# Insert binary data into tables +statement ok +CREATE TABLE t +AS VALUES + ('FF01', X'FF01'), + ('ABC', X'ABC'), + ('000', X'000'); + +query T?TT +SELECT column1, column2, arrow_typeof(column1), arrow_typeof(column2) +FROM t; +---- +FF01 ff01 Utf8 Binary +ABC 0abc Utf8 Binary +000 0000 Utf8 Binary + +# comparisons +query ?BBBB +SELECT + column2, + -- binary compare with string + column2 = 'ABC', + column2 <> 'ABC', + -- binary compared with binary + column2 = X'ABC', + column2 <> X'ABC' +FROM t; +---- +ff01 false true false true +0abc false true true false +0000 false true false true + + +# predicates +query T? +SELECT column1, column2 +FROM t +WHERE column2 > X'123'; +---- +FF01 ff01 +ABC 0abc + +# order by +query T? +SELECT * +FROM t +ORDER BY column2; +---- +000 0000 +ABC 0abc +FF01 ff01 + +# group by +query I +SELECT count(*) +FROM t +GROUP BY column1; +---- +1 +1 +1 + +statement ok +drop table t; + +############# +## Tests for FixedSizeBinary +############# + +# fixed_size_binary +statement ok +CREATE TABLE t_source +AS VALUES + (X'000102', X'000102'), + (X'003102', X'000102'), + (NULL, X'000102'), + (X'FF0102', X'000102'), + (X'000102', X'000102') +; + +# Create a table with FixedSizeBinary +statement ok +CREATE TABLE t +AS SELECT + arrow_cast(column1, 'FixedSizeBinary(3)') as "column1", + arrow_cast(column2, 'FixedSizeBinary(3)') as "column2" +FROM t_source; + +query ?T +SELECT column1, arrow_typeof(column1) FROM t; +---- +000102 FixedSizeBinary(3) +003102 FixedSizeBinary(3) +NULL FixedSizeBinary(3) +ff0102 FixedSizeBinary(3) +000102 FixedSizeBinary(3) + +# Comparison +query ??BB +SELECT + column1, + column2, + column1 = arrow_cast(X'000102', 'FixedSizeBinary(3)'), + column1 = column2 +FROM t +---- +000102 000102 true true +003102 000102 false false +NULL 000102 NULL NULL +ff0102 000102 false false +000102 000102 true true + + +# Comparison to different sized field +query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = FixedSizeBinary\(2\) +SELECT column1, column1 = arrow_cast(X'0102', 'FixedSizeBinary(2)') FROM t + +# Comparison to different sized Binary +query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = Binary +SELECT column1, column1 = X'0102' FROM t + +statement ok +drop table t_source + +statement ok +drop table t + + +############# +## Tests for binary that contains strings +############# + +statement ok +CREATE TABLE t_source +AS VALUES + ('Foo'), + (NULL), + ('Bar'), + ('FooBar') +; + +# Create a table with Binary, LargeBinary but really has strings +statement ok +CREATE TABLE t +AS SELECT + arrow_cast(column1, 'Binary') as "binary", + arrow_cast(column1, 'LargeBinary') as "largebinary" +FROM t_source; + +query ??TT +SELECT binary, largebinary, cast(binary as varchar) as binary_str, cast(largebinary as varchar) as binary_largestr from t; +---- +466f6f 466f6f Foo Foo +NULL NULL NULL NULL +426172 426172 Bar Bar +466f6f426172 466f6f426172 FooBar FooBar + +# ensure coercion works for = and <> +query ?T +SELECT binary, cast(binary as varchar) as str FROM t WHERE binary = 'Foo'; +---- +466f6f Foo + +query ?T +SELECT binary, cast(binary as varchar) as str FROM t WHERE binary <> 'Foo'; +---- +426172 Bar +466f6f426172 FooBar + +# order by +query ? +SELECT binary FROM t ORDER BY binary; +---- +426172 +466f6f +466f6f426172 +NULL + +# order by +query ? +SELECT largebinary FROM t ORDER BY largebinary; +---- +426172 +466f6f +466f6f426172 +NULL + +# LIKE +query ? +SELECT binary FROM t where binary LIKE '%F%'; +---- +466f6f +466f6f426172 + +query ? +SELECT largebinary FROM t where largebinary LIKE '%F%'; +---- +466f6f +466f6f426172 + +# character_length function +query TITI +SELECT + cast(binary as varchar) as str, + character_length(binary) as binary_len, + cast(largebinary as varchar) as large_str, + character_length(binary) as largebinary_len +from t; +---- +Foo 3 Foo 3 +NULL NULL NULL NULL +Bar 3 Bar 3 +FooBar 6 FooBar 6 + +query I +SELECT character_length(X'20'); +---- +1 + +# still errors on values that can not be coerced to utf8 +query error Encountered non UTF\-8 data: invalid utf\-8 sequence of 1 bytes from index 0 +SELECT character_length(X'c328'); + +# regexp_replace +query TTTT +SELECT + cast(binary as varchar) as str, + regexp_replace(binary, 'F', 'f') as binary_replaced, + cast(largebinary as varchar) as large_str, + regexp_replace(largebinary, 'F', 'f') as large_binary_replaced +from t; +---- +Foo foo Foo foo +NULL NULL NULL NULL +Bar Bar Bar Bar +FooBar fooBar FooBar fooBar diff --git a/datafusion/core/tests/sqllogictests/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/cast.slt rename to datafusion/sqllogictest/test_files/cast.slt diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt new file mode 100644 index 0000000000000..f6afa525adcc7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -0,0 +1,275 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# This file contains the clickbench schema and queries +# and the first 10 rows of data. Since ClickBench contains case sensitive queries +# this is also a good test of that usecase too + +# create.sql came from +# https://github.com/ClickHouse/ClickBench/blob/8b9e3aa05ea18afa427f14909ddc678b8ef0d5e6/datafusion/create.sql +# Data file made with DuckDB: +# COPY (SELECT * FROM 'hits.parquet' LIMIT 10) TO 'clickbench_hits_10.parquet' (FORMAT PARQUET); + +statement ok +CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '../core/tests/data/clickbench_hits_10.parquet'; + + +# queries.sql came from +# https://github.com/ClickHouse/ClickBench/blob/8b9e3aa05ea18afa427f14909ddc678b8ef0d5e6/datafusion/queries.sql + +query I +SELECT COUNT(*) FROM hits; +---- +10 + +query I +SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; +---- +0 + +query IIR +SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; +---- +0 10 0 + +query R +SELECT AVG("UserID") FROM hits; +---- +-304548765855551600 + +query I +SELECT COUNT(DISTINCT "UserID") FROM hits; +---- +5 + +query I +SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; +---- +1 + +query DD +SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits; +---- +2013-07-15 2013-07-15 + +query II +SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; +---- + +query II rowsort +SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; +---- +197 1 +229 1 +39 1 +839 2 + +query IIIRI rowsort +SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; +---- +197 0 2 0 1 +229 0 1 0 1 +39 0 1 0 1 +839 0 6 0 2 + +query TI +SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; +---- + +query ITI +SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; +---- + +query TI +SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +---- + +query TI +SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; +---- + +query ITI +SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; +---- + +query II rowsort +SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +---- +-2461439046089301801 5 +376160620089546609 1 +427738049800818189 1 +519640690937130534 2 +7418527520126366595 1 + +query ITI rowsort +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; +---- +-2461439046089301801 (empty) 5 +376160620089546609 (empty) 1 +427738049800818189 (empty) 1 +519640690937130534 (empty) 2 +7418527520126366595 (empty) 1 + +query ITI rowsort +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; +---- +-2461439046089301801 (empty) 5 +376160620089546609 (empty) 1 +427738049800818189 (empty) 1 +519640690937130534 (empty) 2 +7418527520126366595 (empty) 1 + +query IRTI rowsort +SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; +---- +-2461439046089301801 18 (empty) 1 +-2461439046089301801 33 (empty) 1 +-2461439046089301801 38 (empty) 1 +-2461439046089301801 56 (empty) 1 +-2461439046089301801 58 (empty) 1 +376160620089546609 30 (empty) 1 +427738049800818189 40 (empty) 1 +519640690937130534 26 (empty) 1 +519640690937130534 36 (empty) 1 +7418527520126366595 18 (empty) 1 + +query I +SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; +---- + +query I +SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; +---- +0 + +query TTI +SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +---- + +query TTTII +SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; +---- + +query IITIIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIII +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +---- + +query T +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +---- + +query T +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; +---- + +query T +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime"), "SearchPhrase" LIMIT 10; +---- + +query IRI +SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; +---- + +query TRIT +SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; +---- + +query IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII +SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; +---- +0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 840 850 860 870 880 890 + +query IIIIR +SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; +---- + +query IIIIR +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; +---- + +query IIIIR rowsort +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; +---- +4894690465724379622 1568366281 1 0 0 +5206346422301499756 -1216690514 1 0 0 +6308646140879811077 -1216690514 1 0 0 +6635790769678439148 1427531677 1 0 0 +6864353419233967042 1568366281 1 0 0 +8120543446287442873 -1216690514 1 0 0 +8156744413230856864 -1216690514 1 0 0 +8740403056911509777 1615432634 1 0 0 +8924809397503602651 -1216690514 1 0 0 +9110818468285196899 -1216690514 1 0 0 + +query TI rowsort +SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; +---- +(empty) 5 +http://afisha.mail.ru/catalog/314/women.ru/ency=1&page3/?errovat-pinniki 1 +http://bonprix.ru/index.ru/cinema/art/0 986 424 233 сезон 1 +http://bonprix.ru/index.ru/cinema/art/A00387,3797); ru)&bL 1 +http://holodilnik.ru/russia/05jul2013&model=0 1 +http://tours/Ekategoriya%2F&sr=http://slovareniye 1 + +query ITI rowsort +SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; +---- +1 (empty) 5 +1 http://afisha.mail.ru/catalog/314/women.ru/ency=1&page3/?errovat-pinniki 1 +1 http://bonprix.ru/index.ru/cinema/art/0 986 424 233 сезон 1 +1 http://bonprix.ru/index.ru/cinema/art/A00387,3797); ru)&bL 1 +1 http://holodilnik.ru/russia/05jul2013&model=0 1 +1 http://tours/Ekategoriya%2F&sr=http://slovareniye 1 + +query IIIII rowsort +SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; +---- +-1216690514 -1216690515 -1216690516 -1216690517 6 +1427531677 1427531676 1427531675 1427531674 1 +1568366281 1568366280 1568366279 1568366278 2 +1615432634 1615432633 1615432632 1615432631 1 + +query TI +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; +---- + +query TI +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; +---- + +query TI +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +---- + +query IIITTI +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +---- + +query IDI +SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100; +---- + +query III +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; +---- + +query PI +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; +---- diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt new file mode 100644 index 0000000000000..02ab330833159 --- /dev/null +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# tests for copy command +statement ok +create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, 'Bar'); + +# Copy to directory as multiple files +query IT +COPY source_table TO 'test_files/scratch/copy/table' (format parquet, single_file_output false, compression 'zstd(10)'); +---- +2 + +query TT +EXPLAIN COPY source_table TO 'test_files/scratch/copy/table' (format parquet, single_file_output false, compression 'zstd(10)'); +---- +logical_plan +CopyTo: format=parquet output_url=test_files/scratch/copy/table single_file_output=false options: (compression 'zstd(10)') +--TableScan: source_table projection=[col1, col2] +physical_plan +FileSinkExec: sink=ParquetSink(file_groups=[]) +--MemoryExec: partitions=1, partition_sizes=[1] + +# Error case +query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! +EXPLAIN COPY source_table to 'test_files/scratch/copy/table' + +query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: query"\) +EXPLAIN COPY source_table to 'test_files/scratch/copy/table' (format parquet, single_file_output false) +query TT +EXPLAIN COPY source_table to 'test_files/scratch/copy/table' (format parquet, per_thread_output true) + +# Copy more files to directory via query +query IT +COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table' (format parquet, single_file_output false); +---- +4 + +# validate multiple parquet file output +statement ok +CREATE EXTERNAL TABLE validate_parquet STORED AS PARQUET LOCATION 'test_files/scratch/copy/table/'; + +query IT +select * from validate_parquet; +---- +1 Foo +2 Bar +1 Foo +2 Bar +1 Foo +2 Bar + +# Copy parquet with all supported statment overrides +query IT +COPY source_table +TO 'test_files/scratch/copy/table_with_options' +(format parquet, +single_file_output false, +compression snappy, +'compression::col1' 'zstd(5)', +'compression::col2' snappy, +max_row_group_size 12345, +data_pagesize_limit 1234, +write_batch_size 1234, +writer_version 2.0, +dictionary_page_size_limit 123, +created_by 'DF copy.slt', +column_index_truncate_length 123, +data_page_row_count_limit 1234, +bloom_filter_enabled true, +'bloom_filter_enabled::col1' false, +'bloom_filter_fpp::col2' 0.456, +'bloom_filter_ndv::col2' 456, +encoding plain, +'encoding::col1' DELTA_BINARY_PACKED, +'dictionary_enabled::col2' true, +dictionary_enabled false, +statistics_enabled page, +'statistics_enabled::col2' none, +max_statistics_size 123, +bloom_filter_fpp 0.001, +bloom_filter_ndv 100 +) +---- +2 + +# validate multiple parquet file output with all options set +statement ok +CREATE EXTERNAL TABLE validate_parquet_with_options STORED AS PARQUET LOCATION 'test_files/scratch/copy/table_with_options/'; + +query IT +select * from validate_parquet_with_options; +---- +1 Foo +2 Bar + +# Copy from table to single file +query IT +COPY source_table to 'test_files/scratch/copy/table.parquet'; +---- +2 + +# validate single parquet file output +statement ok +CREATE EXTERNAL TABLE validate_parquet_single STORED AS PARQUET LOCATION 'test_files/scratch/copy/table.parquet'; + +query IT +select * from validate_parquet_single; +---- +1 Foo +2 Bar + +# copy from table to folder of compressed json files +query IT +COPY source_table to 'test_files/scratch/copy/table_json_gz' (format json, single_file_output false, compression 'gzip'); +---- +2 + +# validate folder of csv files +statement ok +CREATE EXTERNAL TABLE validate_json_gz STORED AS json COMPRESSION TYPE gzip LOCATION 'test_files/scratch/copy/table_json_gz'; + +query IT +select * from validate_json_gz; +---- +1 Foo +2 Bar + +# copy from table to folder of compressed csv files +query IT +COPY source_table to 'test_files/scratch/copy/table_csv' (format csv, single_file_output false, header false, compression 'gzip'); +---- +2 + +# validate folder of csv files +statement ok +CREATE EXTERNAL TABLE validate_csv STORED AS csv COMPRESSION TYPE gzip LOCATION 'test_files/scratch/copy/table_csv'; + +query IT +select * from validate_csv; +---- +1 Foo +2 Bar + +# Copy from table to single csv +query IT +COPY source_table to 'test_files/scratch/copy/table.csv'; +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_single_csv STORED AS csv WITH HEADER ROW LOCATION 'test_files/scratch/copy/table.csv'; + +query IT +select * from validate_single_csv; +---- +1 Foo +2 Bar + +# Copy from table to folder of json +query IT +COPY source_table to 'test_files/scratch/copy/table_json' (format json, single_file_output false); +---- +2 + +# Validate json output +statement ok +CREATE EXTERNAL TABLE validate_json STORED AS json LOCATION 'test_files/scratch/copy/table_json'; + +query IT +select * from validate_json; +---- +1 Foo +2 Bar + +# Copy from table to single json file +query IT +COPY source_table to 'test_files/scratch/copy/table.json'; +---- +2 + +# Validate single JSON file` +statement ok +CREATE EXTERNAL TABLE validate_single_json STORED AS json LOCATION 'test_files/scratch/copy/table_json'; + +query IT +select * from validate_single_json; +---- +1 Foo +2 Bar + +# COPY csv files with all options set +query IT +COPY source_table +to 'test_files/scratch/copy/table_csv_with_options' +(format csv, +single_file_output false, +header false, +compression 'uncompressed', +datetime_format '%FT%H:%M:%S.%9f', +delimiter ';', +null_value 'NULLVAL'); +---- +2 + +# Validate single csv output +statement ok +CREATE EXTERNAL TABLE validate_csv_with_options +STORED AS csv +LOCATION 'test_files/scratch/copy/table_csv_with_options'; + +query T +select * from validate_csv_with_options; +---- +1;Foo +2;Bar + +# Error cases: + +# Copy from table with options +query error DataFusion error: Invalid or Unsupported Configuration: Found unsupported option row_group_size with value 55 for JSON format! +COPY source_table to 'test_files/scratch/copy/table.json' (row_group_size 55); + +# Incomplete statement +query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) +COPY (select col2, sum(col1) from source_table + +# Copy from table with non literal +query error DataFusion error: SQL error: ParserError\("Expected ',' or '\)' after option definition, found: \+"\) +COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); diff --git a/datafusion/core/tests/sqllogictests/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/create_external_table.slt rename to datafusion/sqllogictest/test_files/create_external_table.slt diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt new file mode 100644 index 0000000000000..9facb064bf32a --- /dev/null +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create_external_table_with_quote_escape +statement ok +CREATE EXTERNAL TABLE csv_with_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('quote' '~') +LOCATION '../core/tests/data/quote.csv'; + +statement ok +CREATE EXTERNAL TABLE csv_with_escape ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +WITH HEADER ROW +DELIMITER ',' +OPTIONS ('escape' '\"') +LOCATION '../core/tests/data/escape.csv'; + +query TT +select * from csv_with_quote; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 +id4 value4 +id5 value5 +id6 value6 +id7 value7 +id8 value8 +id9 value9 + +query TT +select * from csv_with_escape; +---- +id0 value"0 +id1 value"1 +id2 value"2 +id3 value"3 +id4 value"4 +id5 value"5 +id6 value"6 +id7 value"7 +id8 value"8 +id9 value"9 diff --git a/datafusion/core/tests/sqllogictests/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/cte.slt rename to datafusion/sqllogictest/test_files/cte.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt similarity index 83% rename from datafusion/core/tests/sqllogictests/test_files/dates.slt rename to datafusion/sqllogictest/test_files/dates.slt index 5b76739e95ba9..a93a7ff7e73cd 100644 --- a/datafusion/core/tests/sqllogictests/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -85,23 +85,25 @@ g h ## Plan error when compare Utf8 and timestamp in where clause -statement error Error during planning: Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 can't be evaluated because there isn't a common type to coerce the types to +statement error DataFusion error: type_coercion\ncaused by\nError during planning: Cannot coerce arithmetic expression Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 to valid types select i_item_desc from test where d3_date > now() + '5 days'; # DATE minus DATE # https://github.com/apache/arrow-rs/issues/4383 -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Cast error: Cannot perform arithmetic operation between array of type Date32 and array of type Date32 +query ? SELECT DATE '2023-04-09' - DATE '2023-04-02'; +---- +7 days 0 hours 0 mins 0 secs # DATE minus Timestamp query ? SELECT DATE '2023-04-09' - '2000-01-01T00:00:00'::timestamp; ---- -0 years 0 mons 8499 days 0 hours 0 mins 0.000000000 secs +8499 days 0 hours 0 mins 0.000000000 secs # Timestamp minus DATE query ? SELECT '2023-01-01T00:00:00'::timestamp - DATE '2021-01-01'; ---- -0 years 0 mons 730 days 0 hours 0 mins 0.000000000 secs +730 days 0 hours 0 mins 0.000000000 secs diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt similarity index 92% rename from datafusion/core/tests/sqllogictests/test_files/ddl.slt rename to datafusion/sqllogictest/test_files/ddl.slt index 1cf67be3a2182..682972b5572a9 100644 --- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -256,7 +256,7 @@ DROP VIEW non_existent_view ########## statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; # create_table_as statement ok @@ -302,7 +302,7 @@ CREATE TABLE my_table(c1 float, c2 double, c3 boolean, c4 varchar) AS SELECT *,c query RRBT rowsort SELECT * FROM my_table order by c1 LIMIT 1 ---- -0.00001 0.000000000001 true 1 +0.00001 0.000000000001 true true statement ok DROP TABLE my_table; @@ -312,7 +312,7 @@ DROP TABLE aggregate_simple # Arrow format statement ok -CREATE external table arrow_simple STORED as ARROW LOCATION 'tests/data/example.arrow'; +CREATE external table arrow_simple STORED as ARROW LOCATION '../core/tests/data/example.arrow'; query ITB rowsort SELECT * FROM arrow_simple order by f1 LIMIT 1 @@ -348,6 +348,9 @@ SELECT * FROM new_table; statement ok DROP TABLE new_table +statement ok +DROP TABLE my_table; + # create_table_with_schema_as_multiple_values statement ok CREATE TABLE test_table(c1 int, c2 float, c3 varchar) AS VALUES(1, 2, 'hello'),(2, 1, 'there'),(3, 0, '!'); @@ -362,7 +365,32 @@ SELECT * FROM new_table 2 1 there statement ok -DROP TABLE my_table; +DROP TABLE new_table; + +# Select into without alias names of window aggregates +statement ok +SELECT SUM(c1) OVER(ORDER BY c2), c2, c3 INTO new_table FROM test_table + +query IRT +SELECT * FROM new_table +---- +3 0 ! +5 1 there +6 2 hello + +statement ok +DROP TABLE new_table; + +# Create table as without alias names of window aggregates +statement ok +CREATE TABLE new_table AS SELECT SUM(c1) OVER(ORDER BY c2), c2, c3 FROM test_table + +query IRT +SELECT * FROM new_table +---- +3 0 ! +5 1 there +6 2 hello statement ok DROP TABLE new_table; @@ -442,7 +470,7 @@ statement ok CREATE EXTERNAL TABLE csv_with_timestamps ( name VARCHAR, ts TIMESTAMP -) STORED AS CSV LOCATION 'tests/data/timestamps.csv'; +) STORED AS CSV LOCATION '../core/tests/data/timestamps.csv'; query TP SELECT * from csv_with_timestamps @@ -468,7 +496,7 @@ CREATE EXTERNAL TABLE csv_with_timestamps ( ) STORED AS CSV PARTITIONED BY (c_date) -LOCATION 'tests/data/partitioned_table'; +LOCATION '../core/tests/data/partitioned_table'; query TPD SELECT * from csv_with_timestamps where c_date='2018-11-13' @@ -507,7 +535,7 @@ DROP VIEW y; # create_pipe_delimited_csv_table() statement ok -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW DELIMITER '|' LOCATION 'tests/data/aggregate_simple_pipe.csv'; +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW DELIMITER '|' LOCATION '../core/tests/data/aggregate_simple_pipe.csv'; query RRB @@ -553,14 +581,14 @@ statement ok CREATE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2 BIGINT); statement ok -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' # Should not recreate the same EXTERNAL table statement error Execution error: Table 'aggregate_simple' already exists -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' statement ok -CREATE EXTERNAL TABLE IF NOT EXISTS aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE IF NOT EXISTS aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' # create bad custom table statement error DataFusion error: Execution error: Unable to find factory for DELTATABLE @@ -642,33 +670,6 @@ describe TABLE_WITHOUT_NORMALIZATION FIELD1 Int64 YES FIELD2 Int64 YES -query R -select 10000000000000000000.01 ----- -10000000000000000000 - -query T -select arrow_typeof(10000000000000000000.01) ----- -Float64 - -statement ok -set datafusion.sql_parser.parse_float_as_decimal = true; - -query R -select 10000000000000000000.01 ----- -10000000000000000000.01 - -query T -select arrow_typeof(10000000000000000000.01) ----- -Decimal128(22, 2) - -# Restore those to default value again -statement ok -set datafusion.sql_parser.parse_float_as_decimal = false; - statement ok set datafusion.sql_parser.enable_ident_normalization = true; @@ -689,7 +690,7 @@ drop table foo; # create csv table with empty csv file statement ok -CREATE EXTERNAL TABLE empty STORED AS CSV WITH HEADER ROW LOCATION 'tests/data/empty.csv'; +CREATE EXTERNAL TABLE empty STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/empty.csv'; query TTI select column_name, data_type, ordinal_position from information_schema.columns where table_name='empty';; @@ -742,14 +743,14 @@ statement ok CREATE UNBOUNDED external table t(c1 integer, c2 integer, c3 integer) STORED as CSV WITH HEADER ROW -LOCATION 'tests/data/empty.csv'; +LOCATION '../core/tests/data/empty.csv'; # should see infinite_source=true in the explain query TT explain select c1 from t; ---- logical_plan TableScan: t projection=[c1] -physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/empty.csv]]}, projection=[c1], infinite_source=true, has_header=true +physical_plan StreamingTableExec: partition_sizes=1, projection=[c1], infinite_source=true statement ok drop table t; @@ -760,7 +761,7 @@ statement ok CREATE external table t(c1 integer, c2 integer, c3 integer) STORED as CSV WITH HEADER ROW -LOCATION 'tests/data/empty.csv'; +LOCATION '../core/tests/data/empty.csv'; # expect to see no infinite_source in the explain query TT diff --git a/datafusion/core/tests/sqllogictests/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt similarity index 79% rename from datafusion/core/tests/sqllogictests/test_files/decimal.slt rename to datafusion/sqllogictest/test_files/decimal.slt index a6ec1edfd0f36..c220a5fc9a527 100644 --- a/datafusion/core/tests/sqllogictests/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -45,7 +45,7 @@ c5 DECIMAL(12,7) NOT NULL ) STORED AS CSV WITH HEADER ROW -LOCATION 'tests/data/decimal_data.csv'; +LOCATION '../core/tests/data/decimal_data.csv'; query TT @@ -124,6 +124,12 @@ select arrow_typeof(avg(c1)), avg(c1) from decimal_simple; Decimal128(14, 10) 0.0000366666 +query TR +select arrow_typeof(median(c1)), median(c1) from decimal_simple; +---- +Decimal128(10, 6) 0.00004 + + query RRIBR rowsort select * from decimal_simple where c1=CAST(0.00002 as Decimal(10,8)); ---- @@ -359,7 +365,7 @@ select c1*c5 from decimal_simple; query T select arrow_typeof(c1/cast(0.00001 as decimal(5,5))) from decimal_simple limit 1; ---- -Decimal128(21, 12) +Decimal128(19, 10) query R rowsort @@ -385,27 +391,27 @@ select c1/cast(0.00001 as decimal(5,5)) from decimal_simple; query T select arrow_typeof(c1/c5) from decimal_simple limit 1; ---- -Decimal128(30, 19) +Decimal128(21, 10) query R rowsort select c1/c5 from decimal_simple; ---- 0.5 -0.641025641026 -0.714285714286 -0.735294117647 +0.641025641 +0.7142857142 +0.7352941176 0.8 -0.857142857143 -0.909090909091 -0.909090909091 +0.8571428571 +0.909090909 +0.909090909 0.9375 -0.961538461538 +0.9615384615 1 1 -1.052631578947 -1.515151515152 -2.727272727273 +1.0526315789 +1.5151515151 +2.7272727272 query T @@ -463,7 +469,7 @@ select c1%c5 from decimal_simple; query T select arrow_typeof(abs(c1)) from decimal_simple limit 1; ---- -Float64 +Decimal128(10, 6) query R rowsort @@ -501,27 +507,26 @@ select * from decimal_simple where c1 >= 0.00004 order by c1; query RRIBR -select * from decimal_simple where c1 >= 0.00004 order by c1 limit 10; +select * from decimal_simple where c1 >= 0.00004 order by c1, c3 limit 10; ---- 0.00004 0.000000000004 5 true 0.000044 +0.00004 0.000000000004 8 false 0.000044 0.00004 0.000000000004 12 false 0.00004 0.00004 0.000000000004 14 true 0.00004 -0.00004 0.000000000004 8 false 0.000044 -0.00005 0.000000000005 9 true 0.000052 +0.00005 0.000000000005 1 false 0.0001 0.00005 0.000000000005 4 true 0.000078 0.00005 0.000000000005 8 false 0.000033 +0.00005 0.000000000005 9 true 0.000052 0.00005 0.000000000005 100 true 0.000068 -0.00005 0.000000000005 1 false 0.0001 - query RRIBR -select * from decimal_simple where c1 >= 0.00004 order by c1 limit 5; +select * from decimal_simple where c1 >= 0.00004 order by c1, c3 limit 5; ---- 0.00004 0.000000000004 5 true 0.000044 +0.00004 0.000000000004 8 false 0.000044 0.00004 0.000000000004 12 false 0.00004 0.00004 0.000000000004 14 true 0.00004 -0.00004 0.000000000004 8 false 0.000044 -0.00005 0.000000000005 9 true 0.000052 +0.00005 0.000000000005 1 false 0.0001 query RRIBR @@ -597,3 +602,123 @@ query R select try_cast(1234567 as decimal(7,3)); ---- NULL + +statement ok +create table foo (a DECIMAL(38, 20), b DECIMAL(38, 0)); + +statement ok +insert into foo VALUES (1, 5); + +query R +select a / b from foo; +---- +0.2 + +statement ok +create table t as values (arrow_cast(123, 'Decimal256(5,2)')); + +# make sure query below runs in single partition +# otherwise error message may not be deterministic +statement ok +set datafusion.execution.target_partitions = 1; + +query R +select AVG(column1) from t; +---- +123 + +statement ok +drop table t; + +statement ok +CREATE EXTERNAL TABLE decimal256_simple ( +c1 DECIMAL(50,6) NOT NULL, +c2 DOUBLE NOT NULL, +c3 BIGINT NOT NULL, +c4 BOOLEAN NOT NULL, +c5 DECIMAL(52,7) NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/decimal_data.csv'; + +query TT +select arrow_typeof(c1), arrow_typeof(c5) from decimal256_simple limit 1; +---- +Decimal256(50, 6) Decimal256(52, 7) + +query R rowsort +SELECT c1 from decimal256_simple; +---- +0.00001 +0.00002 +0.00002 +0.00003 +0.00003 +0.00003 +0.00004 +0.00004 +0.00004 +0.00004 +0.00005 +0.00005 +0.00005 +0.00005 +0.00005 + +query R rowsort +select c1 from decimal256_simple where c1 > 0.000030; +---- +0.00004 +0.00004 +0.00004 +0.00004 +0.00005 +0.00005 +0.00005 +0.00005 +0.00005 + +query RRIBR rowsort +select * from decimal256_simple where c1 > c5; +---- +0.00002 0.000000000002 3 false 0.000019 +0.00003 0.000000000003 5 true 0.000011 +0.00005 0.000000000005 8 false 0.000033 + +query TR +select arrow_typeof(avg(c1)), avg(c1) from decimal256_simple; +---- +Decimal256(54, 10) 0.0000366666 + +query TR +select arrow_typeof(min(c1)), min(c1) from decimal256_simple where c4=false; +---- +Decimal256(50, 6) 0.00002 + +query TR +select arrow_typeof(max(c1)), max(c1) from decimal256_simple where c4=false; +---- +Decimal256(50, 6) 0.00005 + +query TR +select arrow_typeof(sum(c1)), sum(c1) from decimal256_simple; +---- +Decimal256(60, 6) 0.00055 + +query TR +select arrow_typeof(median(c1)), median(c1) from decimal256_simple; +---- +Decimal256(50, 6) 0.00004 + +query IR +select count(*),c1 from decimal256_simple group by c1 order by c1; +---- +1 0.00001 +2 0.00002 +3 0.00003 +4 0.00004 +5 0.00005 + +statement ok +drop table decimal256_simple; diff --git a/datafusion/core/tests/sqllogictests/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt similarity index 68% rename from datafusion/core/tests/sqllogictests/test_files/describe.slt rename to datafusion/sqllogictest/test_files/describe.slt index 5ee4d1cd21976..f94a2e453884f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -24,7 +24,7 @@ statement ok set datafusion.catalog.information_schema = true statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; query TTT rowsort DESCRIBE aggregate_simple; @@ -44,7 +44,7 @@ statement ok set datafusion.catalog.information_schema = false statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; query TTT rowsort DESCRIBE aggregate_simple; @@ -60,5 +60,29 @@ DROP TABLE aggregate_simple; # Describe file (currently we can only describe file in datafusion-cli, fix this after issue (#4850) has been done) ########## -statement error Error during planning: table 'datafusion.public.tests/data/aggregate_simple.csv' not found -DESCRIBE 'tests/data/aggregate_simple.csv'; +statement error Error during planning: table 'datafusion.public.../core/tests/data/aggregate_simple.csv' not found +DESCRIBE '../core/tests/data/aggregate_simple.csv'; + +########## +# Describe command +########## + +statement ok +CREATE EXTERNAL TABLE alltypes_tiny_pages STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_tiny_pages.parquet'; + +query TTT +describe alltypes_tiny_pages; +---- +id Int32 YES +bool_col Boolean YES +tinyint_col Int8 YES +smallint_col Int16 YES +int_col Int32 YES +bigint_col Int64 YES +float_col Float32 YES +double_col Float64 YES +date_string_col Utf8 YES +string_col Utf8 YES +timestamp_col Timestamp(Nanosecond, None) YES +year Int32 YES +month Int32 YES diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt new file mode 100644 index 0000000000000..8a36b49b98c6e --- /dev/null +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# Basic example: distinct on the first column project the second one, and +# order by the third +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +a 5 +b 4 +c 2 +d 1 +e 3 + +# Basic example + reverse order of the selected column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC; +---- +a 1 +b 5 +c 4 +d 1 +e 1 + +# Basic example + reverse order of the ON column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3; +---- +e 3 +d 1 +c 2 +b 4 +a 4 + +# Basic example + reverse order of both columns + limit +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3 DESC LIMIT 3; +---- +e 1 +d 1 +c 4 + +# Basic example + omit ON column from selection +query I +SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +5 +4 +2 +1 +3 + +# Test explain makes sense +query TT +EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +logical_plan +Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +--Sort: aggregate_test_100.c1 ASC NULLS LAST +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +----SortExec: expr=[c1@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)], ordering_mode=Sorted +--------------SortExec: expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# ON expressions are not a sub-set of the ORDER BY expressions +query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2, c3; + +# ON expressions are empty +query error DataFusion error: Error during planning: No `ON` expressions provided +SELECT DISTINCT ON () c1, c2 FROM aggregate_test_100 ORDER BY c1, c2; + +# Use expressions in the ON and ORDER BY clauses, as well as the selection +query II +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2 % 2 = 0, c3 DESC; +---- +1 25 +4 23 + +# Multiple complex expressions +query TIB +SELECT DISTINCT ON (chr(ascii(c1) + 3), c2 % 2) chr(ascii(upper(c1)) + 3), c2 % 2, c3 > 80 AND c2 % 2 = 1 +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b') +ORDER BY chr(ascii(c1) + 3), c2 % 2, c3 DESC; +---- +D 0 false +D 1 true +E 0 false +E 1 false + +# Joins using CTEs +query II +WITH t1 AS (SELECT * FROM aggregate_test_100), +t2 AS (SELECT * FROM aggregate_test_100) +SELECT DISTINCT ON (t1.c1, t2.c2) t2.c3, t1.c4 +FROM t1 INNER JOIN t2 ON t1.c13 = t2.c13 +ORDER BY t1.c1, t2.c2, t2.c5 +LIMIT 3; +---- +-25 15295 +45 15673 +-72 -11122 diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt new file mode 100644 index 0000000000000..9f4f508e23f32 --- /dev/null +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test( + num INT, + bin_field BYTEA, + base64_field TEXT, + hex_field TEXT, +) as VALUES + (0, 'abc', encode('abc', 'base64'), encode('abc', 'hex')), + (1, 'qweqwe', encode('qweqwe', 'base64'), encode('qweqwe', 'hex')), + (2, NULL, NULL, NULL) +; + +# errors +query error DataFusion error: Error during planning: The encode function can only accept utf8 or binary\. +select encode(12, 'hex') + +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +select encode(bin_field, 'non_encoding') from test; + +query error DataFusion error: Error during planning: The decode function can only accept utf8 or binary\. +select decode(12, 'hex') + +query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex +select decode(hex_field, 'non_encoding') from test; + +query error DataFusion error: Error during planning: No function matches the given name and argument types 'to_hex\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tto_hex\(Int64\) +select to_hex(hex_field) from test; + +# Arrays tests +query T +SELECT encode(bin_field, 'hex') FROM test ORDER BY num; +---- +616263 +717765717765 +NULL + +query T +SELECT arrow_cast(decode(base64_field, 'base64'), 'Utf8') FROM test ORDER BY num; +---- +abc +qweqwe +NULL + +query T +SELECT arrow_cast(decode(hex_field, 'hex'), 'Utf8') FROM test ORDER BY num; +---- +abc +qweqwe +NULL + +query T +select to_hex(num) from test ORDER BY num; +---- +0 +1 +2 diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt new file mode 100644 index 0000000000000..e3b2610e51be3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# create aggregate_test_100 table +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# csv_query_error +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'sin\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tsin\(Float64/Float32\) +SELECT sin(c1) FROM aggregate_test_100 + +# cast_expressions_error +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c' to value of Int32 type +SELECT CAST(c1 AS INT) FROM aggregate_test_100 + +# aggregation_with_bad_arguments +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tCOUNT\(Any, .., Any\) +SELECT COUNT(DISTINCT) FROM aggregate_test_100 + +# query_cte_incorrect +statement error Error during planning: table 'datafusion\.public\.t' not found +WITH t AS (SELECT * FROM t) SELECT * from u + +statement error Error during planning: table 'datafusion\.public\.u' not found +WITH t AS (SELECT * FROM u), u AS (SELECT 1) SELECT * from u + +statement error Error during planning: table 'datafusion\.public\.u' not found +WITH t AS (WITH u as (SELECT 1) SELECT 1) SELECT * from u + +# select_wildcard_without_table +statement error Error during planning: SELECT \* with no tables specified is not valid +SELECT * + +# invalid_qualified_table_references +statement error Error during planning: table 'datafusion\.nonexistentschema\.aggregate_test_100' not found +SELECT COUNT(*) FROM nonexistentschema.aggregate_test_100 + +statement error Error during planning: table 'nonexistentcatalog\.public\.aggregate_test_100' not found +SELECT COUNT(*) FROM nonexistentcatalog.public.aggregate_test_100 + +statement error Error during planning: Unsupported compound identifier '\[Ident \{ value: "way", quote_style: None \}, Ident \{ value: "too", quote_style: None \}, Ident \{ value: "many", quote_style: None \}, Ident \{ value: "namespaces", quote_style: None \}, Ident \{ value: "as", quote_style: None \}, Ident \{ value: "ident", quote_style: None \}, Ident \{ value: "prefixes", quote_style: None \}, Ident \{ value: "aggregate_test_100", quote_style: None \}\]' +SELECT COUNT(*) FROM way.too.many.namespaces.as.ident.prefixes.aggregate_test_100 + + + +# +# Wrong scalar function signature +# + +# error message for wrong function signature (Variadic: arbitrary number of args all from some common types) +statement error Error during planning: No function matches the given name and argument types 'concat\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\) +SELECT concat(); + +# error message for wrong function signature (Uniform: t args all from some common types) +statement error Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tnullif\(Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8, Boolean/UInt8/UInt16/UInt32/UInt64/Int8/Int16/Int32/Int64/Float32/Float64/Utf8/LargeUtf8\) +SELECT nullif(1); + +# error message for wrong function signature (Exact: exact number of args of an exact type) +statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) +SELECT pi(3.14); + +# error message for wrong function signature (Any: fixed number of args of arbitrary types) +statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) +SELECT arrow_typeof(1, 1); + +# error message for wrong function signature (OneOf: fixed number of args of arbitrary types) +statement error Error during planning: No function matches the given name and argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64, Float64\) +SELECT power(1, 2, 3); + +# +# Wrong window/aggregate function signature +# + +# AggregateFunction with wrong number of arguments +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\) +select count(); + +# AggregateFunction with wrong number of arguments +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Utf8, Float64\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tAVG\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +select avg(c1, c12) from aggregate_test_100; + +# AggregateFunction with wrong argument type +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +select regr_slope(1, '2'); + +# WindowFunction using AggregateFunction wrong signature +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +select +c9, +regr_slope(c11, '2') over () as min1 +from aggregate_test_100 +order by c9 + +# WindowFunction with BuiltInWindowFunction wrong signature +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'NTH_VALUE\(Int32, Int64, Int64\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tNTH_VALUE\(Any, Any\) +select +c9, +nth_value(c5, 2, 3) over (order by c9) as nv1 +from aggregate_test_100 +order by c9 + + +statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 +create table foo as values (1), ('foo'); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt new file mode 100644 index 0000000000000..4583ef319b7fc --- /dev/null +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -0,0 +1,381 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INTEGER NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv'; + +query TT +explain SELECT c1 FROM aggregate_test_100 where c2 > 10 +---- +logical_plan +Projection: aggregate_test_100.c1 +--Filter: aggregate_test_100.c2 > Int8(10) +----TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] +physical_plan +ProjectionExec: expr=[c1@0 as c1] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: c2@1 > 10 +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2], has_header=true + +# explain_csv_exec_scan_config + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100_with_order ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INTEGER NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (c1 ASC) +LOCATION '../core/tests/data/aggregate_test_100_order_by_c1_asc.csv'; + +query TT +explain SELECT c1 FROM aggregate_test_100_with_order order by c1 ASC limit 10 +---- +logical_plan +Limit: skip=0, fetch=10 +--Sort: aggregate_test_100_with_order.c1 ASC NULLS LAST, fetch=10 +----TableScan: aggregate_test_100_with_order projection=[c1] +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv]]}, projection=[c1], output_ordering=[c1@0 ASC NULLS LAST], has_header=true + + +## explain_physical_plan_only + +statement ok +set datafusion.explain.physical_plan_only = true + +query TT +EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3) +---- +physical_plan +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec + +statement ok +set datafusion.explain.physical_plan_only = false + + +## explain nested +query error DataFusion error: Error during planning: Nested EXPLAINs are not supported +EXPLAIN explain select 1 + +## explain nested +statement error DataFusion error: Error during planning: Nested EXPLAINs are not supported +EXPLAIN EXPLAIN explain select 1 + +statement ok +set datafusion.explain.physical_plan_only = true + +statement error DataFusion error: Error during planning: Nested EXPLAINs are not supported +EXPLAIN explain select 1 + +statement ok +set datafusion.explain.physical_plan_only = false + +########## +# EXPLAIN VERBOSE will get pass prefixed with "logical_plan after" +########## + +statement ok +CREATE EXTERNAL TABLE simple_explain_test ( + a INT, + b INT, + c INT +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/example.csv' + +query TT +EXPLAIN SELECT a, b, c FROM simple_explain_test +---- +logical_plan TableScan: simple_explain_test projection=[a, b, c] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true + +# create a sink table, path is same with aggregate_test_100 table +# we do not overwrite this file, we only assert plan. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE sink_table ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INTEGER NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv'; + +query TT +EXPLAIN INSERT INTO sink_table SELECT * FROM aggregate_test_100 ORDER by c1 +---- +logical_plan +Dml: op=[Insert Into] table=[sink_table] +--Projection: aggregate_test_100.c1 AS c1, aggregate_test_100.c2 AS c2, aggregate_test_100.c3 AS c3, aggregate_test_100.c4 AS c4, aggregate_test_100.c5 AS c5, aggregate_test_100.c6 AS c6, aggregate_test_100.c7 AS c7, aggregate_test_100.c8 AS c8, aggregate_test_100.c9 AS c9, aggregate_test_100.c10 AS c10, aggregate_test_100.c11 AS c11, aggregate_test_100.c12 AS c12, aggregate_test_100.c13 AS c13 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +physical_plan +FileSinkExec: sink=StreamWrite { location: "../../testing/data/csv/aggregate_test_100.csv", batch_size: 8192, encoding: Csv, header: true, .. } +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true + +# test EXPLAIN VERBOSE +query TT +EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test +---- +initial_logical_plan +Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c +--TableScan: simple_explain_test +logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after type_coercion SAME TEXT AS ABOVE +logical_plan after count_wildcard_rule SAME TEXT AS ABOVE +analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE +logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE +logical_plan after eliminate_join SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE +logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE +logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE +logical_plan after eliminate_filter SAME TEXT AS ABOVE +logical_plan after eliminate_cross_join SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after eliminate_limit SAME TEXT AS ABOVE +logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE +logical_plan after filter_null_join_keys SAME TEXT AS ABOVE +logical_plan after eliminate_outer_join SAME TEXT AS ABOVE +logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE +logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE +logical_plan after eliminate_join SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE +logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE +logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE +logical_plan after eliminate_filter SAME TEXT AS ABOVE +logical_plan after eliminate_cross_join SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after eliminate_limit SAME TEXT AS ABOVE +logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE +logical_plan after filter_null_join_keys SAME TEXT AS ABOVE +logical_plan after eliminate_outer_join SAME TEXT AS ABOVE +logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after optimize_projections SAME TEXT AS ABOVE +logical_plan TableScan: simple_explain_test projection=[a, b, c] +initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true +physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] + + +### tests for EXPLAIN with display statistics enabled +statement ok +set datafusion.explain.show_statistics = true; + +statement ok +set datafusion.explain.physical_plan_only = true; + +# CSV scan with empty statistics +query TT +EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10; +---- +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] + +# Parquet scan with statistics collected +statement ok +set datafusion.execution.collect_statistics = true; + +statement ok +CREATE EXTERNAL TABLE alltypes_plain STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + +query TT +EXPLAIN SELECT * FROM alltypes_plain limit 10; +---- +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + +# explain verbose with both collect & show statistics on +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + + +statement ok +set datafusion.explain.show_statistics = false; + +# explain verbose with collect on and & show statistics off: still has stats +query TT +EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; +---- +initial_physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +initial_physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after OutputRequirements +OutputRequirementExec +--GlobalLimitExec: skip=0, fetch=10 +----ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after OutputRequirements +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan_with_stats +GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] + + +statement ok +set datafusion.execution.collect_statistics = false; + +# Explain ArrayFuncions + +statement ok +set datafusion.explain.physical_plan_only = false + +query TT +explain select make_array(make_array(1, 2, 3), make_array(4, 5, 6)); +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec + +query TT +explain select [[1, 2, 3], [4, 5, 6]]; +---- +logical_plan +Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) +--EmptyRelation +physical_plan +ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] +--PlaceholderRowExec diff --git a/datafusion/core/tests/sqllogictests/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt similarity index 54% rename from datafusion/core/tests/sqllogictests/test_files/functions.slt rename to datafusion/sqllogictest/test_files/functions.slt index 92597118c65a6..4f55ea316bb9f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -63,6 +63,11 @@ SELECT left('abcde', -2) ---- abc +query T +SELECT left(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'), -2) +---- +abc + query T SELECT left('abcde', -200) ---- @@ -103,6 +108,11 @@ SELECT length('') ---- 0 +query I +SELECT length(arrow_cast('', 'Dictionary(Int32, Utf8)')) +---- +0 + query I SELECT length('chars') ---- @@ -113,6 +123,11 @@ SELECT length('josé') ---- 4 +query I +SELECT length(arrow_cast('josé', 'Dictionary(Int32, Utf8)')) +---- +4 + query ? SELECT length(NULL) ---- @@ -158,6 +173,11 @@ SELECT lpad('hi', 5) ---- hi +query T +SELECT lpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5) +---- + hi + query T SELECT lpad('hi', CAST(NULL AS INT), 'xy') ---- @@ -188,6 +208,11 @@ SELECT reverse('abcde') ---- edcba +query T +SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) +---- +edcba + query T SELECT reverse('loẅks') ---- @@ -203,6 +228,11 @@ SELECT right('abcde', -2) ---- cde +query T +SELECT right(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'), 1) +---- +e + query T SELECT right('abcde', -200) ---- @@ -268,6 +298,11 @@ SELECT rpad('hi', 5, 'xy') ---- hixyx +query T +SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') +---- +hixyx + query T SELECT rpad('hi', 5, NULL) ---- @@ -378,11 +413,22 @@ SELECT substr('alphabet', 3, CAST(NULL AS int)) ---- NULL +statement error The "substr" function can only accept strings, but got Int64. +SELECT substr(1, 3) + +statement error The "substr" function can only accept strings, but got Int64. +SELECT substr(1, 3, 4) + query T SELECT translate('12345', '143', 'ax') ---- a2x5 +query T +SELECT translate(arrow_cast('12345', 'Dictionary(Int32, Utf8)'), '143', 'ax') +---- +a2x5 + query ? SELECT translate(NULL, '143', 'ax') ---- @@ -448,6 +494,10 @@ SELECT counter(*) from test; statement error Did you mean 'STDDEV'? SELECT STDEV(v1) from test; +# Aggregate function +statement error Did you mean 'COVAR'? +SELECT COVARIA(1,1); + # Window function statement error Did you mean 'SUM'? SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; @@ -476,7 +526,7 @@ from (values query ? SELECT struct(c1,c2,c3,c4,a,b) from simple_struct_test ---- -{c0: 1, c1: 1, c2: 3.1, c3: 3.14, c4: str, c5: text} +{c0: true, c1: 1, c2: 3.1, c3: 3.14, c4: str, c5: text} statement ok drop table simple_struct_test @@ -566,3 +616,382 @@ SELECT sqrt(column1),sqrt(column2),sqrt(column3),sqrt(column4),sqrt(column5),sqr statement ok drop table t + +query T +SELECT upper('foo') +---- +FOO + +query T +select upper(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +FOO + +query T +SELECT btrim(' foo ') +---- +foo + +query T +SELECT btrim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) +---- +foo + +query T +SELECT initcap('foo') +---- +Foo + +query T +SELECT initcap(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +Foo + +query T +SELECT lower('FOObar') +---- +foobar + +query T +SELECT lower(arrow_cast('FOObar', 'Dictionary(Int32, Utf8)')) +---- +foobar + +query T +SELECT ltrim(' foo') +---- +foo + +query T +SELECT ltrim(arrow_cast(' foo', 'Dictionary(Int32, Utf8)')) +---- +foo + +query T +SELECT md5('foo') +---- +acbd18db4cc2f85cedef654fccc4a4d8 + +query T +SELECT md5(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +acbd18db4cc2f85cedef654fccc4a4d8 + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT repeat('foo', 3) +---- +foofoofoo + +query T +SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) +---- +foofoofoo + +query T +SELECT replace('foobar', 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') +---- +foohello + +query T +SELECT rtrim(' foo ') +---- + foo + +query T +SELECT rtrim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) +---- + foo + +query T +SELECT split_part('foo_bar', '_', 2) +---- +bar + +query T +SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) +---- +bar + +query T +SELECT trim(' foo ') +---- +foo + +query T +SELECT trim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) +---- +foo + +query I +SELECT bit_length('foo') +---- +24 + +query I +SELECT bit_length(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +24 + +query I +SELECT character_length('foo') +---- +3 + +query I +SELECT character_length(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +3 + +query I +SELECT octet_length('foo') +---- +3 + +query I +SELECT octet_length(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) +---- +3 + +query I +SELECT strpos('helloworld', 'world') +---- +6 + +query I +SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), 'world') +---- +6 + +statement ok +CREATE TABLE products ( +product_id INT PRIMARY KEY, +product_name VARCHAR(100), +price DECIMAL(10, 2)) + +statement ok +INSERT INTO products (product_id, product_name, price) VALUES +(1, 'OldBrand Product 1', 19.99), +(2, 'OldBrand Product 2', 29.99), +(3, 'OldBrand Product 3', 39.99), +(4, 'OldBrand Product 4', 49.99) + +query ITR +SELECT * REPLACE (price*2 AS price) FROM products +---- +1 OldBrand Product 1 39.98 +2 OldBrand Product 2 59.98 +3 OldBrand Product 3 79.98 +4 OldBrand Product 4 99.98 + +# types are conserved +query ITR +SELECT * REPLACE (product_id/2 AS product_id) FROM products +---- +0 OldBrand Product 1 19.99 +1 OldBrand Product 2 29.99 +1 OldBrand Product 3 39.99 +2 OldBrand Product 4 49.99 + +# multiple replace statements with qualified wildcard +query ITR +SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM products +---- +1001 OldBrand Product 1 39.98 +1002 OldBrand Product 2 59.98 +1003 OldBrand Product 3 79.98 +1004 OldBrand Product 4 99.98 + +#overlay tests +statement ok +CREATE TABLE over_test( + str TEXT, + characters TEXT, + pos INT, + len INT +) as VALUES + ('123', 'abc', 4, 5), + ('abcdefg', 'qwertyasdfg', 1, 7), + ('xyz', 'ijk', 1, 2), + ('Txxxxas', 'hom', 2, 4), + (NULL, 'hom', 2, 4), + ('Txxxxas', 'hom', NULL, 4), + ('Txxxxas', 'hom', 2, NULL), + ('Txxxxas', NULL, 2, 4) +; + +query T +SELECT overlay(str placing characters from pos for len) from over_test +---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL +NULL +NULL + +query T +SELECT overlay(str placing characters from pos) from over_test +---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas +NULL + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query ? +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query ? +SELECT levenshtein(NULL, NULL) +---- +NULL + +query T +SELECT substr_index('www.apache.org', '.', 1) +---- +www + +query T +SELECT substr_index('www.apache.org', '.', 2) +---- +www.apache + +query T +SELECT substr_index('www.apache.org', '.', -1) +---- +org + +query T +SELECT substr_index('www.apache.org', '.', -2) +---- +apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', 1) +---- +www.ap + +query T +SELECT substr_index('www.apache.org', 'ac', -1) +---- +he.org + +query T +SELECT substr_index('www.apache.org', 'ac', 2) +---- +www.apache.org + +query T +SELECT substr_index('www.apache.org', 'ac', -2) +---- +www.apache.org + +query ? +SELECT substr_index(NULL, 'ac', 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', NULL, 1) +---- +NULL + +query T +SELECT substr_index('www.apache.org', 'ac', NULL) +---- +NULL + +query T +SELECT substr_index('', 'ac', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', '', 1) +---- +(empty) + +query T +SELECT substr_index('www.apache.org', 'ac', 0) +---- +(empty) + +query ? +SELECT substr_index(NULL, NULL, NULL) +---- +NULL + +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query ? +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query ? +SELECT find_in_set(NULL, NULL) +---- +NULL diff --git a/datafusion/core/tests/sqllogictests/test_files/group.slt b/datafusion/sqllogictest/test_files/group.slt similarity index 98% rename from datafusion/core/tests/sqllogictests/test_files/group.slt rename to datafusion/sqllogictest/test_files/group.slt index a56451d7aaea1..2a28efa73a621 100644 --- a/datafusion/core/tests/sqllogictests/test_files/group.slt +++ b/datafusion/sqllogictest/test_files/group.slt @@ -36,7 +36,7 @@ WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; # csv_query_group_by_int_min_max diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt new file mode 100644 index 0000000000000..b7be4d78b583e --- /dev/null +++ b/datafusion/sqllogictest/test_files/groupby.slt @@ -0,0 +1,4281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +statement ok +CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER) + +statement ok +CREATE TABLE tab1(col0 INTEGER, col1 INTEGER, col2 INTEGER) + +statement ok +CREATE TABLE tab2(col0 INTEGER, col1 INTEGER, col2 INTEGER) + +statement ok +INSERT INTO tab0 VALUES(83,0,38) + +statement ok +INSERT INTO tab0 VALUES(26,0,79) + +statement ok +INSERT INTO tab0 VALUES(43,81,24) + +statement ok +INSERT INTO tab1 VALUES(22,6,8) + +statement ok +INSERT INTO tab1 VALUES(28,57,45) + +statement ok +INSERT INTO tab1 VALUES(82,44,71) + +statement ok +INSERT INTO tab2 VALUES(15,61,87) + +statement ok +INSERT INTO tab2 VALUES(91,59,79) + +statement ok +INSERT INTO tab2 VALUES(92,41,58) + +query I rowsort +SELECT - tab1.col0 * 84 + + 38 AS col2 FROM tab1 GROUP BY tab1.col0 +---- +-1810 +-2314 +-6850 + +query I rowsort +SELECT + cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT DISTINCT - ( + col1 ) + - 51 AS col0 FROM tab1 AS cor0 GROUP BY col1 +---- +-108 +-57 +-95 + +query I rowsort +SELECT col1 * cor0.col1 * 56 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +194936 +208376 +94136 + +query I rowsort label-4 +SELECT ALL + tab2.col1 / tab2.col1 FROM tab2 GROUP BY col1 +---- +1 +1 +1 + +query I rowsort +SELECT ALL + tab1.col0 FROM tab1 GROUP BY col0 +---- +22 +28 +82 + +query I rowsort +SELECT DISTINCT tab1.col0 AS col1 FROM tab1 GROUP BY tab1.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ALL col2 FROM tab1 GROUP BY col2 +---- +45 +71 +8 + +query I rowsort +SELECT ALL + cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +26 +43 +83 + +query III rowsort +SELECT DISTINCT * FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 +---- +26 0 79 +43 81 24 +83 0 38 + +query III rowsort +SELECT * FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 +---- +26 0 79 +43 81 24 +83 0 38 + +query I rowsort +SELECT - 9 * cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +-369 +-531 +-549 + +query I rowsort +SELECT DISTINCT - 21 FROM tab2 GROUP BY col2 +---- +-21 + +query I rowsort +SELECT DISTINCT - 97 AS col2 FROM tab1 GROUP BY col0 +---- +-97 + +query I rowsort +SELECT + ( - 1 ) AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +-1 +-1 +-1 + +query I rowsort +SELECT - + cor0.col1 FROM tab0, tab0 cor0 GROUP BY cor0.col1 +---- +-81 +0 + +query I rowsort +SELECT + cor0.col0 + 36 AS col2 FROM tab0 AS cor0 GROUP BY col0 +---- +119 +62 +79 + +query I rowsort +SELECT cor0.col1 AS col1 FROM tab0 AS cor0 GROUP BY col1 +---- +0 +81 + +query I rowsort +SELECT DISTINCT + cor0.col1 FROM tab2 cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT ALL + cor0.col0 + - col0 col1 FROM tab1 AS cor0 GROUP BY col0 +---- +0 +0 +0 + +query I rowsort +SELECT ALL 54 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +54 +54 +54 + +query I rowsort +SELECT 40 AS col1 FROM tab1 cor0 GROUP BY cor0.col0 +---- +40 +40 +40 + +query I rowsort +SELECT DISTINCT ( cor0.col0 ) AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +26 +43 +83 + +query I rowsort +SELECT 62 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +62 +62 +62 + +query I rowsort +SELECT 23 FROM tab2 GROUP BY tab2.col2 +---- +23 +23 +23 + +query I rowsort +SELECT + ( - tab0.col0 ) col2 FROM tab0, tab0 AS cor0 GROUP BY tab0.col0 +---- +-26 +-43 +-83 + +query I rowsort +SELECT + cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col1 +---- +44 +57 +6 + +query I rowsort +SELECT cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col2 +---- +41 +59 +61 + +query I rowsort +SELECT DISTINCT + 80 + cor0.col2 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +104 +118 +159 + +query I rowsort +SELECT DISTINCT 30 * - 9 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +-270 + +query I rowsort +SELECT DISTINCT - col2 FROM tab1 AS cor0 GROUP BY col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT ALL - col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT DISTINCT + 82 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +82 + +query I rowsort +SELECT 79 * 19 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +1501 +1501 +1501 + +query I rowsort +SELECT ALL ( + 68 ) FROM tab1 cor0 GROUP BY cor0.col2 +---- +68 +68 +68 + +query I rowsort +SELECT - col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +-22 +-28 +-82 + +query I rowsort +SELECT + 81 col2 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +81 +81 +81 + +query I rowsort +SELECT ALL cor0.col2 AS col1 FROM tab2 cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT ALL + cor0.col0 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT - cor0.col2 AS col0 FROM tab0 cor0 GROUP BY cor0.col2 +---- +-24 +-38 +-79 + +query I rowsort +SELECT cor0.col0 FROM tab1 AS cor0 GROUP BY col0, cor0.col1, cor0.col1 +---- +22 +28 +82 + +query I rowsort +SELECT 58 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +58 +58 + +query I rowsort +SELECT ALL cor0.col1 + - 20 AS col1 FROM tab0 cor0 GROUP BY cor0.col1 +---- +-20 +61 + +query I rowsort +SELECT ALL + col1 col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT DISTINCT - - 56 FROM tab2, tab0 AS cor0 GROUP BY cor0.col1 +---- +56 + +query I rowsort +SELECT - 10 AS col0 FROM tab2, tab1 AS cor0, tab2 AS cor1 GROUP BY cor1.col0 +---- +-10 +-10 +-10 + +query I rowsort +SELECT 31 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +31 +31 +31 + +query I rowsort +SELECT col2 AS col0 FROM tab0 cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT + 70 AS col1 FROM tab0 GROUP BY col0 +---- +70 +70 +70 + +query I rowsort +SELECT DISTINCT cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT - cor0.col1 FROM tab2, tab2 AS cor0 GROUP BY cor0.col1 +---- +-41 +-59 +-61 + +query I rowsort +SELECT DISTINCT + tab0.col0 col1 FROM tab0 GROUP BY tab0.col0 +---- +26 +43 +83 + +query I rowsort +SELECT DISTINCT - cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +-24 +-38 +-79 + +query I rowsort +SELECT + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT - 5 AS col2 FROM tab2, tab2 AS cor0, tab2 AS cor1 GROUP BY tab2.col1 +---- +-5 +-5 +-5 + +query I rowsort +SELECT DISTINCT 0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +0 + +query I rowsort +SELECT DISTINCT - - tab2.col0 FROM tab2 GROUP BY col0 +---- +15 +91 +92 + +query III rowsort +SELECT DISTINCT * FROM tab2 AS cor0 GROUP BY cor0.col0, col1, cor0.col2 +---- +15 61 87 +91 59 79 +92 41 58 + +query I rowsort label-58 +SELECT 9 / + cor0.col0 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +0 +0 +0 + +query I rowsort +SELECT ( - 72 ) AS col1 FROM tab1 cor0 GROUP BY cor0.col0, cor0.col2 +---- +-72 +-72 +-72 + +query I rowsort +SELECT cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ( col0 ) FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort label-62 +SELECT ALL 59 / 26 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +2 +2 +2 + +query I rowsort +SELECT 15 FROM tab1 AS cor0 GROUP BY col2, col2 +---- +15 +15 +15 + +query I rowsort +SELECT CAST ( NULL AS INTEGER ) FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col2 +---- +NULL +NULL +NULL + +query I rowsort +SELECT ALL - 79 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +-79 +-79 +-79 + +query I rowsort +SELECT ALL 69 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +69 +69 +69 + +query I rowsort +SELECT ALL 37 col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +37 +37 + +query I rowsort +SELECT ALL 55 * 15 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +825 +825 +825 + +query I rowsort +SELECT ( 63 ) FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +63 +63 +63 + +query I rowsort +SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT - col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +-58 +-79 +-87 + +query I rowsort +SELECT ALL 81 * 11 FROM tab2 AS cor0 GROUP BY col1, cor0.col0 +---- +891 +891 +891 + +query I rowsort +SELECT ALL 9 FROM tab2 AS cor0 GROUP BY col2 +---- +9 +9 +9 + +query I rowsort +SELECT DISTINCT ( - 31 ) col1 FROM tab1 GROUP BY tab1.col0 +---- +-31 + +query I rowsort label-75 +SELECT + + cor0.col0 / - cor0.col0 FROM tab1, tab0 AS cor0 GROUP BY cor0.col0 +---- +-1 +-1 +-1 + +query I rowsort +SELECT cor0.col2 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT ALL cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +0 +81 + +query I rowsort +SELECT ALL + - ( - tab0.col2 ) AS col0 FROM tab0 GROUP BY tab0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT 72 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +72 +72 + +query I rowsort +SELECT - 20 - + col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +-101 +-20 + +query I rowsort +SELECT - - 63 FROM tab1 GROUP BY tab1.col0 +---- +63 +63 +63 + +query I rowsort +SELECT cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col2, col1 +---- +45 +71 +8 + +query I rowsort +SELECT + cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +0 +81 + +query I rowsort +SELECT DISTINCT cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col1 +---- +44 +57 +6 + +query I rowsort +SELECT cor0.col0 - col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +0 +0 +0 + +query I rowsort +SELECT 50 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +50 +50 +50 + +query I rowsort +SELECT - 18 AS col0 FROM tab1 cor0 GROUP BY cor0.col2 +---- +-18 +-18 +-18 + +query I rowsort +SELECT + cor0.col2 * cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +1444 +576 +6241 + +query I rowsort +SELECT ALL 91 / cor0.col1 FROM tab2 AS cor0 GROUP BY col1, cor0.col1 +---- +1 +1 +2 + +query I rowsort +SELECT cor0.col2 AS col2 FROM tab0 AS cor0 GROUP BY col2 +---- +24 +38 +79 + +query I rowsort +SELECT ALL + 85 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +85 +85 +85 + +query I rowsort +SELECT + 49 AS col2 FROM tab0 cor0 GROUP BY cor0.col0 +---- +49 +49 +49 + +query I rowsort +SELECT cor0.col2 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT - col0 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +-15 +-91 +-92 + +query I rowsort +SELECT DISTINCT - 87 AS col1 FROM tab0 AS cor0 GROUP BY col0 +---- +-87 + +query I rowsort +SELECT + 39 FROM tab0 AS cor0 GROUP BY col1 +---- +39 +39 + +query I rowsort +SELECT ALL cor0.col2 * + col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +3364 +6241 +7569 + +query I rowsort +SELECT 40 FROM tab0 GROUP BY tab0.col1 +---- +40 +40 + +query I rowsort +SELECT tab1.col2 AS col0 FROM tab1 GROUP BY tab1.col2 +---- +45 +71 +8 + +query I rowsort +SELECT tab2.col0 FROM tab2 GROUP BY tab2.col0 +---- +15 +91 +92 + +query I rowsort +SELECT + col0 * + col0 FROM tab0 GROUP BY tab0.col0 +---- +1849 +676 +6889 + +query I rowsort +SELECT ALL cor0.col2 + cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +158 +48 +76 + +query I rowsort +SELECT DISTINCT cor0.col2 FROM tab1 cor0 GROUP BY cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT ALL + cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT cor0.col2 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort label-106 +SELECT - 53 / cor0.col0 col0 FROM tab1 cor0 GROUP BY cor0.col0 +---- +-1 +-2 +0 + +query I rowsort +SELECT cor0.col1 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +0 +81 + +query I rowsort +SELECT DISTINCT + cor0.col1 col0 FROM tab2 cor0 GROUP BY cor0.col1, cor0.col0 +---- +41 +59 +61 + +query I rowsort +SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT cor0.col1 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col1 +---- +0 +81 + +query I rowsort +SELECT 25 AS col1 FROM tab2 cor0 GROUP BY cor0.col0 +---- +25 +25 +25 + +query I rowsort +SELECT cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT DISTINCT + 6 FROM tab1 cor0 GROUP BY col2, cor0.col0 +---- +6 + +query I rowsort +SELECT cor0.col2 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT ALL 72 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +72 +72 +72 + +query I rowsort +SELECT ALL + 73 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +73 +73 +73 + +query I rowsort +SELECT tab1.col0 AS col2 FROM tab1 GROUP BY col0 +---- +22 +28 +82 + +query I rowsort +SELECT + cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT DISTINCT - cor0.col1 col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +-81 +0 + +query I rowsort +SELECT cor0.col0 * 51 FROM tab1 AS cor0 GROUP BY col0 +---- +1122 +1428 +4182 + +query I rowsort +SELECT ALL + 89 FROM tab2, tab1 AS cor0, tab1 AS cor1 GROUP BY cor0.col2 +---- +89 +89 +89 + +query I rowsort +SELECT ALL + cor0.col0 - + cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +0 +0 +0 + +query I rowsort +SELECT ALL 71 AS col0 FROM tab0 GROUP BY col1 +---- +71 +71 + +query I rowsort +SELECT - ( + cor0.col0 ) AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +-26 +-43 +-83 + +query I rowsort +SELECT 62 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +62 +62 +62 + +query I rowsort +SELECT ALL - 97 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +-97 +-97 +-97 + +query I rowsort +SELECT DISTINCT + 29 * ( cor0.col0 ) + + 47 FROM tab1 cor0 GROUP BY cor0.col0 +---- +2425 +685 +859 + +query I rowsort +SELECT DISTINCT col2 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT ALL 40 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +40 +40 +40 + +query I rowsort +SELECT cor0.col1 + cor0.col1 AS col2 FROM tab2 cor0 GROUP BY cor0.col1 +---- +118 +122 +82 + +query I rowsort +SELECT ( + cor0.col1 ) FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT cor0.col1 * + cor0.col1 col1 FROM tab1 AS cor0 GROUP BY cor0.col1 +---- +1936 +3249 +36 + +query I rowsort +SELECT ALL + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT - 9 FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col1, col2 +---- +-9 +-9 +-9 + +query I rowsort +SELECT ALL - 7 * cor0.col1 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col1 +---- +-308 +-399 +-42 + +query I rowsort +SELECT - 21 AS col2 FROM tab1 cor0 GROUP BY cor0.col1, cor0.col1 +---- +-21 +-21 +-21 + +query I rowsort +SELECT DISTINCT tab1.col2 FROM tab1 GROUP BY tab1.col2 +---- +45 +71 +8 + +query I rowsort +SELECT DISTINCT - 76 FROM tab2 GROUP BY tab2.col2 +---- +-76 + +query I rowsort +SELECT DISTINCT - cor0.col1 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +-41 +-59 +-61 + +query I rowsort +SELECT cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +0 +81 + +query I rowsort +SELECT ALL - cor0.col2 + - 55 AS col1 FROM tab0 AS cor0 GROUP BY col2 +---- +-134 +-79 +-93 + +query I rowsort +SELECT - + 28 FROM tab0, tab2 cor0 GROUP BY tab0.col1 +---- +-28 +-28 + +query I rowsort +SELECT ALL col1 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT ALL + 35 * 14 AS col1 FROM tab2 GROUP BY tab2.col1 +---- +490 +490 +490 + +query I rowsort +SELECT ALL cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, cor0.col1 +---- +15 +91 +92 + +query I rowsort +SELECT DISTINCT - cor0.col2 * 18 + + 56 FROM tab2 AS cor0 GROUP BY col2 +---- +-1366 +-1510 +-988 + +query I rowsort +SELECT cor0.col0 FROM tab0 cor0 GROUP BY col0 +---- +26 +43 +83 + +query I rowsort +SELECT ALL - 38 AS col1 FROM tab2 GROUP BY tab2.col2 +---- +-38 +-38 +-38 + +query I rowsort +SELECT - 79 FROM tab0, tab0 cor0, tab0 AS cor1 GROUP BY cor1.col0 +---- +-79 +-79 +-79 + +query I rowsort +SELECT + cor0.col2 FROM tab1 cor0 GROUP BY cor0.col2, cor0.col1 +---- +45 +71 +8 + +query I rowsort +SELECT cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +26 +43 +83 + +query I rowsort +SELECT cor0.col2 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +24 +38 +79 + +query I rowsort +SELECT + - 57 AS col1 FROM tab2 GROUP BY tab2.col2 +---- +-57 +-57 +-57 + +query I rowsort +SELECT ALL - cor0.col1 FROM tab2 cor0 GROUP BY cor0.col1 +---- +-41 +-59 +-61 + +query I rowsort +SELECT DISTINCT cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT - cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +-26 +-43 +-83 + +query I rowsort +SELECT ( - cor0.col1 ) FROM tab1 AS cor0 GROUP BY cor0.col1 +---- +-44 +-57 +-6 + +query I rowsort +SELECT DISTINCT - cor0.col2 FROM tab0 cor0 GROUP BY cor0.col2, cor0.col2 +---- +-24 +-38 +-79 + +query I rowsort +SELECT DISTINCT tab1.col1 * ( + tab1.col1 ) FROM tab1 GROUP BY col1 +---- +1936 +3249 +36 + +query I rowsort +SELECT - cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +-41 +-59 +-61 + +query III rowsort +SELECT * FROM tab2 AS cor0 GROUP BY cor0.col1, cor0.col2, cor0.col0 +---- +15 61 87 +91 59 79 +92 41 58 + +query I rowsort +SELECT + 83 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +83 +83 +83 + +query I rowsort +SELECT + ( 97 ) + - tab0.col1 FROM tab0, tab1 AS cor0 GROUP BY tab0.col1 +---- +16 +97 + +query I rowsort +SELECT 61 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +61 +61 +61 + +query I rowsort +SELECT ALL cor0.col2 FROM tab0 cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT cor0.col2 FROM tab0, tab1 AS cor0 GROUP BY cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT + - 3 FROM tab2 GROUP BY col1 +---- +-3 +-3 +-3 + +query I rowsort +SELECT DISTINCT + 96 FROM tab2 GROUP BY tab2.col1 +---- +96 + +query I rowsort +SELECT ALL 81 FROM tab1 AS cor0 GROUP BY cor0.col1 +---- +81 +81 +81 + +query I rowsort +SELECT cor0.col0 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +26 +43 +83 + +query I rowsort +SELECT - + 51 col2 FROM tab2, tab2 AS cor0 GROUP BY cor0.col1 +---- +-51 +-51 +-51 + +query I rowsort +SELECT cor0.col1 + - cor0.col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +0 +0 +0 + +query I rowsort +SELECT 35 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col1 +---- +35 +35 +35 + +query I rowsort +SELECT + tab2.col1 col0 FROM tab2 GROUP BY tab2.col1 +---- +41 +59 +61 + +query I rowsort +SELECT 37 AS col1 FROM tab0 AS cor0 GROUP BY col0 +---- +37 +37 +37 + +query I rowsort +SELECT + cor0.col1 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +query I rowsort +SELECT cor0.col1 FROM tab2, tab1 AS cor0 GROUP BY cor0.col1 +---- +44 +57 +6 + +query I rowsort +SELECT ALL - col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +-22 +-28 +-82 + +query I rowsort +SELECT + 77 AS col1 FROM tab1 AS cor0 CROSS JOIN tab0 AS cor1 GROUP BY cor0.col2 +---- +77 +77 +77 + +query I rowsort +SELECT ALL cor0.col0 col1 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT + cor0.col2 * + cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +1032 +2054 +3154 + +query I rowsort +SELECT DISTINCT 39 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +39 + +query III rowsort +SELECT DISTINCT * FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2, cor0.col1 +---- +22 6 8 +28 57 45 +82 44 71 + +query I rowsort +SELECT ALL + 28 FROM tab2 cor0 GROUP BY cor0.col0 +---- +28 +28 +28 + +query I rowsort +SELECT cor0.col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ALL cor0.col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT + ( col0 ) * col0 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +225 +8281 +8464 + +query I rowsort label-188 +SELECT - 21 - + 57 / cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +-21 +-22 +-23 + +query I rowsort +SELECT + 37 + cor0.col0 * cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2, col0 +---- +1342 +5373 +7226 + +query I rowsort +SELECT ALL cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +45 +71 +8 + +query III rowsort +SELECT * FROM tab1 AS cor0 GROUP BY col2, cor0.col1, cor0.col0 +---- +22 6 8 +28 57 45 +82 44 71 + +query I rowsort +SELECT ( cor0.col2 ) AS col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT DISTINCT 28 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +28 + +query I rowsort +SELECT ALL - 18 FROM tab0, tab1 AS cor0 GROUP BY cor0.col0 +---- +-18 +-18 +-18 + +query I rowsort +SELECT DISTINCT cor0.col2 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT + col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT - cor0.col0 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 +---- +-22 +-28 +-82 + +query I rowsort +SELECT 29 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col0 +---- +29 +29 +29 + +query I rowsort +SELECT - + cor0.col0 - 39 AS col0 FROM tab0, tab0 cor0 GROUP BY cor0.col0 +---- +-122 +-65 +-82 + +query I rowsort +SELECT ALL 45 AS col0 FROM tab0 GROUP BY tab0.col0 +---- +45 +45 +45 + +query I rowsort +SELECT + 74 AS col1 FROM tab1 GROUP BY tab1.col0 +---- +74 +74 +74 + +query I rowsort +SELECT cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort label-203 +SELECT - cor0.col2 + CAST ( 80 AS INTEGER ) FROM tab1 AS cor0 GROUP BY col2 +---- +35 +72 +9 + +query I rowsort +SELECT DISTINCT - cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +-81 +0 + +query I rowsort +SELECT - 51 * + cor0.col2 FROM tab0, tab2 cor0, tab1 AS cor1 GROUP BY cor0.col2 +---- +-2958 +-4029 +-4437 + +query I rowsort +SELECT ALL + col0 * cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +225 +8281 +8464 + +query I rowsort +SELECT DISTINCT ( col0 ) FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +26 +43 +83 + +query I rowsort +SELECT 87 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +87 +87 +87 + +query I rowsort +SELECT + cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT DISTINCT + 45 col0 FROM tab1 AS cor0 GROUP BY col0 +---- +45 + +query I rowsort label-211 +SELECT ALL CAST ( NULL AS INTEGER ) FROM tab2 AS cor0 GROUP BY col1 +---- +NULL +NULL +NULL + +query I rowsort +SELECT ALL cor0.col1 + col1 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +0 +162 + +query I rowsort +SELECT - cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +-81 +0 + +query I rowsort +SELECT DISTINCT + 99 * 76 + + tab2.col1 AS col2 FROM tab2 GROUP BY col1 +---- +7565 +7583 +7585 + +query I rowsort +SELECT ALL 54 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +54 +54 +54 + +query I rowsort +SELECT + cor0.col2 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col2, cor0.col0 +---- +58 +79 +87 + +query I rowsort +SELECT cor0.col0 + + 87 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +109 +115 +169 + +query I rowsort +SELECT cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, cor0.col1, cor0.col0 +---- +15 +91 +92 + +query I rowsort +SELECT ALL col0 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT DISTINCT - cor0.col0 - + cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +-182 +-184 +-30 + +query I rowsort +SELECT ALL - 68 * + cor0.col1 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col1 +---- +-5508 +0 + +query I rowsort +SELECT col2 AS col2 FROM tab0 AS cor0 GROUP BY cor0.col1, cor0.col2 +---- +24 +38 +79 + +query I rowsort +SELECT ALL - 11 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +-11 +-11 +-11 + +query I rowsort +SELECT 66 AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +66 +66 +66 + +query I rowsort +SELECT - cor0.col2 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +-58 +-79 +-87 + +query I rowsort +SELECT ALL 37 FROM tab2, tab0 AS cor0 GROUP BY cor0.col1 +---- +37 +37 + +query I rowsort +SELECT DISTINCT + 20 col2 FROM tab0 GROUP BY tab0.col1 +---- +20 + +query I rowsort +SELECT 42 FROM tab0 cor0 GROUP BY col2 +---- +42 +42 +42 + +query I rowsort +SELECT ALL - cor0.col1 AS col1 FROM tab1 cor0 GROUP BY cor0.col1 +---- +-44 +-57 +-6 + +query I rowsort +SELECT - col2 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +-58 +-79 +-87 + +query I rowsort +SELECT DISTINCT + 86 FROM tab1 GROUP BY tab1.col2 +---- +86 + +query I rowsort +SELECT + cor0.col1 AS col1 FROM tab2, tab0 cor0 GROUP BY cor0.col1 +---- +0 +81 + +query I rowsort +SELECT - 13 FROM tab0 cor0 GROUP BY cor0.col1 +---- +-13 +-13 + +query I rowsort +SELECT tab1.col0 AS col1 FROM tab1 GROUP BY tab1.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ALL cor0.col1 * cor0.col1 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +1681 +3481 +3721 + +query I rowsort +SELECT - cor0.col0 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +-15 +-91 +-92 + +query I rowsort +SELECT cor0.col2 FROM tab1 AS cor0 GROUP BY cor0.col0, cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT ALL - 67 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +-67 +-67 +-67 + +query I rowsort +SELECT + 75 AS col2 FROM tab1 cor0 GROUP BY cor0.col0 +---- +75 +75 +75 + +query I rowsort +SELECT ALL cor0.col1 FROM tab0 AS cor0 GROUP BY col0, cor0.col1 +---- +0 +0 +81 + +query I rowsort +SELECT ALL + cor0.col1 FROM tab0 AS cor0 GROUP BY col1 +---- +0 +81 + +query I rowsort +SELECT DISTINCT - 38 - - cor0.col0 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +-12 +45 +5 + +query I rowsort +SELECT + cor0.col0 + - col0 + 21 AS col0 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +21 +21 +21 + +query I rowsort +SELECT + cor0.col0 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col0, cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ALL - cor0.col0 FROM tab0 AS cor0 GROUP BY cor0.col0, cor0.col0 +---- +-26 +-43 +-83 + +query III rowsort +SELECT * FROM tab0 AS cor0 GROUP BY cor0.col2, cor0.col1, cor0.col0 +---- +26 0 79 +43 81 24 +83 0 38 + +query I rowsort +SELECT DISTINCT + + tab2.col2 FROM tab2, tab1 AS cor0 GROUP BY tab2.col2 +---- +58 +79 +87 + +query I rowsort +SELECT cor0.col0 AS col1 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +15 +91 +92 + +query I rowsort +SELECT col0 AS col0 FROM tab2 AS cor0 GROUP BY cor0.col0 +---- +15 +91 +92 + +query I rowsort +SELECT - cor0.col0 AS col1 FROM tab1 AS cor0 GROUP BY col0 +---- +-22 +-28 +-82 + +query I rowsort +SELECT DISTINCT ( + 71 ) col1 FROM tab1 GROUP BY tab1.col2 +---- +71 + +query I rowsort +SELECT + 96 * 29 col1 FROM tab2, tab1 AS cor0 GROUP BY tab2.col0 +---- +2784 +2784 +2784 + +query I rowsort +SELECT + 3 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +3 +3 +3 + +query I rowsort +SELECT 37 FROM tab0 AS cor0 GROUP BY col0 +---- +37 +37 +37 + +query I rowsort +SELECT 82 FROM tab0 cor0 GROUP BY cor0.col1 +---- +82 +82 + +query I rowsort +SELECT cor0.col2 FROM tab2 cor0 GROUP BY cor0.col2 +---- +58 +79 +87 + +query I rowsort +SELECT DISTINCT - 87 FROM tab1, tab2 AS cor0, tab2 AS cor1 GROUP BY tab1.col0 +---- +-87 + +query I rowsort +SELECT 55 FROM tab1 AS cor0 GROUP BY cor0.col2, cor0.col1 +---- +55 +55 +55 + +query I rowsort +SELECT DISTINCT 35 FROM tab0 cor0 GROUP BY cor0.col2, cor0.col0 +---- +35 + +query I rowsort +SELECT cor0.col0 FROM tab2 cor0 GROUP BY col0 +---- +15 +91 +92 + +query I rowsort +SELECT - cor0.col2 AS col1 FROM tab1 AS cor0 GROUP BY col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT ALL ( cor0.col2 ) AS col1 FROM tab2, tab1 AS cor0 GROUP BY cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT DISTINCT - col2 FROM tab1 GROUP BY tab1.col2 +---- +-45 +-71 +-8 + +query I rowsort +SELECT 38 FROM tab1 AS cor0 GROUP BY cor0.col1, cor0.col1 +---- +38 +38 +38 + +query I rowsort +SELECT - 16 * - cor0.col0 * 47 FROM tab0 AS cor0 GROUP BY cor0.col0 +---- +19552 +32336 +62416 + +query I rowsort +SELECT - 31 FROM tab2 AS cor0 GROUP BY cor0.col2 +---- +-31 +-31 +-31 + +query I rowsort +SELECT ( + 34 ) AS col1 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +34 +34 +34 + +query I rowsort +SELECT cor0.col2 AS col0 FROM tab1 AS cor0 GROUP BY cor0.col2 +---- +45 +71 +8 + +query I rowsort +SELECT DISTINCT 21 FROM tab0 AS cor0 GROUP BY cor0.col2 +---- +21 + +query I rowsort +SELECT 62 AS col2 FROM tab0 cor0 GROUP BY cor0.col1, cor0.col2 +---- +62 +62 +62 + +query I rowsort +SELECT cor0.col0 FROM tab1 cor0 GROUP BY cor0.col0, cor0.col1 +---- +22 +28 +82 + +query I rowsort +SELECT DISTINCT cor0.col0 FROM tab2 AS cor0 GROUP BY cor0.col0, col1 +---- +15 +91 +92 + +query I rowsort +SELECT DISTINCT cor0.col0 AS col2 FROM tab1 AS cor0 GROUP BY cor0.col0 +---- +22 +28 +82 + +query I rowsort +SELECT ALL - ( 30 ) * + cor0.col1 AS col2 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +-1230 +-1770 +-1830 + +query I rowsort +SELECT DISTINCT 94 AS col1 FROM tab0 AS cor0 GROUP BY cor0.col1 +---- +94 + +query I rowsort +SELECT DISTINCT + col1 FROM tab2 AS cor0 GROUP BY cor0.col1 +---- +41 +59 +61 + +# Group By All tests +statement ok +CREATE TABLE tab3(col0 INTEGER, col1 INTEGER, col2 INTEGER, col3 INTEGER) + +statement ok +INSERT INTO tab3 VALUES(0,1,12,-1) + +statement ok +INSERT INTO tab3 VALUES(0,2,13,-1) + +statement ok +INSERT INTO tab3 VALUES(0,1,10,-2) + +statement ok +INSERT INTO tab3 VALUES(0,2,15,-2) + +statement ok +INSERT INTO tab3 VALUES(1, NULL, 10, -2) + +query IRI rowsort +SELECT col1, AVG(col2), col0 FROM tab3 GROUP BY ALL +---- +1 11 0 +2 14 0 +NULL 10 1 + +query IIR rowsort +SELECT sub.col1, sub.col0, AVG(sub.col2) AS avg_col2 +FROM ( + SELECT col1, col0, col2 + FROM tab3 + WHERe col3 = -1 + GROUP BY ALL +) AS sub +GROUP BY ALL; +---- +1 0 12 +2 0 13 + +query IIR rowsort +SELECT sub.col1, sub.col0, sub."AVG(tab3.col2)" AS avg_col2 +FROM ( + SELECT col1, AVG(col2), col0 FROM tab3 GROUP BY ALL +) AS sub +GROUP BY ALL; +---- +1 0 11 +2 0 14 +NULL 1 10 + +query IIII rowsort +SELECT col0, col1, COUNT(col2), SUM(col3) FROM tab3 GROUP BY ALL +---- +0 1 2 -3 +0 2 2 -3 +1 NULL 1 -2 + +# query below should work in multi partition, successfully. +query II +SELECT l.col0, LAST_VALUE(r.col1 ORDER BY r.col0) as last_col1 +FROM tab0 as l +JOIN tab0 as r +ON l.col0 = r.col0 +GROUP BY l.col0, l.col1, l.col2 +ORDER BY l.col0; +---- +26 0 +43 81 +83 0 + +# assert that above query works in indeed multi partitions +# physical plan for this query should contain RepartitionExecs. +# Aggregation should be in two stages, Partial + FinalPartitioned stages. +query TT +EXPLAIN SELECT l.col0, LAST_VALUE(r.col1 ORDER BY r.col0) as last_col1 +FROM tab0 as l +JOIN tab0 as r +ON l.col0 = r.col0 +GROUP BY l.col0, l.col1, l.col2 +ORDER BY l.col0; +---- +logical_plan +Sort: l.col0 ASC NULLS LAST +--Projection: l.col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST] AS last_col1 +----Aggregate: groupBy=[[l.col0, l.col1, l.col2]], aggr=[[LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]]] +------Inner Join: l.col0 = r.col0 +--------SubqueryAlias: l +----------TableScan: tab0 projection=[col0, col1, col2] +--------SubqueryAlias: r +----------TableScan: tab0 projection=[col0, col1] +physical_plan +SortPreservingMergeExec: [col0@0 ASC NULLS LAST] +--SortExec: expr=[col0@0 ASC NULLS LAST] +----ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] +------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallySorted([0]) +--------------SortExec: expr=[col0@3 ASC NULLS LAST] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[3] + +# Columns in the table are a,b,c,d. Source is CsvExec which is ordered by +# a,b,c column. Column a has cardinality 2, column b has cardinality 4. +# Column c has cardinality 100 (unique entries). Column d has cardinality 5. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite2 ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC, c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Create a table with 2 ordered columns. +# In the next step, we will expect to observe the removed sort execs. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Expected a sort exec for b DESC +query TT +EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY b DESC; +---- +logical_plan +Projection: multiple_ordered_table.a +--Sort: multiple_ordered_table.b DESC NULLS FIRST +----TableScan: multiple_ordered_table projection=[a, b] +physical_plan +ProjectionExec: expr=[a@0 as a] +--SortExec: expr=[b@1 DESC] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true + +# Final plan shouldn't have SortExec c ASC, +# because table already satisfies this ordering. +query TT +EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY c ASC; +---- +logical_plan +Projection: multiple_ordered_table.a +--Sort: multiple_ordered_table.c ASC NULLS LAST +----TableScan: multiple_ordered_table projection=[a, c] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# Final plan shouldn't have SortExec a ASC, b ASC, +# because table already satisfies this ordering. +query TT +EXPLAIN SELECT a FROM multiple_ordered_table ORDER BY a ASC, b ASC; +---- +logical_plan +Projection: multiple_ordered_table.a +--Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST +----TableScan: multiple_ordered_table projection=[a, b] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# test_window_agg_sort +statement ok +set datafusion.execution.target_partitions = 1; + +# test_source_sorted_groupby +query TT +EXPLAIN SELECT a, b, + SUM(c) as summation1 + FROM annotated_data_infinite2 + GROUP BY b, a +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotated_data_infinite2.c) AS summation1 +--Aggregate: groupBy=[[annotated_data_infinite2.b, annotated_data_infinite2.a]], aggr=[[SUM(CAST(annotated_data_infinite2.c AS Int64))]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] +--AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + + +query III + SELECT a, b, + SUM(c) as summation1 + FROM annotated_data_infinite2 + GROUP BY b, a +---- +0 0 300 +0 1 925 +1 2 1550 +1 3 2175 + + +# test_source_sorted_groupby2 +# If ordering is not important for the aggregation function, we should ignore the ordering requirement. Hence +# "ORDER BY a DESC" should have no effect. +query TT +EXPLAIN SELECT a, d, + SUM(c ORDER BY a DESC) as summation1 + FROM annotated_data_infinite2 + GROUP BY d, a +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS summation1 +--Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +----TableScan: annotated_data_infinite2 projection=[a, c, d] +physical_plan +ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] +--AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallySorted([1]) +----StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + +query III +SELECT a, d, + SUM(c ORDER BY a DESC) as summation1 + FROM annotated_data_infinite2 + GROUP BY d, a +---- +0 0 292 +0 2 196 +0 1 315 +0 4 164 +0 3 258 +1 0 622 +1 3 299 +1 1 1043 +1 4 913 +1 2 848 + +# test_source_sorted_groupby3 + +query TT +EXPLAIN SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS first_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +query III +SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 0 +0 1 25 +1 2 50 +1 3 75 + +# test_source_sorted_groupby4 + +query TT +EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS last_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +query III +SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 24 +0 1 49 +1 2 74 +1 3 99 + +# when LAST_VALUE, or FIRST_VALUE value do not contain ordering requirement +# queries should still work, However, result depends on the scanning order and +# not deterministic +query TT +EXPLAIN SELECT a, b, LAST_VALUE(c) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) AS last_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c)]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +query III +SELECT a, b, LAST_VALUE(c) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 24 +0 1 49 +1 2 74 +1 3 99 + +statement ok +drop table annotated_data_infinite2; + +# create a table for testing +statement ok +CREATE TABLE sales_global (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), + (0, 'GRC', 4, '2022-01-03 10:00:00'::timestamp, 'EUR', 80.0) + +# create a new table named exchange rates +statement ok +CREATE TABLE exchange_rates ( + sn INTEGER, + ts TIMESTAMP, + currency_from VARCHAR(3), + currency_to VARCHAR(3), + rate DECIMAL(10,2) +) as VALUES + (0, '2022-01-01 06:00:00'::timestamp, 'EUR', 'USD', 1.10), + (1, '2022-01-01 08:00:00'::timestamp, 'TRY', 'USD', 0.10), + (2, '2022-01-01 11:30:00'::timestamp, 'EUR', 'USD', 1.12), + (3, '2022-01-02 12:00:00'::timestamp, 'TRY', 'USD', 0.11), + (4, '2022-01-03 10:00:00'::timestamp, 'EUR', 'USD', 1.12) + +# test_ordering_sensitive_aggregation +# ordering sensitive requirement should add a SortExec in the final plan. To satisfy amount ASC +# in the aggregation +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +----TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] +----SortExec: expr=[amount@1 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + + +query T? +SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts + FROM sales_global + GROUP BY country +---- +GRC [30.0, 80.0] +FRA [50.0, 200.0] +TUR [75.0, 100.0] + +# test_ordering_sensitive_aggregation2 +# We should be able to satisfy the finest requirement among all aggregators, when we have multiple aggregators. +# Hence final plan should have SortExec: expr=[amount@1 DESC] to satisfy array_agg requirement. +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM sales_global AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +----SubqueryAlias: s +------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)] +----SortExec: expr=[amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM sales_global AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +TUR [100.0, 75.0] 175 +GRC [80.0, 30.0] 110 + +# test_ordering_sensitive_aggregation3 +# When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. +# test below should raise Plan Error. +statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported +SELECT ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + ARRAY_AGG(s.amount ORDER BY s.amount ASC) AS amounts2, + ARRAY_AGG(s.amount ORDER BY s.sn ASC) AS amounts3 + FROM sales_global AS s + GROUP BY s.country + +# test_ordering_sensitive_aggregation4 +# If aggregators can work with bounded memory (Sorted or PartiallySorted mode), we should append requirement to +# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. +# This test checks for whether we can satisfy aggregation requirement in Sorted mode. +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted +----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 + +# test_ordering_sensitive_aggregation5 +# If aggregators can work with bounded memory (Sorted or PartiallySorted mode), we should be append requirement to +# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. +# This test checks for whether we can satisfy aggregation requirement in PartiallySorted mode. +query TT +EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country, s.zip_code +---- +logical_plan +Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[zip_code, country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] +--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=PartiallySorted([0]) +----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TI?R +SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country, s.zip_code +---- +FRA 1 [200.0, 50.0] 250 +GRC 0 [80.0, 30.0] 110 +TUR 1 [100.0, 75.0] 175 + +# test_ordering_sensitive_aggregation6 +# If aggregators can work with bounded memory (FullySorted or PartiallySorted mode), we should be append requirement to +# the existing ordering. When group by expressions contain aggregation requirement, we shouldn't append redundant expression. +# Hence in the final plan SortExec should be SortExec: expr=[country@0 DESC] not SortExec: expr=[country@0 ASC NULLS LAST,country@0 DESC] +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted +----SortExec: expr=[country@0 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 + +# test_ordering_sensitive_aggregation7 +# Lexicographical ordering requirement can be given as +# argument to the aggregate functions +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=Sorted +----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 + +# test_reverse_aggregate_expr +# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering +# that have contradictory requirements at first glance. +query TT +EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +----TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----SortExec: expr=[amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?RR +SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country +---- +FRA [200.0, 50.0] 50 50 +TUR [100.0, 75.0] 75 75 +GRC [80.0, 30.0] 30 30 + +# test_reverse_aggregate_expr2 +# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering +# that have contradictory requirements at first glance. +query TT +EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +----TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +----SortExec: expr=[amount@1 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?RR +SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country +---- +GRC [30.0, 80.0] 30 30 +FRA [50.0, 200.0] 50 50 +TUR [75.0, 100.0] 75 75 + +# test_reverse_aggregate_expr3 +# Some of the Aggregators can be reversed, by this way we can still run aggregators without re-ordering +# that have contradictory requirements at first glance. This algorithm shouldn't depend +# on the order of the aggregation expressions. +query TT +EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2, + ARRAY_AGG(amount ORDER BY amount ASC) AS amounts + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +----TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), ARRAY_AGG(sales_global.amount)] +----SortExec: expr=[amount@1 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TRR? +SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2, + ARRAY_AGG(amount ORDER BY amount ASC) AS amounts + FROM sales_global + GROUP BY country +---- +GRC 30 30 [30.0, 80.0] +FRA 50 50 [50.0, 200.0] +TUR 75 75 [75.0, 100.0] + +# test_reverse_aggregate_expr4 +# Ordering requirement by the ordering insensitive aggregators shouldn't have effect on +# final plan. Hence seemingly conflicting requirements by SUM and ARRAY_AGG shouldn't raise error. +query TT +EXPLAIN SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, + ARRAY_AGG(amount ORDER BY amount ASC) AS amounts + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +--Aggregate: groupBy=[[sales_global.country]], aggr=[[SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +----TableScan: sales_global projection=[country, ts, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[SUM(sales_global.amount), ARRAY_AGG(sales_global.amount)] +----SortExec: expr=[amount@2 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TR? +SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, + ARRAY_AGG(amount ORDER BY amount ASC) AS amounts + FROM sales_global + GROUP BY country +---- +GRC 110 [30.0, 80.0] +FRA 250 [50.0, 200.0] +TUR 175 [75.0, 100.0] + +# test_reverse_aggregate_expr5 +# If all of the ordering sensitive aggregation functions are reversible +# we should be able to reverse requirements, if this helps to remove a SortExec. +# Hence in query below, FIRST_VALUE, and LAST_VALUE should be reversed to calculate its result according to `ts ASC` ordering. +# Please note that after `ts ASC` ordering because of inner query. There is no SortExec in the final plan. +query TT +EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, + LAST_VALUE(amount ORDER BY ts DESC) as lv1, + SUM(amount ORDER BY ts DESC) as sum1 + FROM (SELECT * + FROM sales_global + ORDER BY ts ASC) + GROUP BY country +---- +logical_plan +Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 +--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +----Sort: sales_global.ts ASC NULLS LAST +------TableScan: sales_global projection=[country, ts, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[LAST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount), SUM(sales_global.amount)] +----SortExec: expr=[ts@1 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TRRR +SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, + LAST_VALUE(amount ORDER BY ts DESC) as lv1, + SUM(amount ORDER BY ts DESC) as sum1 + FROM (SELECT * + FROM sales_global + ORDER BY ts ASC) + GROUP BY country +---- +GRC 80 30 110 +FRA 200 50 250 +TUR 100 75 175 + +# If existing ordering doesn't satisfy requirement, we should do calculations +# on naive requirement (by convention, otherwise the final plan will be unintuitive), +# even if reverse ordering is possible. +# hence, below query should add `SortExec(ts DESC)` to the final plan. +query TT +EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, + LAST_VALUE(amount ORDER BY ts DESC) as lv1, + SUM(amount ORDER BY ts DESC) as sum1 + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 +--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +----TableScan: sales_global projection=[country, ts, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount), SUM(sales_global.amount)] +----SortExec: expr=[ts@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TRRR +SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, + LAST_VALUE(amount ORDER BY ts DESC) as lv1, + SUM(amount ORDER BY ts DESC) as sum1 + FROM sales_global + GROUP BY country +---- +TUR 100 75 175 +GRC 80 30 110 +FRA 200 50 250 + +query TT +EXPLAIN SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +logical_plan +Sort: s.sn ASC NULLS LAST +--Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST] AS last_rate +----Aggregate: groupBy=[[s.sn, s.zip_code, s.country, s.ts, s.currency]], aggr=[[LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]]] +------Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, e.sn, e.amount +--------Inner Join: s.currency = e.currency Filter: s.ts >= e.ts +----------SubqueryAlias: s +------------TableScan: sales_global projection=[zip_code, country, sn, ts, currency] +----------SubqueryAlias: e +------------TableScan: sales_global projection=[sn, ts, currency, amount] +physical_plan +SortExec: expr=[sn@2 ASC NULLS LAST] +--ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] +----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount)] +------SortExec: expr=[sn@5 ASC NULLS LAST] +--------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, sn@5 as sn, amount@8 as amount] +----------CoalesceBatchesExec: target_batch_size=8192 +------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@4, currency@2)], filter=ts@0 >= ts@1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query ITIPTR +SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 4 2022-01-03T10:00:00 TRY 100 + +# Run order-sensitive aggregators in multiple partitions +statement ok +set datafusion.execution.target_partitions = 8; + +# order-sensitive FIRST_VALUE and LAST_VALUE aggregators should work in +# multi-partitions without group by also. +query TT +EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts ASC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +logical_plan +Sort: sales_global.country ASC NULLS LAST +--Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 +----Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +------TableScan: sales_global projection=[country, ts, amount] +physical_plan +SortPreservingMergeExec: [country@0 ASC NULLS LAST] +--SortExec: expr=[country@0 ASC NULLS LAST] +----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@2 as fv2] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----------------SortExec: expr=[ts@1 ASC NULLS LAST] +------------------MemoryExec: partitions=1, partition_sizes=[1] + +query TRR +SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts ASC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +FRA 50 200 +GRC 30 80 +TUR 75 100 + +# Conversion in between FIRST_VALUE and LAST_VALUE to resolve +# contradictory requirements should work in multi partitions. +query TT +EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts DESC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +logical_plan +Sort: sales_global.country ASC NULLS LAST +--Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 +----Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +------TableScan: sales_global projection=[country, ts, amount] +physical_plan +SortPreservingMergeExec: [country@0 ASC NULLS LAST] +--SortExec: expr=[country@0 ASC NULLS LAST] +----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +----------------SortExec: expr=[ts@1 ASC NULLS LAST] +------------------MemoryExec: partitions=1, partition_sizes=[1] + +query TRR +SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts DESC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +FRA 50 50 +GRC 30 30 +TUR 75 75 + +# make sure that batch size is small. So that query below runs in multi partitions +# row number of the sales_global is 5. Hence we choose batch size 4 to make is smaller. +statement ok +set datafusion.execution.batch_size = 4; + +# order-sensitive FIRST_VALUE and LAST_VALUE aggregators should work in +# multi-partitions without group by also. +query TT +EXPLAIN SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts ASC) AS fv2 + FROM sales_global +---- +logical_plan +Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 +--Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +----TableScan: sales_global projection=[ts, amount] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv2] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------SortExec: expr=[ts@0 ASC NULLS LAST] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query RR +SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts ASC) AS fv2 + FROM sales_global +---- +30 80 + +# Conversion in between FIRST_VALUE and LAST_VALUE to resolve +# contradictory requirements should work in multi partitions. +query TT +EXPLAIN SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts DESC) AS fv2 + FROM sales_global +---- +logical_plan +Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 +--Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +----TableScan: sales_global projection=[ts, amount] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] +--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--------SortExec: expr=[ts@0 ASC NULLS LAST] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query RR +SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, + LAST_VALUE(amount ORDER BY ts DESC) AS fv2 + FROM sales_global +---- +30 30 + +# ARRAY_AGG should work in multiple partitions +query TT +EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts ASC) AS array_agg1 + FROM sales_global +---- +logical_plan +Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS array_agg1 +--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +----TableScan: sales_global projection=[ts, amount] +physical_plan +ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] +--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +--------SortExec: expr=[ts@0 ASC NULLS LAST] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query ? +SELECT ARRAY_AGG(amount ORDER BY ts ASC) AS array_agg1 + FROM sales_global +---- +[30.0, 50.0, 75.0, 200.0, 100.0, 80.0] + +# ARRAY_AGG should work in multiple partitions +query TT +EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts DESC) AS array_agg1 + FROM sales_global +---- +logical_plan +Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS array_agg1 +--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +----TableScan: sales_global projection=[ts, amount] +physical_plan +ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] +--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +--------SortExec: expr=[ts@0 DESC] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query ? +SELECT ARRAY_AGG(amount ORDER BY ts DESC) AS array_agg1 + FROM sales_global +---- +[100.0, 80.0, 200.0, 75.0, 50.0, 30.0] + +# ARRAY_AGG should work in multiple partitions +query TT +EXPLAIN SELECT ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 + FROM sales_global +---- +logical_plan +Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +----TableScan: sales_global projection=[amount] +physical_plan +ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] +--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] +--------SortExec: expr=[amount@0 ASC NULLS LAST] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query ? +SELECT ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 + FROM sales_global +---- +[30.0, 50.0, 75.0, 80.0, 100.0, 200.0] + +# ARRAY_AGG should work in multiple partitions +query TT +EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 + FROM sales_global + GROUP BY country + ORDER BY country +---- +logical_plan +Sort: sales_global.country ASC NULLS LAST +--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +------TableScan: sales_global projection=[country, amount] +physical_plan +SortPreservingMergeExec: [country@0 ASC NULLS LAST] +--SortExec: expr=[country@0 ASC NULLS LAST] +----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] +--------CoalesceBatchesExec: target_batch_size=4 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] +--------------SortExec: expr=[amount@1 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] + +query T? +SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 + FROM sales_global + GROUP BY country + ORDER BY country +---- +FRA [50.0, 200.0] +GRC [30.0, 80.0] +TUR [75.0, 100.0] + +# ARRAY_AGG, FIRST_VALUE, LAST_VALUE should work in multiple partitions +query TT +EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +logical_plan +Sort: sales_global.country ASC NULLS LAST +--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +------TableScan: sales_global projection=[country, amount] +physical_plan +SortPreservingMergeExec: [country@0 ASC NULLS LAST] +--SortExec: expr=[country@0 ASC NULLS LAST] +----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------CoalesceBatchesExec: target_batch_size=4 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------------SortExec: expr=[amount@1 DESC] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] + +query T?RR +SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, + FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, + LAST_VALUE(amount ORDER BY amount DESC) AS fv2 + FROM sales_global + GROUP BY country + ORDER BY country +---- +FRA [200.0, 50.0] 50 50 +GRC [80.0, 30.0] 30 30 +TUR [100.0, 75.0] 75 75 + +# make sure that query below runs in multi partitions +statement ok +set datafusion.execution.target_partitions = 8; + +query ? +SELECT ARRAY_AGG(e.rate ORDER BY e.sn) +FROM sales_global AS s +JOIN exchange_rates AS e +ON s.currency = e.currency_from AND + e.currency_to = 'USD' AND + s.ts >= e.ts +GROUP BY s.sn +ORDER BY s.sn; +---- +[1.10] +[1.10] +[0.10] +[1.10, 1.12] +[1.10, 0.10, 1.12, 0.11, 1.12] + + +query I +SELECT FIRST_VALUE(C order by c ASC) as first_c +FROM multiple_ordered_table +GROUP BY d +ORDER BY first_c +---- +0 +1 +4 +9 +15 + +query ITIPTR +SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate +FROM sales_global AS s +JOIN sales_global AS e + ON s.currency = e.currency AND + s.ts >= e.ts +GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency +ORDER BY s.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 +1 FRA 1 2022-01-01T08:00:00 EUR 50 +1 TUR 2 2022-01-01T11:30:00 TRY 75 +1 FRA 3 2022-01-02T12:00:00 EUR 200 +0 GRC 4 2022-01-03T10:00:00 EUR 80 +1 TUR 4 2022-01-03T10:00:00 TRY 100 + +# create a table for testing +statement ok +CREATE TABLE sales_global_with_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# create a table for testing, with primary key alternate syntax +statement ok +CREATE TABLE sales_global_with_pk_alternate (zip_code INT, + country VARCHAR(3), + sn INT primary key, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# we do not currently support foreign key constraints. +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT references sales_global_with_pk_alternate(sn), + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# we do not currently support foreign key +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT REFERENCES sales_global_with_pk_alternate(sn), + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# we do not currently support foreign key +# foreign key can be defined with a different syntax. +# we should get the same error. +statement error DataFusion error: Error during planning: Foreign key constraints are not currently supported +CREATE TABLE sales_global_with_foreign_key (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + FOREIGN KEY (sn) + REFERENCES sales_global_with_pk_alternate(sn) +) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# create a table for testing, where primary key is composite +statement ok +CREATE TABLE sales_global_with_composite_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn, ts) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# create a table for testing, where sn is unique key +statement ok +CREATE TABLE sales_global_with_unique (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + unique(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), + (1, 'TUR', NULL, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# when group by contains primary key expression +# we can use all the expressions in the table during selection +# (not just group by expressions + aggregation result) +query TT +EXPLAIN SELECT s.sn, s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +logical_plan +Sort: s.sn ASC NULLS LAST +--Projection: s.sn, s.amount, Int64(2) * CAST(s.sn AS Int64) +----Aggregate: groupBy=[[s.sn, s.amount]], aggr=[[]] +------SubqueryAlias: s +--------TableScan: sales_global_with_pk projection=[sn, amount] +physical_plan +SortPreservingMergeExec: [sn@0 ASC NULLS LAST] +--SortExec: expr=[sn@0 ASC NULLS LAST] +----ProjectionExec: expr=[sn@0 as sn, amount@1 as amount, 2 * CAST(sn@0 AS Int64) as Int64(2) * s.sn] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[] +--------CoalesceBatchesExec: target_batch_size=4 +----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] +--------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] + +query IRI +SELECT s.sn, s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 30 0 +1 50 2 +2 75 4 +3 200 6 +4 100 8 + +# we should be able to re-write group by expression +# using functional dependencies for complex expressions also. +# In this case, we use 2*s.amount instead of s.amount. +query IRI +SELECT s.sn, 2*s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 60 0 +1 100 2 +2 150 4 +3 400 6 +4 200 8 + +query IRI +SELECT s.sn, s.amount, 2*s.sn + FROM sales_global_with_pk_alternate AS s + GROUP BY sn + ORDER BY sn +---- +0 30 0 +1 50 2 +2 75 4 +3 200 6 +4 100 8 + +# Join should propagate primary key successfully +query TT +EXPLAIN SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn +---- +logical_plan +Sort: r.sn ASC NULLS LAST +--Projection: r.sn, SUM(l.amount), r.amount +----Aggregate: groupBy=[[r.sn, r.amount]], aggr=[[SUM(CAST(l.amount AS Float64))]] +------Projection: l.amount, r.sn, r.amount +--------Inner Join: Filter: l.sn >= r.sn +----------SubqueryAlias: l +------------TableScan: sales_global_with_pk projection=[sn, amount] +----------SubqueryAlias: r +------------TableScan: sales_global_with_pk projection=[sn, amount] +physical_plan +SortPreservingMergeExec: [sn@0 ASC NULLS LAST] +--SortExec: expr=[sn@0 ASC NULLS LAST] +----ProjectionExec: expr=[sn@0 as sn, SUM(l.amount)@2 as SUM(l.amount), amount@1 as amount] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[SUM(l.amount)] +--------CoalesceBatchesExec: target_batch_size=4 +----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[SUM(l.amount)] +--------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] +----------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1 +------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +------------------CoalescePartitionsExec +--------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] + +query IRR +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn +---- +0 455 30 +1 425 50 +2 375 75 +3 300 200 +4 100 100 + +# when primary key consists of composite columns +# to associate it with other fields, aggregate should contain all the composite columns +query IRR +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_composite_pk AS l + JOIN sales_global_with_composite_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn, r.ts + ORDER BY r.sn +---- +0 455 30 +1 425 50 +2 375 75 +3 300 200 +4 100 100 + +# when primary key consists of composite columns +# to associate it with other fields, aggregate should contain all the composite columns +# if any of the composite column is missing, we cannot use associated indices, inside select expression +# below query should fail +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(l.amount\) +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_composite_pk AS l + JOIN sales_global_with_composite_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn + +# left join should propagate right side constraint, +# if right side is a primary key (unique and doesn't contain null) +query IRR +SELECT r.sn, r.amount, SUM(r.amount) + FROM (SELECT * + FROM sales_global_with_pk as l + LEFT JOIN sales_global_with_pk as r + ON l.amount >= r.amount + 10) + GROUP BY r.sn +ORDER BY r.sn +---- +0 30 120 +1 50 150 +2 75 150 +4 100 100 +NULL NULL NULL + +# left join shouldn't propagate right side constraint, +# if right side is a unique key (unique and can contain null) +# Please note that, above query and this one is same except the constraint in the table. +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(r.amount\) +SELECT r.sn, r.amount, SUM(r.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY r.sn +ORDER BY r.sn + +# left semi join should propagate constraint of left side as is. +query IRR +SELECT l.sn, l.amount, SUM(l.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT SEMI JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY l.sn +ORDER BY l.sn +---- +1 50 50 +2 75 75 +3 200 200 +4 100 100 +NULL 100 100 + +# Similarly, left anti join should propagate constraint of left side as is. +query IRR +SELECT l.sn, l.amount, SUM(l.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT ANTI JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY l.sn +ORDER BY l.sn +---- +0 30 30 + +# Should support grouping by list column +query ?I +SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e'], 2), (['a', 'b'], 3)) as values0 GROUP BY column1 ORDER BY column2; +---- +[c, d, e] 1 +[a, b] 2 + + +# primary key should be aware from which columns it is associated +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) +SELECT l.sn, r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY l.sn + ORDER BY l.sn + +# window should propagate primary key successfully +query TT +EXPLAIN SELECT * + FROM(SELECT *, SUM(l.amount) OVER(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum_amount + FROM sales_global_with_pk AS l + ) as l + GROUP BY l.sn + ORDER BY l.sn +---- +logical_plan +Sort: l.sn ASC NULLS LAST +--Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, l.sum_amount +----Aggregate: groupBy=[[l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, l.sum_amount]], aggr=[[]] +------SubqueryAlias: l +--------Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum_amount +----------WindowAggr: windowExpr=[[SUM(CAST(l.amount AS Float64)) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +------------SubqueryAlias: l +--------------TableScan: sales_global_with_pk projection=[zip_code, country, sn, ts, currency, amount] +physical_plan +SortPreservingMergeExec: [sn@2 ASC NULLS LAST] +--SortExec: expr=[sn@2 ASC NULLS LAST] +----ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, zip_code@1 as zip_code, country@2 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] +--------CoalesceBatchesExec: target_batch_size=4 +----------RepartitionExec: partitioning=Hash([sn@0, zip_code@1, country@2, ts@3, currency@4, amount@5, sum_amount@6], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] +------------------BoundedWindowAggExec: wdw=[SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------------------CoalescePartitionsExec +----------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] + + +query ITIPTRR +SELECT * + FROM(SELECT *, SUM(l.amount) OVER(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum_amount + FROM sales_global_with_pk AS l + ) as l + GROUP BY l.sn + ORDER BY l.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 80 +1 FRA 1 2022-01-01T08:00:00 EUR 50 155 +1 TUR 2 2022-01-01T11:30:00 TRY 75 325 +1 FRA 3 2022-01-02T12:00:00 EUR 200 375 +1 TUR 4 2022-01-03T10:00:00 TRY 100 300 + +# join should propagate primary key correctly +query IRP +SELECT l.sn, SUM(l.amount), l.ts +FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn) +GROUP BY l.sn +ORDER BY l.sn +---- +0 30 2022-01-01T06:00:00 +1 100 2022-01-01T08:00:00 +2 225 2022-01-01T11:30:00 +3 800 2022-01-02T12:00:00 +4 500 2022-01-03T10:00:00 + +# Projection propagates primary keys correctly +# (we can use r.ts at the final projection, because it +# is associated with primary key r.sn) +query IRP +SELECT r.sn, SUM(r.amount), r.ts +FROM + (SELECT r.ts, r.sn, r.amount + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn)) +GROUP BY r.sn +ORDER BY r.sn +---- +0 150 2022-01-01T06:00:00 +1 200 2022-01-01T08:00:00 +2 225 2022-01-01T11:30:00 +3 400 2022-01-02T12:00:00 +4 100 2022-01-03T10:00:00 + +# after join, new window expressions shouldn't be associated with primary keys +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) +SELECT r.sn, SUM(r.amount), rn1 +FROM + (SELECT r.ts, r.sn, r.amount, + ROW_NUMBER() OVER() AS rn1 + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn)) +GROUP BY r.sn + +# aggregate should propagate primary key successfully +query IPR +SELECT sn, ts, sum1 +FROM ( + SELECT ts, sn, SUM(amount) as sum1 + FROM sales_global_with_pk + GROUP BY sn) +GROUP BY sn +ORDER BY sn +---- +0 2022-01-01T06:00:00 30 +1 2022-01-01T08:00:00 50 +2 2022-01-01T11:30:00 75 +3 2022-01-02T12:00:00 200 +4 2022-01-03T10:00:00 100 + +# aggregate should be able to introduce functional dependence +# (when group by contains single expression, group by expression +# becomes determinant, after aggregation; since we are sure that +# it will consist of unique values.) +# please note that ts is not primary key, still +# we can use sum1, after outer aggregation because +# after inner aggregation, ts becomes determinant +# of functional dependence. +query PR +SELECT ts, sum1 +FROM ( + SELECT ts, SUM(amount) as sum1 + FROM sales_global_with_pk + GROUP BY ts) +GROUP BY ts +ORDER BY ts +---- +2022-01-01T06:00:00 30 +2022-01-01T08:00:00 50 +2022-01-01T11:30:00 75 +2022-01-02T12:00:00 200 +2022-01-03T10:00:00 100 + +# aggregate should update its functional dependence +# mode, if it is guaranteed that, after aggregation +# group by expressions will be unique. +query IRI +SELECT * +FROM ( + SELECT *, ROW_NUMBER() OVER(ORDER BY l.sn) AS rn1 + FROM ( + SELECT l.sn, SUM(l.amount) + FROM ( + SELECT l.sn, l.amount, SUM(l.amount) as sum1 + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn) + GROUP BY l.sn) + GROUP BY l.sn) + ) +GROUP BY l.sn +ORDER BY l.sn +---- +0 30 1 +1 50 2 +2 75 3 +3 200 4 +4 100 5 + +# create a table +statement ok +CREATE TABLE FOO (x int, y int) AS VALUES (1, 2), (2, 3), (1, 3); + +# make sure that query runs in multi partitions +statement ok +set datafusion.execution.target_partitions = 8; + +query I +SELECT LAST_VALUE(x) +FROM FOO; +---- +1 + +query II +SELECT x, LAST_VALUE(x) +FROM FOO +GROUP BY x +ORDER BY x; +---- +1 1 +2 2 + +query II +SELECT y, LAST_VALUE(x) +FROM FOO +GROUP BY y +ORDER BY y; +---- +2 1 +3 1 + +# Make sure to choose a batch size smaller than, row number of the table. +# In this case we choose 2 (Row number of the table is 3). +# otherwise we won't see parallelism in tests. +statement ok +set datafusion.execution.batch_size = 2; + +# plan of the query above should contain partial +# and final aggregation stages +query TT +EXPLAIN SELECT LAST_VALUE(x) + FROM FOO; +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[LAST_VALUE(foo.x)]] +--TableScan: foo projection=[x] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[LAST_VALUE(foo.x)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[LAST_VALUE(foo.x)] +------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +query I +SELECT FIRST_VALUE(x) +FROM FOO; +---- +1 + +# similarly plan of the above query should +# contain partial and final aggregation stages. +query TT +EXPLAIN SELECT FIRST_VALUE(x) + FROM FOO; +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(foo.x)]] +--TableScan: foo projection=[x] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(foo.x)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(foo.x)] +------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +# Since both ordering requirements are satisfied, there shouldn't be +# any SortExec in the final plan. +query TT +EXPLAIN SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +logical_plan +Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST] AS first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] AS last_c +--Aggregate: groupBy=[[multiple_ordered_table.d]], aggr=[[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +----TableScan: multiple_ordered_table projection=[a, c, d] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] +--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a), FIRST_VALUE(multiple_ordered_table.c)] +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +query II rowsort +SELECT FIRST_VALUE(a ORDER BY a ASC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +0 0 +0 1 +0 15 +0 4 +0 9 + +query III rowsort +SELECT d, FIRST_VALUE(c ORDER BY a DESC, c DESC) as first_a, + LAST_VALUE(c ORDER BY c DESC) as last_c +FROM multiple_ordered_table +GROUP BY d; +---- +0 95 0 +1 90 4 +2 97 1 +3 99 15 +4 98 9 + +query TT +EXPLAIN SELECT c +FROM multiple_ordered_table +ORDER BY c ASC; +---- +logical_plan +Sort: multiple_ordered_table.c ASC NULLS LAST +--TableScan: multiple_ordered_table projection=[c] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT LAST_VALUE(l.d ORDER BY l.a) AS amount_usd +FROM multiple_ordered_table AS l +INNER JOIN ( + SELECT *, ROW_NUMBER() OVER (ORDER BY r.a) as row_n FROM multiple_ordered_table AS r +) +ON l.d = r.d AND + l.a >= r.a - 10 +GROUP BY row_n +ORDER BY row_n +---- +logical_plan +Projection: amount_usd +--Sort: row_n ASC NULLS LAST +----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +--------Projection: l.a, l.d, row_n +----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) +------------SubqueryAlias: l +--------------TableScan: multiple_ordered_table projection=[a, d] +------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------------SubqueryAlias: r +------------------TableScan: multiple_ordered_table projection=[a, d] +physical_plan +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# reset partition number to 8. +statement ok +set datafusion.execution.target_partitions = 8; + +# Create an external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER, + primary key(c) +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# We can use column b during selection +# even if it is not among group by expressions +# because column c is primary key. +query TT +EXPLAIN SELECT c, b, SUM(d) +FROM multiple_ordered_table_with_pk +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--SortExec: expr=[c@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +# drop table multiple_ordered_table_with_pk +statement ok +drop table multiple_ordered_table_with_pk; + +# Create an external table with primary key +# column c, in this case use alternative syntax +# for defining primary key +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# We can use column b during selection +# even if it is not among group by expressions +# because column c is primary key. +query TT +EXPLAIN SELECT c, b, SUM(d) +FROM multiple_ordered_table_with_pk +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--SortExec: expr=[c@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT c, sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +GROUP BY c; +---- +logical_plan +Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] +--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) +--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c); +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs + ON lhs.b=rhs.b; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--Inner Join: lhs.b = rhs.b +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@3 as c, sum1@2 as sum1, sum1@5 as sum1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)] +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true +------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 + FROM + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as lhs + CROSS JOIN + (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) as rhs; +---- +logical_plan +Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 +--CrossJoin: +----SubqueryAlias: lhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +----SubqueryAlias: rhs +------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----------TableScan: multiple_ordered_table_with_pk projection=[c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] +--CrossJoinExec +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] +------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# we do not generate physical plan for Repartition yet (e.g Distribute By queries). +query TT +EXPLAIN SELECT a, b, sum1 +FROM (SELECT c, b, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c) +DISTRIBUTE BY a +---- +logical_plan +Repartition: DistributeBy(a) +--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] + +# union with aggregate +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +UNION ALL + SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Union +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +UnionExec +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true +--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# table scan should be simplified. +query TT +EXPLAIN SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +# limit should be simplified +query TT +EXPLAIN SELECT * + FROM (SELECT c, a, SUM(d) as sum1 + FROM multiple_ordered_table_with_pk + GROUP BY c + LIMIT 5) +---- +logical_plan +Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +--Limit: skip=0, fetch=5 +----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] +physical_plan +ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +--GlobalLimitExec: skip=0, fetch=5 +----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +statement ok +set datafusion.execution.target_partitions = 8; + +# Tests for single distinct to group by optimization rule +statement ok +CREATE TABLE t(x int) AS VALUES (1), (2), (1); + +statement ok +create table t1(x bigint,y int) as values (9223372036854775807,2), (9223372036854775806,2); + +query II +SELECT SUM(DISTINCT x), MAX(DISTINCT x) from t GROUP BY x ORDER BY x; +---- +1 1 +2 2 + +query II +SELECT MAX(DISTINCT x), SUM(DISTINCT x) from t GROUP BY x ORDER BY x; +---- +1 1 +2 2 + +query TT +EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT x) FROM t1 GROUP BY y; +---- +logical_plan +Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(DISTINCT CAST(t1.x AS Float64)), MAX(DISTINCT t1.x)]] +----TableScan: t1 projection=[x, y] +physical_plan +ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 +--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------AggregateExec: mode=Partial, gby=[y@1 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y; +---- +logical_plan +Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x) +--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]] +----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]] +------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y +--------TableScan: t1 projection=[x, y] +physical_plan +ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)] +--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 +--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[] +------------CoalesceBatchesExec: target_batch_size=2 +--------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8 +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[] +--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] +----------------------MemoryExec: partitions=1, partition_sizes=[1] + +# create an unbounded table that contains ordered timestamp. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE unbounded_csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv' + +# below query should work in streaming mode. +query TT +EXPLAIN SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: time_chunks DESC NULLS FIRST, fetch=5 +----Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks +------Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: unbounded_csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 +----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] +------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------CoalesceBatchesExec: target_batch_size=2 +----------SortPreservingRepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] + +query P +SELECT date_bin('15 minutes', ts) as time_chunks + FROM unbounded_csv_with_timestamps + GROUP BY date_bin('15 minutes', ts) + ORDER BY time_chunks DESC + LIMIT 5; +---- +2018-12-13T12:00:00 +2018-11-13T17:00:00 + +# Since extract is not a monotonic function, below query should not run. +# when source is unbounded. +query error +SELECT extract(month from ts) as months + FROM unbounded_csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; + +# Create a table where timestamp is ordered +statement ok +CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts DESC) +LOCATION '../core/tests/data/timestamps.csv'; + +# below query should run since it operates on a bounded source and have a sort +# at the top of its plan. +query TT +EXPLAIN SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: months DESC NULLS FIRST, fetch=5 +----Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months +------Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] +--------TableScan: csv_with_timestamps projection=[ts] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [months@0 DESC], fetch=5 +----SortExec: TopK(fetch=5), expr=[months@0 DESC] +------ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] +--------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 +--------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false + +query R +SELECT extract(month from ts) as months + FROM csv_with_timestamps + GROUP BY extract(month from ts) + ORDER BY months DESC + LIMIT 5; +---- +12 +11 + +statement ok +drop table t1 + +# Reproducer for https://github.com/apache/arrow-datafusion/issues/8175 + +statement ok +create table t1(state string, city string, min_temp float, area int, time timestamp) as values + ('MA', 'Boston', 70.4, 1, 50), + ('MA', 'Bedford', 71.59, 2, 150); + +query RI +select date_part('year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970 1 + +query PI +select date_bin(interval '1 year', time) as bla, count(distinct state) as count from t1 group by bla; +---- +1970-01-01T00:00:00 1 + +statement ok +drop table t1 + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TIIII +SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 5 1 -101 32064 +b 5 1 -117 25286 +c 5 1 -117 29106 +d 5 1 -99 31106 +e 5 1 -95 32514 + +query TT +EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +logical_plan +Sort: aggregate_test_100.c1 ASC NULLS LAST +--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]] +------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]] +--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] +physical_plan +SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +--SortExec: expr=[c1@0 ASC NULLS LAST] +----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 +--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3] +----------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true + +# Use PostgreSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +query II +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 +2 17 +3 13 +4 19 +5 11 + +query III +SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER (WHERE c1 != 'b') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 17 19 +2 17 18 +3 13 17 +4 19 18 +5 11 9 + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +statement ok +drop table aggregate_test_100; + + +# Create an unbounded external table with primary key +# column c +statement ok +CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER primary key, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Query below can be executed, since c is primary key. +query III rowsort +SELECT c, a, SUM(d) +FROM unbounded_multiple_ordered_table_with_pk +GROUP BY c +ORDER BY c +LIMIT 5 +---- +0 0 0 +1 0 2 +2 0 0 +3 0 0 +4 0 1 diff --git a/datafusion/core/tests/sqllogictests/test_files/identifiers.slt b/datafusion/sqllogictest/test_files/identifiers.slt similarity index 98% rename from datafusion/core/tests/sqllogictests/test_files/identifiers.slt rename to datafusion/sqllogictest/test_files/identifiers.slt index c4605979d1515..f60d60b2bfe03 100644 --- a/datafusion/core/tests/sqllogictests/test_files/identifiers.slt +++ b/datafusion/sqllogictest/test_files/identifiers.slt @@ -23,7 +23,7 @@ CREATE EXTERNAL TABLE case_insensitive_test ( ) STORED AS CSV WITH HEADER ROW -LOCATION './tests/data/example.csv' +LOCATION '../core/tests/data/example.csv' # normalized column identifiers query II diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt new file mode 100644 index 0000000000000..5c6bf6e2dac13 --- /dev/null +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -0,0 +1,542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Verify the information schema does not exit by default +statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found +SELECT * from information_schema.tables + +statement error DataFusion error: Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +show all + +# Turn it on + +# expect that the queries now work +statement ok +set datafusion.catalog.information_schema = true; + +# Verify the information schema now does exist and is empty +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +# Disable information_schema and verify it now errors again +statement ok +set datafusion.catalog.information_schema = false + +statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found +SELECT * from information_schema.tables + +statement error Error during planning: table 'datafusion.information_schema.columns' not found +SELECT * from information_schema.columns; + + +############ +## Enable information schema for the rest of the test +############ +statement ok +set datafusion.catalog.information_schema = true + +############ +# New tables should show up in information schema +########### +statement ok +create table t as values (1); + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +datafusion public t BASE TABLE + +# Another new table should show up in information schema +statement ok +create table t2 as values (1); + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +datafusion public t BASE TABLE +datafusion public t2 BASE TABLE + +query TTTT rowsort +SELECT * from information_schema.tables WHERE tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +query TTTT rowsort +SELECT * from information_schema.tables WHERE information_schema.tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +query TTTT rowsort +SELECT * from information_schema.tables WHERE datafusion.information_schema.tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +# Cleanup +statement ok +drop table t + +statement ok +drop table t2 + +############ +## SHOW VARIABLES should work +########### + +# target_partitions defaults to num_cores, so set +# to a known value that is unlikely to be +# the real number of cores on a system +statement ok +SET datafusion.execution.target_partitions=7 + +# planning_concurrency defaults to num_cores, so set +# to a known value that is unlikely to be +# the real number of cores on a system +statement ok +SET datafusion.execution.planning_concurrency=13 + +# pin the version string for test +statement ok +SET datafusion.execution.parquet.created_by=datafusion + +# show all variables +query TT rowsort +SHOW ALL +---- +datafusion.catalog.create_default_catalog_and_schema true +datafusion.catalog.default_catalog datafusion +datafusion.catalog.default_schema public +datafusion.catalog.format NULL +datafusion.catalog.has_header false +datafusion.catalog.information_schema true +datafusion.catalog.location NULL +datafusion.execution.aggregate.scalar_update_factor 10 +datafusion.execution.batch_size 8192 +datafusion.execution.coalesce_batches true +datafusion.execution.collect_statistics false +datafusion.execution.max_buffered_batches_per_output_file 2 +datafusion.execution.meta_fetch_concurrency 32 +datafusion.execution.minimum_parallel_output_files 4 +datafusion.execution.parquet.allow_single_file_parallelism true +datafusion.execution.parquet.bloom_filter_enabled false +datafusion.execution.parquet.bloom_filter_fpp NULL +datafusion.execution.parquet.bloom_filter_ndv NULL +datafusion.execution.parquet.column_index_truncate_length NULL +datafusion.execution.parquet.compression zstd(3) +datafusion.execution.parquet.created_by datafusion +datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 +datafusion.execution.parquet.data_pagesize_limit 1048576 +datafusion.execution.parquet.dictionary_enabled NULL +datafusion.execution.parquet.dictionary_page_size_limit 1048576 +datafusion.execution.parquet.enable_page_index true +datafusion.execution.parquet.encoding NULL +datafusion.execution.parquet.max_row_group_size 1048576 +datafusion.execution.parquet.max_statistics_size NULL +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 +datafusion.execution.parquet.metadata_size_hint NULL +datafusion.execution.parquet.pruning true +datafusion.execution.parquet.pushdown_filters false +datafusion.execution.parquet.reorder_filters false +datafusion.execution.parquet.skip_metadata true +datafusion.execution.parquet.statistics_enabled NULL +datafusion.execution.parquet.write_batch_size 1024 +datafusion.execution.parquet.writer_version 1.0 +datafusion.execution.planning_concurrency 13 +datafusion.execution.soft_max_rows_per_output_file 50000000 +datafusion.execution.sort_in_place_threshold_bytes 1048576 +datafusion.execution.sort_spill_reservation_bytes 10485760 +datafusion.execution.target_partitions 7 +datafusion.execution.time_zone +00:00 +datafusion.explain.logical_plan_only false +datafusion.explain.physical_plan_only false +datafusion.explain.show_statistics false +datafusion.optimizer.allow_symmetric_joins_without_pruning true +datafusion.optimizer.default_filter_selectivity 20 +datafusion.optimizer.enable_distinct_aggregation_soft_limit true +datafusion.optimizer.enable_round_robin_repartition true +datafusion.optimizer.enable_topk_aggregation true +datafusion.optimizer.filter_null_join_keys false +datafusion.optimizer.hash_join_single_partition_threshold 1048576 +datafusion.optimizer.max_passes 3 +datafusion.optimizer.prefer_existing_sort false +datafusion.optimizer.prefer_hash_join true +datafusion.optimizer.repartition_aggregations true +datafusion.optimizer.repartition_file_min_size 10485760 +datafusion.optimizer.repartition_file_scans true +datafusion.optimizer.repartition_joins true +datafusion.optimizer.repartition_sorts true +datafusion.optimizer.repartition_windows true +datafusion.optimizer.skip_failed_rules false +datafusion.optimizer.top_down_join_key_reordering true +datafusion.sql_parser.dialect generic +datafusion.sql_parser.enable_ident_normalization true +datafusion.sql_parser.parse_float_as_decimal false + +# show all variables with verbose +query TTT rowsort +SHOW ALL VERBOSE +---- +datafusion.catalog.create_default_catalog_and_schema true Whether the default catalog and schema should be created automatically. +datafusion.catalog.default_catalog datafusion The default catalog name - this impacts what SQL queries use if not specified +datafusion.catalog.default_schema public The default schema name - this impacts what SQL queries use if not specified +datafusion.catalog.format NULL Type of `TableProvider` to use when loading `default` schema +datafusion.catalog.has_header false If the file has a header +datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information +datafusion.catalog.location NULL Location scanned to load tables for `default` schema +datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. +datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption +datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting +datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption +datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics +datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. +datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.bloom_filter_enabled false Sets if bloom filter is enabled for any column +datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting +datafusion.execution.parquet.column_index_truncate_length NULL Sets column index truncate length +datafusion.execution.parquet.compression zstd(3) Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.created_by datafusion Sets "created by" property +datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 Sets best effort maximum number of rows in data page +datafusion.execution.parquet.data_pagesize_limit 1048576 Sets best effort maximum size of data page in bytes +datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting +datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes +datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. +datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_row_group_size 1048576 Sets maximum number of rows in a row group +datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer +datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file +datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded +datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query +datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata +datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.write_batch_size 1024 Sets write_batch_size in bytes +datafusion.execution.parquet.writer_version 1.0 Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system +datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max +datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. +datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). +datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour +datafusion.explain.logical_plan_only false When set to true, the explain statement will only print logical plans +datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans +datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans +datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. +datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). +datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. +datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores +datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible +datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. +datafusion.optimizer.hash_join_single_partition_threshold 1048576 The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition +datafusion.optimizer.max_passes 3 Number of times that the optimizer will attempt to optimize the plan +datafusion.optimizer.prefer_existing_sort false When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. +datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory +datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level +datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. +datafusion.optimizer.repartition_file_scans true When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. +datafusion.optimizer.repartition_joins true Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level +datafusion.optimizer.repartition_sorts true Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below ```text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` would turn into the plan below which performs better in multithreaded environments ```text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` +datafusion.optimizer.repartition_windows true Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level +datafusion.optimizer.skip_failed_rules false When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail +datafusion.optimizer.top_down_join_key_reordering true When set to true, the physical plan optimizer will run a top down process to reorder the join keys +datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. +datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) +datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type + +# show_variable_in_config_options +query TT +SHOW datafusion.execution.batch_size +---- +datafusion.execution.batch_size 8192 + +# show_variable_in_config_options_verbose +query TTT +SHOW datafusion.execution.batch_size VERBOSE +---- +datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption + +# show_time_zone_default_utc +# https://github.com/apache/arrow-datafusion/issues/3255 +query TT +SHOW TIME ZONE +---- +datafusion.execution.time_zone +00:00 + +# show_timezone_default_utc +# https://github.com/apache/arrow-datafusion/issues/3255 +query TT +SHOW TIMEZONE +---- +datafusion.execution.time_zone +00:00 + + +# show_time_zone_default_utc_verbose +# https://github.com/apache/arrow-datafusion/issues/3255 +query TTT +SHOW TIME ZONE VERBOSE +---- +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour + +# show_timezone_default_utc +# https://github.com/apache/arrow-datafusion/issues/3255 +query TTT +SHOW TIMEZONE VERBOSE +---- +datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour + + +# show empty verbose +query TTT +SHOW VERBOSE +---- + +# information_schema_describe_table + +## some_table +statement ok +CREATE OR REPLACE TABLE some_table AS VALUES (1,2),(3,4); + +query TTT rowsort +DESCRIBE some_table +---- +column1 Int64 YES +column2 Int64 YES + +statement ok +DROP TABLE public.some_table; + +## public.some_table + +statement ok +CREATE OR REPLACE TABLE public.some_table AS VALUES (1,2),(3,4); + +query TTT rowsort +DESCRIBE public.some_table +---- +column1 Int64 YES +column2 Int64 YES + +statement ok +DROP TABLE public.some_table; + +## datafusion.public.some_table + +statement ok +CREATE OR REPLACE TABLE datafusion.public.some_table AS VALUES (1,2),(3,4); + +query TTT rowsort +DESCRIBE datafusion.public.some_table +---- +column1 Int64 YES +column2 Int64 YES + +statement ok +DROP TABLE datafusion.public.some_table; + +# information_schema_describe_table_not_exists + +statement error Error during planning: table 'datafusion.public.table' not found +describe table; + + +# information_schema_show_tables +query TTTT rowsort +SHOW TABLES +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + + +# information_schema_show_tables_no_information_schema + +statement ok +set datafusion.catalog.information_schema = false; + +statement error Error during planning: SHOW TABLES is not supported unless information_schema is enabled +SHOW TABLES + +statement ok +set datafusion.catalog.information_schema = true; + + +# information_schema_show_columns +statement ok +CREATE TABLE t AS SELECT 1::int as i; + +statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported +SHOW COLUMNS FROM t LIKE 'f'; + +statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported +SHOW COLUMNS FROM t WHERE column_name = 'bar'; + +query TTTTTT +SHOW COLUMNS FROM t; +---- +datafusion public t i Int32 NO + +# This isn't ideal but it is consistent behavior for `SELECT * from "T"` +statement error Error during planning: table 'datafusion.public.T' not found +SHOW columns from "T" + +# information_schema_show_columns_full_extended +query TTTTITTTIIIIIIT +SHOW FULL COLUMNS FROM t; +---- +datafusion public t i 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL + +# expect same as above +query TTTTITTTIIIIIIT +SHOW EXTENDED COLUMNS FROM t; +---- +datafusion public t i 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL + +# information_schema_show_columns_no_information_schema + +statement ok +set datafusion.catalog.information_schema = false; + +statement error Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled +SHOW COLUMNS FROM t + +statement ok +set datafusion.catalog.information_schema = true; + + +# information_schema_show_columns_names() +query TTTTTT +SHOW columns from public.t +---- +datafusion public t i Int32 NO + +query TTTTTT +SHOW columns from datafusion.public.t +---- +datafusion public t i Int32 NO + +statement error Error during planning: table 'datafusion.public.t2' not found +SHOW columns from t2 + +statement error Error during planning: table 'datafusion.public.t2' not found +SHOW columns from datafusion.public.t2 + + +# show_non_existing_variable +# FIXME +# currently we cannot know whether a variable exists, this will output 0 row instead +statement ok +SHOW SOMETHING_UNKNOWN; + +statement ok +DROP TABLE t; + +# show_unsupported_when_information_schema_off + +statement ok +set datafusion.catalog.information_schema = false; + +statement error Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +SHOW SOMETHING + +statement error Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +SHOW SOMETHING VERBOSE + +statement ok +set datafusion.catalog.information_schema = true; + + + +# show_create_view() +statement ok +CREATE TABLE abc AS VALUES (1,2,3), (4,5,6); + +statement ok +CREATE VIEW xyz AS SELECT * FROM abc + +query TTTT +SHOW CREATE TABLE xyz +---- +datafusion public xyz CREATE VIEW xyz AS SELECT * FROM abc + +statement ok +DROP TABLE abc; + +statement ok +DROP VIEW xyz; + +# show_create_view_in_catalog +statement ok +CREATE TABLE abc AS VALUES (1,2,3), (4,5,6) + +statement ok +CREATE SCHEMA test; + +statement ok +CREATE VIEW test.xyz AS SELECT * FROM abc; + +query TTTT +SHOW CREATE TABLE test.xyz +---- +datafusion test xyz CREATE VIEW test.xyz AS SELECT * FROM abc + +statement error DataFusion error: Execution error: Cannot drop schema test because other tables depend on it: xyz +DROP SCHEMA test; + +statement ok +DROP TABLE abc; + +statement ok +DROP VIEW test.xyz + + +# show_external_create_table() +statement ok +CREATE EXTERNAL TABLE abc +STORED AS CSV +WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv'; + +query TTTT +SHOW CREATE TABLE abc; +---- +datafusion public abc CREATE EXTERNAL TABLE abc STORED AS CSV LOCATION ../../testing/data/csv/aggregate_test_100.csv diff --git a/datafusion/sqllogictest/test_files/information_schema_columns.slt b/datafusion/sqllogictest/test_files/information_schema_columns.slt new file mode 100644 index 0000000000000..7cf845c16d738 --- /dev/null +++ b/datafusion/sqllogictest/test_files/information_schema_columns.slt @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +set datafusion.catalog.information_schema = true; + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +########### +# Information schema columns +########### + +statement ok +CREATE TABLE t1 (i int) as values(1); + +# table t2 is created using rust code because it is not possible to set nullable columns with `arrow_cast` syntax + +query TTTTITTTIIIIIIT rowsort +SELECT * from information_schema.columns; +---- +my_catalog my_schema t1 i 0 NULL YES Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema t2 binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema t2 float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL +my_catalog my_schema t2 int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema t2 large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema t2 large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema t2 timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL +my_catalog my_schema t2 utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL + +# Cleanup +statement ok +drop table t1 + +statement ok +drop table t2 diff --git a/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt b/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt new file mode 100644 index 0000000000000..c7f4dcfd54d86 --- /dev/null +++ b/datafusion/sqllogictest/test_files/information_schema_multiple_catalogs.slt @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Verify the information schema does not exit by default +statement error DataFusion error: Error during planning: table 'datafusion.information_schema.tables' not found +SELECT * from information_schema.tables + +statement error DataFusion error: Error during planning: SHOW \[VARIABLE\] is not supported unless information_schema is enabled +show all + +# Turn it on + +# expect that the queries now work +statement ok +set datafusion.catalog.information_schema = true; + +# Verify the information schema now does exist and is empty +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +############ +# Create multiple catalogs +########### +statement ok +create database my_catalog; + +statement ok +create schema my_catalog.my_schema; + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +statement ok +create table t1 as values(1); + +statement ok +create table t2 as values(1); + +statement ok +create database my_other_catalog; + +statement ok +create schema my_other_catalog.my_other_schema; + +statement ok +set datafusion.catalog.default_catalog = my_other_catalog; + +statement ok +set datafusion.catalog.default_schema = my_other_schema; + +statement ok +create table t3 as values(1); + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +my_catalog information_schema columns VIEW +my_catalog information_schema df_settings VIEW +my_catalog information_schema tables VIEW +my_catalog information_schema views VIEW +my_catalog my_schema t1 BASE TABLE +my_catalog my_schema t2 BASE TABLE +my_other_catalog information_schema columns VIEW +my_other_catalog information_schema df_settings VIEW +my_other_catalog information_schema tables VIEW +my_other_catalog information_schema views VIEW +my_other_catalog my_other_schema t3 BASE TABLE + +# Cleanup + +statement ok +drop table t3 + +statement ok +set datafusion.catalog.default_catalog = my_catalog; + +statement ok +set datafusion.catalog.default_schema = my_schema; + +statement ok +drop table t1 + +statement ok +drop table t2 diff --git a/datafusion/sqllogictest/test_files/information_schema_table_types.slt b/datafusion/sqllogictest/test_files/information_schema_table_types.slt new file mode 100644 index 0000000000000..eb72f3399fe7f --- /dev/null +++ b/datafusion/sqllogictest/test_files/information_schema_table_types.slt @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Turn it on +statement ok +set datafusion.catalog.information_schema = true; + +############ +# Table with many types +############ + +statement ok +create table physical as values(1); + +statement ok +create view query as values(1); + +# Temporary tables cannot be created using SQL syntax so it is done using Rust code. + +query TTTT rowsort +SELECT * from information_schema.tables; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW +datafusion public physical BASE TABLE +datafusion public query VIEW +datafusion public temp LOCAL TEMPORARY + +# Cleanup + +statement ok +drop table physical + +statement ok +drop view query diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt new file mode 100644 index 0000000000000..e20b3779459bf --- /dev/null +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -0,0 +1,435 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## INSERT tests +########## + + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# test_insert_into + +statement ok +set datafusion.execution.target_partitions = 8; + +statement ok +CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); + +query TT +EXPLAIN +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +FROM aggregate_test_100 +ORDER by c1 +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +----------TableScan: aggregate_test_100 projection=[c1, c4, c9] +physical_plan +FileSinkExec: sink=MemoryTable (partitions=1) +--ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] +------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true + +query II +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +FROM aggregate_test_100 +ORDER by c1 +---- +100 + +# verify there is data now in the table +query I +SELECT COUNT(*) from table_without_values; +---- +100 + +# verify there is data now in the table +query II +SELECT * +FROM table_without_values +ORDER BY field1, field2 +LIMIT 5; +---- +-70111 3 +-65362 3 +-62295 3 +-56721 3 +-55414 3 + +statement ok +drop table table_without_values; + + + +# test_insert_into_as_select_multi_partitioned +statement ok +CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) + +query TT +EXPLAIN +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +------TableScan: aggregate_test_100 projection=[c1, c4, c9] +physical_plan +FileSinkExec: sink=MemoryTable (partitions=1) +--CoalescePartitionsExec +----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true + + + +query II +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +---- +100 + +statement ok +drop table table_without_values; + + +# test_insert_into_as_select_single_partition + +statement ok +CREATE TABLE table_without_values AS SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 + + +# // TODO: The generated plan is suboptimal since SortExec is in global state. +query TT +EXPLAIN +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +ORDER BY c1 +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: a1 AS a1, a2 AS a2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 +--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +----------TableScan: aggregate_test_100 projection=[c1, c4, c9] +physical_plan +FileSinkExec: sink=MemoryTable (partitions=8) +--ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] +----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] +------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] +--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true + + +query II +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +ORDER BY c1 +---- +100 + + +statement ok +drop table table_without_values; + +# test_insert_into_with_sort + +statement ok +create table table_without_values(c1 varchar not null); + +# verify that the sort order of the insert query is maintained into the +# insert (there should be a SortExec in the following plan) +# See https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 for more background +query TT +explain insert into table_without_values select c1 from aggregate_test_100 order by c1; +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: aggregate_test_100.c1 AS c1 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------TableScan: aggregate_test_100 projection=[c1] +physical_plan +FileSinkExec: sink=MemoryTable (partitions=1) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true + +query T +insert into table_without_values select c1 from aggregate_test_100 order by c1; +---- +100 + +query I +select count(*) from table_without_values; +---- +100 + +statement ok +drop table table_without_values; + + +# test insert with column names +statement ok +CREATE TABLE table_without_values(id BIGINT, name varchar); + +query IT +insert into table_without_values(id, name) values(1, 'foo'); +---- +1 + +query IT +insert into table_without_values(name, id) values('bar', 2); +---- +1 + +statement error Schema error: Schema contains duplicate unqualified field name id +insert into table_without_values(id, id) values(3, 3); + +statement error Arrow error: Cast error: Cannot cast string 'zoo' to value of Int64 type +insert into table_without_values(name, id) values(4, 'zoo'); + +statement error Error during planning: Column count doesn't match insert query! +insert into table_without_values(id) values(4, 'zoo'); + +# insert NULL values for the missing column (name) +query IT +insert into table_without_values(id) values(4); +---- +1 + +query IT rowsort +select * from table_without_values; +---- +1 foo +2 bar +4 NULL + +statement ok +drop table table_without_values; + + +# test insert with non-nullable column +statement ok +CREATE TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL); + +query II +insert into table_without_values values(1, 100); +---- +1 + +query II +insert into table_without_values values(2, NULL); +---- +1 + +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values values(NULL, 300); + +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values values(3, 300), (NULL, 400); + +query II rowsort +select * from table_without_values; +---- +1 100 +2 NULL +3 NULL + +statement ok +drop table table_without_values; + + +### Test for creating tables into directories that do not already exist +# note use of `scratch` directory (which is cleared between runs) + +statement ok +create external table new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/new_empty_table/'; -- needs trailing slash + +# should start empty +query I +select * from new_empty_table; +---- + +# should succeed and the table should create the direectory +statement ok +insert into new_empty_table values (1); + +# Now has values +query I +select * from new_empty_table; +---- +1 + +statement ok +drop table new_empty_table; + +## test we get an error if the path doesn't end in slash +statement ok +create external table bad_new_empty_table(x int) stored as parquet location 'test_files/scratch/insert/bad_new_empty_table'; -- no trailing slash + +# should fail +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +insert into bad_new_empty_table values (1); + +statement ok +drop table bad_new_empty_table; + + +### Test for specifying column's default value + +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) + +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + + +# test create table as +statement ok +create table test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) as values(1, 10, 100, 'ABC', now()) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +statement ok +drop table test_column_defaults + +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +create table test_column_defaults(a int, b int default a+1) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt new file mode 100644 index 0000000000000..85c2db7faaf60 --- /dev/null +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -0,0 +1,612 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +################################### +## INSERT to external table tests## +################################### + + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + + +statement ok +create table dictionary_encoded_values as values +('a', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('b', arrow_cast('bar', 'Dictionary(Int32, Utf8)')); + +query TTT +describe dictionary_encoded_values; +---- +column1 Utf8 YES +column2 Dictionary(Int32, Utf8) YES + +statement ok +CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( + a varchar, + b varchar, +) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned' +PARTITIONED BY (b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +insert into dictionary_encoded_parquet_partitioned +select * from dictionary_encoded_values +---- +2 + +query TT +select * from dictionary_encoded_parquet_partitioned order by (a); +---- +a foo +b bar + + +# test_insert_into +statement ok +set datafusion.execution.target_partitions = 8; + +statement ok +CREATE EXTERNAL TABLE +ordered_insert_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_ordered/' +WITH ORDER (a ASC, B DESC) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +EXPLAIN INSERT INTO ordered_insert_test values (5, 1), (4, 2), (7,7), (7,8), (7,9), (7,10), (3, 3), (2, 4), (1, 5); +---- +logical_plan +Dml: op=[Insert Into] table=[ordered_insert_test] +--Projection: column1 AS a, column2 AS b +----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... +physical_plan +FileSinkExec: sink=CsvSink(file_groups=[]) +--SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] +----ProjectionExec: expr=[column1@0 as a, column2@1 as b] +------ValuesExec + +query II +INSERT INTO ordered_insert_test values (5, 1), (4, 2), (7,7), (7,8), (7,9), (7,10), (3, 3), (2, 4), (1, 5); +---- +9 + +query II +SELECT * from ordered_insert_test; +---- +1 5 +2 4 +3 3 +4 2 +5 1 +7 10 +7 9 +7 8 +7 7 + +# test partitioned insert + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test(a string, b string, c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/' +PARTITIONED BY (a, b) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +#note that partitioned cols are moved to the end so value tuples are (c, a, b) +query ITT +INSERT INTO partitioned_insert_test values (1, 10, 100), (1, 10, 200), (1, 20, 100), (1, 20, 200), (2, 20, 100), (2, 20, 200); +---- +6 + +query ITT +select * from partitioned_insert_test order by a,b,c +---- +1 10 100 +1 10 200 +1 20 100 +2 20 100 +1 20 200 +2 20 200 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify(c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/a=20/b=100/' +OPTIONS( +insert_mode 'append_new_files', +); + +query I +select * from partitioned_insert_test_verify; +---- +1 +2 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_json(a string, b string) +STORED AS json +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_json/' +PARTITIONED BY (a) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query TT +INSERT INTO partitioned_insert_test_json values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); +---- +6 + +# Issue open for this error: https://github.com/apache/arrow-datafusion/issues/7816 +query error DataFusion error: Arrow error: Json error: Encountered unmasked nulls in non\-nullable StructArray child: Field \{ name: "a", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +select * from partitioned_insert_test_json order by a,b + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify_json(b string) +STORED AS json +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_json/a=2/' +OPTIONS( +insert_mode 'append_new_files', +); + +query T +select * from partitioned_insert_test_verify_json; +---- +1 +1 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_pq(a string, b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_pq/' +PARTITIONED BY (a) +OPTIONS( +create_local_path 'true', +insert_mode 'append_new_files', +); + +query IT +INSERT INTO partitioned_insert_test_pq values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); +---- +6 + +query IT +select * from partitioned_insert_test_pq order by a ASC, b ASC +---- +1 2 +1 2 +3 4 +3 4 +5 6 +5 6 + +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_verify_pq(b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_pq/a=2/' +OPTIONS( +insert_mode 'append_new_files', +); + +query I +select * from partitioned_insert_test_verify_pq; +---- +1 +1 + + +statement ok +CREATE EXTERNAL TABLE +single_file_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' +OPTIONS( +create_local_path 'true', +single_file 'true', +); + +query error DataFusion error: Error during planning: Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`\. To append to an existing file use StreamTable, e\.g\. by using CREATE UNBOUNDED EXTERNAL TABLE +INSERT INTO single_file_test values (1, 2), (3, 4); + +statement ok +drop table single_file_test; + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE +single_file_test(a bigint, b bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv' +OPTIONS( +create_local_path 'true', +single_file 'true', +); + +query II +INSERT INTO single_file_test values (1, 2), (3, 4); +---- +2 + +query II +INSERT INTO single_file_test values (4, 5), (6, 7); +---- +2 + +query II +select * from single_file_test; +---- +1 2 +3 4 +4 5 +6 7 + +statement ok +CREATE EXTERNAL TABLE +directory_test(a bigint, b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0' +OPTIONS( +create_local_path 'true', +); + +query II +INSERT INTO directory_test values (1, 2), (3, 4); +---- +2 + +query II +select * from directory_test; +---- +1 2 +3 4 + +statement ok +CREATE EXTERNAL TABLE +table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q1' +OPTIONS (create_local_path 'true'); + +query TT +EXPLAIN +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +FROM aggregate_test_100 +ORDER by c1 +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +----------TableScan: aggregate_test_100 projection=[c1, c4, c9] +physical_plan +FileSinkExec: sink=ParquetSink(file_groups=[]) +--ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] +------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true + +query II +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +FROM aggregate_test_100 +ORDER by c1 +---- +100 + +# verify there is data now in the table +query I +SELECT COUNT(*) from table_without_values; +---- +100 + +# verify there is data now in the table +query II +SELECT * +FROM table_without_values +ORDER BY field1, field2 +LIMIT 5; +---- +-70111 3 +-65362 3 +-62295 3 +-56721 3 +-55414 3 + +statement ok +drop table table_without_values; + +# test_insert_into_as_select_multi_partitioned +statement ok +CREATE EXTERNAL TABLE +table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q2' +OPTIONS (create_local_path 'true'); + +query TT +EXPLAIN +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +------TableScan: aggregate_test_100 projection=[c1, c4, c9] +physical_plan +FileSinkExec: sink=ParquetSink(file_groups=[]) +--CoalescePartitionsExec +----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true + + + +query II +INSERT INTO table_without_values SELECT +SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, +COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 +FROM aggregate_test_100 +---- +100 + +statement ok +drop table table_without_values; + + +# test_insert_into_with_sort +statement ok +CREATE EXTERNAL TABLE +table_without_values(c1 varchar NULL) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q3' +OPTIONS (create_local_path 'true'); + +# verify that the sort order of the insert query is maintained into the +# insert (there should be a SortExec in the following plan) +# See https://github.com/apache/arrow-datafusion/pull/6354#discussion_r1195284178 for more background +query TT +explain insert into table_without_values select c1 from aggregate_test_100 order by c1; +---- +logical_plan +Dml: op=[Insert Into] table=[table_without_values] +--Projection: aggregate_test_100.c1 AS c1 +----Sort: aggregate_test_100.c1 ASC NULLS LAST +------TableScan: aggregate_test_100 projection=[c1] +physical_plan +FileSinkExec: sink=ParquetSink(file_groups=[]) +--SortExec: expr=[c1@0 ASC NULLS LAST] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true + +query T +insert into table_without_values select c1 from aggregate_test_100 order by c1; +---- +100 + +query I +select count(*) from table_without_values; +---- +100 + + +statement ok +drop table table_without_values; + + +# test insert with column names +statement ok +CREATE EXTERNAL TABLE +table_without_values(id BIGINT, name varchar) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4' +OPTIONS (create_local_path 'true'); + +query IT +insert into table_without_values(id, name) values(1, 'foo'); +---- +1 + +query IT +insert into table_without_values(name, id) values('bar', 2); +---- +1 + +statement error Schema error: Schema contains duplicate unqualified field name id +insert into table_without_values(id, id) values(3, 3); + +statement error Arrow error: Cast error: Cannot cast string 'zoo' to value of Int64 type +insert into table_without_values(name, id) values(4, 'zoo'); + +statement error Error during planning: Column count doesn't match insert query! +insert into table_without_values(id) values(4, 'zoo'); + +# insert NULL values for the missing column (name) +query IT +insert into table_without_values(id) values(4); +---- +1 + +query IT rowsort +select * from table_without_values; +---- +1 foo +2 bar +4 NULL + +statement ok +drop table table_without_values; + +# test insert with non-nullable column +statement ok +CREATE EXTERNAL TABLE +table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5' +OPTIONS (create_local_path 'true'); + +query II +insert into table_without_values values(1, 100); +---- +1 + +query II +insert into table_without_values values(2, NULL); +---- +1 + +# insert NULL values for the missing column (field2) +query II +insert into table_without_values(field1) values(3); +---- +1 + +# insert NULL values for the missing column (field1), but column is non-nullable +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values(field2) values(300); + +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values values(NULL, 300); + +statement error Execution error: Invalid batch column at '0' has null but schema specifies non-nullable +insert into table_without_values values(3, 300), (NULL, 400); + +query II rowsort +select * from table_without_values; +---- +1 100 +2 NULL +3 NULL + +statement ok +drop table table_without_values; + + +### Test for specifying column's default value + +statement ok +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int not null default null, + c int default 100*2+300, + d text default lower('DEFAULT_TEXT'), + e timestamp default now() +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6' +OPTIONS (create_local_path 'true'); + +# fill in all column values +query IIITP +insert into test_column_defaults values(1, 10, 100, 'ABC', now()) +---- +1 + +statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable +insert into test_column_defaults(a) values(2) + +query IIITP +insert into test_column_defaults(b) values(20) +---- +1 + +query IIIT rowsort +select a,b,c,d from test_column_defaults +---- +1 10 100 ABC +NULL 20 500 default_text + +# fill the timestamp column with default value `now()` again, it should be different from the previous one +query IIITP +insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') +---- +1 + +# Ensure that the default expression `now()` is evaluated during insertion, not optimized away. +# Rows are inserted during different time, so their timestamp values should be different. +query I rowsort +select count(distinct e) from test_column_defaults +---- +3 + +# Expect all rows to be true as now() was inserted into the table +query B rowsort +select e < now() from test_column_defaults +---- +true +true +true + +statement ok +drop table test_column_defaults + +# test invalid default value +statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a. +CREATE EXTERNAL TABLE test_column_defaults( + a int, + b int default a+1 +) STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q7' +OPTIONS (create_local_path 'true'); diff --git a/datafusion/core/tests/sqllogictests/test_files/intersection.slt b/datafusion/sqllogictest/test_files/intersection.slt similarity index 99% rename from datafusion/core/tests/sqllogictests/test_files/intersection.slt rename to datafusion/sqllogictest/test_files/intersection.slt index 31121a333df81..301878cc98e2c 100644 --- a/datafusion/core/tests/sqllogictests/test_files/intersection.slt +++ b/datafusion/sqllogictest/test_files/intersection.slt @@ -30,7 +30,6 @@ SELECT * FROM (SELECT null AS id1, 1 AS id2) t1 ---- NULL 1 - query IR SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4 ---- diff --git a/datafusion/core/tests/sqllogictests/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt similarity index 79% rename from datafusion/core/tests/sqllogictests/test_files/interval.slt rename to datafusion/sqllogictest/test_files/interval.slt index 911b28c84be31..500876f76221c 100644 --- a/datafusion/core/tests/sqllogictests/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -190,25 +190,25 @@ select interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 seco query ? select -interval '5' - '1' - '2' year; ---- -0 years -24 mons 0 days 0 hours 0 mins 0.000000000 secs +0 years -96 mons 0 days 0 hours 0 mins 0.000000000 secs # Interval with nested string literal negation query ? select -interval '1 month' + '1 day' + '1 hour'; ---- -0 years -1 mons -1 days -1 hours 0 mins 0.000000000 secs +0 years -1 mons 1 days 1 hours 0 mins 0.000000000 secs # Interval with nested string literal negation and leading field query ? select -interval '10' - '1' - '1' month; ---- -0 years -8 mons 0 days 0 hours 0 mins 0.000000000 secs +0 years -12 mons 0 days 0 hours 0 mins 0.000000000 secs # Interval mega nested string literal negation query ? select -interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond' ---- -0 years -11 mons 1 days 1 hours 1 mins 1.001001001 secs +0 years -13 mons -1 days -1 hours -1 mins -1.001001001 secs # Interval string literal + date query D @@ -276,6 +276,12 @@ create table t (i interval) as values ('5 days 3 nanoseconds'::interval); statement ok insert into t values ('6 days 7 nanoseconds'::interval) +query ? rowsort +select -i from t order by 1; +---- +0 years 0 mons -5 days 0 hours 0 mins -0.000000003 secs +0 years 0 mons -6 days 0 hours 0 mins -0.000000007 secs + query ?T rowsort select i, @@ -430,12 +436,10 @@ select '1 month'::interval + '1980-01-01T12:00:00'::timestamp; ---- 1980-02-01T12:00:00 -# Exected error: interval (scalar) - date / timestamp (scalar) - -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select '1 month'::interval - '1980-01-01'::date; -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select '1 month'::interval - '1980-01-01T12:00:00'::timestamp; # interval (array) + date / timestamp (array) @@ -454,10 +458,10 @@ select i + ts from t; 2000-02-01T00:01:00 # expected error interval (array) - date / timestamp (array) -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select i - d from t; -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select i - ts from t; @@ -477,10 +481,10 @@ select '1 month'::interval + ts from t; 2000-03-01T00:00:00 # expected error interval (scalar) - date / timestamp (array) -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Date32 can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select '1 month'::interval - d from t; -query error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select '1 month'::interval - ts from t; # interval + date @@ -489,5 +493,55 @@ select interval '1 month' + '2012-01-01'::date; ---- 2012-02-01 +# is (not) distinct from +query BBBBBB +select + i is distinct from null, + i is distinct from (interval '1 month'), + i is distinct from i, + i is not distinct from null, + i is not distinct from (interval '1 day'), + i is not distinct from i +from t; +---- +true false false false false true +true true false false true true +true true false false false true + +### interval (array) cmp interval (array) +query BBBBBB +select i = i, i != i, i < i, i <= i, i > i, i >= i from t; +---- +true false false true false true +true false false true false true +true false false true false true + +### interval (array) cmp interval (scalar) +query BBBBBB +select + (interval '1 day') = i, + (interval '1 day') != i, + i < (interval '1 day'), + i <= (interval '1 day'), + i > (interval '1 day'), + i >= (interval '1 day') +from t; +---- +false true false false true true +true false false true false true +false true true true false false + +### interval (scalar) cmp interval (scalar) +query BBBBBB +select + (interval '1 day') = (interval '1 day'), + (interval '1 month') != (interval '1 day'), + (interval '1 minute') < (interval '1 day'), + (interval '1 hour') <= (interval '1 day'), + (interval '1 year') > (interval '1 day'), + (interval '1 day') >= (interval '1 day'); +---- +true true true true true true + statement ok drop table t diff --git a/datafusion/core/tests/sqllogictests/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt similarity index 97% rename from datafusion/core/tests/sqllogictests/test_files/join.slt rename to datafusion/sqllogictest/test_files/join.slt index 283ff57a984cb..386ffe766b193 100644 --- a/datafusion/core/tests/sqllogictests/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -556,7 +556,11 @@ query TT explain select * from t1 join t2 on false; ---- logical_plan EmptyRelation -physical_plan EmptyExec: produce_one_row=false +physical_plan EmptyExec + +# Make batch size smaller than table row number. to introduce parallelism to the plan. +statement ok +set datafusion.execution.batch_size = 1; # test covert inner join to cross join when condition is true query TT @@ -568,9 +572,9 @@ CrossJoin: --TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan CrossJoinExec ---CoalescePartitionsExec -----MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--MemoryExec: partitions=1, partition_sizes=[1] +--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----MemoryExec: partitions=1, partition_sizes=[1] statement ok drop table IF EXISTS t1; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt new file mode 100644 index 0000000000000..1312f2916ed61 --- /dev/null +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Join Tests +########## + +# turn off repartition_joins +statement ok +set datafusion.optimizer.repartition_joins = false; + +include ./join.slt + +statement ok +CREATE EXTERNAL TABLE annotated_data ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC, c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +query TT +EXPLAIN SELECT t2.a + FROM annotated_data as t1 + INNER JOIN annotated_data as t2 + ON t1.c = t2.c ORDER BY t2.a + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: t2.a ASC NULLS LAST, fetch=5 +----Projection: t2.a +------Inner Join: t1.c = t2.c +--------SubqueryAlias: t1 +----------TableScan: annotated_data projection=[c] +--------SubqueryAlias: t2 +----------TableScan: annotated_data projection=[a, c] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 +----ProjectionExec: expr=[a@1 as a] +------CoalesceBatchesExec: target_batch_size=8192 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# preserve_inner_join +query IIII nosort +SELECT t1.a, t1.b, t1.c, t2.a as a2 + FROM annotated_data as t1 + INNER JOIN annotated_data as t2 + ON t1.d = t2.d ORDER BY a2, t2.b + LIMIT 5 +---- +0 0 0 0 +0 0 2 0 +0 0 3 0 +0 0 6 0 +0 0 20 0 + +query TT +EXPLAIN SELECT t2.a as a2, t2.b + FROM annotated_data as t1 + RIGHT SEMI JOIN annotated_data as t2 + ON t1.d = t2.d AND t1.c = t2.c + WHERE t2.d = 3 + ORDER BY a2, t2.b +LIMIT 10 +---- +logical_plan +Limit: skip=0, fetch=10 +--Sort: a2 ASC NULLS LAST, t2.b ASC NULLS LAST, fetch=10 +----Projection: t2.a AS a2, t2.b +------RightSemi Join: t1.d = t2.d, t1.c = t2.c +--------SubqueryAlias: t1 +----------TableScan: annotated_data projection=[c, d] +--------SubqueryAlias: t2 +----------Filter: annotated_data.d = Int32(3) +------------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--SortPreservingMergeExec: [a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 +----ProjectionExec: expr=[a@0 as a2, b@1 as b] +------CoalesceBatchesExec: target_batch_size=8192 +--------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true +----------CoalesceBatchesExec: target_batch_size=8192 +------------FilterExec: d@3 = 3 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true + +# preserve_right_semi_join +query II nosort +SELECT t2.a as a2, t2.b + FROM annotated_data as t1 + RIGHT SEMI JOIN annotated_data as t2 + ON t1.d = t2.d AND t1.c = t2.c + WHERE t2.d = 3 + ORDER BY a2, t2.b +LIMIT 10 +---- +0 0 +0 0 +0 0 +0 1 +0 1 +0 1 +0 1 +0 1 +1 2 +1 2 + +# turn on repartition_joins +statement ok +set datafusion.optimizer.repartition_joins = true; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt new file mode 100644 index 0000000000000..0fea8da5a3420 --- /dev/null +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -0,0 +1,3488 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +#### +# Configuration +#### + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +statement ok +set datafusion.explain.logical_plan_only = true; + + +#### +# Data Setup +#### + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE join_t1(t1_id INT UNSIGNED, t1_name VARCHAR, t1_int INT UNSIGNED) +AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE join_t2(t2_id INT UNSIGNED, t2_name VARCHAR, t2_int INT UNSIGNED) +AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +# Left semi anti join + +statement ok +CREATE TABLE lsaj_t1(t1_id INT UNSIGNED, t1_name VARCHAR, t1_int INT UNSIGNED) +AS VALUES +(11, 'a', 1), +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4), +(NULL, 'e', 0); + +statement ok +CREATE TABLE lsaj_t2(t2_id INT UNSIGNED, t2_name VARCHAR, t2_int INT UNSIGNED) +AS VALUES +(11, 'z', 3), +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3), +(NULL, 'v', 0); + +statement ok +CREATE TABLE left_semi_anti_join_table_t1(t1_id INT UNSIGNED, t1_name VARCHAR, t1_int INT UNSIGNED) +AS VALUES +(11, 'a', 1), +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4), +(NULL, 'e', 0); + +statement ok +CREATE TABLE left_semi_anti_join_table_t2(t2_id INT UNSIGNED, t2_name VARCHAR, t2_int INT UNSIGNED) +AS VALUES +(11, 'z', 3), +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3), +(NULL, 'v', 0); + + +statement ok +CREATE TABLE right_semi_anti_join_table_t1(t1_id INT UNSIGNED, t1_name VARCHAR, t1_int INT UNSIGNED) +AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4), +(NULL, 'e', 0); + +statement ok +CREATE TABLE right_semi_anti_join_table_t2(t2_id INT UNSIGNED, t2_name VARCHAR) +AS VALUES +(11, 'a'), +(11, 'x'), +(NULL, NULL); + +# Table with all of the supported timestamp types values +# +# Columns are named: +# "nanos" --> TimestampNanosecondArray +# "micros" --> TimestampMicrosecondArray +# "millis" --> TimestampMillisecondArray +# "secs" --> TimestampSecondArray +# "names" --> StringArray +statement ok +CREATE TABLE test_timestamps_table_source(ts varchar, names varchar) +AS VALUES +('2018-11-13T17:11:10.011375885995', 'Row 0'), +('2011-12-13T11:13:10.12345', 'Row 1'), +(NULL, 'Row 2'), +('2021-01-01T05:11:10.432', 'Row 3'); + + +statement ok +CREATE TABLE test_timestamps_table as +SELECT + arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, None)') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, None)') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, None)') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, None)') as secs, + names +FROM + test_timestamps_table_source; + +# create a table of timestamps with time zone +statement ok +CREATE TABLE test_timestamps_tz_table as +SELECT + arrow_cast(ts::timestamp::bigint, 'Timestamp(Nanosecond, Some("UTC"))') as nanos, + arrow_cast(ts::timestamp::bigint / 1000, 'Timestamp(Microsecond, Some("UTC"))') as micros, + arrow_cast(ts::timestamp::bigint / 1000000, 'Timestamp(Millisecond, Some("UTC"))') as millis, + arrow_cast(ts::timestamp::bigint / 1000000000, 'Timestamp(Second, Some("UTC"))') as secs, + names +FROM + test_timestamps_table_source; + + +statement ok +CREATE TABLE hashjoin_datatype_table_t1_source(c1 INT, c2 BIGINT, c3 DECIMAL(5,2), c4 VARCHAR) +AS VALUES +(1, 86400000, 1.23, 'abc'), +(2, 172800000, 456.00, 'def'), +(null, 259200000, 789.000, 'ghi'), +(3, null, -123.12, 'jkl') +; + +statement ok +CREATE TABLE hashjoin_datatype_table_t1 +AS SELECT + arrow_cast(c1, 'Date32') as c1, + arrow_cast(c2, 'Date64') as c2, + c3, + arrow_cast(c4, 'Dictionary(Int32, Utf8)') as c4 +FROM + hashjoin_datatype_table_t1_source + +statement ok +CREATE TABLE hashjoin_datatype_table_t2_source(c1 INT, c2 BIGINT, c3 DECIMAL(10,2), c4 VARCHAR) +AS VALUES +(1, 86400000, -123.12, 'abc'), +(null, null, 100000.00, 'abcdefg'), +(null, 259200000, 0.00, 'qwerty'), +(3, null, 789.000, 'qwe') +; + +statement ok +CREATE TABLE hashjoin_datatype_table_t2 +AS SELECT + arrow_cast(c1, 'Date32') as c1, + arrow_cast(c2, 'Date64') as c2, + c3, + arrow_cast(c4, 'Dictionary(Int32, Utf8)') as c4 +FROM + hashjoin_datatype_table_t2_source + + + +statement ok +set datafusion.execution.target_partitions = 2; + +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + +########## +## Joins Tests +########## + +# create table t1 +statement ok +CREATE TABLE t1(a INT, b INT, c INT) AS VALUES +(1, 10, 50), +(2, 20, 60), +(3, 30, 70), +(4, 40, 80) + +# create table t2 +statement ok +CREATE TABLE t2(a INT, b INT, c INT) AS VALUES +(1, 100, 500), +(2, 200, 600), +(9, 300, 700), +(4, 400, 800) + +# equijoin +query II nosort +SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a +---- +1 100 +2 200 +4 400 + +query II nosort +SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a +---- +1 100 +2 200 +4 400 + +# inner_join_nulls +query ?? +SELECT * FROM (SELECT null AS id1) t1 +INNER JOIN (SELECT null AS id2) t2 ON id1 = id2 +---- + +statement ok +DROP TABLE t1 + +statement ok +DROP TABLE t2 + + +# create table a +statement ok +CREATE TABLE a(a INT, b INT, c INT) AS VALUES +(1, 10, 50), +(2, 20, 60), +(3, 30, 70), +(4, 40, 80) + +# create table b +statement ok +CREATE TABLE b(a INT, b INT, c INT) AS VALUES +(1, 100, 500), +(2, 200, 600), +(9, 300, 700), +(4, 400, 800) + +# issue_3002 +# // repro case for https://github.com/apache/arrow-datafusion/issues/3002 + +query II +select a.a, b.b from a join b on a.a = b.b +---- + +statement ok +DROP TABLE a + +statement ok +DROP TABLE b + +# create table t1 +statement ok +CREATE TABLE t1(t1_id INT, t1_name VARCHAR) AS VALUES +(11, 'a'), +(22, 'b'), +(33, 'c'), +(44, 'd'), +(77, 'e') + +# create table t2 +statement ok +CREATE TABLE t2(t2_id INT, t2_name VARCHAR) AS VALUES +(11, 'z'), +(22, 'y'), +(44, 'x'), +(55, 'w') + +# left_join_unbalanced +# // the t1_id is larger than t2_id so the join_selection optimizer should kick in +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id +---- +11 a z +22 b y +33 c NULL +44 d x +77 e NULL + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id +---- +11 a z +22 b y +33 c NULL +44 d x +77 e NULL + + +# cross_join_unbalanced +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name +---- +11 a w +11 a x +11 a y +11 a z +22 b w +22 b x +22 b y +22 b z +33 c w +33 c x +33 c y +33 c z +44 d w +44 d x +44 d y +44 d z +77 e w +77 e x +77 e y +77 e z + +statement ok +DROP TABLE t1 + +statement ok +DROP TABLE t2 + +# create table t1 +statement ok +CREATE TABLE t1(t1_id INT, t1_name VARCHAR) AS VALUES +(11, 'a'), +(22, 'b'), +(33, 'c'), +(44, 'd'), +(77, 'e'), +(88, NULL), +(99, NULL) + +# create table t2 +statement ok +CREATE TABLE t2(t2_id INT, t2_name VARCHAR) AS VALUES +(11, 'z'), +(22, NULL), +(44, 'x'), +(55, 'w'), +(99, 'u') + +# left_join_null_filter +# // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter. +# // Note that this is only true because IS NULL does not remove nulls. For filters that +# // remove nulls, we can rewrite the join as an inner join and then push down the filter. +query IIT nosort +SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id +---- +22 22 NULL +33 NULL NULL +77 NULL NULL +88 NULL NULL + +# left_join_null_filter_on_join_column +# // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter. +query IIT nosort +SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id +---- +33 NULL NULL +77 NULL NULL +88 NULL NULL + +# left_join_not_null_filter +query IIT nosort +SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id +---- +11 11 z +44 44 x +99 99 u + +# left_join_not_null_filter_on_join_column +query IIT nosort +SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id +---- +11 11 z +22 22 NULL +44 44 x +99 99 u + +# self_join_non_equijoin +query II nosort +SELECT x.t1_id, y.t1_id FROM t1 x JOIN t1 y ON x.t1_id = 11 AND y.t1_id = 44 +---- +11 44 + +# right_join_null_filter +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id +---- +NULL NULL 55 +99 NULL 99 + +# right_join_null_filter_on_join_column +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id +---- +NULL NULL 55 + +# right_join_not_null_filter +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id +---- +11 a 11 +22 b 22 +44 d 44 + +# right_join_not_null_filter_on_join_column +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id +---- +11 a 11 +22 b 22 +44 d 44 +99 NULL 99 + +# full_join_null_filter +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id +---- +88 NULL NULL +99 NULL 99 +NULL NULL 55 + +# full_join_not_null_filter +query ITI nosort +SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id +---- +11 a 11 +22 b 22 +33 c NULL +44 d 44 +77 e NULL + +statement ok +DROP TABLE t1 + +statement ok +DROP TABLE t2 + +# create table t1 +statement ok +CREATE TABLE t1(id INT, t1_name VARCHAR, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4) + +# create table t2 +statement ok +CREATE TABLE t2(id INT, t2_name VARCHAR, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3) + +# left_join_using + +# set repartition_joins to true +statement ok +set datafusion.optimizer.repartition_joins = true + +query ITT nosort +SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id +---- +11 a z +22 b y +33 c NULL +44 d x + +# set repartition_joins to false +statement ok +set datafusion.optimizer.repartition_joins = false + +query ITT nosort +SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id +---- +11 a z +22 b y +33 c NULL +44 d x + +statement ok +DROP TABLE t1 + +statement ok +DROP TABLE t2 + +# create table t1 +statement ok +CREATE TABLE t1(t1_id INT, t1_name VARCHAR, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4) + +# create table t2 +statement ok +CREATE TABLE t2(t2_id INT, t2_name VARCHAR, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3) + +# cross_join + +# set repartition_joins to true +statement ok +set datafusion.optimizer.repartition_joins = true + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITITI rowsort +SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 +---- +11 a 11 z 3 +11 a 11 z 3 +11 a 22 y 1 +11 a 22 y 1 +11 a 44 x 3 +11 a 44 x 3 +11 a 55 w 3 +11 a 55 w 3 +22 b 11 z 3 +22 b 11 z 3 +22 b 22 y 1 +22 b 22 y 1 +22 b 44 x 3 +22 b 44 x 3 +22 b 55 w 3 +22 b 55 w 3 +33 c 11 z 3 +33 c 11 z 3 +33 c 22 y 1 +33 c 22 y 1 +33 c 44 x 3 +33 c 44 x 3 +33 c 55 w 3 +33 c 55 w 3 +44 d 11 z 3 +44 d 11 z 3 +44 d 22 y 1 +44 d 22 y 1 +44 d 44 x 3 +44 d 44 x 3 +44 d 55 w 3 +44 d 55 w 3 + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2_data +---- +11 a w +11 a w +11 a x +11 a x +11 a y +11 a y +11 a z +11 a z +22 b w +22 b w +22 b x +22 b x +22 b y +22 b y +22 b z +22 b z +33 c w +33 c w +33 c x +33 c x +33 c y +33 c y +33 c z +33 c z +44 d w +44 d w +44 d x +44 d x +44 d y +44 d y +44 d z +44 d z + +# set repartition_joins to true +statement ok +set datafusion.optimizer.repartition_joins = false + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITT nosort +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id +---- +11 a z +11 a y +11 a x +11 a w +22 b z +22 b y +22 b x +22 b w +33 c z +33 c y +33 c x +33 c w +44 d z +44 d y +44 d x +44 d w + +query ITITI rowsort +SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 +---- +11 a 11 z 3 +11 a 11 z 3 +11 a 22 y 1 +11 a 22 y 1 +11 a 44 x 3 +11 a 44 x 3 +11 a 55 w 3 +11 a 55 w 3 +22 b 11 z 3 +22 b 11 z 3 +22 b 22 y 1 +22 b 22 y 1 +22 b 44 x 3 +22 b 44 x 3 +22 b 55 w 3 +22 b 55 w 3 +33 c 11 z 3 +33 c 11 z 3 +33 c 22 y 1 +33 c 22 y 1 +33 c 44 x 3 +33 c 44 x 3 +33 c 55 w 3 +33 c 55 w 3 +44 d 11 z 3 +44 d 11 z 3 +44 d 22 y 1 +44 d 22 y 1 +44 d 44 x 3 +44 d 44 x 3 +44 d 55 w 3 +44 d 55 w 3 + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN (SELECT t2_name FROM t2 UNION ALL SELECT t2_name FROM t2) AS t2_data +---- +11 a w +11 a w +11 a x +11 a x +11 a y +11 a y +11 a z +11 a z +22 b w +22 b w +22 b x +22 b x +22 b y +22 b y +22 b z +22 b z +33 c w +33 c w +33 c x +33 c x +33 c y +33 c y +33 c z +33 c z +44 d w +44 d w +44 d x +44 d x +44 d y +44 d y +44 d z +44 d z + +statement ok +DROP TABLE t1 + +statement ok +DROP TABLE t2 + +# Join timestamp + +statement ok +CREATE TABLE timestamp(time TIMESTAMP) AS VALUES + (131964190213133), + (131964190213134), + (131964190213135); + +query PP +SELECT * +FROM timestamp as a +JOIN (SELECT * FROM timestamp) as b +ON a.time = b.time +ORDER BY a.time +---- +1970-01-02T12:39:24.190213133 1970-01-02T12:39:24.190213133 +1970-01-02T12:39:24.190213134 1970-01-02T12:39:24.190213134 +1970-01-02T12:39:24.190213135 1970-01-02T12:39:24.190213135 + +statement ok +DROP TABLE timestamp; + +# Join float32 + +statement ok +CREATE TABLE population(city VARCHAR, population FLOAT) AS VALUES + ('a', 838.698), + ('b', 1778.934), + ('c', 626.443); + +query TRTR +SELECT * +FROM population as a +JOIN (SELECT * FROM population) as b +ON a.population = b.population +ORDER BY a.population +---- +c 626.443 c 626.443 +a 838.698 a 838.698 +b 1778.934 b 1778.934 + +statement ok +DROP TABLE population; + +# Join float64 + +statement ok +CREATE TABLE population(city VARCHAR, population DOUBLE) AS VALUES + ('a', 838.698), + ('b', 1778.934), + ('c', 626.443); + +query TRTR +SELECT * +FROM population as a +JOIN (SELECT * FROM population) as b +ON a.population = b.population +ORDER BY a.population +---- +c 626.443 c 626.443 +a 838.698 a 838.698 +b 1778.934 b 1778.934 + +statement ok +DROP TABLE population; + +# Inner join qualified names + +statement ok +CREATE TABLE t1 (a INT, b INT, c INT) AS VALUES + (1, 10, 50), + (2, 20, 60), + (3, 30, 70), + (4, 40, 80); + +statement ok +CREATE TABLE t2 (a INT, b INT, c INT) AS VALUES + (1, 100, 500), + (2, 200, 600), + (9, 300, 700), + (4, 400, 800); + +query IIIIII +SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c +FROM t1 +INNER JOIN t2 ON t1.a = t2.a +ORDER BY t1.a +---- +1 10 50 1 100 500 +2 20 60 2 200 600 +4 40 80 4 400 800 + +query IIIIII +SELECT t1.a, t1.b, t1.c, t2.a, t2.b, t2.c +FROM t1 +INNER JOIN t2 ON t2.a = t1.a +ORDER BY t1.a +---- +1 10 50 1 100 500 +2 20 60 2 200 600 +4 40 80 4 400 800 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +# TODO: nestedjoin_with_alias + +# Nested join without alias + +query IIII +select * from (select 1 as a, 2 as b) c INNER JOIN (select 1 as a, 3 as d) e on c.a = e.a +---- +1 2 1 3 + +# Join tables with duplicated column name not in on constraint + +statement ok +CREATE TABLE countries (id INT, country VARCHAR) AS VALUES + (1, 'Germany'), + (2, 'Sweden'), + (3, 'Japan'); + +statement ok +CREATE TABLE cities (id INT, country_id INT, city VARCHAR) AS VALUES + (1, 1, 'Hamburg'), + (2, 2, 'Stockholm'), + (3, 3, 'Osaka'), + (4, 1, 'Berlin'), + (5, 2, 'Göteborg'), + (6, 3, 'Tokyo'), + (7, 3, 'Kyoto'); + +query IITT +SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id +---- +1 1 Hamburg Germany +2 2 Stockholm Sweden +3 3 Osaka Japan +4 1 Berlin Germany +5 2 Göteborg Sweden +6 3 Tokyo Japan +7 3 Kyoto Japan + +statement ok +DROP TABLE countries; + +statement ok +DROP TABLE cities; + +# TODO: join_timestamp + +# Left join and right join should not panic with empty side + +statement ok +CREATE TABLE t1 (t1_id BIGINT, ti_value VARCHAR NOT NULL) AS VALUES + (5247, 'a'), + (3821, 'b'), + (6321, 'c'), + (8821, 'd'), + (7748, 'e'); + +statement ok +CREATE TABLE t2 (t2_id BIGINT, t2_value BOOLEAN) AS VALUES + (358, true), + (2820, false), + (3804, NULL), + (7748, NULL); + +query ITIB rowsort +SELECT * FROM t1 LEFT JOIN t2 ON t1_id = t2_id +---- +3821 b NULL NULL +5247 a NULL NULL +6321 c NULL NULL +7748 e 7748 NULL +8821 d NULL NULL + +query IBIT rowsort +SELECT * FROM t2 RIGHT JOIN t1 ON t1_id = t2_id +---- +7748 NULL 7748 e +NULL NULL 3821 b +NULL NULL 5247 a +NULL NULL 6321 c +NULL NULL 8821 d + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +# TODO: left_join_using_2 + +# TODO: left_join_using_join_key_projection + +# TODO: left_join_2 + +# TODO: join_partitioned + +# TODO: hash_join_with_date32 + +# TODO: hash_join_with_date64 + +# TODO: hash_join_with_decimal + +# TODO: hash_join_with_dictionary + + +### +# Configuration setup +### + +statement ok +set datafusion.optimizer.repartition_joins = false; + +# Reduce left join 1 (to inner join) + +query TT +EXPLAIN +SELECT * +FROM join_t1 +LEFT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE join_t2.t2_id < 100 +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id +--Filter: join_t1.t1_id < UInt32(100) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--Filter: join_t2.t2_id < UInt32(100) +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce left join 2 (to inner join) + +query TT +EXPLAIN +SELECT * +FROM join_t1 +LEFT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE join_t2.t2_int < 10 or (join_t1.t1_int > 2 and join_t2.t2_name != 'w') +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8("w") +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8("w") +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce left join 3 (to inner join) + +query TT +EXPLAIN +SELECT * +FROM ( + SELECT join_t1.* + FROM join_t1 + LEFT JOIN join_t2 ON join_t1.t1_id = join_t2.t2_id + WHERE join_t2.t2_int < 3 +) t3 +LEFT JOIN join_t2 on t3.t1_int = join_t2.t2_int +WHERE t3.t1_id < 100 +---- +logical_plan +Left Join: t3.t1_int = join_t2.t2_int +--SubqueryAlias: t3 +----Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +------Inner Join: join_t1.t1_id = join_t2.t2_id +--------Filter: join_t1.t1_id < UInt32(100) +----------TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--------Projection: join_t2.t2_id +----------Filter: join_t2.t2_int < UInt32(3) AND join_t2.t2_id < UInt32(100) +------------TableScan: join_t2 projection=[t2_id, t2_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce right join 1 (to inner join) + +query TT +EXPLAIN +SELECT * +FROM join_t1 +RIGHT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE join_t1.t1_int IS NOT NULL +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id +--Filter: join_t1.t1_int IS NOT NULL +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce right join 2 (to inner join) + +query TT +EXPLAIN +SELECT * +FROM join_t1 +RIGHT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE NOT (join_t1.t1_int = join_t2.t2_int) +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t1.t1_int != join_t2.t2_int +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce full join to right join + +query TT +EXPLAIN +SELECT * +FROM join_t1 +FULL JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE join_t2.t2_name IS NOT NULL +---- +logical_plan +Right Join: join_t1.t1_id = join_t2.t2_id +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--Filter: join_t2.t2_name IS NOT NULL +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce full join to left join + +query TT +EXPLAIN +SELECT * +FROM join_t1 +FULL JOIN join_t2 ON join_t1.t1_id = join_t2.t2_id +WHERE join_t1.t1_name != 'b' +---- +logical_plan +Left Join: join_t1.t1_id = join_t2.t2_id +--Filter: join_t1.t1_name != Utf8("b") +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce full join to inner join + +query TT +EXPLAIN +SELECT * +FROM join_t1 +FULL JOIN join_t2 on join_t1.t1_id = join_t2.t2_id +WHERE join_t1.t1_name != 'b' and join_t2.t2_name = 'x' +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id +--Filter: join_t1.t1_name != Utf8("b") +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--Filter: join_t2.t2_name = Utf8("x") +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +### +# Configuration teardown +### + +statement ok +set datafusion.optimizer.repartition_joins = true; + + +# Sort merge equijoin + +query ITT +SELECT t1_id, t1_name, t2_name +FROM join_t1 +JOIN join_t2 ON t1_id = t2_id +ORDER BY t1_id +---- +11 a z +22 b y +44 d x + +query ITT +SELECT t1_id, t1_name, t2_name +FROM join_t1 +JOIN join_t2 ON t2_id = t1_id +ORDER BY t1_id +---- +11 a z +22 b y +44 d x + +# TODO: sort_merge_join_on_date32 + +# TODO: sort_merge_join_on_decimal + + +# TODO: Left semi join + +# Left semi join pushdown + +query TT +EXPLAIN +SELECT lsaj_t1.t1_id, lsaj_t1.t1_name +FROM lsaj_t1 +LEFT SEMI JOIN lsaj_t2 ON (lsaj_t1.t1_id = lsaj_t2.t2_id and lsaj_t2.t2_int > 1) +---- +logical_plan +LeftSemi Join: lsaj_t1.t1_id = lsaj_t2.t2_id +--TableScan: lsaj_t1 projection=[t1_id, t1_name] +--Projection: lsaj_t2.t2_id +----Filter: lsaj_t2.t2_int > UInt32(1) +------TableScan: lsaj_t2 projection=[t2_id, t2_int] + +# Left anti join + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +WHERE NOT EXISTS (SELECT 1 FROM lsaj_t2 WHERE t1_id = t2_id) +ORDER BY t1_id +---- +33 c +NULL e + +query I +SELECT t1_id +FROM lsaj_t1 +EXCEPT SELECT t2_id FROM lsaj_t2 +ORDER BY t1_id +---- +33 + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +LEFT ANTI JOIN lsaj_t2 ON (t1_id = t2_id) +ORDER BY t1_id +---- +33 c +NULL e + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +WHERE NOT EXISTS (SELECT 1 FROM lsaj_t2 WHERE t1_id = t2_id) +ORDER BY t1_id +---- +33 c +NULL e + +query I +SELECT t1_id +FROM lsaj_t1 +EXCEPT SELECT t2_id FROM lsaj_t2 +ORDER BY t1_id +---- +33 + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +LEFT ANTI JOIN lsaj_t2 ON (t1_id = t2_id) +ORDER BY t1_id +---- +33 c +NULL e + +# Error left anti join +# https://github.com/apache/arrow-datafusion/issues/4366 + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +WHERE NOT EXISTS (SELECT 1 FROM lsaj_t2 WHERE t1_id = t2_id and t1_id > 11) +ORDER BY t1_id +---- +11 a +11 a +33 c +NULL e + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query IT +SELECT t1_id, t1_name +FROM lsaj_t1 +WHERE NOT EXISTS (SELECT 1 FROM lsaj_t2 WHERE t1_id = t2_id and t1_id > 11) +ORDER BY t1_id +---- +11 a +11 a +33 c +NULL e + +# TODO: null_aware_left_anti_join + +# TODO: right_semi_join + +# Join and aggregate on same key + +statement ok +set datafusion.explain.logical_plan_only = false; + +query TT +EXPLAIN +select distinct(join_t1.t1_id) +from join_t1 +inner join join_t2 on join_t1.t1_id = join_t2.t2_id +---- +logical_plan +Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[]] +--Projection: join_t1.t1_id +----Inner Join: join_t1.t1_id = join_t2.t2_id +------TableScan: join_t1 projection=[t1_id] +------TableScan: join_t2 projection=[t2_id] +physical_plan +AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[] +--ProjectionExec: expr=[t1_id@0 as t1_id] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN +select count(*) +from (select * from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id) +group by t1_id +---- +logical_plan +Projection: COUNT(*) +--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----Projection: join_t1.t1_id +------Inner Join: join_t1.t1_id = join_t2.t2_id +--------TableScan: join_t1 projection=[t1_id] +--------TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] +--AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[COUNT(*)] +----ProjectionExec: expr=[t1_id@0 as t1_id] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN +select count(distinct join_t1.t1_id) +from join_t1 +inner join join_t2 on join_t1.t1_id = join_t2.t2_id +---- +logical_plan +Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id) +--Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] +----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] +------Projection: join_t1.t1_id +--------Inner Join: join_t1.t1_id = join_t2.t2_id +----------TableScan: join_t1 projection=[t1_id] +----------TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] +--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] +--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] +------------ProjectionExec: expr=[t1_id@0 as t1_id] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.explain.logical_plan_only = true; + +# Reduce cross join with expr join key all (to inner join) + +query TT +EXPLAIN +select * +from join_t1 +cross join join_t2 +where join_t1.t1_id + 12 = join_t2.t2_id + 1 +---- +logical_plan +Inner Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = CAST(join_t2.t2_id AS Int64) + Int64(1) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +# Reduce cross join with cast expr join key (to inner join) + +query TT +EXPLAIN +select join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +from join_t1 +cross join join_t2 where join_t1.t1_id + 11 = cast(join_t2.t2_id as BIGINT) +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: CAST(join_t1.t1_id AS Int64) + Int64(11) = CAST(join_t2.t2_id AS Int64) +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] + + +##### +# Config setup +##### + +statement ok +set datafusion.explain.logical_plan_only = false; + +# Reduce cross join with wildcard and expr (to inner join) + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +select *, join_t1.t1_id + 11 +from join_t1, join_t2 +where join_t1.t1_id + 11 = join_t2.t2_id +---- +logical_plan +Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_t2.t2_name, join_t2.t2_int, CAST(join_t1.t1_id AS Int64) + Int64(11) +--Inner Join: CAST(join_t1.t1_id AS Int64) + Int64(11) = CAST(join_t2.t2_id AS Int64) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +select *, join_t1.t1_id + 11 +from join_t1, join_t2 +where join_t1.t1_id + 11 = join_t2.t2_id +---- +logical_plan +Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int, join_t2.t2_id, join_t2.t2_name, join_t2.t2_int, CAST(join_t1.t1_id AS Int64) + Int64(11) +--Inner Join: CAST(join_t1.t1_id AS Int64) + Int64(11) = CAST(join_t2.t2_id AS Int64) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 as t2_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + Int64(11)@3, CAST(join_t2.t2_id AS Int64)@3)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + Int64(11)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as join_t1.t1_id + Int64(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(join_t2.t2_id AS Int64)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(join_t2.t2_id AS Int64)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +# Both side expr key inner join + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id + cast(12 as INT UNSIGNED) = join_t2.t2_id + cast(1 as INT UNSIGNED) +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id + UInt32(12) = join_t2.t2_id + UInt32(1) +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id + cast(12 as INT UNSIGNED) = join_t2.t2_id + cast(1 as INT UNSIGNED) +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id + UInt32(12) = join_t2.t2_id + UInt32(1) +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id + UInt32(1)@1, join_t1.t1_id + UInt32(12)@2)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t2.t2_id + UInt32(1)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as join_t2.t2_id + UInt32(1)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(12)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as join_t1.t1_id + UInt32(12)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +# Left side expr key inner join + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id + cast(11 as INT UNSIGNED) = join_t2.t2_id +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id + UInt32(11) = join_t2.t2_id +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id + cast(11 as INT UNSIGNED) = join_t2.t2_id +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id + UInt32(11) = join_t2.t2_id +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@3 as t2_id, t1_name@1 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t1.t1_id + UInt32(11)@2, t2_id@0)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t1.t1_id + UInt32(11)@2], 2), input_partitions=2 +----------ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as join_t1.t1_id + UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +# Right side expr key inner join + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id = join_t2.t2_id - cast(11 as INT UNSIGNED) +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalescePartitionsExec +--------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id = join_t2.t2_id - cast(11 as INT UNSIGNED) +---- +logical_plan +Projection: join_t1.t1_id, join_t2.t2_id, join_t1.t1_name +--Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) +----TableScan: join_t1 projection=[t1_id, t1_name] +----TableScan: join_t2 projection=[t2_id] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, t2_id@0 as t2_id, t1_name@3 as t1_name] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(join_t2.t2_id - UInt32(11)@1, t1_id@0)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@1], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +# Select wildcard with expr key inner join + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +SELECT * +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id = join_t2.t2_id - cast(11 as INT UNSIGNED) +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, join_t2.t2_id - UInt32(11)@3)] +------MemoryExec: partitions=1, partition_sizes=[1] +------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +SELECT * +FROM join_t1 +INNER JOIN join_t2 +ON join_t1.t1_id = join_t2.t2_id - cast(11 as INT UNSIGNED) +---- +logical_plan +Inner Join: join_t1.t1_id = join_t2.t2_id - UInt32(11) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, join_t2.t2_id - UInt32(11)@3)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([join_t2.t2_id - UInt32(11)@3], 2), input_partitions=2 +----------ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as join_t2.t2_id - UInt32(11)] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +##### +# Config teardown +##### + +statement ok +set datafusion.explain.logical_plan_only = true; + + +# Join with type coercion for equi expr + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t2.t2_id +from join_t1 +inner join join_t2 on join_t1.t1_id + 11 = join_t2.t2_id +---- +logical_plan +Inner Join: CAST(join_t1.t1_id AS Int64) + Int64(11) = CAST(join_t2.t2_id AS Int64) +--TableScan: join_t1 projection=[t1_id, t1_name] +--TableScan: join_t2 projection=[t2_id] + +# Join only with filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t2.t2_id +from join_t1 +inner join join_t2 on join_t1.t1_id * 4 < join_t2.t2_id +---- +logical_plan +Inner Join: Filter: CAST(join_t1.t1_id AS Int64) * Int64(4) < CAST(join_t2.t2_id AS Int64) +--TableScan: join_t1 projection=[t1_id, t1_name] +--TableScan: join_t2 projection=[t2_id] + +# Type coercion join with filter and equi expr + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t2.t2_id +from join_t1 +inner join join_t2 +on join_t1.t1_id * 5 = join_t2.t2_id and join_t1.t1_id * 4 < join_t2.t2_id +---- +logical_plan +Inner Join: CAST(join_t1.t1_id AS Int64) * Int64(5) = CAST(join_t2.t2_id AS Int64) Filter: CAST(join_t1.t1_id AS Int64) * Int64(4) < CAST(join_t2.t2_id AS Int64) +--TableScan: join_t1 projection=[t1_id, t1_name] +--TableScan: join_t2 projection=[t2_id] + +# Test cross join to groupby with different key ordering + +statement ok +CREATE TABLE tbl(col1 VARCHAR, col2 BIGINT UNSIGNED, col3 BIGINT UNSIGNED) +AS VALUES +('A', 1, 1), +('A', 1, 1), +('A', 2, 1), +('A', 2, 1), +('A', 3, 1), +('A', 3, 1), +('A', 4, 1), +('A', 4, 1), +('BB', 5, 1), +('BB', 5, 1), +('BB', 6, 1), +('BB', 6, 1); + +query TIR +select col1, col2, coalesce(sum_col3, 0) as sum_col3 +from (select distinct col2 from tbl) AS q1 +cross join (select distinct col1 from tbl) AS q2 +left outer join (SELECT col1, col2, sum(col3) as sum_col3 FROM tbl GROUP BY col1, col2) AS q3 +USING(col2, col1) +ORDER BY col1, col2 +---- +A 1 2 +A 2 2 +A 3 2 +A 4 2 +A 5 0 +A 6 0 +BB 1 0 +BB 2 0 +BB 3 0 +BB 4 0 +BB 5 2 +BB 6 2 + +statement ok +DROP TABLE tbl; + +# Subquery to join with both side expr + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in (select join_t2.t2_id + 1 from join_t2) +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1) +------TableScan: join_t2 projection=[t2_id] + +query ITI rowsort +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in (select join_t2.t2_id + 1 from join_t2) +---- +11 a 1 +33 c 3 +44 d 4 + +# Subquery to join with multi filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int and join_t2.t2_int > 0 + ) +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) Filter: join_t1.t1_int <= __correlated_sq_1.t2_int +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1), join_t2.t2_int +------Filter: join_t2.t2_int > UInt32(0) +--------TableScan: join_t2 projection=[t2_id, t2_int] + +query ITI rowsort +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int and join_t2.t2_int > 0 + ) +---- +11 a 1 +33 c 3 + +# Three projection exprs subquery to join + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int + and join_t1.t1_name != join_t2.t2_name + and join_t2.t2_int > 0 + ) +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) Filter: join_t1.t1_int <= __correlated_sq_1.t2_int AND join_t1.t1_name != __correlated_sq_1.t2_name +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1), join_t2.t2_int, join_t2.t2_name +------Filter: join_t2.t2_int > UInt32(0) +--------TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +query ITI rowsort +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int + and join_t1.t1_name != join_t2.t2_name + and join_t2.t2_int > 0 + ) +---- +11 a 1 +33 c 3 + +# In subquery to join with correlated outer filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + (select join_t2.t2_id + 1 from join_t2 where join_t1.t1_int > 0) +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) +--Filter: join_t1.t1_int > UInt32(0) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1) +------TableScan: join_t2 projection=[t2_id] + +# Not in subquery to join with correlated outer filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 not in + (select join_t2.t2_id + 1 from join_t2 where join_t1.t1_int > 0) +---- +logical_plan +LeftAnti Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) Filter: join_t1.t1_int > UInt32(0) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1) +------TableScan: join_t2 projection=[t2_id] + +# In subquery to join with outer filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int + and join_t1.t1_name != join_t2.t2_name + ) + and join_t1.t1_id > 0 +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) Filter: join_t1.t1_int <= __correlated_sq_1.t2_int AND join_t1.t1_name != __correlated_sq_1.t2_name +--Filter: join_t1.t1_id > UInt32(0) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: CAST(join_t2.t2_id AS Int64) + Int64(1), join_t2.t2_int, join_t2.t2_name +------TableScan: join_t2 projection=[t2_id, t2_name, t2_int] + +query ITI rowsort +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in + ( + select join_t2.t2_id + 1 + from join_t2 + where join_t1.t1_int <= join_t2.t2_int + and join_t1.t1_name != join_t2.t2_name + ) + and join_t1.t1_id > 0 +---- +11 a 1 +33 c 3 + +# Two in subquery to join with outer filter + +query TT +EXPLAIN +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in (select join_t2.t2_id + 1 from join_t2) + and join_t1.t1_int in (select join_t2.t2_int + 1 from join_t2) + and join_t1.t1_id > 0 +---- +logical_plan +LeftSemi Join: CAST(join_t1.t1_int AS Int64) = __correlated_sq_2.join_t2.t2_int + Int64(1) +--LeftSemi Join: CAST(join_t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.join_t2.t2_id + Int64(1) +----Filter: join_t1.t1_id > UInt32(0) +------TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +----SubqueryAlias: __correlated_sq_1 +------Projection: CAST(join_t2.t2_id AS Int64) + Int64(1) +--------TableScan: join_t2 projection=[t2_id] +--SubqueryAlias: __correlated_sq_2 +----Projection: CAST(join_t2.t2_int AS Int64) + Int64(1) +------TableScan: join_t2 projection=[t2_int] + +query ITI +select join_t1.t1_id, join_t1.t1_name, join_t1.t1_int +from join_t1 +where join_t1.t1_id + 12 in (select join_t2.t2_id + 1 from join_t2) + and join_t1.t1_int in (select join_t2.t2_int + 1 from join_t2) + and join_t1.t1_id > 0 +---- +44 d 4 + + +##### +# Configuration setup +##### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.execution.target_partitions = 4; + +# Right as inner table nested loop join + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +---- +logical_plan +Inner Join: Filter: join_t1.t1_id > join_t2.t2_id +--Filter: join_t1.t1_id > UInt32(10) +----TableScan: join_t1 projection=[t1_id] +--Projection: join_t2.t2_id +----Filter: join_t2.t2_int > UInt32(1) +------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +NestedLoopJoinExec: join_type=Inner, filter=t1_id@0 > t2_id@1 +--CoalesceBatchesExec: target_batch_size=2 +----FilterExec: t1_id@0 > 10 +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] +--CoalescePartitionsExec +----ProjectionExec: expr=[t2_id@0 as t2_id] +------CoalesceBatchesExec: target_batch_size=2 +--------FilterExec: t2_int@1 > 1 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +---- +22 11 +33 11 +44 11 + +# Left as inner table nested loop join + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id +FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 +RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 + ON join_t1.t1_id < join_t2.t2_id +---- +logical_plan +Right Join: Filter: join_t1.t1_id < join_t2.t2_id +--SubqueryAlias: join_t1 +----Filter: join_t1.t1_id > UInt32(22) +------TableScan: join_t1 projection=[t1_id] +--SubqueryAlias: join_t2 +----Filter: join_t2.t2_id > UInt32(11) +------TableScan: join_t2 projection=[t2_id] +physical_plan +NestedLoopJoinExec: join_type=Right, filter=t1_id@0 < t2_id@1 +--CoalescePartitionsExec +----CoalesceBatchesExec: target_batch_size=2 +------FilterExec: t1_id@0 > 22 +--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +--CoalesceBatchesExec: target_batch_size=2 +----FilterExec: t2_id@0 > 11 +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT join_t1.t1_id, join_t2.t2_id +FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 +RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 + ON join_t1.t1_id < join_t2.t2_id +---- +33 44 +33 55 +44 55 +NULL 22 + +##### +# Configuration teardown +##### + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.explain.logical_plan_only = true; + + +# Exists subquery to join expr filter + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftSemi Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----TableScan: join_t2 projection=[t2_id] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +22 b 2 +33 c 3 +44 d 4 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +22 b 2 +33 c 3 +44 d 4 + +# Exists subquery to join inner filter + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t2.t2_int < 3 +) +---- +logical_plan +LeftSemi Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: join_t2.t2_id +------Filter: join_t2.t2_int < UInt32(3) +--------TableScan: join_t2 projection=[t2_id, t2_int] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t2.t2_int < 3 +) +---- +44 d 4 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t2.t2_int < 3 +) +---- +44 d 4 + +# Exists subquery to join outer filter + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t1.t1_int < 3 +) +---- +logical_plan +LeftSemi Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--Filter: join_t1.t1_int < UInt32(3) +----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----TableScan: join_t2 projection=[t2_id] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t1.t1_int < 3 +) +---- +22 b 2 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 AND join_t1.t1_int < 3 +) +---- +22 b 2 + +# Not exists subquery to join expr filter + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----TableScan: join_t2 projection=[t2_id] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT t2_id + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +# Exists distinct subquery to join + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT DISTINCT t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: join_t2.t2_id +------Aggregate: groupBy=[[join_t2.t2_int, join_t2.t2_id]], aggr=[[]] +--------Projection: join_t2.t2_int, join_t2.t2_id +----------TableScan: join_t2 projection=[t2_id, t2_int] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT DISTINCT t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS ( + SELECT DISTINCT t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +# Exists distinct subquery to join with expr + +query TT +EXPLAIN +SELECT * +FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT t2_id + t2_int, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: join_t2.t2_id +------Aggregate: groupBy=[[join_t2.t2_id + join_t2.t2_int, join_t2.t2_int, join_t2.t2_id]], aggr=[[]] +--------Projection: join_t2.t2_id + join_t2.t2_int, join_t2.t2_int, join_t2.t2_id +----------TableScan: join_t2 projection=[t2_id, t2_int] + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT t2_id + t2_int, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query ITI +SELECT * +FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT t2_id + t2_int, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +# Exists distinct subquery left anti join with literal + +statement ok +set datafusion.optimizer.repartition_joins = false; + +query TT +EXPLAIN +SELECT * FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT 1, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: join_t2.t2_id +------Aggregate: groupBy=[[Int64(1), join_t2.t2_int, join_t2.t2_id]], aggr=[[]] +--------Projection: Int64(1), join_t2.t2_int, join_t2.t2_id +----------TableScan: join_t2 projection=[t2_id, t2_int] + +query ITI +SELECT * FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT 1, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +statement ok +set datafusion.optimizer.repartition_joins = true; + +query TT +EXPLAIN +SELECT * FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT 1, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +logical_plan +LeftAnti Join: Filter: CAST(join_t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) +--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: join_t2.t2_id +------Aggregate: groupBy=[[Int64(1), join_t2.t2_int, join_t2.t2_id]], aggr=[[]] +--------Projection: Int64(1), join_t2.t2_int, join_t2.t2_id +----------TableScan: join_t2 projection=[t2_id, t2_int] + +query ITI +SELECT * FROM join_t1 +WHERE NOT EXISTS( + SELECT DISTINCT 1, t2_int + FROM join_t2 + WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 +) +---- +11 a 1 + +statement ok +set datafusion.explain.logical_plan_only = false; + +# show the contents of the timestamp table +query PPPPT +select * from +test_timestamps_table +---- +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 + +# show the contents of the timestamp with timezone table +query PPPPT +select * from +test_timestamps_tz_table +---- +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +# test timestamp join on nanos datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.nanos = t2.nanos; +---- +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 + +# test timestamp with timezone join on nanos datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.nanos = t2.nanos; +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +# test timestamp join on micros datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.micros = t2.micros +---- +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 + +# test timestamp with timezone join on micros datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.micros = t2.micros +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +# test timestamp join on millis datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_table as t1 JOIN (SELECT * FROM test_timestamps_table ) as t2 ON t1.millis = t2.millis +---- +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 + +# test timestamp with timezone join on millis datatype +query PPPPTPPPPT rowsort +SELECT * FROM test_timestamps_tz_table as t1 JOIN (SELECT * FROM test_timestamps_tz_table ) as t2 ON t1.millis = t2.millis +---- +2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123450Z 2011-12-13T11:13:10.123Z 2011-12-13T11:13:10Z Row 1 +2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 2018-11-13T17:11:10.011375885Z 2018-11-13T17:11:10.011375Z 2018-11-13T17:11:10.011Z 2018-11-13T17:11:10Z Row 0 +2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10.432Z 2021-01-01T05:11:10Z Row 3 + +#### +# Config setup +#### + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# explain hash join on timestamp with timezone type +query TT +EXPLAIN SELECT * FROM test_timestamps_tz_table as t1 JOIN test_timestamps_tz_table as t2 ON t1.millis = t2.millis +---- +logical_plan +Inner Join: t1.millis = t2.millis +--SubqueryAlias: t1 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +--SubqueryAlias: t2 +----TableScan: test_timestamps_tz_table projection=[nanos, micros, millis, secs, names] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(millis@2, millis@2)] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([millis@2], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +# left_join_using_2 +query II +SELECT t1.c1, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 USING (c2) ORDER BY t2.c2; +---- +0 1 +0 2 +0 3 +0 4 +0 5 +0 6 +0 7 +0 8 +0 9 +0 10 + +# left_join_using_join_key_projection +query III +SELECT t1.c1, t1.c2, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 USING (c2) ORDER BY t2.c2 +---- +0 1 1 +0 2 2 +0 3 3 +0 4 4 +0 5 5 +0 6 6 +0 7 7 +0 8 8 +0 9 9 +0 10 10 + +# left_join_2 +query III +SELECT t1.c1, t1.c2, t2.c2 FROM test_partition_table t1 JOIN test_partition_table t2 ON t1.c2 = t2.c2 ORDER BY t1.c2 +---- +0 1 1 +0 2 2 +0 3 3 +0 4 4 +0 5 5 +0 6 6 +0 7 7 +0 8 8 +0 9 9 +0 10 10 + +#### +# Config setup +#### + +statement ok +set datafusion.explain.logical_plan_only = true + +# explain hash_join_with_date32 +query TT +explain select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 +---- +logical_plan +Inner Join: t1.c1 = t2.c1 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] + +# hash_join_with_date32 +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 +---- +1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc +1970-01-04 NULL -123.12 jkl 1970-01-04 NULL 789 qwe + + +# explain hash_join_with_date64 +query TT +explain select * from hashjoin_datatype_table_t1 t1 left join hashjoin_datatype_table_t2 t2 on t1.c2 = t2.c2 +---- +logical_plan +Left Join: t1.c2 = t2.c2 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] + +# hash_join_with_date64 +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 left join hashjoin_datatype_table_t2 t2 on t1.c2 = t2.c2 +---- +1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc +1970-01-03 1970-01-03T00:00:00 456 def NULL NULL NULL NULL +1970-01-04 NULL -123.12 jkl NULL NULL NULL NULL +NULL 1970-01-04T00:00:00 789 ghi NULL 1970-01-04T00:00:00 0 qwerty + + +# explain hash_join_with_decimal +query TT +explain select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t1 t2 on t1.c3 = t2.c3 +---- +logical_plan +Right Join: t1.c3 = t2.c3 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] + +# hash_join_with_decimal +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t1 t2 on t1.c3 = t2.c3 +---- +1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 1.23 abc +1970-01-03 1970-01-03T00:00:00 456 def 1970-01-03 1970-01-03T00:00:00 456 def +1970-01-04 NULL -123.12 jkl 1970-01-04 NULL -123.12 jkl +NULL 1970-01-04T00:00:00 789 ghi NULL 1970-01-04T00:00:00 789 ghi + +# explain hash_join_with_dictionary +query TT +explain select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t1 t2 on t1.c4 = t2.c4 +---- +logical_plan +Inner Join: t1.c4 = t2.c4 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] + +# hash_join_with_dictionary +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c4 = t2.c4 +---- +1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = false + + +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = false; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +# explain sort_merge_join_on_date32 inner sort merge join on data type (Date32) +query TT +explain select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 +---- +logical_plan +Inner Join: t1.c1 = t2.c1 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] +physical_plan +SortMergeJoin: join_type=Inner, on=[(c1@0, c1@0)] +--SortExec: expr=[c1@0 ASC] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] +--SortExec: expr=[c1@0 ASC] +----CoalesceBatchesExec: target_batch_size=2 +------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +# sort_merge_join_on_date32 inner sort merge join on data type (Date32) +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 +---- +1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc +1970-01-04 NULL -123.12 jkl 1970-01-04 NULL 789 qwe + +# explain sort_merge_join_on_decimal right join on data type (Decimal) +query TT +explain select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t2 t2 on t1.c3 = t2.c3 +---- +logical_plan +Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 +--SubqueryAlias: t1 +----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] +--SubqueryAlias: t2 +----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] +physical_plan +ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4] +--SortMergeJoin: join_type=Right, on=[(CAST(t1.c3 AS Decimal128(10, 2))@4, c3@2)] +----SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(t1.c3 AS Decimal128(10, 2))@4], 2), input_partitions=2 +----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))] +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +----SortExec: expr=[c3@2 ASC] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([c3@2], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] + +# sort_merge_join_on_decimal right join on data type (Decimal) +query DDR?DDR? rowsort +select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t2 t2 on t1.c3 = t2.c3 +---- +1970-01-04 NULL -123.12 jkl 1970-01-02 1970-01-02T00:00:00 -123.12 abc +NULL 1970-01-04T00:00:00 789 ghi 1970-01-04 NULL 789 qwe +NULL NULL NULL NULL NULL 1970-01-04T00:00:00 0 qwerty +NULL NULL NULL NULL NULL NULL 100000 abcdefg + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + + + +#Test the left_semi_join scenarios where the current repartition_joins parameter is set to true . +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +query TT +explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query IT rowsort +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +query IT rowsort +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT 1 FROM left_semi_anti_join_table_t2 t2 WHERE t1_id = t2_id) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +query I rowsort +SELECT t1_id FROM left_semi_anti_join_table_t1 t1 INTERSECT SELECT t2_id FROM left_semi_anti_join_table_t2 t2 ORDER BY t1_id +---- +11 +22 +44 +NULL + +query TT +explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOIN left_semi_anti_join_table_t2 t2 ON (t1_id = t2_id) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query IT +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOIN left_semi_anti_join_table_t2 t2 ON (t1_id = t2_id) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +#Test the left_semi_join scenarios where the current repartition_joins parameter is set to false . +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.optimizer.repartition_joins = false; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +query TT +explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +--------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +query IT rowsort +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id IN (SELECT t2_id FROM left_semi_anti_join_table_t2 t2) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +query IT rowsort +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT 1 FROM left_semi_anti_join_table_t2 t2 WHERE t1_id = t2_id) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +query I rowsort +SELECT t1_id FROM left_semi_anti_join_table_t1 t1 INTERSECT SELECT t2_id FROM left_semi_anti_join_table_t2 t2 ORDER BY t1_id +---- +11 +22 +44 +NULL + +query TT +explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOIN left_semi_anti_join_table_t2 t2 ON (t1_id = t2_id) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)] +--------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +query IT +SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOIN left_semi_anti_join_table_t2 t2 ON (t1_id = t2_id) ORDER BY t1_id +---- +11 a +11 a +22 b +44 d + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + + +#Test the right_semi_join scenarios where the current repartition_joins parameter is set to true . +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +query TT +explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query ITI rowsort +SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +11 a 1 + +query TT +explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGHT SEMI JOIN right_semi_anti_join_table_t1 t1 on (t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + +query ITI rowsort +SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGHT SEMI JOIN right_semi_anti_join_table_t1 t1 on (t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +11 a 1 + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + + +#Test the right_semi_join scenarios where the current repartition_joins parameter is set to false . +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.optimizer.repartition_joins = false; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + +query TT +explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +--------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +query ITI rowsort +SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHERE EXISTS (SELECT * FROM right_semi_anti_join_table_t2 t2 where t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +11 a 1 + +query TT +explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGHT SEMI JOIN right_semi_anti_join_table_t1 t1 on (t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +physical_plan +SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +--SortExec: expr=[t1_id@0 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +--------MemoryExec: partitions=1, partition_sizes=[1] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------MemoryExec: partitions=1, partition_sizes=[1] + +query ITI rowsort +SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGHT SEMI JOIN right_semi_anti_join_table_t1 t1 on (t2.t2_id = t1.t1_id and t2.t2_name <> t1.t1_name) ORDER BY t1_id +---- +11 a 1 + +#### +# Config teardown +#### +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.optimizer.repartition_joins = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 2; + + +#### +# Config setup +#### +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.optimizer.prefer_hash_join = false; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +CREATE EXTERNAL TABLE annotated_data ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC NULLS FIRST, b ASC, c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# sort merge join should propagate ordering equivalence of the left side +# for inner join. Hence final requirement rn1 ASC is already satisfied at +# the end of SortMergeJoinExec. +query TT +EXPLAIN SELECT * + FROM (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data ) as l_table + JOIN annotated_data as r_table + ON l_table.a = r_table.a + ORDER BY l_table.rn1 +---- +logical_plan +Sort: l_table.rn1 ASC NULLS LAST +--Inner Join: l_table.a = r_table.a +----SubqueryAlias: l_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: r_table +------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [rn1@5 ASC NULLS LAST] +--SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] +----SortExec: expr=[rn1@5 ASC NULLS LAST] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----SortExec: expr=[a@1 ASC] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# sort merge join should propagate ordering equivalence of the right side +# for right join. Hence final requirement rn1 ASC is already satisfied at +# the end of SortMergeJoinExec. +query TT +EXPLAIN SELECT * + FROM annotated_data as l_table + RIGHT JOIN (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data ) as r_table + ON l_table.a = r_table.a + ORDER BY r_table.rn1 +---- +logical_plan +Sort: r_table.rn1 ASC NULLS LAST +--Right Join: l_table.a = r_table.a +----SubqueryAlias: l_table +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: r_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [rn1@10 ASC NULLS LAST] +--SortMergeJoin: join_type=Right, on=[(a@1, a@1)] +----SortExec: expr=[a@1 ASC] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----SortExec: expr=[rn1@5 ASC NULLS LAST] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# SortMergeJoin should add ordering equivalences of +# right table as lexicographical append to the global ordering +# below query shouldn't add any SortExec for order by clause. +# since its requirement is already satisfied at the output of SortMergeJoinExec +query TT +EXPLAIN SELECT * + FROM (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data ) as l_table + JOIN (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data ) as r_table + ON l_table.a = r_table.a + ORDER BY l_table.a ASC NULLS FIRST, l_table.b, l_table.c, r_table.rn1 +---- +logical_plan +Sort: l_table.a ASC NULLS FIRST, l_table.b ASC NULLS LAST, l_table.c ASC NULLS LAST, r_table.rn1 ASC NULLS LAST +--Inner Join: l_table.a = r_table.a +----SubqueryAlias: l_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: r_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 ASC NULLS LAST] +--SortExec: expr=[a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 ASC NULLS LAST] +----SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] +------SortExec: expr=[a@1 ASC] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +----------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +------SortExec: expr=[a@1 ASC] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +----------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +# to preserve ordering until Hash join set target partition to 1. +# Otherwise RepartitionExec s inserted may broke ordering. +statement ok +set datafusion.execution.target_partitions = 1; + +# hash join should propagate ordering equivalence of the right side for INNER join. +# Hence final requirement rn1 ASC is already satisfied at the end of HashJoinExec. +query TT +EXPLAIN SELECT * + FROM annotated_data as l_table + JOIN (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data) as r_table + ON l_table.a = r_table.a + ORDER BY r_table.rn1 +---- +logical_plan +Sort: r_table.rn1 ASC NULLS LAST +--Inner Join: l_table.a = r_table.a +----SubqueryAlias: l_table +------TableScan: annotated_data projection=[a0, a, b, c, d] +----SubqueryAlias: r_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@1, a@1)] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# hash join should propagate ordering equivalence of the right side for RIGHT ANTI join. +# Hence final requirement rn1 ASC is already satisfied at the end of HashJoinExec. +query TT +EXPLAIN SELECT * + FROM annotated_data as l_table + RIGHT ANTI JOIN (SELECT *, ROW_NUMBER() OVER() as rn1 + FROM annotated_data) as r_table + ON l_table.a = r_table.a + ORDER BY r_table.rn1 +---- +logical_plan +Sort: r_table.rn1 ASC NULLS LAST +--RightAnti Join: l_table.a = r_table.a +----SubqueryAlias: l_table +------TableScan: annotated_data projection=[a] +----SubqueryAlias: r_table +------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +----------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=2 +--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@1)] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC], has_header=true +----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT l.a, LAST_VALUE(r.b ORDER BY r.a ASC NULLS FIRST) as last_col1 +FROM annotated_data as l +JOIN annotated_data as r +ON l.a = r.a +GROUP BY l.a, l.b, l.c +ORDER BY l.a ASC NULLS FIRST; +---- +logical_plan +Sort: l.a ASC NULLS FIRST +--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +------Inner Join: l.a = r.a +--------SubqueryAlias: l +----------TableScan: annotated_data projection=[a, b, c] +--------SubqueryAlias: r +----------TableScan: annotated_data projection=[a, b] +physical_plan +ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true + +# create a table where there more than one valid ordering +# that describes table. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +query TT +EXPLAIN SELECT LAST_VALUE(l.d ORDER BY l.a) AS amount_usd +FROM multiple_ordered_table AS l +INNER JOIN ( + SELECT *, ROW_NUMBER() OVER (ORDER BY r.a) as row_n FROM multiple_ordered_table AS r +) +ON l.d = r.d AND + l.a >= r.a - 10 +GROUP BY row_n +ORDER BY row_n +---- +logical_plan +Projection: amount_usd +--Sort: row_n ASC NULLS LAST +----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +--------Projection: l.a, l.d, row_n +----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) +------------SubqueryAlias: l +--------------TableScan: multiple_ordered_table projection=[a, d] +------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------------SubqueryAlias: r +------------------TableScan: multiple_ordered_table projection=[a, d] +physical_plan +ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d)], ordering_mode=Sorted +----ProjectionExec: expr=[a@0 as a, d@1 as d, row_n@4 as row_n] +------CoalesceBatchesExec: target_batch_size=2 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true + +# run query above in multiple partitions +statement ok +set datafusion.execution.target_partitions = 2; + +# use bounded variants +statement ok +set datafusion.optimizer.prefer_existing_sort = true; + +query TT +EXPLAIN SELECT l.a, LAST_VALUE(r.b ORDER BY r.a ASC NULLS FIRST) as last_col1 +FROM annotated_data as l +JOIN annotated_data as r +ON l.a = r.a +GROUP BY l.a, l.b, l.c +ORDER BY l.a ASC NULLS FIRST; +---- +logical_plan +Sort: l.a ASC NULLS FIRST +--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +------Inner Join: l.a = r.a +--------SubqueryAlias: l +----------TableScan: annotated_data projection=[a, b, c] +--------SubqueryAlias: r +----------TableScan: annotated_data projection=[a, b] +physical_plan +SortPreservingMergeExec: [a@0 ASC] +--SortExec: expr=[a@0 ASC] +----ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 +------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b)], ordering_mode=PartiallySorted([0]) +--------------CoalesceBatchesExec: target_batch_size=2 +----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +------------------CoalesceBatchesExec: target_batch_size=2 +--------------------SortPreservingRepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2, sort_exprs=a@0 ASC,b@1 ASC NULLS LAST +----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST], has_header=true + +#### +# Config teardown +#### + +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.optimizer.prefer_existing_sort = false; + +statement ok +drop table annotated_data; diff --git a/datafusion/core/tests/sqllogictests/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt similarity index 78% rename from datafusion/core/tests/sqllogictests/test_files/json.slt rename to datafusion/sqllogictest/test_files/json.slt index 7092127a793cf..c0d5e895f0f2e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -22,12 +22,12 @@ statement ok CREATE EXTERNAL TABLE json_test STORED AS JSON -LOCATION 'tests/data/2.json'; +LOCATION '../core/tests/data/2.json'; statement ok CREATE EXTERNAL TABLE single_nan STORED AS JSON -LOCATION 'tests/data/3.json'; +LOCATION '../core/tests/data/3.json'; query IR rowsort SELECT a, b FROM json_test @@ -49,17 +49,19 @@ query TT EXPLAIN SELECT count(*) from json_test ---- logical_plan -Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] ---TableScan: json_test projection=[a] +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--TableScan: json_test projection=[] physical_plan -AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))] +AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] --CoalescePartitionsExec -----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))] +----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] ------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]}, projection=[a] +--------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]} -query error DataFusion error: Schema error: No field named mycol\. +query ? SELECT mycol FROM single_nan +---- +NULL statement ok DROP TABLE json_test diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt new file mode 100644 index 0000000000000..e063d6e8960af --- /dev/null +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -0,0 +1,508 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Limit Tests +########## + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# async fn csv_query_limit +query T +SELECT c1 FROM aggregate_test_100 LIMIT 2 +---- +c +d + +# async fn csv_query_limit_bigger_than_nbr_of_rows +query I +SELECT c2 FROM aggregate_test_100 LIMIT 200 +---- +2 +5 +1 +1 +5 +4 +3 +3 +1 +4 +1 +4 +3 +2 +1 +1 +2 +1 +3 +2 +4 +1 +5 +4 +2 +1 +4 +5 +2 +3 +4 +2 +1 +5 +3 +1 +2 +3 +3 +3 +2 +4 +1 +3 +2 +5 +2 +1 +4 +1 +4 +2 +5 +4 +2 +3 +4 +4 +4 +5 +4 +2 +1 +2 +4 +2 +3 +5 +1 +1 +4 +2 +1 +2 +1 +1 +5 +4 +5 +2 +3 +2 +4 +1 +3 +4 +3 +2 +5 +3 +3 +2 +5 +5 +4 +1 +3 +3 +4 +4 + +# async fn csv_query_limit_with_same_nbr_of_rows +query I +SELECT c2 FROM aggregate_test_100 LIMIT 100 +---- +2 +5 +1 +1 +5 +4 +3 +3 +1 +4 +1 +4 +3 +2 +1 +1 +2 +1 +3 +2 +4 +1 +5 +4 +2 +1 +4 +5 +2 +3 +4 +2 +1 +5 +3 +1 +2 +3 +3 +3 +2 +4 +1 +3 +2 +5 +2 +1 +4 +1 +4 +2 +5 +4 +2 +3 +4 +4 +4 +5 +4 +2 +1 +2 +4 +2 +3 +5 +1 +1 +4 +2 +1 +2 +1 +1 +5 +4 +5 +2 +3 +2 +4 +1 +3 +4 +3 +2 +5 +3 +3 +2 +5 +5 +4 +1 +3 +3 +4 +4 + +# async fn csv_query_limit_zero +query T +SELECT c1 FROM aggregate_test_100 LIMIT 0 +---- + +# async fn csv_offset_without_limit_99 +query T +SELECT c1 FROM aggregate_test_100 OFFSET 99 +---- +e + +# async fn csv_offset_without_limit_100 +query T +SELECT c1 FROM aggregate_test_100 OFFSET 100 +---- + +# async fn csv_offset_without_limit_101 +query T +SELECT c1 FROM aggregate_test_100 OFFSET 101 +---- + +# async fn csv_query_offset +query T +SELECT c1 FROM aggregate_test_100 OFFSET 2 LIMIT 2 +---- +b +a + +# async fn csv_query_offset_the_same_as_nbr_of_rows +query T +SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 100 +---- + +# async fn csv_query_offset_bigger_than_nbr_of_rows +query T +SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101 +---- + +# +# global limit statistics test +# + +statement ok +CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10); + +# The aggregate does not need to be computed because the input statistics are exact and +# the number of rows is less than the skip value (OFFSET). +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=11, fetch=3 +----TableScan: t1 projection=[], fetch=14 +physical_plan +ProjectionExec: expr=[0 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); +---- +0 + +# The aggregate does not need to be computed because the input statistics are exact and +# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET). +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=8, fetch=3 +----TableScan: t1 projection=[], fetch=11 +physical_plan +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +2 + +# The aggregate does not need to be computed because the input statistics are exact and +# an OFFSET, but no LIMIT, is specified. +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Limit: skip=8, fetch=None +----TableScan: t1 projection=[] +physical_plan +ProjectionExec: expr=[2 as COUNT(*)] +--PlaceholderRowExec + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); +---- +2 + +# The aggregate needs to be computed because the input statistics are inexact. +query TT +EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); +---- +logical_plan +Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Projection: +----Limit: skip=6, fetch=3 +------Filter: t1.a > Int32(3) +--------TableScan: t1 projection=[a] +physical_plan +AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +--CoalescePartitionsExec +----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------ProjectionExec: expr=[] +----------GlobalLimitExec: skip=6, fetch=3 +------------CoalesceBatchesExec: target_batch_size=8192 +--------------FilterExec: a@0 > 3 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query I +SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); +---- +1 + +# generate BIGINT data from 1 to 1000 in multiple partitions +statement ok +CREATE TABLE t1000 (i BIGINT) AS +WITH t AS (VALUES (0), (0), (0), (0), (0), (0), (0), (0), (0), (0)) +SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3; + +# verify that there are multiple partitions in the input (i.e. MemoryExec says +# there are 4 partitions) so that this tests multi-partition limit. +query TT +EXPLAIN SELECT DISTINCT i FROM t1000; +---- +logical_plan +Aggregate: groupBy=[[t1000.i]], aggr=[[]] +--TableScan: t1000 projection=[i] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 +------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] +--------MemoryExec: partitions=4, partition_sizes=[1, 1, 2, 1] + +query I +SELECT i FROM t1000 ORDER BY i DESC LIMIT 3; +---- +1000 +999 +998 + +query I +SELECT i FROM t1000 ORDER BY i LIMIT 3; +---- +1 +2 +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t1000 LIMIT 3); +---- +3 + +# limit_multi_partitions +statement ok +CREATE TABLE t15 (i BIGINT); + +query I +INSERT INTO t15 VALUES (1); +---- +1 + +query I +INSERT INTO t15 VALUES (1), (2); +---- +2 + +query I +INSERT INTO t15 VALUES (1), (2), (3); +---- +3 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4); +---- +4 + +query I +INSERT INTO t15 VALUES (1), (2), (3), (4), (5); +---- +5 + +query I +SELECT COUNT(*) FROM t15; +---- +15 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 1); +---- +1 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 2); +---- +2 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 3); +---- +3 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 4); +---- +4 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 5); +---- +5 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 6); +---- +6 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 7); +---- +7 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 8); +---- +8 + +query I +SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 9); +---- +9 + +######## +# Clean up after the test +######## + +statement ok +drop table aggregate_test_100; diff --git a/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/map.slt similarity index 62% rename from datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt rename to datafusion/sqllogictest/test_files/map.slt index 5f680fcae73fb..c3d16fca904e0 100644 --- a/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -5,9 +5,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -15,16 +15,32 @@ # specific language governing permissions and limitations # under the License. -########## -## Join Tests -########## - -# turn off repartition_joins statement ok -set datafusion.optimizer.repartition_joins = false; +CREATE EXTERNAL TABLE data +STORED AS PARQUET +LOCATION '../core/tests/data/parquet_map.parquet'; -include ./join.slt +query I +SELECT SUM(ints['bytes']) FROM data; +---- +5636785 -# turn on repartition_joins -statement ok -set datafusion.optimizer.repartition_joins = true; +query I +SELECT SUM(ints['bytes']) FROM data WHERE strings['method'] == 'GET'; +---- +649668 + +query TI +SELECT strings['method'] AS method, COUNT(*) as count FROM data GROUP BY method ORDER BY count DESC; +---- +POST 41 +HEAD 33 +PATCH 30 +OPTION 29 +GET 27 +PUT 25 +DELETE 24 + +query T +SELECT strings['not_found'] FROM data LIMIT 1; +---- diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt new file mode 100644 index 0000000000000..ee1e345f946a8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/math.slt @@ -0,0 +1,567 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Math expression Tests +########## + +statement ok +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; + +# Round +query R +SELECT ROUND(c1) FROM aggregate_simple +---- +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 + +# Round +query R +SELECT round(c1/3, 2) FROM aggregate_simple order by c1 +---- +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 + +# Round +query R +SELECT round(c1, 4) FROM aggregate_simple order by c1 +---- +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0.0001 +0.0001 +0.0001 +0.0001 +0.0001 + +# Round +query RRRRRRRR +SELECT round(125.2345, -3), round(125.2345, -2), round(125.2345, -1), round(125.2345), round(125.2345, 0), round(125.2345, 1), round(125.2345, 2), round(125.2345, 3) +---- +0 100 130 125 125 125.2 125.23 125.235 + +# atan2 +query RRRRRRR +SELECT atan2(2.0, 1.0), atan2(-2.0, 1.0), atan2(2.0, -1.0), atan2(-2.0, -1.0), atan2(NULL, 1.0), atan2(2.0, NULL), atan2(NULL, NULL); +---- +1.107148717794 -1.107148717794 2.034443935796 -2.034443935796 NULL NULL NULL + +# nanvl +query RRR +SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) +---- +1 1 NaN + +# isnan +query BBBB +SELECT isnan(1.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +---- +false true true NULL + +# iszero +query BBBB +SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) +---- +false true true NULL + +# abs: empty argumnet +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +SELECT abs(); + +# abs: wrong number of arguments +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +SELECT abs(1, 2); + +# abs: unsupported argument type +statement error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nThis feature is not implemented: Unsupported data type Utf8 for function abs +SELECT abs('foo'); + + +statement ok +CREATE TABLE test_nullable_integer( + c1 TINYINT, + c2 SMALLINT, + c3 INT, + c4 BIGINT, + c5 TINYINT UNSIGNED, + c6 SMALLINT UNSIGNED, + c7 INT UNSIGNED, + c8 BIGINT UNSIGNED, + dataset TEXT + ) + AS VALUES + (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 'nulls'), + (0, 0, 0, 0, 0, 0, 0, 0, 'zeros'), + (1, 1, 1, 1, 1, 1, 1, 1, 'ones'); + +query IIIIIIIIT +INSERT into test_nullable_integer values(-128, -32768, -2147483648, -9223372036854775808, 0, 0, 0, 0, 'mins'); +---- +1 + +query IIIIIIIIT +INSERT into test_nullable_integer values(127, 32767, 2147483647, 9223372036854775807, 255, 65535, 4294967295, 18446744073709551615, 'maxs'); +---- +1 + +query IIIIIIII +SELECT c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 FROM test_nullable_integer where dataset = 'nulls' +---- +NULL NULL NULL NULL NULL NULL NULL NULL + +query IIIIIIII +SELECT c1/0, c2/0, c3/0, c4/0, c5/0, c6/0, c7/0, c8/0 FROM test_nullable_integer where dataset = 'nulls' +---- +NULL NULL NULL NULL NULL NULL NULL NULL + +query IIIIIIII +SELECT c1%0, c2%0, c3%0, c4%0, c5%0, c6%0, c7%0, c8%0 FROM test_nullable_integer where dataset = 'nulls' +---- +NULL NULL NULL NULL NULL NULL NULL NULL + +query IIIIIIII rowsort +select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_nullable_integer where dataset != 'maxs' +---- +0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 +0 0 0 0 0 0 0 0 +NULL NULL NULL NULL NULL NULL NULL NULL + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c1/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c2/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c3/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c4/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c5/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c6/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c7/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c8/0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c1%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c2%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c3%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c4%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c5%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c6%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c7%0 FROM test_nullable_integer + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c8%0 FROM test_nullable_integer + +# abs: return type +query TTTTTTTT rowsort +select + arrow_typeof(abs(c1)), arrow_typeof(abs(c2)), arrow_typeof(abs(c3)), arrow_typeof(abs(c4)), + arrow_typeof(abs(c5)), arrow_typeof(abs(c6)), arrow_typeof(abs(c7)), arrow_typeof(abs(c8)) +from test_nullable_integer limit 1 +---- +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 + +# abs: unsigned integers +query IIII rowsort +select abs(c5), abs(c6), abs(c7), abs(c8) from test_nullable_integer +---- +0 0 0 0 +0 0 0 0 +1 1 1 1 +255 65535 4294967295 18446744073709551615 +NULL NULL NULL NULL + +# abs: signed integers +query IIII rowsort +select abs(c1), abs(c2), abs(c3), abs(c4) from test_nullable_integer where dataset != 'mins' +---- +0 0 0 0 +1 1 1 1 +127 32767 2147483647 9223372036854775807 +NULL NULL NULL NULL + +# abs: Int8 overlow +statement error DataFusion error: Arrow error: Compute error: Int8Array overflow on abs\(-128\) +select abs(c1) from test_nullable_integer where dataset = 'mins' + +# abs: Int16 overlow +statement error DataFusion error: Arrow error: Compute error: Int16Array overflow on abs\(-32768\) +select abs(c2) from test_nullable_integer where dataset = 'mins' + +# abs: Int32 overlow +statement error DataFusion error: Arrow error: Compute error: Int32Array overflow on abs\(-2147483648\) +select abs(c3) from test_nullable_integer where dataset = 'mins' + +# abs: Int64 overlow +statement error DataFusion error: Arrow error: Compute error: Int64Array overflow on abs\(-9223372036854775808\) +select abs(c4) from test_nullable_integer where dataset = 'mins' + +statement ok +drop table test_nullable_integer + + +statement ok +CREATE TABLE test_non_nullable_integer( + c1 TINYINT NOT NULL, + c2 SMALLINT NOT NULL, + c3 INT NOT NULL, + c4 BIGINT NOT NULL, + c5 TINYINT UNSIGNED NOT NULL, + c6 SMALLINT UNSIGNED NOT NULL, + c7 INT UNSIGNED NOT NULL, + c8 BIGINT UNSIGNED NOT NULL, + ); + +query IIIIIIII +INSERT INTO test_non_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1) +---- +1 + +query IIIIIIII rowsort +select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from test_non_nullable_integer +---- +0 0 0 0 0 0 0 0 + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c1/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c2/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c3/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c4/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c5/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c6/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c7/0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c8/0 FROM test_non_nullable_integer + + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c1%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c2%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c3%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c4%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c5%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c6%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c7%0 FROM test_non_nullable_integer + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c8%0 FROM test_non_nullable_integer + +statement ok +drop table test_non_nullable_integer + + +statement ok +CREATE TABLE test_nullable_float( + c1 float, + c2 double, + ) AS VALUES + (-1.0, -1.0), + (1.0, 1.0), + (NULL, NULL), + (0., 0.), + ('NaN'::double, 'NaN'::double); + +query RR rowsort +SELECT c1*0, c2*0 FROM test_nullable_float +---- +0 0 +0 0 +0 0 +NULL NULL +NaN NaN + +query RR rowsort +SELECT c1/0, c2/0 FROM test_nullable_float +---- +-Infinity -Infinity +Infinity Infinity +NULL NULL +NaN NaN +NaN NaN + +query RR rowsort +SELECT c1%0, c2%0 FROM test_nullable_float +---- +NULL NULL +NaN NaN +NaN NaN +NaN NaN +NaN NaN + +query RR rowsort +SELECT c1%1, c2%1 FROM test_nullable_float +---- +0 0 +0 0 +0 0 +NULL NULL +NaN NaN + +# abs: return type +query TT rowsort +SELECT arrow_typeof(abs(c1)), arrow_typeof(abs(c2)) FROM test_nullable_float limit 1 +---- +Float32 Float64 + +# abs: floats +query RR rowsort +SELECT abs(c1), abs(c2) from test_nullable_float +---- +0 0 +1 1 +1 1 +NULL NULL +NaN NaN + +statement ok +drop table test_nullable_float + + +statement ok +CREATE TABLE test_non_nullable_float( + c1 float NOT NULL, + c2 double NOT NULL, + ); + +query RR +INSERT INTO test_non_nullable_float VALUES + (-1.0, -1.0), + (1.0, 1.0), + (0., 0.), + ('NaN'::double, 'NaN'::double) +---- +4 + +query RR rowsort +SELECT c1*0, c2*0 FROM test_non_nullable_float +---- +0 0 +0 0 +0 0 +NaN NaN + +query RR rowsort +SELECT c1/0, c2/0 FROM test_non_nullable_float +---- +-Infinity -Infinity +Infinity Infinity +NaN NaN +NaN NaN + +query RR rowsort +SELECT c1%0, c2%0 FROM test_non_nullable_float +---- +NaN NaN +NaN NaN +NaN NaN +NaN NaN + +query RR rowsort +SELECT c1%1, c2%1 FROM test_non_nullable_float +---- +0 0 +0 0 +0 0 +NaN NaN + +statement ok +drop table test_non_nullable_float + + +statement ok +CREATE TABLE test_nullable_decimal( + c1 DECIMAL(10, 2), /* Decimal128 */ + c2 DECIMAL(38, 10), /* Decimal128 with max precision */ + c3 DECIMAL(40, 2), /* Decimal256 */ + c4 DECIMAL(76, 10) /* Decimal256 with max precision */ + ) AS VALUES + (0, 0, 0, 0), + (NULL, NULL, NULL, NULL); + +query RRRR +INSERT into test_nullable_decimal values + ( + -99999999.99, + '-9999999999999999999999999999.9999999999', + '-99999999999999999999999999999999999999.99', + '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ), + ( + 99999999.99, + '9999999999999999999999999999.9999999999', + '99999999999999999999999999999999999999.99', + '999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ) +---- +2 + + +query R +SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL; +---- +NULL + +query R +SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL; +---- +NULL + +query R +SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL; +---- +NULL + +query R +SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; +---- +0 +0 +0 + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; + +query error DataFusion error: Arrow error: Divide by zero error +SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; + +# abs: return type +query TTTT +SELECT + arrow_typeof(abs(c1)), + arrow_typeof(abs(c2)), + arrow_typeof(abs(c3)), + arrow_typeof(abs(c4)) +FROM test_nullable_decimal limit 1 +---- +Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10) + +# abs: decimals +query RRRR rowsort +SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal +---- +0 0 0 0 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +NULL NULL NULL NULL + +statement ok +drop table test_nullable_decimal + + +statement ok +CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); + +query R +INSERT INTO test_non_nullable_decimal VALUES(1) +---- +1 + +query R rowsort +SELECT c1*0 FROM test_non_nullable_decimal +---- +0 + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c1/0 FROM test_non_nullable_decimal + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nArrow error: Divide by zero error +SELECT c1%0 FROM test_non_nullable_decimal + +statement ok +drop table test_non_nullable_decimal diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt new file mode 100644 index 0000000000000..3b2b219244f55 --- /dev/null +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Tests for tables that has both metadata on each field as well as metadata on +## the schema itself. +########## + +## Note that table_with_metadata is defined using Rust code +## in the test harness as there is no way to define schema +## with metadata in SQL. + +query IT +select * from table_with_metadata; +---- +1 NULL +NULL bar +3 baz + +query I rowsort +SELECT ( + SELECT id FROM table_with_metadata + ) UNION ( + SELECT id FROM table_with_metadata + ); +---- +1 +3 +NULL + +query I rowsort +SELECT "data"."id" +FROM + ( + (SELECT "id" FROM "table_with_metadata") + UNION + (SELECT "id" FROM "table_with_metadata") + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples" +WHERE "data"."id" = "samples"."id"; +---- +1 +3 + +statement ok +drop table table_with_metadata; diff --git a/datafusion/core/tests/sqllogictests/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/misc.slt rename to datafusion/sqllogictest/test_files/misc.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/nullif.slt b/datafusion/sqllogictest/test_files/nullif.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/nullif.slt rename to datafusion/sqllogictest/test_files/nullif.slt diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt new file mode 100644 index 0000000000000..9366a9b3b3c8f --- /dev/null +++ b/datafusion/sqllogictest/test_files/options.slt @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +## Tests for config options +####### + + +statement ok +create table a(c0 int) as values (1), (2); + +# Expect coalesce and default batch size +query TT +explain SELECT * FROM a WHERE c0 < 1; +---- +logical_plan +Filter: a.c0 < Int32(1) +--TableScan: a projection=[c0] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: c0@0 < 1 +----MemoryExec: partitions=1, partition_sizes=[1] + +## +# test_disable_coalesce +## + +statement ok +set datafusion.execution.coalesce_batches = false + +# expect no coalsece +query TT +explain SELECT * FROM a WHERE c0 < 1; +---- +logical_plan +Filter: a.c0 < Int32(1) +--TableScan: a projection=[c0] +physical_plan +FilterExec: c0@0 < 1 +--MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.execution.coalesce_batches = true + + +## +# test_custom_batch_size +## + +statement ok +set datafusion.execution.batch_size = 1234; + +# expect batch size to be 1234 +query TT +explain SELECT * FROM a WHERE c0 < 1; +---- +logical_plan +Filter: a.c0 < Int32(1) +--TableScan: a projection=[c0] +physical_plan +CoalesceBatchesExec: target_batch_size=1234 +--FilterExec: c0@0 < 1 +----MemoryExec: partitions=1, partition_sizes=[1] + + +statement ok +set datafusion.execution.batch_size = 8192; + +statement ok +drop table a + +# test datafusion.sql_parser.parse_float_as_decimal +# +# default option value is false +query RR +select 10000000000000000000.01, -10000000000000000000.01 +---- +10000000000000000000 -10000000000000000000 + +query TT +select arrow_typeof(10000000000000000000.01), arrow_typeof(-10000000000000000000.01) +---- +Float64 Float64 + +# select 0, i64::MIN, i64::MIN-1, i64::MAX, i64::MAX + 1, u64::MAX, u64::MAX + 1 +query IIRIIIR +select 0, -9223372036854775808, -9223372036854775809, 9223372036854775807, + 9223372036854775808, 18446744073709551615, 18446744073709551616 +---- +0 -9223372036854775808 -9223372036854776000 9223372036854775807 9223372036854775808 18446744073709551615 18446744073709552000 + +query TTTTTTT +select arrow_typeof(0), arrow_typeof(-9223372036854775808), arrow_typeof(-9223372036854775809), + arrow_typeof(9223372036854775807), arrow_typeof(9223372036854775808), + arrow_typeof(18446744073709551615), arrow_typeof(18446744073709551616) +---- +Int64 Int64 Float64 Int64 UInt64 UInt64 Float64 + + +statement ok +set datafusion.sql_parser.parse_float_as_decimal = true; + +query RR +select 10000000000000000000.01, -10000000000000000000.01 +---- +10000000000000000000.01 -10000000000000000000.01 + +query TT +select arrow_typeof(10000000000000000000.01), arrow_typeof(-10000000000000000000.01) +---- +Decimal128(22, 2) Decimal128(22, 2) + +# select 0, i64::MIN, i64::MIN-1, i64::MAX, i64::MAX + 1, u64::MAX, u64::MAX + 1 +query IIRIIIR +select 0, -9223372036854775808, -9223372036854775809, 9223372036854775807, + 9223372036854775808, 18446744073709551615, 18446744073709551616 +---- +0 -9223372036854775808 -9223372036854775809 9223372036854775807 9223372036854775808 18446744073709551615 18446744073709551616 + +query TTTTTTT +select arrow_typeof(0), arrow_typeof(-9223372036854775808), arrow_typeof(-9223372036854775809), + arrow_typeof(9223372036854775807), arrow_typeof(9223372036854775808), + arrow_typeof(18446744073709551615), arrow_typeof(18446744073709551616) +---- +Int64 Int64 Decimal128(19, 0) Int64 UInt64 UInt64 Decimal128(20, 0) + +# special cases +query RRRR +select .0 as c1, 0. as c2, 0000. as c3, 00000.00 as c4 +---- +0 0 0 0 + +query TTTT +select arrow_typeof(.0) as c1, arrow_typeof(0.) as c2, arrow_typeof(0000.) as c3, arrow_typeof(00000.00) as c4 +---- +Decimal128(1, 1) Decimal128(1, 0) Decimal128(1, 0) Decimal128(2, 2) + +query RR +select 999999999999999999999999999999999999, -999999999999999999999999999999999999 +---- +999999999999999999999999999999999999 -999999999999999999999999999999999999 + +query TT +select arrow_typeof(999999999999999999999999999999999999), arrow_typeof(-999999999999999999999999999999999999) +---- +Decimal128(36, 0) Decimal128(36, 0) + +query RR +select 99999999999999999999999999999999999999, -99999999999999999999999999999999999999 +---- +99999999999999999999999999999999999999 -99999999999999999999999999999999999999 + +query TT +select arrow_typeof(99999999999999999999999999999999999999), arrow_typeof(-99999999999999999999999999999999999999) +---- +Decimal128(38, 0) Decimal128(38, 0) + +query RR +select 9999999999999999999999999999999999.9999, -9999999999999999999999999999999999.9999 +---- +9999999999999999999999999999999999.9999 -9999999999999999999999999999999999.9999 + +query TT +select arrow_typeof(9999999999999999999999999999999999.9999), arrow_typeof(-9999999999999999999999999999999999.9999) +---- +Decimal128(38, 4) Decimal128(38, 4) + +# leading zeroes +query RRR +select 00009999999999999999999999999999999999.9999, -00009999999999999999999999999999999999.9999, 0018446744073709551616 +---- +9999999999999999999999999999999999.9999 -9999999999999999999999999999999999.9999 18446744073709551616 + +query TTT +select arrow_typeof(00009999999999999999999999999999999999.9999), + arrow_typeof(-00009999999999999999999999999999999999.9999), + arrow_typeof(0018446744073709551616) +---- +Decimal128(38, 4) Decimal128(38, 4) Decimal128(20, 0) + +# precision overflow +statement error DataFusion error: SQL error: ParserError\("Cannot parse 123456789012345678901234567890123456789 as i128 when building decimal: precision overflow"\) +select 123456789.012345678901234567890123456789 + +statement error SQL error: ParserError\("Cannot parse 123456789012345678901234567890123456789 as i128 when building decimal: precision overflow"\) +select -123456789.012345678901234567890123456789 + +# can not fit in i128 +statement error SQL error: ParserError\("Cannot parse 1234567890123456789012345678901234567890 as i128 when building decimal: number too large to fit in target type"\) +select 123456789.0123456789012345678901234567890 + +statement error SQL error: ParserError\("Cannot parse 1234567890123456789012345678901234567890 as i128 when building decimal: number too large to fit in target type"\) +select -123456789.0123456789012345678901234567890 + +# Restore option to default value +statement ok +set datafusion.sql_parser.parse_float_as_decimal = false; diff --git a/datafusion/core/tests/sqllogictests/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt similarity index 51% rename from datafusion/core/tests/sqllogictests/test_files/order.slt rename to datafusion/sqllogictest/test_files/order.slt index 92faff623c1e6..77df9e0bb4937 100644 --- a/datafusion/core/tests/sqllogictests/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -98,7 +98,7 @@ NULL three statement ok CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION 'tests/data/partitioned_csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; # Demonstrate types query TTT @@ -410,3 +410,171 @@ SELECT DISTINCT time as "first_seen" FROM t ORDER BY 1; ## Cleanup statement ok drop table t; + +# Create a table having 3 columns which are ordering equivalent by the source. In the next step, +# we will expect to observe the removed SortExec by propagating the orders across projection. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC) +WITH ORDER (b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +query TT +EXPLAIN SELECT (b+a+c) AS result +FROM multiple_ordered_table +ORDER BY result; +---- +logical_plan +Sort: result ASC NULLS LAST +--Projection: multiple_ordered_table.b + multiple_ordered_table.a + multiple_ordered_table.c AS result +----TableScan: multiple_ordered_table projection=[a, b, c] +physical_plan +SortPreservingMergeExec: [result@0 ASC NULLS LAST] +--ProjectionExec: expr=[b@1 + a@0 + c@2 as result] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST], [b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +statement ok +drop table multiple_ordered_table; + +# Create tables having some ordered columns. In the next step, we will expect to observe that scalar +# functions, such as mathematical functions like atan(), ceil(), sqrt(), or date_time functions +# like date_bin() and date_trunc(), will maintain the order of its argument columns. +statement ok +CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP +) +STORED AS CSV +WITH ORDER (ts ASC NULLS LAST) +LOCATION '../core/tests/data/timestamps.csv'; + +query TT +EXPLAIN SELECT DATE_BIN(INTERVAL '15 minutes', ts, TIMESTAMP '2022-08-03 14:40:00Z') as db15 +FROM csv_with_timestamps +ORDER BY db15; +---- +logical_plan +Sort: db15 ASC NULLS LAST +--Projection: date_bin(IntervalMonthDayNano("900000000000"), csv_with_timestamps.ts, TimestampNanosecond(1659537600000000000, None)) AS db15 +----TableScan: csv_with_timestamps projection=[ts] +physical_plan +SortPreservingMergeExec: [db15@0 ASC NULLS LAST] +--ProjectionExec: expr=[date_bin(900000000000, ts@0, 1659537600000000000) as db15] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 ASC NULLS LAST], has_header=false + +query TT +EXPLAIN SELECT DATE_TRUNC('DAY', ts) as dt_day +FROM csv_with_timestamps +ORDER BY dt_day; +---- +logical_plan +Sort: dt_day ASC NULLS LAST +--Projection: date_trunc(Utf8("DAY"), csv_with_timestamps.ts) AS dt_day +----TableScan: csv_with_timestamps projection=[ts] +physical_plan +SortPreservingMergeExec: [dt_day@0 ASC NULLS LAST] +--ProjectionExec: expr=[date_trunc(DAY, ts@0) as dt_day] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 ASC NULLS LAST], has_header=false + +statement ok +drop table csv_with_timestamps; + +statement ok +drop table aggregate_test_100; + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER(c11) +WITH ORDER(c12 DESC) +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TT +EXPLAIN SELECT ATAN(c11) as atan_c11 +FROM aggregate_test_100 +ORDER BY atan_c11; +---- +logical_plan +Sort: atan_c11 ASC NULLS LAST +--Projection: atan(aggregate_test_100.c11) AS atan_c11 +----TableScan: aggregate_test_100 projection=[c11] +physical_plan +SortPreservingMergeExec: [atan_c11@0 ASC NULLS LAST] +--ProjectionExec: expr=[atan(c11@0) as atan_c11] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true + +query TT +EXPLAIN SELECT CEIL(c11) as ceil_c11 +FROM aggregate_test_100 +ORDER BY ceil_c11; +---- +logical_plan +Sort: ceil_c11 ASC NULLS LAST +--Projection: ceil(aggregate_test_100.c11) AS ceil_c11 +----TableScan: aggregate_test_100 projection=[c11] +physical_plan +SortPreservingMergeExec: [ceil_c11@0 ASC NULLS LAST] +--ProjectionExec: expr=[ceil(c11@0) as ceil_c11] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true + +query TT + EXPLAIN SELECT LOG(c11, c12) as log_c11_base_c12 + FROM aggregate_test_100 + ORDER BY log_c11_base_c12; +---- +logical_plan +Sort: log_c11_base_c12 ASC NULLS LAST +--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c11_base_c12 +----TableScan: aggregate_test_100 projection=[c11, c12] +physical_plan +SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] +--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true + +query TT +EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 +FROM aggregate_test_100 +ORDER BY log_c12_base_c11 DESC; +---- +logical_plan +Sort: log_c12_base_c11 DESC NULLS FIRST +--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c12_base_c11 +----TableScan: aggregate_test_100 projection=[c11, c12] +physical_plan +SortPreservingMergeExec: [log_c12_base_c11@0 DESC] +--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true + +statement ok +drop table aggregate_test_100; diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_null.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_null.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_type_coercion.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_type_coercion.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_type_coercion.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_type_coercion.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_types.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_types.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_union.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_union.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_window.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_window.slt rename to datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt similarity index 71% rename from datafusion/core/tests/sqllogictests/test_files/predicates.slt rename to datafusion/sqllogictest/test_files/predicates.slt index f37495c47cc73..e992a440d0a25 100644 --- a/datafusion/core/tests/sqllogictests/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -192,6 +192,10 @@ statement ok CREATE TABLE IF NOT EXISTS test AS VALUES('foo'),('Barrr'),('Bazzz'),('ZZZZZ'); # async fn test_regexp_is_match +query error Error during planning: Cannot infer common argument type for regex operation Int64 \~ Utf8 +SELECT * FROM test WHERE 12 ~ 'z' + + query T SELECT * FROM test WHERE column1 ~ 'z' ---- @@ -249,6 +253,93 @@ SELECT * FROM test WHERE column1 IN ('foo', 'Bar', 'fazzz') foo fazzz +statement ok +CREATE TABLE IF NOT EXISTS test_float AS VALUES + ('a', 1.2, 2.3, 1.2, -3.5, 1.1), + ('b', 2.1, 'NaN'::double, -1.7, -8.2, NULL), + ('c', NULL, NULL, '-NaN'::double, -5.4, 1.5), + ('d', 'NaN'::double, 'NaN'::double, 1.1, '-NaN'::double, NULL), + ('e', '-NaN'::double, 6.2, 'NaN'::double, -3.3, 5.6) + ; + +# IN expr for float +query T +SELECT column1 FROM test_float WHERE column2 IN (0.0, -1.2) +---- + +query T +SELECT column1 FROM test_float WHERE column2 IN (0.0, 1.2) +---- +a + +query T +SELECT column1 FROM test_float WHERE column2 IN (2.1, 1.2) +---- +a +b + +query T +SELECT column1 FROM test_float WHERE column2 IN (0.0, 1.2, NULL) +---- +a + +query T +SELECT column1 FROM test_float WHERE column2 IN (0.0, -1.2, NULL) +---- + +query T +SELECT column1 FROM test_float WHERE column2 IN (0.0, 1.2, 'NaN'::double, '-NaN'::double) +---- +a +d +e + +query T +SELECT column1 FROM test_float WHERE column2 IN (column3, column4, column5, column6) +---- +a +d + +query T +SELECT column1 FROM test_float WHERE column2 IN (column3, column4, column5, column6, 2.1, NULL, '-NaN'::double) +---- +a +b +d +e + +query T +SELECT column1 FROM test_float WHERE column2 NOT IN (column3, column4, column5, column6) +---- +e + +query T +SELECT column1 FROM test_float WHERE column2 NOT IN (column3, column4, column5, column6, 2.1, NULL, '-NaN'::double) +---- + + +query T +SELECT column1 FROM test_float WHERE NULL IN (column2, column2 + 1, column2 + 2, column2 + 3) +---- + +query T +SELECT column1 FROM test_float WHERE 'NaN'::double IN (column2, column2 + 1, column2 + 2, column2 + 3) +---- +d + +query T +SELECT column1 FROM test_float WHERE '-NaN'::double IN (column2, column2 + 1, column2 + 2, column2 + 3) +---- +e + +query II +SELECT c3, c7 FROM aggregate_test_100 WHERE c3 IN (c7 / 10, c7 / 20, c7 / 30, c7 / 40, 68, 103) +---- +1 25 +103 146 +68 224 +68 121 +3 133 ### # Test logical plan simplifies large OR chains @@ -347,3 +438,86 @@ drop table alltypes_plain; statement ok DROP TABLE test; + +statement ok +DROP TABLE test_float; + +######### +# Predicates on memory tables / statistics generation +# Reproducer for https://github.com/apache/arrow-datafusion/issues/7125 +######### + +statement ok +CREATE TABLE t (i integer, s string, b boolean) AS VALUES + (1, 'One', true), + (2, 'Two', false), + (NULL, NULL, NULL), + (4, 'Four', false) + ; + +query ITB +select * from t where (b OR b) = b; +---- +1 One true +2 Two false +4 Four false + +query ITB +select * from t where (s LIKE 'T%') = true; +---- +2 Two false + +query ITB +select * from t where (i & 3) = 1; +---- +1 One true + + + + +######## +# Clean up after the test +######## +statement ok +DROP TABLE t; + + +######## +# Test query with bloom filter +# Refer to https://github.com/apache/arrow-datafusion/pull/7821#pullrequestreview-1688062599 +######## + +statement ok +CREATE EXTERNAL TABLE data_index_bloom_encoding_stats STORED AS PARQUET LOCATION '../../parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + +statement ok +set datafusion.execution.parquet.bloom_filter_enabled=true; + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; +---- + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'test'; +---- +test + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" like '%e%'; +---- +Hello +test +are you +the quick +over +the lazy + +statement ok +set datafusion.execution.parquet.bloom_filter_enabled=false; + + +######## +# Clean up after the test +######## +statement ok +DROP TABLE data_index_bloom_encoding_stats; diff --git a/datafusion/core/tests/sqllogictests/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/prepare.slt rename to datafusion/sqllogictest/test_files/prepare.slt diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt new file mode 100644 index 0000000000000..b752f5644b7fb --- /dev/null +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Projection Statement Tests +########## + +# prepare data +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +statement ok +CREATE EXTERNAL TABLE aggregate_simple ( + c1 FLOAT NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../core/tests/data/aggregate_simple.csv' + +statement ok +CREATE TABLE memory_table(a INT NOT NULL, b INT NOT NULL, c INT NOT NULL) AS VALUES +(1, 2, 3), +(10, 12, 12), +(10, 12, 12), +(100, 120, 120); + +statement ok +CREATE TABLE cpu_load_short(host STRING NOT NULL) AS VALUES +('host1'), +('host2'); + +statement ok +CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; + +statement ok +CREATE EXTERNAL TABLE test_simple (c1 int, c2 bigint, c3 boolean) +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv'; + +# projection same fields +query I rowsort +select (1+1) as a from (select 1 as a) as b; +---- +2 + +# projection type alias +query R rowsort +SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2; +---- +0.00001 +0.00002 + +# csv query group by avg with projection +query RT rowsort +SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1; +---- +0.410407092638 b +0.486006692713 e +0.487545174661 a +0.488553793875 d +0.660045653644 c + +# parallel projection +query II +SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC +---- +3 0 +3 1 +3 2 +3 3 +3 4 +3 5 +3 6 +3 7 +3 8 +3 9 +3 10 +2 0 +2 1 +2 2 +2 3 +2 4 +2 5 +2 6 +2 7 +2 8 +2 9 +2 10 +1 0 +1 1 +1 2 +1 3 +1 4 +1 5 +1 6 +1 7 +1 8 +1 9 +1 10 +0 0 +0 1 +0 2 +0 3 +0 4 +0 5 +0 6 +0 7 +0 8 +0 9 +0 10 + +# subquery alias case insensitive +query II +SELECT V1.c1, v1.C2 FROM (SELECT test_simple.C1, TEST_SIMPLE.c2 FROM test_simple) V1 ORDER BY v1.c1, V1.C2 LIMIT 1; +---- +0 0 + +# projection on table scan +statement ok +set datafusion.explain.logical_plan_only = true + +query TT +EXPLAIN SELECT c2 FROM test; +---- +logical_plan TableScan: test projection=[c2] + +statement count 44 +select c2 from test; + +statement ok +set datafusion.explain.logical_plan_only = false + +# project cast dictionary +query T +SELECT + CASE + WHEN cpu_load_short.host IS NULL THEN '' + ELSE cpu_load_short.host + END AS host +FROM + cpu_load_short; +---- +host1 +host2 + +# projection on memory scan +query TT +explain select b from memory_table; +---- +logical_plan TableScan: memory_table projection=[b] +physical_plan MemoryExec: partitions=1, partition_sizes=[1] + +query I +select b from memory_table; +---- +2 +12 +12 +120 + +# project column with same name as relation +query I +select a.a from (select 1 as a) as a; +---- +1 + +# project column with filters that cant pushed down always false +query I +select * from (select 1 as a) f where f.a=2; +---- + + +# project column with filters that cant pushed down always true +query I +select * from (select 1 as a) f where f.a=1; +---- +1 + +# project columns in memory without propagation +query I +SELECT column1 as a from (values (1), (2)) f where f.column1 = 2; +---- +2 + +# clean data +statement ok +DROP TABLE aggregate_simple; + +statement ok +DROP TABLE aggregate_test_100; + +statement ok +DROP TABLE memory_table; + +statement ok +DROP TABLE cpu_load_short; + +statement ok +DROP TABLE test; + +statement ok +DROP TABLE test_simple; diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt similarity index 57% rename from datafusion/core/tests/sqllogictests/test_files/scalar.slt rename to datafusion/sqllogictest/test_files/scalar.slt index 2d1925702c4d0..b3597c664fbb2 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -20,29 +20,75 @@ ############# statement ok -CREATE TABLE t1( +CREATE TABLE unsigned_integers( a INT, b INT, c INT, - d INT + d INT, + e INT, + f INT ) as VALUES - (1, 100, 567, 1024), - (2, 1000, 123, 256), - (3, 10000, 978, 2048) + (1, 100, 567, 1024, 4, 10), + (2, 1000, 123, 256, 5, 11), + (3, 10000, 978, 2048, 6, 12), + (4, NULL, NULL, 512, NULL, NULL) ; +statement ok +CREATE TABLE signed_integers( + a INT, + b INT, + c INT, + d INT, + e INT, + f INT +) as VALUES + (-1, 100, -567, 1024, -4, 10), + (2, -1000, 123, -256, 5, -11), + (-3, 10000, -978, 2048, -6, 12), + (4, NULL, NULL, -512, NULL, NULL) +; + +statement ok +CREATE TABLE small_floats( + a FLOAT, + b FLOAT, + c FLOAT, + d FLOAT, + e FLOAT, + f FLOAT +) as VALUES + (0.2, -0.1, 1.0, -0.9, 0.1, 0.5), + (0.5, -0.2, 0.0, 0.9, -0.2, 0.6), + (-0.7, 0.1, -1.0, 0.9, 0.3, -0.7), + (-1.0, NULL, NULL, -0.9, NULL, NULL) +; + +## abs + # abs scalar function -query RRR rowsort +query III rowsort select abs(64), abs(0), abs(-64); ---- 64 0 64 # abs scalar nulls -query R rowsort +query ? rowsort select abs(null); ---- NULL +# abs with columns +query III rowsort +select abs(a), abs(b), abs(c) from signed_integers; +---- +1 100 567 +2 1000 123 +3 10000 978 +4 NULL NULL + +## acos + # acos scalar function query RRR rowsort select acos(0), acos(0.5), acos(1); @@ -55,6 +101,17 @@ select acos(null); ---- NULL +# acos with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(acos(a), 5), round(acos(b), 5), round(acos(c), 5) from small_floats; +---- +1.0472 1.77215 1.5708 +1.36944 1.67096 0 +2.34619 1.47063 3.14159 +3.14159 NULL NULL + +## acosh + # acosh scalar function # cosh(x) = (exp(x) + exp(-x)) / 2 query RRR rowsort @@ -68,6 +125,17 @@ select acosh(null); ---- NULL +# acosh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(acosh(a), 5), round(acosh(b), 5), round(acosh(c), 5) from signed_integers; +---- +1.31696 NaN 5.50532 +2.06344 NULL NULL +NaN 5.29829 NaN +NaN 9.90349 NaN + +## asin + # asin scalar function query RRR rowsort select asin(0), asin(0.5), asin(1); @@ -80,6 +148,17 @@ select asin(null); ---- NULL +# asin with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(asin(a), 5), round(asin(b), 5), round(asin(c), 5) from small_floats; +---- +-0.7754 0.10017 -1.5708 +-1.5708 NULL NULL +0.20136 -0.10017 1.5708 +0.5236 -0.20136 0 + +## asinh + # asinh scalar function # sinh(x) = (exp(x) - exp(-x)) / 2 query RRR rowsort @@ -93,6 +172,17 @@ select asinh(null); ---- NULL +# asinh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(asinh(a), 5), round(asinh(b), 5), round(asinh(c), 5) from small_floats; +---- +-0.65267 0.09983 -0.88137 +-0.88137 NULL NULL +0.19869 -0.09983 0.88137 +0.48121 -0.19869 0 + +## atan + # atan scalar function query RRR rowsort select atan(0), atan(cbrt(3)), atan(1); @@ -105,6 +195,17 @@ select atan(null); ---- NULL +# atan with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(atan(a), 5), round(atan(b), 5), round(atan(c), 5) from small_floats; +---- +-0.61073 0.09967 -0.7854 +-0.7854 NULL NULL +0.1974 -0.09967 0.7854 +0.46365 -0.1974 0 + +## atanh + # atanh scalar function # tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) query RRR rowsort @@ -118,6 +219,17 @@ select atanh(null); ---- NULL +# atanh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(atanh(a), 5), round(atanh(b), 5), round(atanh(c), 5) from small_floats; +---- +-0.8673 0.10034 -Infinity +-Infinity NULL NULL +0.20273 -0.10034 Infinity +0.54931 -0.20273 0 + +## atan2 + # atan2 scalar function query RRR rowsort select atan2(0, 1), atan2(1, 2), atan2(2, 2); @@ -130,18 +242,29 @@ select atan2(null, 64); ---- NULL -# atan2 scalar nulls 1 +# atan2 scalar nulls #1 query R rowsort select atan2(2, null); ---- NULL -# atan2 scalar nulls 2 +# atan2 scalar nulls #2 query R rowsort select atan2(null, null); ---- NULL +# atan2 with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(atan2(a, b), 5), round(atan2(c, d), 5), round(atan2(f, e), 5) from small_floats; +---- +-1.4289 -0.83798 -1.1659 +1.9513 0 1.89255 +2.03444 2.30361 1.3734 +NULL NULL NULL + +## cbrt + # cbrt scalar function query RRR rowsort select cbrt(0), cbrt(8), cbrt(27); @@ -154,6 +277,17 @@ select cbrt(null); ---- NULL +# cbrt with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(cbrt(a), 5), round(cbrt(b), 5), round(cbrt(c), 5) from signed_integers; +---- +-1 4.64159 -8.27677 +-1.44225 21.54435 -9.92612 +1.25992 -10 4.97319 +1.5874 NULL NULL + +## ceil + # ceil scalar function query RRR rowsort select ceil(1.6), ceil(1.5), ceil(1.4); @@ -166,6 +300,17 @@ select ceil(null); ---- NULL +# ceil with columns +query RRR rowsort +select ceil(a), ceil(b), ceil(c) from small_floats; +---- +-1 NULL NULL +0 1 -1 +1 0 0 +1 0 1 + +## degrees + # degrees scalar function query RRR rowsort select degrees(0), degrees(pi() / 2), degrees(pi()); @@ -178,6 +323,17 @@ select degrees(null); ---- NULL +# degrees with columns +query RRR rowsort +select round(degrees(a), 5), round(degrees(e), 5), round(degrees(f), 5) from signed_integers; +---- +-171.88734 -343.77468 687.54935 +-57.29578 -229.18312 572.9578 +114.59156 286.4789 -630.25357 +229.18312 NULL NULL + +## cos + # cos scalar function query RRR rowsort select cos(0), cos(pi() / 3), cos(pi() / 2); @@ -190,6 +346,17 @@ select cos(null); ---- NULL +# cos with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(cos(a), 5), round(cos(b), 5), round(cos(c), 5) from signed_integers; +---- +-0.41615 0.56238 -0.88797 +-0.65364 NULL NULL +-0.98999 -0.95216 -0.56968 +0.5403 0.86232 0.05744 + +## cosh + # cosh scalar function # cosh(x) = (exp(x) + exp(-x)) / 2 query RRR rowsort @@ -203,6 +370,17 @@ select cosh(null); ---- NULL +# cosh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(cosh(a), 5), round(cosh(b), 5), round(cosh(c), 5) from small_floats; +---- +1.02007 1.005 1.54308 +1.12763 1.02007 1 +1.25517 1.005 1.54308 +1.54308 NULL NULL + +## exp + # exp scalar function query RRR rowsort select exp(0), exp(1), exp(2); @@ -215,6 +393,17 @@ select exp(null); ---- NULL +# exp with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(exp(a), 5), round(exp(e), 5), round(exp(f), 5) from signed_integers; +---- +0.04979 0.00248 162754.79142 +0.36788 0.01832 22026.46579 +54.59815 NULL NULL +7.38906 148.41316 0.00002 + +## factorial + # factorial scalar function query III rowsort select factorial(0), factorial(10), factorial(15); @@ -227,6 +416,17 @@ select factorial(null); ---- NULL +# factorial with columns +query III rowsort +select factorial(a), factorial(e), factorial(f) from unsigned_integers; +---- +1 24 3628800 +2 120 39916800 +24 NULL NULL +6 720 479001600 + +## floor + # floor scalar function query RRR rowsort select floor(1.4), floor(1.5), floor(1.6); @@ -239,6 +439,17 @@ select floor(null); ---- NULL +# floor with columns +query RRR rowsort +select floor(a), floor(b), floor(c) from signed_integers; +---- +-1 100 -567 +-3 10000 -978 +2 -1000 123 +4 NULL NULL + +## gcd + # gcd scalar function query III rowsort select gcd(0, 0), gcd(2, 3), gcd(15, 10); @@ -251,18 +462,29 @@ select gcd(null, 64); ---- NULL -# gcd scalar nulls 1 +# gcd scalar nulls #1 query I rowsort select gcd(2, null); ---- NULL -# gcd scalar nulls 2 +# gcd scalar nulls #2 query I rowsort select gcd(null, null); ---- NULL +# gcd with columns +query III rowsort +select gcd(a, b), gcd(c, d), gcd(e, f) from signed_integers; +---- +1 1 2 +1 2 6 +2 1 1 +NULL NULL NULL + +## lcm + # lcm scalar function query III rowsort select lcm(0, 0), lcm(2, 3), lcm(15, 10); @@ -275,18 +497,29 @@ select lcm(null, 64); ---- NULL -# lcm scalar nulls 1 +# lcm scalar nulls #1 query I rowsort select lcm(2, null); ---- NULL -# lcm scalar nulls 2 +# lcm scalar nulls #2 query I rowsort select lcm(null, null); ---- NULL +# lcm with columns +query III rowsort +select lcm(a, b), lcm(c, d), lcm(e, f) from signed_integers; +---- +100 580608 20 +1000 31488 55 +30000 1001472 12 +NULL NULL NULL + +## ln + # ln scalar function query RRR rowsort select ln(1), ln(exp(1)), ln(3); @@ -306,6 +539,17 @@ select ln(0); ---- -Infinity +# ln with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers; +---- +0.69315 NaN 4.81218 +1.38629 NULL NULL +NaN 4.60517 NaN +NaN 9.21034 NaN + +## log + # log scalar function query RR rowsort select log(2, 64) a, log(100) b union all select log(2, 8), log(10); @@ -313,27 +557,19 @@ select log(2, 64) a, log(100) b union all select log(2, 8), log(10); 3 1 6 2 -# log scalar function -query RRR rowsort -select log(a, 64) a, log(b), log(10, b) from t1; ----- -3.7855785 4 4 -6 3 3 -Infinity 2 2 - # log scalar nulls query RR rowsort select log(null, 64) a, log(null) b; ---- NULL NULL -# log scalar nulls 1 +# log scalar nulls #1 query RR rowsort select log(2, null) a, log(null) b; ---- NULL NULL -# log scalar nulls 2 +# log scalar nulls #2 query RR rowsort select log(null, null) a, log(null) b; ---- @@ -346,6 +582,26 @@ select log(0) a, log(1, 64) b; ---- -Infinity Infinity +# log with columns #1 +query RRR rowsort +select log(a, 64) a, log(b), log(10, b) from unsigned_integers; +---- +3 NULL NULL +3.7855785 4 4 +6 3 3 +Infinity 2 2 + +# log with columns #2 +query RRR rowsort +select log(a, 64) a, log(b), log(10, b) from signed_integers; +---- +3 NULL NULL +6 NaN NaN +NaN 2 2 +NaN 4 4 + +## log10 + # log10 scalar function query RRR rowsort select log10(1), log10(10), log10(100); @@ -365,6 +621,17 @@ select log10(0); ---- -Infinity +# log10 with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(log(a), 5), round(log(b), 5), round(log(c), 5) from signed_integers; +---- +0.30103 NaN 2.08991 +0.60206 NULL NULL +NaN 2 NaN +NaN 4 NaN + +## log2 + # log2 scalar function query RRR rowsort select log2(1), log2(2), log2(4); @@ -384,12 +651,106 @@ select log2(0); ---- -Infinity +# log2 with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integers; +---- +1 NaN 6.94251 +2 NULL NULL +NaN 13.28771 NaN +NaN 6.64386 NaN + +## nanvl + +# nanvl scalar function +query RRR rowsort +select nanvl(0, 1), nanvl(asin(10), 2), nanvl(3, asin(10)); +---- +0 2 3 + +# nanvl scalar nulls +query R rowsort +select nanvl(null, 64); +---- +NULL + +# nanvl scalar nulls #1 +query R rowsort +select nanvl(2, null); +---- +NULL + +# nanvl scalar nulls #2 +query R rowsort +select nanvl(null, null); +---- +NULL + +# nanvl with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(nanvl(asin(d + e), 4), 5) from small_floats; +---- +0.7754 1.11977 -0.9273 +2 -0.20136 0.7754 +2 -1.11977 4 +NULL NULL NULL + +## isnan + +# isnan scalar function +query BBB +select isnan(10.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE) +---- +false true true + +# isnan scalar nulls +query B +select isnan(NULL) +---- +NULL + +# isnan with columns +query BBBB +select isnan(asin(a + b + c)), isnan(-asin(a + b + c)), isnan(asin(d + e + f)), isnan(-asin(d + e + f)) from small_floats; +---- +true true false false +false false true true +true true false false +NULL NULL NULL NULL + +## iszero + +# iszero scalar function +query BBB +select iszero(10.0), iszero(0.0), iszero(-0.0) +---- +false true true + +# iszero scalar nulls +query B +select iszero(NULL) +---- +NULL + +# iszero with columns +query BBBB +select iszero(floor(a + b + c)), iszero(-floor(a + b + c)), iszero(floor(d + e + f)), iszero(-floor(d + e + f)) from small_floats; +---- +false false false false +true true false false +false false true true +NULL NULL NULL NULL + +## pi + # pi scalar function query RRR rowsort select pi(), pi() / 2, pi() / 3; ---- 3.14159265359 1.570796326795 1.047197551197 +## power + # power scalar function query III rowsort select power(2, 0), power(2, 1), power(2, 2); @@ -402,18 +763,29 @@ select power(null, 64); ---- NULL -# power scalar nulls 1 +# power scalar nulls #1 query I rowsort select power(2, null); ---- NULL -# power scalar nulls 2 +# power scalar nulls #2 query R rowsort select power(null, null); ---- NULL +# power with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(power(a, b), 5), round(power(c, d), 5), round(power(e, f), 5) from small_floats; +---- +1.1487 0 NaN +1.17462 1 0.31623 +NULL NULL NULL +NaN NaN 2.32282 + +## radians + # radians scalar function query RRR rowsort select radians(0), radians(90), radians(180); @@ -426,6 +798,17 @@ select radians(null); ---- NULL +# radians with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(radians(a), 5), round(radians(b), 5), round(radians(c), 5) from signed_integers; +---- +-0.01745 1.74533 -9.89602 +-0.05236 174.53293 -17.06932 +0.03491 -17.45329 2.14675 +0.06981 NULL NULL + +## round + # round scalar function query RRR rowsort select round(1.4), round(1.5), round(1.6); @@ -438,6 +821,17 @@ select round(null); ---- NULL +# round with columns +query RRR rowsort +select round(a), round(b), round(c) from small_floats; +---- +-1 0 -1 +-1 NULL NULL +0 0 1 +1 0 0 + +## signum + # signum scalar function query RRR rowsort select signum(-2), signum(0), signum(2); @@ -450,6 +844,17 @@ select signum(null); ---- NULL +# signum with columns +query RRR rowsort +select signum(a), signum(b), signum(c) from signed_integers; +---- +-1 1 -1 +-1 1 -1 +1 -1 1 +1 NULL NULL + +## sin + # sin scalar function query RRR rowsort select sin(0), sin(pi() / 3), sin(pi() / 2); @@ -462,6 +867,17 @@ select sin(null); ---- NULL +# sin with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(sin(a), 5), round(sin(b), 5), round(sin(c), 5) from small_floats; +---- +-0.64422 0.09983 -0.84147 +-0.84147 NULL NULL +0.19867 -0.09983 0.84147 +0.47943 -0.19867 0 + +## sinh + # sinh scalar function # sinh(x) = (exp(x) - exp(-x)) / 2 query RRR rowsort @@ -475,6 +891,17 @@ select sinh(null); ---- NULL +# sinh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(sinh(a), 5), round(sinh(b), 5), round(sinh(c), 5) from small_floats; +---- +-0.75858 0.10017 -1.1752 +-1.1752 NULL NULL +0.20134 -0.10017 1.1752 +0.5211 -0.20134 0 + +## sqrt + # sqrt scalar function query RRR rowsort select sqrt(0), sqrt(4), sqrt(9); @@ -487,6 +914,17 @@ select sqrt(null); ---- NULL +# sqrt with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(sqrt(a), 5), round(sqrt(b), 5), round(sqrt(c), 5) from signed_integers; +---- +1.41421 NaN 11.09054 +2 NULL NULL +NaN 10 NaN +NaN 100 NaN + +## tan + # tan scalar function query RRR rowsort select tan(0), tan(pi() / 6), tan(pi() / 4); @@ -499,6 +937,17 @@ select tan(null); ---- NULL +# tan with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(tan(a), 5), round(tan(b), 5), round(tan(c), 5) from small_floats; +---- +-0.84229 0.10033 -1.55741 +-1.55741 NULL NULL +0.20271 -0.10033 1.55741 +0.5463 -0.20271 0 + +## tanh + # tanh scalar function # tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) query RRR rowsort @@ -512,6 +961,17 @@ select tanh(null); ---- NULL +# tanh with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(tanh(a), 5), round(tanh(b), 5), round(tanh(c), 5) from small_floats; +---- +-0.60437 0.09967 -0.76159 +-0.76159 NULL NULL +0.19738 -0.09967 0.76159 +0.46212 -0.19738 0 + +## trunc + # trunc scalar function query RRR rowsort select trunc(1.5), trunc(2.71), trunc(3.1415); @@ -524,48 +984,191 @@ select trunc(null); ---- NULL +# trunc with columns +query RRR rowsort +select trunc(a), trunc(b), trunc(c) from small_floats; +---- +-1 NULL NULL +0 0 -1 +0 0 0 +0 0 1 + +# trunc with precision +query RRRRR rowsort +select trunc(4.267, 3), trunc(1.1234, 2), trunc(-1.1231, 6), trunc(1.2837284, 2), trunc(1.1, 0); +---- +4.267 1.12 -1.1231 1.28 1 + +# trunc with negative precision should truncate digits left of decimal +query R +select trunc(12345.678, -3); +---- +12000 + +# trunc with columns and precision +query RRR rowsort +select + trunc(sqrt(abs(a)), 3) as a3, + trunc(sqrt(abs(a)), 1) as a1, + trunc(arrow_cast(sqrt(abs(a)), 'Float64'), 3) as a3_f64 +from small_floats; +---- +0.447 0.4 0.447 +0.707 0.7 0.707 +0.837 0.8 0.837 +1 1 1 + +## bitwise and + # bitwise and with column and scalar query I rowsort -select c & 856 from t1; +select c & 856 from signed_integers; ---- -528 -848 +328 +8 88 +NULL + +# bitwise and with columns +query III rowsort +select a & b, c & d, e & f from signed_integers; +---- +0 0 5 +100 1024 8 +10000 2048 8 +NULL NULL NULL + +## bitwise or # bitwise or with column and scalar query I rowsort -select c | 856 from t1; +select c | 856 from signed_integers; ---- +-130 +-39 891 -895 -986 +NULL + +# bitwise or with columns +query III rowsort +select a | b, c | d, e | f from signed_integers; +---- +-1 -567 -2 +-3 -978 -2 +-998 -133 -11 +NULL NULL NULL + +## bitwise xor # bitwise xor with column and scalar query I rowsort -select c ^ 856 from t1; +select c ^ 856 from signed_integers; ---- -138 -367 +-138 +-367 803 +NULL + +# bitwise xor with columns +query III rowsort +select a ^ b, c ^ d, e ^ f from signed_integers; +---- +-10003 -3026 -10 +-101 -1591 -10 +-998 -133 -16 +NULL NULL NULL + +# bitwise xor with other operators +query II rowsort +select 2 * c - 1 ^ 856 + d + 3, d ^ 7 >> 4 from signed_integers; +---- +-3328 128 +-822 64 +686 -16 +NULL -32 + +statement ok +set datafusion.sql_parser.dialect = postgresql; + +# postgresql bitwise xor with column and scalar +query I rowsort +select c # 856 from signed_integers; +---- +-138 +-367 +803 +NULL + +# postgresql bitwise xor with columns +query III rowsort +select a # b, c # d, e # f from signed_integers; +---- +-10003 -3026 -10 +-101 -1591 -10 +-998 -133 -16 +NULL NULL NULL -# right shift with column and scalar +# postgresql bitwise xor with other operators +query II rowsort +select 2 * c - 1 # 856 + d + 3, d # 7 >> 4 from signed_integers; +---- +-3328 128 +-822 64 +686 -16 +NULL -32 + +statement ok +set datafusion.sql_parser.dialect = generic; + + +## bitwise right shift + +# bitwise right shift with column and scalar query I rowsort -select d >> 2 from t1; +select d >> 2 from signed_integers; ---- +-128 +-64 256 512 -64 -# left shift with column and scalar +# bitwise right shift with columns +query III rowsort +select a >> b, c >> d, e >> f from signed_integers; +---- +-1 -567 -1 +-1 -978 -1 +0 123 0 +NULL NULL NULL + +## bitwise left shift + +# bitwise left shift with column and scalar query I rowsort -select d << 2 from t1; +select d << 2 from signed_integers; ---- -1024 +-1024 +-2048 4096 8192 +# bitwise left shift with columns +query III rowsort +select a << b, c << d, e << f from signed_integers; +---- +-16 -567 -4096 +-196608 -978 -24576 +33554432 123 10485760 +NULL NULL NULL + statement ok -drop table t1 +drop table unsigned_integers; + +statement ok +drop table signed_integers; + +statement ok +drop table small_floats; statement ok @@ -636,6 +1239,18 @@ FROM t1 999 999 +# issue: https://github.com/apache/arrow-datafusion/issues/7004 +query B +select case c1 + when 'foo' then TRUE + when 'bar' then FALSE +end from t1 +---- +NULL +NULL +NULL +NULL + statement ok drop table t1 @@ -710,21 +1325,22 @@ SELECT arrow_typeof(c8), arrow_typeof(c6), arrow_typeof(c8 + c6) FROM aggregate_ Int32 Int64 Int64 # in list array -query BBBBB rowsort +query BBBBBB rowsort SELECT c1 IN ('a', 'c') AS utf8_in_true ,c1 IN ('x', 'y') AS utf8_in_false ,c1 NOT IN ('x', 'y') AS utf8_not_in_true ,c1 NOT IN ('a', 'c') AS utf8_not_in_false ,NULL IN ('a', 'c') AS utf8_in_null + ,'a' IN (c1, NULL, 'c') uft8_in_column FROM aggregate_test_100 WHERE c12 < 0.05 ---- -false false true true NULL -false false true true NULL -false false true true NULL -false false true true NULL -true false true false NULL -true false true false NULL -true false true false NULL +false false true true NULL NULL +false false true true NULL NULL +false false true true NULL NULL +false false true true NULL NULL +true false true false NULL NULL +true false true false NULL true +true false true false NULL true # csv count star query III @@ -911,7 +1527,7 @@ SELECT not(true), not(false) ---- false true -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nInternal error: NOT 'Literal \{ value: Int64\(1\) \}' can't be evaluated because the expression's type is Int64, not boolean or NULL\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nInternal error: NOT 'Literal \{ value: Int64\(1\) \}' can't be evaluated because the expression's type is Int64, not boolean or NULL SELECT not(1), not(0) query ?B @@ -919,7 +1535,7 @@ SELECT null, not(null) ---- NULL NULL -query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nInternal error: NOT 'Literal \{ value: Utf8\("hi"\) \}' can't be evaluated because the expression's type is Utf8, not boolean or NULL\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +query error DataFusion error: Optimizer rule 'simplify_expressions' failed\ncaused by\nInternal error: NOT 'Literal \{ value: Utf8\("hi"\) \}' can't be evaluated because the expression's type is Utf8, not boolean or NULL SELECT NOT('hi') # test_negative_expressions() @@ -1025,7 +1641,6 @@ true true false true true true # csv query boolean gt gt eq query BBBBBB rowsort SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1 ----- ---- NULL NULL NULL NULL NULL NULL NULL false NULL false NULL NULL @@ -1040,10 +1655,10 @@ true true false true true true # csv query boolean distinct from query BBBBBB rowsort SELECT a, b, - a is distinct from b as df, - b is distinct from true as df_scalar, - a is not distinct from b as ndf, - a is not distinct from true as ndf_scalar + a is distinct from b as df, + b is distinct from true as df_scalar, + a is not distinct from b as ndf, + a is not distinct from true as ndf_scalar FROM t1 ---- NULL NULL false true true false @@ -1263,3 +1878,68 @@ query T SELECT CONCAT('Hello', 'World') ---- HelloWorld + +statement ok +CREATE TABLE simple_string( + letter STRING, + letter2 STRING +) as VALUES + ('A', 'APACHE'), + ('B', 'APACHE'), + ('C', 'APACHE'), + ('D', 'APACHE') +; + +query TT +EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; +---- +logical_plan +Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) +--TableScan: simple_string projection=[letter] +physical_plan +ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TB +SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; + ---- +---- +A true +B false +C false +D false + +query TT +EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; +---- +logical_plan +Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) +--TableScan: simple_string projection=[letter, letter2] +physical_plan +ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TB +SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; +---- +A true +B false +C false +D false + +# test string_temporal_coercion +query BBBBBBBBBB +select + arrow_cast(to_timestamp('2020-01-01 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == '2020-01-01T01:01:11', + arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') == arrow_cast('2020-01-02T01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-03 01:01:11.1234567890Z'), 'Time32(Second)') == '01:01:11', + arrow_cast(to_timestamp('2020-01-04 01:01:11.1234567890Z'), 'Time32(Second)') == arrow_cast('01:01:11', 'LargeUtf8'), + arrow_cast(to_timestamp('2020-01-05 01:01:11.1234567890Z'), 'Time64(Microsecond)') == '01:01:11.123456', + arrow_cast(to_timestamp('2020-01-06 01:01:11.1234567890Z'), 'Time64(Microsecond)') == arrow_cast('01:01:11.123456', 'LargeUtf8'), + arrow_cast('2020-01-07', 'Date32') == '2020-01-07', + arrow_cast('2020-01-08', 'Date64') == '2020-01-08', + arrow_cast('2020-01-09', 'Date32') == arrow_cast('2020-01-09', 'LargeUtf8'), + arrow_cast('2020-01-10', 'Date64') == arrow_cast('2020-01-10', 'LargeUtf8') +; +---- +true true true true true true true true true true diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt similarity index 63% rename from datafusion/core/tests/sqllogictests/test_files/select.slt rename to datafusion/sqllogictest/test_files/select.slt index cc8828ef879c4..ea570b99d4dd1 100644 --- a/datafusion/core/tests/sqllogictests/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -44,7 +44,7 @@ CREATE EXTERNAL TABLE aggregate_simple ( ) STORED AS CSV WITH HEADER ROW -LOCATION 'tests/data/aggregate_simple.csv' +LOCATION '../core/tests/data/aggregate_simple.csv' ########## @@ -224,6 +224,12 @@ select ---- false true false true true false false true false true true false true true false false true +# select NaNs +query BBBB +select (isnan('NaN'::double) AND 'NaN'::double > 0) a, (isnan('-NaN'::double) AND '-NaN'::double < 0) b, (isnan('NaN'::float) AND 'NaN'::float > 0) c, (isnan('-NaN'::float) AND '-NaN'::float < 0) d +---- +true true true true + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit 1; @@ -479,8 +485,7 @@ Projection: select_between_data.c1 >= Int64(2) AND select_between_data.c1 <= Int --TableScan: select_between_data projection=[c1] physical_plan ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as select_between_data.c1 BETWEEN Int64(2) AND Int64(3)] ---RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----MemoryExec: partitions=1, partition_sizes=[1] +--MemoryExec: partitions=1, partition_sizes=[1] # TODO: query_get_indexed_field @@ -729,7 +734,7 @@ CREATE EXTERNAL TABLE annotated_data_finite2 ( STORED AS CSV WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION 'tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv'; # test_source_projection @@ -843,8 +848,259 @@ statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT conta SELECT * EXCLUDE(a, a) FROM table1 +# if EXCEPT all the columns, query should still succeed but return empty +statement ok +SELECT * EXCEPT(a, b, c, d) +FROM table1 + +# EXCLUDE order shouldn't matter +query II +SELECT * EXCLUDE(b, a) +FROM table1 +ORDER BY c +LIMIT 5 +---- +100 1000 +200 2000 + +# EXCLUDE with out of order but duplicate columns should error +statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT contains duplicate column names +SELECT * EXCLUDE(d, b, c, a, a, b, c, d) +FROM table1 + +# avoiding adding an alias if the column name is the same +query TT +EXPLAIN select a as a FROM table1 order by a +---- +logical_plan +Sort: table1.a ASC NULLS LAST +--TableScan: table1 projection=[a] +physical_plan +SortExec: expr=[a@0 ASC NULLS LAST] +--MemoryExec: partitions=1, partition_sizes=[1] + +# ambiguous column references in on join +query error DataFusion error: Schema error: Ambiguous reference to unqualified field a +EXPLAIN select a as a FROM table1 t1 CROSS JOIN table1 t2 order by a + +# run below query in multi partitions +statement ok +set datafusion.execution.target_partitions = 2; + +# since query below contains computation +# inside projection expr, increasing partitions +# is beneficial +query TT +EXPLAIN SELECT a, a+b +FROM annotated_data_finite2 +ORDER BY a ASC; +---- +logical_plan +Sort: annotated_data_finite2.a ASC NULLS LAST +--Projection: annotated_data_finite2.a, annotated_data_finite2.a + annotated_data_finite2.b +----TableScan: annotated_data_finite2 projection=[a, b] +physical_plan +SortPreservingMergeExec: [a@0 ASC NULLS LAST] +--ProjectionExec: expr=[a@0 as a, a@0 + b@1 as annotated_data_finite2.a + annotated_data_finite2.b] +----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true + +# since query below doesn't computation +# inside projection expr, increasing partitions +# is not beneficial. Hence plan doesn't contain +# RepartitionExec +query TT +EXPLAIN SELECT a, b, 2 +FROM annotated_data_finite2 +ORDER BY a ASC; +---- +logical_plan +Sort: annotated_data_finite2.a ASC NULLS LAST +--Projection: annotated_data_finite2.a, annotated_data_finite2.b, Int64(2) +----TableScan: annotated_data_finite2 projection=[a, b] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, 2 as Int64(2)] +--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true + +# source is ordered by a,b,c +# when filter result is constant for column a +# ordering b, c is still satisfied. Final plan shouldn't have +# SortExec. +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE a=0 +ORDER BY b, c; +---- +logical_plan +Sort: annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST +--Filter: annotated_data_finite2.a = Int32(0) +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0)] +physical_plan +SortPreservingMergeExec: [b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: a@1 = 0 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# source is ordered by a,b,c +# when filter result is constant for column a and b +# ordering c is still satisfied. Final plan shouldn't have +# SortExec. +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE a=0 and b=0 +ORDER BY c; +---- +logical_plan +Sort: annotated_data_finite2.c ASC NULLS LAST +--Filter: annotated_data_finite2.a = Int32(0) AND annotated_data_finite2.b = Int32(0) +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0), annotated_data_finite2.b = Int32(0)] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: a@1 = 0 AND b@2 = 0 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# source is ordered by a,b,c +# when filter result is constant for column a and b +# ordering b, c is still satisfied. Final plan shouldn't have +# SortExec. +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE a=0 and b=0 +ORDER BY b, c; +---- +logical_plan +Sort: annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST +--Filter: annotated_data_finite2.a = Int32(0) AND annotated_data_finite2.b = Int32(0) +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0), annotated_data_finite2.b = Int32(0)] +physical_plan +SortPreservingMergeExec: [b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: a@1 = 0 AND b@2 = 0 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# source is ordered by a,b,c +# when filter result is constant for column a and b +# ordering a, b, c is still satisfied. Final plan shouldn't have +# SortExec. +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE a=0 and b=0 +ORDER BY a, b, c; +---- +logical_plan +Sort: annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST +--Filter: annotated_data_finite2.a = Int32(0) AND annotated_data_finite2.b = Int32(0) +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0), annotated_data_finite2.b = Int32(0)] +physical_plan +SortPreservingMergeExec: [a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: a@1 = 0 AND b@2 = 0 +------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# source is ordered by a,b,c +# when filter result is when filter contains or +# column a, and b may not be constant. Hence final plan +# should contain SortExec +query TT +EXPLAIN SELECT * +FROM annotated_data_finite2 +WHERE a=0 or b=0 +ORDER BY c; +---- +logical_plan +Sort: annotated_data_finite2.c ASC NULLS LAST +--Filter: annotated_data_finite2.a = Int32(0) OR annotated_data_finite2.b = Int32(0) +----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0) OR annotated_data_finite2.b = Int32(0)] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--SortExec: expr=[c@3 ASC NULLS LAST] +----CoalesceBatchesExec: target_batch_size=8192 +------FilterExec: a@1 = 0 OR b@2 = 0 +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +# When ordering lost during projection, we shouldn't keep the SortExec. +# in the final physical plan. +query TT +EXPLAIN SELECT c2, COUNT(*) +FROM (SELECT c2 +FROM aggregate_test_100 +ORDER BY c1, c2) +GROUP BY c2; +---- +logical_plan +Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--Projection: aggregate_test_100.c2 +----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST +------Projection: aggregate_test_100.c2, aggregate_test_100.c1 +--------TableScan: aggregate_test_100 projection=[c1, c2] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2 +------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)] +--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true + statement ok drop table annotated_data_finite2; statement ok drop table t; + +statement ok +create table t(x bigint, y bigint) as values (1,2), (1,3); + +query II +select z+1, y from (select x+1 as z, y from t) where y > 1; +---- +3 2 +3 3 + +query TT +EXPLAIN SELECT x/2, x/2+1 FROM t; +---- +logical_plan +Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1) +--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x +----TableScan: t projection=[x] +physical_plan +ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)] +--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT x/2, x/2+1 FROM t; +---- +0 1 +0 1 + +query TT +EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t; +---- +logical_plan +Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y) +--Projection: abs(t.x) AS abs(t.x)t.x, t.y +----TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)] +--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y] +----MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT abs(x), abs(x) + abs(y) FROM t; +---- +1 3 +1 4 + +statement ok +DROP TABLE t; diff --git a/datafusion/core/tests/sqllogictests/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt similarity index 86% rename from datafusion/core/tests/sqllogictests/test_files/set_variable.slt rename to datafusion/sqllogictest/test_files/set_variable.slt index 04e3715fd353d..440fb2c6ef2b0 100644 --- a/datafusion/core/tests/sqllogictests/test_files/set_variable.slt +++ b/datafusion/sqllogictest/test_files/set_variable.slt @@ -93,14 +93,10 @@ datafusion.execution.coalesce_batches false statement ok set datafusion.catalog.information_schema = true -statement error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -External error: provided string was not `true` or `false` +statement error DataFusion error: Error parsing 1 as bool SET datafusion.execution.coalesce_batches to 1 -statement error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -External error: provided string was not `true` or `false` +statement error DataFusion error: Error parsing abc as bool SET datafusion.execution.coalesce_batches to abc # set u64 variable @@ -136,19 +132,13 @@ datafusion.execution.batch_size 2 statement ok set datafusion.catalog.information_schema = true -statement error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -External error: invalid digit found in string +statement error DataFusion error: Error parsing -1 as usize SET datafusion.execution.batch_size to -1 -statement error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -External error: invalid digit found in string +statement error DataFusion error: Error parsing abc as usize SET datafusion.execution.batch_size to abc -statement error DataFusion error: SQL error: ParserError\("Expected an SQL statement, found: caused"\) -caused by -External error: invalid digit found in string +statement error External error: invalid digit found in string SET datafusion.execution.batch_size to 0.1 # set time zone diff --git a/datafusion/core/tests/sqllogictests/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/strings.slt rename to datafusion/sqllogictest/test_files/strings.slt diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt new file mode 100644 index 0000000000000..fc14798a3bfed --- /dev/null +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Struct Expressions Tests +############# + +statement ok +CREATE TABLE values( + a INT, + b FLOAT, + c VARCHAR +) AS VALUES + (1, 1.1, 'a'), + (2, 2.2, 'b'), + (3, 3.3, 'c') +; + +# struct[i] +query IRT +select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; +---- +1 2.55 a + +# struct[i] with columns +query R +select struct(a, b, c)['c1'] from values; +---- +1.1 +2.2 +3.3 + +# struct scalar function #1 +query ? +select struct(1, 3.14, 'e'); +---- +{c0: 1, c1: 3.14, c2: e} + +# struct scalar function with columns #1 +query ? +select struct(a, b, c) from values; +---- +{c0: 1, c1: 1.1, c2: a} +{c0: 2, c1: 2.2, c2: b} +{c0: 3, c1: 3.3, c2: c} + +statement ok +drop table values; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt new file mode 100644 index 0000000000000..3e0fcb7aa96eb --- /dev/null +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -0,0 +1,1062 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + +############# +## Subquery Tests +############# + + +############# +## Setup test data table +############# +# there tables for subquery +statement ok +CREATE TABLE t0(t0_id INT, t0_name TEXT, t0_int INT) AS VALUES +(11, 'o', 6), +(22, 'p', 7), +(33, 'q', 8), +(44, 'r', 9); + +statement ok +CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS customer ( + c_custkey BIGINT, + c_name VARCHAR, + c_address VARCHAR, + c_nationkey BIGINT, + c_phone VARCHAR, + c_acctbal DECIMAL(15, 2), + c_mktsegment VARCHAR, + c_comment VARCHAR, +) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/customer.csv'; + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS orders ( + o_orderkey BIGINT, + o_custkey BIGINT, + o_orderstatus VARCHAR, + o_totalprice DECIMAL(15, 2), + o_orderdate DATE, + o_orderpriority VARCHAR, + o_clerk VARCHAR, + o_shippriority INTEGER, + o_comment VARCHAR, +) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/orders.csv'; + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( + l_orderkey BIGINT, + l_partkey BIGINT, + l_suppkey BIGINT, + l_linenumber INTEGER, + l_quantity DECIMAL(15, 2), + l_extendedprice DECIMAL(15, 2), + l_discount DECIMAL(15, 2), + l_tax DECIMAL(15, 2), + l_returnflag VARCHAR, + l_linestatus VARCHAR, + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct VARCHAR, + l_shipmode VARCHAR, + l_comment VARCHAR, +) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/lineitem.csv'; + +# in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +11 a 1 +33 c 3 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id + 12 not in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +22 b 2 + +# in subquery with two parentheses, see #5529 +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (( + select t2.t2_id from t2 + )) +---- +11 a 1 +22 b 2 +44 d 4 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (( + select t2.t2_id from t2 + )) +and t1.t1_int < 3 +---- +11 a 1 +22 b 2 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id not in (( + select t2.t2_id from t2 where t2.t2_int = 3 + )) +---- +22 b 2 +33 c 3 + +# VALUES in subqueries, see 6017 +query I +select t1_id +from t1 +where t1_int = (select max(i) from (values (1)) as s(i)); +---- +11 + +# aggregated_correlated_scalar_subquery +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +--Left Join: t1.t1_id = __scalar_sq_1.t2_id +----TableScan: t1 projection=[t1_id] +----SubqueryAlias: __scalar_sq_1 +------Projection: SUM(t2.t2_int), t2.t2_id +--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +----------TableScan: t2 projection=[t2_id, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +11 3 +22 1 +33 NULL +44 3 + +# aggregated_correlated_scalar_subquery_with_cast +query TT +explain SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int * Float64(1)) + Int64(1) AS t2_sum +--Left Join: t1.t1_id = __scalar_sq_1.t2_id +----TableScan: t1 projection=[t1_id] +----SubqueryAlias: __scalar_sq_1 +------Projection: SUM(t2.t2_int * Float64(1)) + Float64(1) AS SUM(t2.t2_int * Float64(1)) + Int64(1), t2.t2_id +--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Float64)) AS SUM(t2.t2_int * Float64(1))]] +----------TableScan: t2 projection=[t2_id, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] +--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------CoalesceBatchesExec: target_batch_size=2 +------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query IR rowsort +SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +11 4 +22 2 +33 NULL +44 4 + +# aggregated_correlated_scalar_subquery_with_extra_group_by_constant +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +--Left Join: t1.t1_id = __scalar_sq_1.t2_id +----TableScan: t1 projection=[t1_id] +----SubqueryAlias: __scalar_sq_1 +------Projection: SUM(t2.t2_int), t2.t2_id +--------Aggregate: groupBy=[[t2.t2_id, Utf8("a")]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +----------TableScan: t2 projection=[t2_id, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Left, on=[(t1_id@0, t2_id@1)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t2_id@1], 4), input_partitions=4 +----------ProjectionExec: expr=[SUM(t2.t2_int)@2 as SUM(t2.t2_int), t2_id@0 as t2_id] +------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id, Utf8("a")@1 as Utf8("a")], aggr=[SUM(t2.t2_int)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------RepartitionExec: partitioning=Hash([t2_id@0, Utf8("a")@1], 4), input_partitions=4 +------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id, a as Utf8("a")], aggr=[SUM(t2.t2_int)] +--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 +---- +11 3 +22 1 +33 NULL +44 3 + +# aggregated_correlated_scalar_subquery_with_having +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +--Left Join: t1.t1_id = __scalar_sq_1.t2_id +----TableScan: t1 projection=[t1_id] +----SubqueryAlias: __scalar_sq_1 +------Projection: SUM(t2.t2_int), t2.t2_id +--------Filter: SUM(t2.t2_int) < Int64(3) +----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +------------TableScan: t2 projection=[t2_id, t2_int] +physical_plan +ProjectionExec: expr=[t1_id@2 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)] +------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +--------CoalesceBatchesExec: target_batch_size=2 +----------FilterExec: SUM(t2.t2_int)@1 < 3 +------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------CoalesceBatchesExec: target_batch_size=2 +----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 +---- +11 NULL +22 1 +33 NULL +44 NULL + + +statement ok +set datafusion.explain.logical_plan_only = true; + +# correlated_recursive_scalar_subquery +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + ) +) order by c_custkey; +---- +logical_plan +Sort: customer.c_custkey ASC NULLS LAST +--Projection: customer.c_custkey +----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.SUM(orders.o_totalprice) +------TableScan: customer projection=[c_custkey, c_acctbal] +------SubqueryAlias: __scalar_sq_1 +--------Projection: SUM(orders.o_totalprice), orders.o_custkey +----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] +------------Projection: orders.o_custkey, orders.o_totalprice +--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price +----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +----------------SubqueryAlias: __scalar_sq_2 +------------------Projection: SUM(lineitem.l_extendedprice) AS price, lineitem.l_orderkey +--------------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] +----------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# correlated_where_in +query TT +explain select o_orderkey from orders +where o_orderstatus in ( + select l_linestatus from lineitem where l_orderkey = orders.o_orderkey +); +---- +logical_plan +Projection: orders.o_orderkey +--LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey +----TableScan: orders projection=[o_orderkey, o_orderstatus] +----SubqueryAlias: __correlated_sq_1 +------Projection: lineitem.l_linestatus, lineitem.l_orderkey +--------TableScan: lineitem projection=[l_orderkey, l_linestatus] + +query I rowsort +select o_orderkey from orders +where o_orderstatus in ( + select l_linestatus from lineitem where l_orderkey = orders.o_orderkey +); +---- +2 +3 + +#exists_subquery_with_same_table +#Subquery and outer query refer to the same table. +#It will not be rewritten to join because it is not a correlated subquery. +query TT +explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE EXISTS(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int) +---- +logical_plan +Filter: EXISTS () +--Subquery: +----Projection: t1.t1_int +------Filter: t1.t1_id > t1.t1_int +--------TableScan: t1 +--TableScan: t1 projection=[t1_id, t1_name, t1_int] + + +#in_subquery_with_same_table +#Subquery and outer query refer to the same table. +#It will be rewritten to join because in-subquery has extra predicate(`t1.t1_id = __correlated_sq_10.t1_int`). +query TT +explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t1_int FROM t1 WHERE t1.t1_id > t1.t1_int) +---- +logical_plan +LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int +--TableScan: t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: t1.t1_int +------Filter: t1.t1_id > t1.t1_int +--------TableScan: t1 projection=[t1_id, t1_int] + +#in_subquery_nested_exist_subquery +query TT +explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int)) +---- +logical_plan +LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +--TableScan: t1 projection=[t1_id, t1_name, t1_int] +--SubqueryAlias: __correlated_sq_1 +----Projection: t2.t2_id +------LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int +--------TableScan: t2 projection=[t2_id, t2_int] +--------SubqueryAlias: __correlated_sq_2 +----------TableScan: t1 projection=[t1_int] + +#invalid_scalar_subquery +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Scalar subquery should only return one column, but found 2: t2.t2_id, t2.t2_name +SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t1.t1_int) FROM t1 + +#subquery_not_allowed +#In/Exist Subquery is not allowed in ORDER BY clause. +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes +SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) + +#non_aggregated_correlated_scalar_subquery +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row +SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1 + +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + + +#non_aggregated_correlated_scalar_subquery +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row +SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1 + +#non_aggregated_correlated_scalar_subquery_with_limit +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row +SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 2) as t2_int from t1 + +#non_aggregated_correlated_scalar_subquery_with_single_row +query TT +explain SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) as t2_int from t1 +---- +logical_plan +Projection: t1.t1_id, () AS t2_int +--Subquery: +----Limit: skip=0, fetch=1 +------Projection: t2.t2_int +--------Filter: t2.t2_int = outer_ref(t1.t1_int) +----------TableScan: t2 +--TableScan: t1 projection=[t1_id, t1_int] + +query TT +explain SELECT t1_id from t1 where t1_int = (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1) +---- +logical_plan +Projection: t1.t1_id +--Filter: t1.t1_int = () +----Subquery: +------Limit: skip=0, fetch=1 +--------Projection: t2.t2_int +----------Filter: t2.t2_int = outer_ref(t1.t1_int) +------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_int] + +query TT +explain SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.a AS t2_int +--Left Join: CAST(t1.t1_int AS Int64) = __scalar_sq_1.a +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: Int64(1) AS a +--------EmptyRelation + +query II rowsort +SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1 +---- +11 1 +22 NULL +33 NULL +44 NULL + +#non_equal_correlated_scalar_subquery +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated column is not allowed in predicate: t2\.t2_id < outer_ref\(t1\.t1_id\) +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 + +#aggregated_correlated_scalar_subquery_with_extra_group_by_columns +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_name) as t2_sum from t1 + +#support_agg_correlated_columns +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name) +---- +logical_plan +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) +--------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +----------Filter: outer_ref(t1.t1_name) = t2.t2_name +------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] + +#support_agg_correlated_columns2 +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0) +---- +logical_plan +Projection: t1.t1_id, t1.t1_name +--Filter: EXISTS () +----Subquery: +------Projection: COUNT(*) +--------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) +----------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +------------Filter: outer_ref(t1.t1_name) = t2.t2_name +--------------TableScan: t2 +----TableScan: t1 projection=[t1_id, t1_name, t1_int] + +#support_join_correlated_columns +query TT +explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name)) +---- +logical_plan +Filter: EXISTS () +--Subquery: +----Projection: Int64(1) +------Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) +--------TableScan: t1 +--------TableScan: t2 +--TableScan: t0 projection=[t0_id, t0_name] + +#subquery_contains_join_contains_correlated_columns +query TT +explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN (select * from t2 where t2.t2_name = t0.t0_name) as t2 ON(t1.t1_id = t2.t2_id )) +---- +logical_plan +LeftSemi Join: t0.t0_name = __correlated_sq_1.t2_name +--TableScan: t0 projection=[t0_id, t0_name] +--SubqueryAlias: __correlated_sq_1 +----Projection: t2.t2_name +------Inner Join: t1.t1_id = t2.t2_id +--------TableScan: t1 projection=[t1_id] +--------SubqueryAlias: t2 +----------TableScan: t2 projection=[t2_id, t2_name] + +#subquery_contains_join_contains_sub_query_alias_correlated_columns +query TT +explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (select 1 from (SELECT * FROM t1 where t1.t1_id = t0.t0_id) as x INNER JOIN (select * from t2 where t2.t2_name = t0.t0_name) as y ON(x.t1_id = y.t2_id)) +---- +logical_plan +LeftSemi Join: t0.t0_id = __correlated_sq_1.t1_id, t0.t0_name = __correlated_sq_1.t2_name +--TableScan: t0 projection=[t0_id, t0_name] +--SubqueryAlias: __correlated_sq_1 +----Projection: x.t1_id, y.t2_name +------Inner Join: x.t1_id = y.t2_id +--------SubqueryAlias: x +----------TableScan: t1 projection=[t1_id] +--------SubqueryAlias: y +----------TableScan: t2 projection=[t2_id, t2_name] + +#support_order_by_correlated_columns +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id >= t1_id order by t1_id) +---- +logical_plan +Filter: EXISTS () +--Subquery: +----Sort: outer_ref(t1.t1_id) ASC NULLS LAST +------Projection: t2.t2_id, t2.t2_name, t2.t2_int +--------Filter: t2.t2_id >= outer_ref(t1.t1_id) +----------TableScan: t2 +--TableScan: t1 projection=[t1_id, t1_name] + +#exists_subquery_with_select_null +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT NULL) +---- +logical_plan +Filter: EXISTS () +--Subquery: +----Projection: NULL +------EmptyRelation +--TableScan: t1 projection=[t1_id, t1_name] + +#exists_subquery_with_limit +#de-correlated, limit is removed +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1) +---- +logical_plan +LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +--TableScan: t1 projection=[t1_id, t1_name] +--SubqueryAlias: __correlated_sq_1 +----TableScan: t2 projection=[t2_id] + +query IT rowsort +SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1) +---- +11 a +22 b +44 d + +#exists_subquery_with_limit0 +#de-correlated, limit is removed and replaced with EmptyRelation +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) +---- +logical_plan +LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +--TableScan: t1 projection=[t1_id, t1_name] +--EmptyRelation + +query IT rowsort +SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) +---- + + +#not_exists_subquery_with_limit0 +#de-correlated, limit is removed and replaced with EmptyRelation +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) +---- +logical_plan +LeftAnti Join: t1.t1_id = __correlated_sq_1.t2_id +--TableScan: t1 projection=[t1_id, t1_name] +--EmptyRelation + +query IT rowsort +SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) +---- +11 a +22 b +33 c +44 d + +#in_correlated_subquery_with_limit +#not de-correlated +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where t1_name = t2_name limit 10) +---- +logical_plan +Filter: t1.t1_id IN () +--Subquery: +----Limit: skip=0, fetch=10 +------Projection: t2.t2_id +--------Filter: outer_ref(t1.t1_name) = t2.t2_name +----------TableScan: t2 +--TableScan: t1 projection=[t1_id, t1_name] + +#in_non_correlated_subquery_with_limit +#de-correlated, limit is kept +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 limit 10) +---- +logical_plan +LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +--TableScan: t1 projection=[t1_id, t1_name] +--SubqueryAlias: __correlated_sq_1 +----Limit: skip=0, fetch=10 +------TableScan: t2 projection=[t2_id], fetch=10 + + +#uncorrelated_scalar_subquery_with_limit0 +query TT +explain SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.t2_id AS t2_id +--Left Join: +----TableScan: t1 projection=[t1_id] +----EmptyRelation + +query II rowsort +SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 +---- +11 NULL +22 NULL +33 NULL +44 NULL + +#support_union_subquery +query TT +explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id UNION ALL SELECT * FROM t2 WHERE upper(t2_name) = upper(t1.t1_name)) +---- +logical_plan +Filter: EXISTS () +--Subquery: +----Union +------Projection: t2.t2_id, t2.t2_name, t2.t2_int +--------Filter: t2.t2_id = outer_ref(t1.t1_id) +----------TableScan: t2 +------Projection: t2.t2_id, t2.t2_name, t2.t2_int +--------Filter: upper(t2.t2_name) = upper(outer_ref(t1.t1_name)) +----------TableScan: t2 +--TableScan: t1 projection=[t1_id, t1_name] + +#simple_uncorrelated_scalar_subquery +query TT +explain select (select count(*) from t1) as b +---- +logical_plan +Projection: __scalar_sq_1.COUNT(*) AS b +--SubqueryAlias: __scalar_sq_1 +----Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +------TableScan: t1 projection=[] + +#simple_uncorrelated_scalar_subquery2 +query TT +explain select (select count(*) from t1) as b, (select count(1) from t2) +---- +logical_plan +Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT(Int64(1)) +--Left Join: +----SubqueryAlias: __scalar_sq_1 +------Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------TableScan: t1 projection=[] +----SubqueryAlias: __scalar_sq_2 +------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +--------TableScan: t2 projection=[] + +query II +select (select count(*) from t1) as b, (select count(1) from t2) +---- +4 4 + +#correlated_scalar_subquery_count_agg +query TT +explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +---- +logical_plan +Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS COUNT(*) +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*), t2.t2_int, __always_true +--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +---- +11 1 +22 0 +33 3 +44 0 + + +#correlated_scalar_subquery_count_agg2 +query TT +explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +---- +logical_plan +Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS cnt +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*), t2.t2_int, __always_true +--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +---- +11 1 +22 0 +33 3 +44 0 + +#correlated_scalar_subquery_count_agg_with_alias +query TT +explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +---- +logical_plan +Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) AS _cnt ELSE __scalar_sq_1._cnt END AS cnt +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*) AS _cnt, t2.t2_int, __always_true +--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +---- +11 1 +22 0 +33 3 +44 0 + +#correlated_scalar_subquery_count_agg_complex_expr +query TT +explain SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +---- +logical_plan +Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS _cnt ELSE __scalar_sq_1._cnt END AS _cnt +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*) + Int64(2) AS _cnt, t2.t2_int, __always_true +--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +---- +11 3 +22 2 +33 5 +44 2 + +#correlated_scalar_subquery_count_agg_where_clause +query TT +explain select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = t2.t2_id) < t1.t1_int +---- +logical_plan +Projection: t1.t1_int +--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END < CAST(t1.t1_int AS Int64) +----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +------Left Join: t1.t1_id = __scalar_sq_1.t2_id +--------TableScan: t1 projection=[t1_id, t1_int] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: COUNT(*), t2.t2_id, __always_true +------------Aggregate: groupBy=[[t2.t2_id, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_id] + +query I rowsort +select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = t2.t2_id) < t1.t1_int +---- +2 +3 +4 + +#correlated_scalar_subquery_count_agg_with_having +#the having condition is kept as the normal filter condition, no need to pull up +query TT +explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) >1) from t1 +---- +logical_plan +Projection: t1.t1_id, __scalar_sq_1.cnt_plus_2 AS cnt_plus_2 +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int +--------Filter: COUNT(*) > Int64(1) +----------Projection: t2.t2_int, COUNT(*) +------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) >1) from t1 +---- +11 NULL +22 NULL +33 5 +44 NULL + +#correlated_scalar_subquery_count_agg_with_pull_up_having +#the having condition need to pull up and evaluated after the left out join +query TT +explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 +---- +logical_plan +Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +--Left Join: t1.t1_int = __scalar_sq_1.t2_int +----TableScan: t1 projection=[t1_id, t1_int] +----SubqueryAlias: __scalar_sq_1 +------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int, COUNT(*), __always_true +--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +----------TableScan: t2 projection=[t2_int] + +query II rowsort +SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 +---- +11 NULL +22 2 +33 NULL +44 2 + +#correlated_scalar_subquery_count_agg_in_having +query TT +explain select t1.t1_int from t1 group by t1.t1_int having (select count(*) from t2 where t1.t1_int = t2.t2_int) = 0 +---- +logical_plan +Projection: t1.t1_int +--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END = Int64(0) +----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +------Left Join: t1.t1_int = __scalar_sq_1.t2_int +--------Aggregate: groupBy=[[t1.t1_int]], aggr=[[]] +----------TableScan: t1 projection=[t1_int] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: COUNT(*), t2.t2_int, __always_true +------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_int] + +query I rowsort +select t1.t1_int from t1 group by t1.t1_int having (select count(*) from t2 where t1.t1_int = t2.t2_int) = 0 +---- +2 +4 + +#correlated_scalar_subquery_count_agg_in_nested_projection +query TT +explain select t1.t1_int from t1 where (select cnt from (select count(*) as cnt, sum(t2_int) from t2 where t1.t1_int = t2.t2_int)) = 0 +---- +logical_plan +Projection: t1.t1_int +--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.cnt END = Int64(0) +----Projection: t1.t1_int, __scalar_sq_1.cnt, __scalar_sq_1.__always_true +------Left Join: t1.t1_int = __scalar_sq_1.t2_int +--------TableScan: t1 projection=[t1_int] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: COUNT(*) AS cnt, t2.t2_int, __always_true +------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_int] + + +query I rowsort +select t1.t1_int from t1 where (select cnt from (select count(*) as cnt, sum(t2_int) from t2 where t1.t1_int = t2.t2_int)) = 0 +---- +2 +4 + +#correlated_scalar_subquery_count_agg_in_nested_subquery +#pull up the deeply nested having condition +query TT +explain +select t1.t1_int from t1 where ( + select cnt_plus_one + 1 as cnt_plus_two from ( + select cnt + 1 as cnt_plus_one from ( + select count(*) as cnt, sum(t2_int) s from t2 where t1.t1_int = t2.t2_int having cnt = 0 + ) + ) +) = 2 +---- +logical_plan +Projection: t1.t1_int +--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) +----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +------Left Join: t1.t1_int = __scalar_sq_1.t2_int +--------TableScan: t1 projection=[t1_int] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: COUNT(*) + Int64(1) + Int64(1) AS cnt_plus_two, t2.t2_int, COUNT(*), __always_true +------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_int] + +query I rowsort +select t1.t1_int from t1 where ( + select cnt_plus_one + 1 as cnt_plus_two from ( + select cnt + 1 as cnt_plus_one from ( + select count(*) as cnt, sum(t2_int) s from t2 where t1.t1_int = t2.t2_int having cnt = 0 + ) + ) +) = 2 +---- +2 +4 + +#correlated_scalar_subquery_count_agg_in_case_when +query TT +explain +select t1.t1_int from t1 where + (select case when count(*) = 1 then null else count(*) end as cnt from t2 where t2.t2_int = t1.t1_int) = 0 +---- +logical_plan +Projection: t1.t1_int +--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.cnt END = Int64(0) +----Projection: t1.t1_int, __scalar_sq_1.cnt, __scalar_sq_1.__always_true +------Left Join: t1.t1_int = __scalar_sq_1.t2_int +--------TableScan: t1 projection=[t1_int] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: CASE WHEN COUNT(*) = Int64(1) THEN Int64(NULL) ELSE COUNT(*) END AS cnt, t2.t2_int, __always_true +------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------------TableScan: t2 projection=[t2_int] + + +query I rowsort +select t1.t1_int from t1 where + (select case when count(*) = 1 then null else count(*) end as cnt from t2 where t2.t2_int = t1.t1_int) = 0 +---- +2 +4 + +query B rowsort +select t1_int > (select avg(t1_int) from t1) from t1 +---- +false +false +true +true + + +# issue: https://github.com/apache/arrow-datafusion/issues/7027 +query TTTT rowsort +SELECT * FROM + (VALUES ('catan-prod1-daily', 'success')) as jobs(cron_job_name, status) + JOIN + (VALUES ('catan-prod1-daily', 'high')) as severity(cron_job_name, level) + ON (severity.cron_job_name = jobs.cron_job_name); +---- +catan-prod1-daily success catan-prod1-daily high + +##correlated_scalar_subquery_sum_agg_bug +#query TT +#explain +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#logical_plan +#Projection: t1.t1_int +#--Inner Join: t1.t1_id = __scalar_sq_1.t2_id +#----TableScan: t1 projection=[t1_id, t1_int] +#----SubqueryAlias: __scalar_sq_1 +#------Projection: t2.t2_id +#--------Filter: SUM(t2.t2_int) IS NULL +#----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] +#------------TableScan: t2 projection=[t2_id, t2_int] + +#query I rowsort +#select t1.t1_int from t1 where +# (select sum(t2_int) is null from t2 where t1.t1_id = t2.t2_id) +#---- +#2 +#3 +#4 + +statement ok +create table t(a bigint); + +# Result of query below shouldn't depend on +# number of optimization passes +# See issue: https://github.com/apache/arrow-datafusion/issues/8296 +statement ok +set datafusion.optimizer.max_passes = 1; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] + +statement ok +set datafusion.optimizer.max_passes = 3; + +query TT +explain select a/2, a/2 + 1 from t +---- +logical_plan +Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) +--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +----TableScan: t projection=[a] diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt similarity index 53% rename from datafusion/core/tests/sqllogictests/test_files/timestamps.slt rename to datafusion/sqllogictest/test_files/timestamps.slt index baf1f4d5b9de3..71b6ddf33f39a 100644 --- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -15,6 +15,38 @@ # specific language governing permissions and limitations # under the License. +########## +## Common timestamp data +# +# ts_data: Int64 nanosecods +# ts_data_nanos: Timestamp(Nanosecond, None) +# ts_data_micros: Timestamp(Microsecond, None) +# ts_data_millis: Timestamp(Millisecond, None) +# ts_data_secs: Timestamp(Second, None) +########## + +# Create timestamp tables with different precisions but the same logical values + +statement ok +create table ts_data(ts bigint, value int) as values + (1599572549190855123, 1), + (1599568949190855123, 2), + (1599565349190855123, 3); + +statement ok +create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(Nanosecond, None)') as ts, value from ts_data; + +statement ok +create table ts_data_micros as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, None)') as ts, value from ts_data; + +statement ok +create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(Millisecond, None)') as ts, value from ts_data; + +statement ok +create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; + + + ########## ## Timestamp Handling Tests ########## @@ -68,6 +100,40 @@ select * from foo where ts != '2000-02-01T00:00:00'; statement ok drop table foo; + +########## +## Timezone Handling Tests +########## + +statement ok +SET TIME ZONE = '+08' + +# should use execution timezone +query P +SELECT TIMESTAMPTZ '2000-01-01T01:01:01' +---- +2000-01-01T01:01:01+08:00 + +# casts return timezone to use execution timezone (same as postgresql) +query P +SELECT TIMESTAMPTZ '2000-01-01T01:01:01+07:00' +---- +2000-01-01T02:01:01+08:00 + +query P +SELECT TIMESTAMPTZ '2000-01-01T01:01:01Z' +---- +2000-01-01T09:01:01+08:00 + +statement ok +SET TIME ZONE = '+00' + +query P +SELECT TIMESTAMPTZ '2000-01-01T01:01:01' +---- +2000-01-01T01:01:01Z + + ########## ## to_timestamp tests ########## @@ -105,29 +171,6 @@ SELECT to_timestamp_seconds(ts / 1000) FROM t1 LIMIT 3 2009-03-01T00:01:00 2009-04-01T00:00:00 -statement error DataFusion error: Execution error: Table 'ts' doesn't exist\. -drop table ts; - -# Create timestamp tables with different precisions but the same logical values - -statement ok -create table ts_data(ts bigint, value int) as values - (1599572549190855000, 1), - (1599568949190855000, 2), - (1599565349190855000, 3); - -statement ok -create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(Nanosecond, None)') as ts, value from ts_data; - -statement ok -create table ts_data_micros as select arrow_cast(ts / 1000, 'Timestamp(Microsecond, None)') as ts, value from ts_data; - -statement ok -create table ts_data_millis as select arrow_cast(ts / 1000000, 'Timestamp(Millisecond, None)') as ts, value from ts_data; - -statement ok -create table ts_data_secs as select arrow_cast(ts / 1000000000, 'Timestamp(Second, None)') as ts, value from ts_data; - # query_cast_timestamp_nanos_to_others @@ -174,7 +217,7 @@ SELECT to_timestamp_micros(ts) FROM ts_data_secs LIMIT 3 # to nanos query P -SELECT to_timestamp(ts) FROM ts_data_secs LIMIT 3 +SELECT to_timestamp_nanos(ts) FROM ts_data_secs LIMIT 3 ---- 2020-09-08T13:42:29 2020-09-08T12:42:29 @@ -201,7 +244,7 @@ SELECT to_timestamp_seconds(ts) FROM ts_data_micros LIMIT 3 2020-09-08T11:42:29 -# Original column is micros, convert to nanos and check timestamp +# Original column is micros, convert to seconds and check timestamp query P SELECT to_timestamp(ts) FROM ts_data_micros LIMIT 3 @@ -223,7 +266,7 @@ SELECT from_unixtime(ts / 1000000000) FROM ts_data LIMIT 3; # to_timestamp query I -SELECT COUNT(*) FROM ts_data_nanos where ts > to_timestamp('2020-09-08T12:00:00+00:00') +SELECT COUNT(*) FROM ts_data_nanos where ts > timestamp '2020-09-08T12:00:00+00:00' ---- 2 @@ -248,6 +291,35 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08T12 ---- 2 + +# to_timestamp float inputs + +query PPP +SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; +---- +1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 + +query PPP +SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, (-1.1)::timestamp as c3; +---- +1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 + +query PPP +SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3; +---- +1970-01-01T00:00:00 1970-01-01T00:00:00 1970-01-01T00:00:00 + +query PPP +SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3; +---- +1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 + +query PPP +SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3; +---- +1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 + + # from_unixtime # 1599566400 is '2020-09-08T12:00:00+00:00' @@ -261,9 +333,9 @@ SELECT COUNT(*) FROM ts_data_secs where ts > from_unixtime(1599566400) query P rowsort SELECT DISTINCT ts FROM ts_data_nanos; ---- -2020-09-08T11:42:29.190855 -2020-09-08T12:42:29.190855 -2020-09-08T13:42:29.190855 +2020-09-08T11:42:29.190855123 +2020-09-08T12:42:29.190855123 +2020-09-08T13:42:29.190855123 query I @@ -332,7 +404,7 @@ set datafusion.optimizer.skip_failed_rules = true query P select to_timestamp(a) from (select to_timestamp(1) as a) A; ---- -1970-01-01T00:00:00.000000001 +1970-01-01T00:00:01 # cast_to_timestamp_seconds_twice query P @@ -340,7 +412,6 @@ select to_timestamp_seconds(a) from (select to_timestamp_seconds(1) as a)A ---- 1970-01-01T00:00:01 - # cast_to_timestamp_millis_twice query P select to_timestamp_millis(a) from (select to_timestamp_millis(1) as a)A; @@ -353,11 +424,17 @@ select to_timestamp_micros(a) from (select to_timestamp_micros(1) as a)A; ---- 1970-01-01T00:00:00.000001 +# cast_to_timestamp_nanos_twice +query P +select to_timestamp_nanos(a) from (select to_timestamp_nanos(1) as a)A; +---- +1970-01-01T00:00:00.000000001 + # to_timestamp_i32 query P select to_timestamp(cast (1 as int)); ---- -1970-01-01T00:00:00.000000001 +1970-01-01T00:00:01 # to_timestamp_micros_i32 query P @@ -365,6 +442,12 @@ select to_timestamp_micros(cast (1 as int)); ---- 1970-01-01T00:00:00.000001 +# to_timestamp_nanos_i32 +query P +select to_timestamp_nanos(cast (1 as int)); +---- +1970-01-01T00:00:00.000000001 + # to_timestamp_millis_i32 query P select to_timestamp_millis(cast (1 as int)); @@ -377,25 +460,14 @@ select to_timestamp_seconds(cast (1 as int)); ---- 1970-01-01T00:00:01 -statement ok -drop table ts_data - -statement ok -drop table ts_data_nanos - -statement ok -drop table ts_data_micros - -statement ok -drop table ts_data_millis - -statement ok -drop table ts_data_secs - ########## ## test date_bin function ########## +# invalid second arg type +query error DataFusion error: Error during planning: No function matches the given name and argument types 'date_bin\(Interval\(MonthDayNano\), Int64, Timestamp\(Nanosecond, None\)\)'\. +SELECT DATE_BIN(INTERVAL '0 second', 25, TIMESTAMP '1970-01-01T00:00:00Z') + # not support interval 0 statement error Execution error: DATE_BIN stride must be non-zero SELECT DATE_BIN(INTERVAL '0 second', TIMESTAMP '2022-08-03 14:38:50.000000006Z', TIMESTAMP '1970-01-01T00:00:00Z') @@ -843,6 +915,16 @@ SELECT DATE_TRUNC('YEAR', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-01-01T00:00:00 +query P +SELECT DATE_TRUNC('year', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('YEAR', NULL); +---- +NULL + query P SELECT DATE_TRUNC('quarter', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -853,6 +935,16 @@ SELECT DATE_TRUNC('QUARTER', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-07-01T00:00:00 +query P +SELECT DATE_TRUNC('quarter', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('QUARTER', NULL); +---- +NULL + query P SELECT DATE_TRUNC('month', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -863,6 +955,16 @@ SELECT DATE_TRUNC('MONTH', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-01T00:00:00 +query P +SELECT DATE_TRUNC('month', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('MONTH', NULL); +---- +NULL + query P SELECT DATE_TRUNC('week', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -873,6 +975,16 @@ SELECT DATE_TRUNC('WEEK', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-01T00:00:00 +query P +SELECT DATE_TRUNC('week', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('WEEK', NULL); +---- +NULL + query P SELECT DATE_TRUNC('day', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -883,6 +995,16 @@ SELECT DATE_TRUNC('DAY', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-03T00:00:00 +query P +SELECT DATE_TRUNC('day', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('DAY', NULL); +---- +NULL + query P SELECT DATE_TRUNC('hour', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -893,6 +1015,16 @@ SELECT DATE_TRUNC('HOUR', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-03T14:00:00 +query P +SELECT DATE_TRUNC('hour', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('HOUR', NULL); +---- +NULL + query P SELECT DATE_TRUNC('minute', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -903,6 +1035,16 @@ SELECT DATE_TRUNC('MINUTE', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-03T14:38:00 +query P +SELECT DATE_TRUNC('minute', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('MINUTE', NULL); +---- +NULL + query P SELECT DATE_TRUNC('second', TIMESTAMP '2022-08-03 14:38:50Z'); ---- @@ -913,6 +1055,144 @@ SELECT DATE_TRUNC('SECOND', TIMESTAMP '2022-08-03 14:38:50Z'); ---- 2022-08-03T14:38:50 +query P +SELECT DATE_TRUNC('second', NULL); +---- +NULL + +query P +SELECT DATE_TRUNC('SECOND', NULL); +---- +NULL + +# Test date trunc on different timestamp types and ensure types are consistent +query TP rowsort +SELECT 'ts_data_nanos', DATE_TRUNC('day', ts) FROM ts_data_nanos + UNION ALL +SELECT 'ts_data_micros', DATE_TRUNC('day', ts) FROM ts_data_micros + UNION ALL +SELECT 'ts_data_millis', DATE_TRUNC('day', ts) FROM ts_data_millis + UNION ALL +SELECT 'ts_data_secs', DATE_TRUNC('day', ts) FROM ts_data_secs +---- +ts_data_micros 2020-09-08T00:00:00 +ts_data_micros 2020-09-08T00:00:00 +ts_data_micros 2020-09-08T00:00:00 +ts_data_millis 2020-09-08T00:00:00 +ts_data_millis 2020-09-08T00:00:00 +ts_data_millis 2020-09-08T00:00:00 +ts_data_nanos 2020-09-08T00:00:00 +ts_data_nanos 2020-09-08T00:00:00 +ts_data_nanos 2020-09-08T00:00:00 +ts_data_secs 2020-09-08T00:00:00 +ts_data_secs 2020-09-08T00:00:00 +ts_data_secs 2020-09-08T00:00:00 + +# Test date trun on different granularity +query TP rowsort +SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_nanos + UNION ALL +SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_nanos + UNION ALL +SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_nanos + UNION ALL +SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_nanos +---- +microsecond 2020-09-08T11:42:29.190855 +microsecond 2020-09-08T12:42:29.190855 +microsecond 2020-09-08T13:42:29.190855 +millisecond 2020-09-08T11:42:29.190 +millisecond 2020-09-08T12:42:29.190 +millisecond 2020-09-08T13:42:29.190 +minute 2020-09-08T11:42:00 +minute 2020-09-08T12:42:00 +minute 2020-09-08T13:42:00 +second 2020-09-08T11:42:29 +second 2020-09-08T12:42:29 +second 2020-09-08T13:42:29 + +query TP rowsort +SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_micros + UNION ALL +SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_micros + UNION ALL +SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_micros + UNION ALL +SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_micros +---- +microsecond 2020-09-08T11:42:29.190855 +microsecond 2020-09-08T12:42:29.190855 +microsecond 2020-09-08T13:42:29.190855 +millisecond 2020-09-08T11:42:29.190 +millisecond 2020-09-08T12:42:29.190 +millisecond 2020-09-08T13:42:29.190 +minute 2020-09-08T11:42:00 +minute 2020-09-08T12:42:00 +minute 2020-09-08T13:42:00 +second 2020-09-08T11:42:29 +second 2020-09-08T12:42:29 +second 2020-09-08T13:42:29 + +query TP rowsort +SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_millis + UNION ALL +SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_millis + UNION ALL +SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_millis + UNION ALL +SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_millis +---- +microsecond 2020-09-08T11:42:29.190 +microsecond 2020-09-08T12:42:29.190 +microsecond 2020-09-08T13:42:29.190 +millisecond 2020-09-08T11:42:29.190 +millisecond 2020-09-08T12:42:29.190 +millisecond 2020-09-08T13:42:29.190 +minute 2020-09-08T11:42:00 +minute 2020-09-08T12:42:00 +minute 2020-09-08T13:42:00 +second 2020-09-08T11:42:29 +second 2020-09-08T12:42:29 +second 2020-09-08T13:42:29 + +query TP rowsort +SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_secs + UNION ALL +SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_secs + UNION ALL +SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_secs + UNION ALL +SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_secs +---- +microsecond 2020-09-08T11:42:29 +microsecond 2020-09-08T12:42:29 +microsecond 2020-09-08T13:42:29 +millisecond 2020-09-08T11:42:29 +millisecond 2020-09-08T12:42:29 +millisecond 2020-09-08T13:42:29 +minute 2020-09-08T11:42:00 +minute 2020-09-08T12:42:00 +minute 2020-09-08T13:42:00 +second 2020-09-08T11:42:29 +second 2020-09-08T12:42:29 +second 2020-09-08T13:42:29 + + +# test date trunc on different timestamp scalar types and ensure they are consistent +query P rowsort +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)')) as ts + UNION ALL +SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Nanosecond, None)')) as ts + UNION ALL +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Microsecond, None)')) as ts + UNION ALL +SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Millisecond, None)')) as ts +---- +2023-08-03T00:00:00 +2023-08-03T00:00:00 +2023-08-03T14:38:50 +2023-08-03T14:38:50 + # Demonstrate that strings are automatically coerced to timestamps (don't use TIMESTAMP) @@ -980,10 +1260,10 @@ create table bar (val int, i1 interval, i2 interval) as values query I? SELECT val, ts1 - ts2 FROM foo ORDER BY ts2 - ts1; ---- -4 0 years 0 mons -15250 days -13 hours -28 mins -44.999876545 secs -3 0 years 0 mons 15952 days 23 hours 22 mins 12.667123455 secs -2 0 years 0 mons 8406 days 1 hours 1 mins 54.877123455 secs -1 0 years 0 mons 53 days 16 hours 0 mins 20.000000024 secs +3 15952 days 23 hours 22 mins 12.667123455 secs +2 8406 days 1 hours 1 mins 54.877123455 secs +1 53 days 16 hours 0 mins 20.000000024 secs +4 -15250 days -13 hours -28 mins -44.999876545 secs # Interval - Interval query ? @@ -1031,7 +1311,7 @@ SELECT ts1 + i FROM foo; 2003-07-12T01:31:15.000123463 # Timestamp + Timestamp => error -query error DataFusion error: type_coercion\ncaused by\nInternal error: Unsupported operation Plus between Timestamp\(Nanosecond, None\) and Timestamp\(Nanosecond, None\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +query error DataFusion error: Error during planning: Cannot get result type for temporal operation Timestamp\(Nanosecond, None\) \+ Timestamp\(Nanosecond, None\): Invalid argument error: Invalid timestamp arithmetic operation: Timestamp\(Nanosecond, None\) \+ Timestamp\(Nanosecond, None\) SELECT ts1 + ts2 FROM foo; @@ -1039,17 +1319,19 @@ FROM foo; query ? SELECT '2000-01-01T00:00:00'::timestamp - '2000-01-01T00:00:00'::timestamp; ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000000000 secs +0 days 0 hours 0 mins 0.000000000 secs # large timestamp - small timestamp query ? SELECT '2000-01-01T00:00:00'::timestamp - '2010-01-01T00:00:00'::timestamp; ---- -0 years 0 mons -3653 days 0 hours 0 mins 0.000000000 secs +-3653 days 0 hours 0 mins 0.000000000 secs # Interval - Timestamp => error -statement error DataFusion error: type_coercion\ncaused by\nError during planning: Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to -SELECT i - ts1 from FOO; +# statement error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types +# TODO: This query should raise error +# query P +# SELECT i - ts1 from FOO; statement ok drop table foo; @@ -1078,3 +1360,507 @@ SELECT ; ---- true false true true + + + +########## +## Common timestamp data +########## + +statement ok +drop table ts_data + +statement ok +drop table ts_data_nanos + +statement ok +drop table ts_data_micros + +statement ok +drop table ts_data_millis + +statement ok +drop table ts_data_secs + + + +########## +## Timezone impact on builtin scalar functions +# +# server time = +07 +########## + +statement ok +set timezone to '+07'; + +# postgresql: 2000-01-01 01:00:00+07 +query P +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as ts +---- +2000-01-01T01:00:00+07:00 + +# postgresql: 2000-01-01 00:00:00+07 +query P +SELECT date_trunc('day', TIMESTAMPTZ '2000-01-01T01:01:01') as ts +---- +2000-01-01T00:00:00+07:00 + +# postgresql: 2000-01-01 08:00:00+07 +query P +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T01:01:01Z') as ts +---- +2000-01-01T08:00:00+07:00 + +# postgresql: 2000-01-01 00:00:00+07 +query P +SELECT date_trunc('day', TIMESTAMPTZ '2000-01-01T01:01:01Z') as ts +---- +2000-01-01T00:00:00+07:00 + +# postgresql: 2022-01-01 00:00:00+07 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00', TIMESTAMPTZ '2020-01-01') +---- +2022-01-01T00:00:00+07:00 + +# postgresql: 2022-01-02 00:00:00+07 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00Z', TIMESTAMPTZ '2020-01-01') +---- +2022-01-02T00:00:00+07:00 + +# coerce TIMESTAMP to TIMESTAMPTZ +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00Z', TIMESTAMP '2020-01-01') +---- +2022-01-01T07:00:00+07:00 + +# postgresql: 1 +query R +SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as part +---- +1 + +# postgresql: 8 +query R +SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01Z') as part +---- +8 + + + +########## +## Timezone impact on builtin scalar functions +# +# server time = UTC +########## + +statement ok +set timezone to '+00'; + +# postgresql: 2000-01-01T01:00:00+00 +query P +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as ts +---- +2000-01-01T01:00:00Z + +# postgresql: 2000-01-01T00:00:00+00 +query P +SELECT date_trunc('day', TIMESTAMPTZ '2000-01-01T01:01:01') as ts +---- +2000-01-01T00:00:00Z + +# postgresql: 1999-12-31T18:00:00+00 +query P +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T01:01:01+07') as ts +---- +1999-12-31T18:00:00Z + +# postgresql: 1999-12-31T00:00:00+00 +query P +SELECT date_trunc('day', TIMESTAMPTZ '2000-01-01T01:01:01+07') as ts +---- +1999-12-31T00:00:00Z + +# postgresql: 2022-01-01 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 20:10:00', TIMESTAMPTZ '2020-01-01') +---- +2022-01-01T00:00:00Z + +# postgresql: 2021-12-31 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', TIMESTAMPTZ '2020-01-01') +---- +2021-12-31T00:00:00Z + +# postgresql: 2021-12-31 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01') +---- +2021-12-31T00:00:00Z + +# postgresql: 2021-12-31 00:00:00+00 +query P +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01T00:00:00Z') +---- +2021-12-31T00:00:00Z + +# postgresql: 2021-12-31 18:00:00+00 +query P +SELECT date_bin('2 hour', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01') +---- +2021-12-31T18:00:00Z + +# postgresql: 2021-12-31 18:00:00+00 +query P +SELECT date_bin('2 hour', TIMESTAMPTZ '2022-01-01 01:10:00+07', '2020-01-01T00:00:00Z') +---- +2021-12-31T18:00:00Z + +# postgresql: 1 +query R +SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01') as part +---- +1 + +# postgresql: 18 +query R +SELECT date_part('hour', TIMESTAMPTZ '2000-01-01T01:01:01+07') as part +---- +18 + + + +########## +## Timezone impact on builtin scalar functions +# +# irregular offsets +########## + +query P rowsort +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00+00:45') as ts_irregular_offset + UNION ALL +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00+00:30') as ts_irregular_offset + UNION ALL +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00+00:15') as ts_irregular_offset + UNION ALL +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00-00:15') as ts_irregular_offset + UNION ALL +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00-00:30') as ts_irregular_offset + UNION ALL +SELECT date_trunc('hour', TIMESTAMPTZ '2000-01-01T00:00:00-00:45') as ts_irregular_offset +---- +1999-12-31T23:00:00Z +1999-12-31T23:00:00Z +1999-12-31T23:00:00Z +2000-01-01T00:00:00Z +2000-01-01T00:00:00Z +2000-01-01T00:00:00Z + +query P rowsort +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 00:00:00+00:30', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 00:00:00+00:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 00:00:00-00:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 day', TIMESTAMPTZ '2022-01-01 00:00:00-00:30', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset +---- +2021-12-31T00:00:00Z +2021-12-31T00:00:00Z +2022-01-01T00:00:00Z +2022-01-01T00:00:00Z + +query P rowsort +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00+01:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00+00:45', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00+00:30', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00+00:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00-00:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00-00:30', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00-00:45', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset + UNION ALL +SELECT date_bin('1 hour', TIMESTAMPTZ '2022-01-01 00:00:00-01:15', TIMESTAMPTZ '2020-01-01') as ts_irregular_offset +---- +2021-12-31T22:00:00Z +2021-12-31T23:00:00Z +2021-12-31T23:00:00Z +2021-12-31T23:00:00Z +2022-01-01T00:00:00Z +2022-01-01T00:00:00Z +2022-01-01T00:00:00Z +2022-01-01T01:00:00Z + + + +########## +## Timezone acceptance bounds +# +# standard formats +########## + +query P +SELECT TIMESTAMPTZ '2022-01-01 01:10:00' as rfc3339_no_tz +---- +2022-01-01T01:10:00Z + +# +00, +00:00, +0000 +# +01, +01:00, +0100 +# -01, -01:00, -0100 +query P rowsort +SELECT TIMESTAMPTZ '2022-01-01 01:10:00+00' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00+00:00' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00+0000' as rfc3339_offset_tz + UNION ALL + SELECT TIMESTAMPTZ '2022-01-01 01:10:00+01' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00+01:00' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00+0100' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00-01' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00-01:00' as rfc3339_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00-0100' as rfc3339_offset_tz +---- +2022-01-01T00:10:00Z +2022-01-01T00:10:00Z +2022-01-01T00:10:00Z +2022-01-01T01:10:00Z +2022-01-01T01:10:00Z +2022-01-01T01:10:00Z +2022-01-01T02:10:00Z +2022-01-01T02:10:00Z +2022-01-01T02:10:00Z + +query P +SELECT TIMESTAMPTZ '2022-01-01T01:10:00' as iso8601_no_tz +---- +2022-01-01T01:10:00Z + +# +00, +00:00, +0000 +# +01, +01:00, +0100 +# -01, -01:00, -0100 +query P rowsort +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+00' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+00:00' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+0000' as iso8601_offset_tz + UNION ALL + SELECT TIMESTAMPTZ '2022-01-01T01:10:00+01' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+01:00' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+0100' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00-01' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00-01:00' as iso8601_offset_tz + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01T01:10:00-0100' as iso8601_offset_tz +---- +2022-01-01T00:10:00Z +2022-01-01T00:10:00Z +2022-01-01T00:10:00Z +2022-01-01T01:10:00Z +2022-01-01T01:10:00Z +2022-01-01T01:10:00Z +2022-01-01T02:10:00Z +2022-01-01T02:10:00Z +2022-01-01T02:10:00Z + +statement error +SELECT TIMESTAMPTZ '2023‐W38‐5' as iso8601_week_designation + +statement error +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+Foo' as bad_tz + +statement error +SELECT TIMESTAMPTZ '2022-01-01T01:10:00+42:00' as bad_tz + +query P rowsort +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 GMT' as ts_gmt +---- +2022-01-01T01:10:00Z + +statement error +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 GMT-1' as ts_gmt_offset + +# will not accept non-GMT geo abbr +# postgresql: accepts +statement error +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 AEST' + +# ok to use geo longform +query P rowsort +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Australia/Sydney' as ts_geo + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Antarctica/Vostok' as ts_geo + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 Africa/Johannesburg' as ts_geo + UNION ALL +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los_Angeles' as ts_geo +---- +2021-12-31T14:10:00Z +2021-12-31T19:10:00Z +2021-12-31T23:10:00Z +2022-01-01T09:10:00Z + +# geo longform timezones need whitespace converted to underscore +statement error +SELECT TIMESTAMPTZ '2022-01-01 01:10:00 America/Los Angeles' as ts_geo + +statement error +SELECT TIMESTAMPTZ 'Sat, 1 Jan 2022 01:10:00 GMT' as rfc1123 + + + +########## +## Timezone acceptance bounds +# +# daylight savings +########## + +# will not accept daylight savings designations as geo abbr (because not accepting geo abbr) +# postgresql: accepts +statement error +SELECT TIMESTAMPTZ '2023-03-12 02:00:00 EDT' + +# ok to use geo longform +query P +SELECT TIMESTAMPTZ '2023-03-11 02:00:00 America/Los_Angeles' as ts_geo +---- +2023-03-11T10:00:00Z + +# will error if provide geo longform with time not possible due to daylight savings +# Arrow error: Parser error: Error parsing timestamp from '2023-03-12 02:00:00 America/Los_Angeles': error computing timezone offset +# postgresql: accepts +statement error +SELECT TIMESTAMPTZ '2023-03-12 02:00:00 America/Los_Angeles' as ts_geo + + + +########## +## Timezone column tests +########## + +# create a table with a non-UTC time zone. +statement ok +SET TIME ZONE = '+05:00' + +statement ok +CREATE TABLE foo (time TIMESTAMPTZ) AS VALUES + ('2020-01-01T00:00:00+05:00'), + ('2020-01-01T01:00:00+05:00'), + ('2020-01-01T02:00:00+05:00'), + ('2020-01-01T03:00:00+05:00') + +statement ok +SET TIME ZONE = '+00' + +# verify column type +query T +SELECT arrow_typeof(time) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + +# check date_trunc +query P +SELECT date_trunc('day', time) FROM foo +---- +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 + +# verify date_trunc column type +query T +SELECT arrow_typeof(date_trunc('day', time)) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + +# check date_bin +query P +SELECT date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00') FROM foo +---- +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 +2020-01-01T00:00:00+05:00 + +# verify date_trunc column type +query T +SELECT arrow_typeof(date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00')) FROM foo LIMIT 1 +---- +Timestamp(Nanosecond, Some("+05:00")) + + +# timestamp comparison with and without timezone +query B +SELECT TIMESTAMPTZ '2022-01-01 20:10:00Z' = TIMESTAMP '2020-01-01' +---- +false + +query B +SELECT TIMESTAMPTZ '2020-01-01 00:00:00Z' = TIMESTAMP '2020-01-01' +---- +true + +# verify timestamp cast with integer input +query PPPPPP +SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), to_timestamp(1), to_timestamp(-1), to_timestamp(0-1) +---- +NULL 1970-01-01T00:00:00 2031-01-19T23:33:25 1970-01-01T00:00:01 1969-12-31T23:59:59 1969-12-31T23:59:59 + +# verify timestamp syntax stlyes are consistent +query BBBBBBBBBBBBB +SELECT to_timestamp(null) is null as c1, + null::timestamp is null as c2, + cast(null as timestamp) is null as c3, + to_timestamp(0) = 0::timestamp as c4, + to_timestamp(1926632005) = 1926632005::timestamp as c5, + to_timestamp(1) = 1::timestamp as c6, + to_timestamp(-1) = -1::timestamp as c7, + to_timestamp(0-1) = (0-1)::timestamp as c8, + to_timestamp(0) = cast(0 as timestamp) as c9, + to_timestamp(1926632005) = cast(1926632005 as timestamp) as c10, + to_timestamp(1) = cast(1 as timestamp) as c11, + to_timestamp(-1) = cast(-1 as timestamp) as c12, + to_timestamp(0-1) = cast(0-1 as timestamp) as c13 +---- +true true true true true true true true true true true true true + +# verify timestamp output types +query TTT +SELECT arrow_typeof(to_timestamp(1)), arrow_typeof(to_timestamp(null)), arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) +---- +Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) Timestamp(Nanosecond, None) + +# verify timestamp output types using timestamp literal syntax +query BBBBBB +SELECT arrow_typeof(to_timestamp(1)) = arrow_typeof(1::timestamp) as c1, + arrow_typeof(to_timestamp(null)) = arrow_typeof(null::timestamp) as c2, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof('2023-01-10 12:34:56.000'::timestamp) as c3, + arrow_typeof(to_timestamp(1)) = arrow_typeof(cast(1 as timestamp)) as c4, + arrow_typeof(to_timestamp(null)) = arrow_typeof(cast(null as timestamp)) as c5, + arrow_typeof(to_timestamp('2023-01-10 12:34:56.000')) = arrow_typeof(cast('2023-01-10 12:34:56.000' as timestamp)) as c6 +---- +true true true true true true + +# known issues. currently overflows (expects default precision to be microsecond instead of nanoseconds. Work pending) +#verify extreme values +#query PPPPPPPP +#SELECT to_timestamp(-62125747200), to_timestamp(1926632005177), -62125747200::timestamp, 1926632005177::timestamp, cast(-62125747200 as timestamp), cast(1926632005177 as timestamp) +#---- +#0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 0001-04-25T00:00:00 +63022-07-16T12:59:37 diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt new file mode 100644 index 0000000000000..5eba20fdc655f --- /dev/null +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -0,0 +1,232 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for development + +statement ok +create table topk(x int) as values (10), (2), (3), (0), (5), (4), (3), (2), (1), (3), (8); + +query I +select * from topk order by x; +---- +0 +1 +2 +2 +3 +3 +3 +4 +5 +8 +10 + +query I +select * from topk order by x limit 3; +---- +0 +1 +2 + +query I +select * from topk order by x desc limit 3; +---- +10 +8 +5 + + + + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +query TT +explain select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: aggregate_test_100.c13 DESC NULLS FIRST, fetch=5 +----TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[c13@12 DESC] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true + + + + +query T +select c13 from aggregate_test_100 ORDER BY c13; +---- +0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm +0keZ5G8BffGwgF2RwQD59TFzMStxCB +0og6hSkhbX8AC1ktFS4kounvTzy8Vo +1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO +2T3wSlHdEmASmO0xcXHnndkKEt6bz8 +3BEOHQsMEFZ58VcNTOJYShTBpAPzbt +4HX6feIvmNXBN7XGqgO4YVBkhu8GDI +4JznSdBajNWhu4hRQwjV1FjTTxY68i +52mKlRE3aHCBZtjECq6sY9OqVf8Dze +56MZa5O1hVtX4c5sbnCfxuX5kDChqI +6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ +6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW +6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE +6x93sxYioWuq5c9Kkk8oTAAORM7cH0 +802bgTGl6Bk5TlkPYYTxp5JkKyaYUA +8LIh0b6jmDGm87BmIyjdxNIpX4ugjD +90gAtmGEeIqUTbo1ZrxCvWtsseukXC +9UbObCsVkmYpJGcGrgfK90qOnwb2Lj +AFGCj7OWlEB5QfniEFgonMq90Tq5uH +ALuRhobVWbnQTTWZdSOk0iVe8oYFhW +Amn2K87Db5Es3dFQO9cw9cvpAM6h35 +AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz +BJqx5WokrmrrezZA0dUbleMYkG5U2O +BPtQMxnuSPpxMExYV9YkDa6cAN7GP3 +BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE +C2GT5KVyOPZpgKVl110TyZO0NcJ434 +DuJNG8tufSqW0ZstHqWj3aGvFLMg4A +EcCuckwsF3gV1Ecgmh5v4KM8g1ozif +ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU +F7NSTjWvQJyBburN7CXRUlbgp2dIrA +Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u +H5j5ZHy1FGesOAHjkQEDYCucbpKWRu +HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g +IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr +IZTkHMLvIKuiLjhDjYMmIHxh166we4 +Ig1QcuKsjHXkproePdERo2w0mYzIqd +JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ +JN0VclewmjwYlSl8386MlWv5rEhWCz +JafwVLSVk5AVoXFuzclesQ000EE2k1 +KJFcmTVjdkCMv94wYCtfHMFhzyRsmH +Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn +Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV +LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW +MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ +MeSTAXq8gVxVjbEjgkvU9YLte0X9uE +NEhyk8uIx4kEULJGa8qIyFjjBcP2G6 +O66j6PaYuZhEUtqV6fuU7TyjM2WxC5 +OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh +OPwBqCEK5PWTjWaiOyL45u2NLTaDWv +Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0 +Ow5PGpfTm4dXCfTDsXAOTatXRoAydR +QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv +QJYm7YRA3YetcBHI5wkMZeLXVmfuNy +QYlaIAnJA6r8rlAb6f59wcxvcPcWFf +RilTlL1tKkPOUFuzmLydHAVZwv1OGl +Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH +TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX +TtDKUZxzVxsq758G6AWPSYuZgVgbcl +VDhtJkYjAYPykCgOU9x3v7v3t4SO1a +VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4 +Vp3gmWunM5A7wOC9YW2JroFqTWjvTi +WHmjWk2AY4c6m7DA4GitUx6nmb1yYS +XemNcT1xp61xcM1Qz3wZ1VECCnq06O +Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK +aDxBtor7Icd9C5hnTvvw5NrIre740e +akiiY5N0I44CMwEnBL6RTBk7BRkxEj +b3b9esRhTzFEawbs6XhpKnD9ojutHB +bgK1r6v3BCTh0aejJUhkA1Hn6idXGp +cBGc0kSm32ylBDnxogG727C0uhZEYZ +cq4WSAIFwx3wwTUS5bp1wCe71R6U5I +dVdvo6nUD5FgCgsbOZLds28RyGTpnx +e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG +f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX +fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG +gTpyQnEODMcpsPnJMZC66gh33i3m0b +gpo8K5qtYePve6jyPt6xgJx4YOVjms +gxfHWUF8XgY2KdFxigxvNEXe2V2XMl +i6RQVXKUh7MzuGMDaNclUYnFUAireU +ioEncce3mPOXD2hWhpZpCPWGATG6GU +jQimhdepw3GKmioWUlVSWeBVRKFkY3 +l7uwDoTepWwnAP0ufqtHJS3CRi7RfP +lqhzgLsXZ8JhtpeeUWWNbMz8PHI705 +m6jD0LBIQWaMfenwRCTANI9eOdyyto +mhjME0zBHbrK6NMkytMTQzOssOa1gF +mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS +nYVJnVicpGRqKZibHyBAmtmzBXAFfT +oHJMNvWuunsIMIWFnYG31RCfkOo2V7 +oLZ21P2JEDooxV1pU31cIxQHEeeoLu +okOkcWflkNXIy4R8LzmySyY1EC3sYd +pLk3i59bZwd5KBZrI1FiweYTd5hteG +pTeu0WMjBRTaNRT15rLCuEh3tBJVc5 +qnPOOmslCJaT45buUisMRnM0rc77EK +t6fQUjJejPcjc04wHvHTPe55S65B4V +ukOiFGGFnQJDHFgZxHMpvhD3zybF0M +ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8 +waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs +wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +xipQ93429ksjNcXPX5326VSg1xJZcW +y7C453hRWd4E7ImjNDWlpexB8nUqjh +ydkwycaISlYSlEq3TlkS2m15I2pcp8 + + +query TIIIIIIIITRRT +select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +a 4 -38 20744 762932956 308913475857409919 7 45465 1787652631 878137512938218976 0.7459874 0.021825780392 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 13630 -1991133944 1184110014998006843 220 2986 225513085 9634106610243643486 0.89651865 0.164088254508 y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 -12056 -1090239422 9011500141803970147 238 4168 2013662838 12565360638488684051 0.6694766 0.391444365692 xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 25590 1188089983 3090286296481837049 241 832 3542840110 5885937420286765261 0.41980565 0.215354023438 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 12636 794623392 2909750622865366631 15 24022 2669374863 4776679784701509574 0.29877836 0.253725340799 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs + + + +## -- make tiny batches to trigger batch compaction +statement ok +set datafusion.execution.batch_size = 2 + +query TIIIIIIIITRRT +select * from aggregate_test_100 ORDER BY c13 desc limit 5; +---- +a 4 -38 20744 762932956 308913475857409919 7 45465 1787652631 878137512938218976 0.7459874 0.021825780392 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 13630 -1991133944 1184110014998006843 220 2986 225513085 9634106610243643486 0.89651865 0.164088254508 y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 -12056 -1090239422 9011500141803970147 238 4168 2013662838 12565360638488684051 0.6694766 0.391444365692 xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 25590 1188089983 3090286296481837049 241 832 3542840110 5885937420286765261 0.41980565 0.215354023438 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 12636 794623392 2909750622865366631 15 24022 2669374863 4776679784701509574 0.29877836 0.253725340799 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs + + +## make an example for dictionary encoding + +statement ok +create table dict as select c1, c2, c3, c13, arrow_cast(c13, 'Dictionary(Int32, Utf8)') as c13_dict from aggregate_test_100; + +query TIIT? +select * from dict order by c13 desc limit 5; +---- +a 4 -38 ydkwycaISlYSlEq3TlkS2m15I2pcp8 ydkwycaISlYSlEq3TlkS2m15I2pcp8 +d 1 -98 y7C453hRWd4E7ImjNDWlpexB8nUqjh y7C453hRWd4E7ImjNDWlpexB8nUqjh +e 2 52 xipQ93429ksjNcXPX5326VSg1xJZcW xipQ93429ksjNcXPX5326VSg1xJZcW +d 1 -72 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS wwXqSGKLyBQyPkonlzBNYUJTCo4LRS +a 1 -5 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/create_tables.slt.part b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part similarity index 81% rename from datafusion/core/tests/sqllogictests/test_files/tpch/create_tables.slt.part rename to datafusion/sqllogictest/test_files/tpch/create_tables.slt.part index 007cfd7062b75..2f5e2d5a76163 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/create_tables.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part @@ -31,7 +31,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS supplier ( s_acctbal DECIMAL(15, 2), s_comment VARCHAR, s_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/supplier.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/supplier.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS part ( @@ -45,7 +45,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS part ( p_retailprice DECIMAL(15, 2), p_comment VARCHAR, p_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/part.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/part.tbl'; statement ok @@ -56,7 +56,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS partsupp ( ps_supplycost DECIMAL(15, 2), ps_comment VARCHAR, ps_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/partsupp.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/partsupp.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( @@ -69,7 +69,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_mktsegment VARCHAR, c_comment VARCHAR, c_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/customer.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/customer.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS orders ( @@ -83,7 +83,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS orders ( o_shippriority INTEGER, o_comment VARCHAR, o_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/orders.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/orders.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( @@ -104,7 +104,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( l_shipmode VARCHAR, l_comment VARCHAR, l_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/lineitem.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/lineitem.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS nation ( @@ -113,7 +113,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS nation ( n_regionkey BIGINT, n_comment VARCHAR, n_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/nation.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/nation.tbl'; statement ok CREATE EXTERNAL TABLE IF NOT EXISTS region ( @@ -121,4 +121,4 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'tests/sqllogictests/test_files/tpch/data/region.tbl'; +) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/region.tbl'; diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/drop_tables.slt.part b/datafusion/sqllogictest/test_files/tpch/drop_tables.slt.part similarity index 99% rename from datafusion/core/tests/sqllogictests/test_files/tpch/drop_tables.slt.part rename to datafusion/sqllogictest/test_files/tpch/drop_tables.slt.part index 6e6acee5b5769..35faf3719d9ff 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/drop_tables.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/drop_tables.slt.part @@ -1,5 +1,3 @@ - - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part similarity index 88% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q1.slt.part rename to datafusion/sqllogictest/test_files/tpch/q1.slt.part index 8a0e5fd3228b2..3086ab487aaa0 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q1.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part @@ -41,25 +41,23 @@ explain select ---- logical_plan Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST ---Projection: lineitem.l_returnflag, lineitem.l_linestatus, SUM(lineitem.l_quantity) AS sum_qty, SUM(lineitem.l_extendedprice) AS sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(UInt8(1)) AS count_order -----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(UInt8(1))]] +--Projection: lineitem.l_returnflag, lineitem.l_linestatus, SUM(lineitem.l_quantity) AS sum_qty, SUM(lineitem.l_extendedprice) AS sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order +----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(UInt8(1)) AS COUNT(*)]] ------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus --------Filter: lineitem.l_shipdate <= Date32("10471") ----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("10471")] physical_plan SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST] --SortExec: expr=[l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST] -----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, SUM(lineitem.l_quantity)@2 as sum_qty, SUM(lineitem.l_extendedprice)@3 as sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@4 as sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, AVG(lineitem.l_quantity)@6 as avg_qty, AVG(lineitem.l_extendedprice)@7 as avg_price, AVG(lineitem.l_discount)@8 as avg_disc, COUNT(UInt8(1))@9 as count_order] -------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(UInt8(1))] +----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, SUM(lineitem.l_quantity)@2 as sum_qty, SUM(lineitem.l_extendedprice)@3 as sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@4 as sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, AVG(lineitem.l_quantity)@6 as avg_qty, AVG(lineitem.l_extendedprice)@7 as avg_price, AVG(lineitem.l_discount)@8 as avg_disc, COUNT(*)@9 as count_order] +------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(UInt8(1))] +----------RepartitionExec: partitioning=Hash([l_returnflag@0, l_linestatus@1], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] --------------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, l_quantity@0 as l_quantity, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_tax@3 as l_tax, l_returnflag@4 as l_returnflag, l_linestatus@5 as l_linestatus] ----------------CoalesceBatchesExec: target_batch_size=8192 ------------------FilterExec: l_shipdate@6 <= 10471 ---------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], has_header=false - +--------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], has_header=false query TTRRRRRRRI select diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/q10.slt.part similarity index 80% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part rename to datafusion/sqllogictest/test_files/tpch/q10.slt.part index 82220d7c9377a..eb0b66f024de9 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q10.slt.part @@ -71,48 +71,46 @@ Limit: skip=0, fetch=10 ------------TableScan: nation projection=[n_nationkey, n_name] physical_plan GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [revenue@2 DESC] -----SortExec: fetch=10, expr=[revenue@2 DESC] +--SortPreservingMergeExec: [revenue@2 DESC], fetch=10 +----SortExec: TopK(fetch=10), expr=[revenue@2 DESC] ------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment] --------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 2 }, Column { name: "c_phone", index: 3 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 6 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([c_custkey@0, c_name@1, c_acctbal@2, c_phone@3, n_name@4, c_address@5, c_comment@6], 4), input_partitions=4 --------------AggregateExec: mode=Partial, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@4 as c_acctbal, c_phone@3 as c_phone, n_name@8 as n_name, c_address@2 as c_address, c_comment@5 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_address@2 as c_address, c_phone@4 as c_phone, c_acctbal@5 as c_acctbal, c_comment@6 as c_comment, l_extendedprice@7 as l_extendedprice, l_discount@8 as l_discount, n_name@10 as n_name] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })] +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@3, n_nationkey@0)] ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "c_nationkey", index: 3 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([c_nationkey@3], 4), input_partitions=4 --------------------------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_address@2 as c_address, c_nationkey@3 as c_nationkey, c_phone@4 as c_phone, c_acctbal@5 as c_acctbal, c_comment@6 as c_comment, l_extendedprice@9 as l_extendedprice, l_discount@10 as l_discount] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_orderkey", index: 7 }, Column { name: "l_orderkey", index: 0 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@7, l_orderkey@0)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 7 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([o_orderkey@7], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_address@2 as c_address, c_nationkey@3 as c_nationkey, c_phone@4 as c_phone, c_acctbal@5 as c_acctbal, c_comment@6 as c_comment, o_orderkey@7 as o_orderkey] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 1 })] +----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)] ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 ----------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment], has_header=false +------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment], has_header=false ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 1 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 ----------------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] ------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------------------------------FilterExec: o_orderdate@2 >= 8674 AND o_orderdate@2 < 8766 -----------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +----------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------FilterExec: l_returnflag@3 = R -------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], has_header=false +------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], has_header=false ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 --------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/q11.slt.part new file mode 100644 index 0000000000000..4efa29e2c0ac7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/tpch/q11.slt.part @@ -0,0 +1,176 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query TT +explain select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc +limit 10; +---- +logical_plan +Limit: skip=0, fetch=10 +--Sort: value DESC NULLS FIRST, fetch=10 +----Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value +------Inner Join: Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) +--------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +----------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +------------Inner Join: supplier.s_nationkey = nation.n_nationkey +--------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +----------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +--------------Projection: nation.n_nationkey +----------------Filter: nation.n_name = Utf8("GERMANY") +------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +--------SubqueryAlias: __scalar_sq_1 +----------Projection: CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) +------------Aggregate: groupBy=[[]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +--------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost +----------------Inner Join: supplier.s_nationkey = nation.n_nationkey +------------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +--------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +----------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] +----------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +------------------Projection: nation.n_nationkey +--------------------Filter: nation.n_name = Utf8("GERMANY") +----------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--SortPreservingMergeExec: [value@1 DESC], fetch=10 +----SortExec: TopK(fetch=10), expr=[value@1 DESC] +------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] +--------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1 +----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +----------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] +------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_availqty@1 as ps_availqty, ps_supplycost@2 as ps_supplycost] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)] +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +----------------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_availqty@2 as ps_availqty, ps_supplycost@3 as ps_supplycost, s_nationkey@5 as s_nationkey] +------------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)] +----------------------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +--------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], has_header=false +----------------------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +----------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] +------------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------------FilterExec: n_name@1 = GERMANY +----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +----------ProjectionExec: expr=[CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] +------------AggregateExec: mode=Final, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] +--------------CoalescePartitionsExec +----------------AggregateExec: mode=Partial, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] +------------------ProjectionExec: expr=[ps_availqty@0 as ps_availqty, ps_supplycost@1 as ps_supplycost] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)] +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +----------------------------ProjectionExec: expr=[ps_availqty@1 as ps_availqty, ps_supplycost@2 as ps_supplycost, s_nationkey@4 as s_nationkey] +------------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)] +----------------------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +--------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], has_header=false +----------------------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +----------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] +------------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------------FilterExec: n_name@1 = GERMANY +----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false + + + +query IR +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc +limit 10; +---- +12098 16227681.21 +5134 15709338.52 +13334 15023662.41 +17052 14351644.2 +3452 14070870.14 +12552 13332469.18 +1084 13170428.29 +5797 13038622.72 +12633 12892561.61 +403 12856217.34 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part b/datafusion/sqllogictest/test_files/tpch/q12.slt.part similarity index 81% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part rename to datafusion/sqllogictest/test_files/tpch/q12.slt.part index fdada35952f83..09939359ce122 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q12.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q12.slt.part @@ -64,22 +64,20 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] ----ProjectionExec: expr=[l_shipmode@0 as l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@1 as high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@2 as low_line_count] ------AggregateExec: mode=FinalPartitioned, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 4), input_partitions=4 +----------RepartitionExec: partitioning=Hash([l_shipmode@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] --------------ProjectionExec: expr=[l_shipmode@1 as l_shipmode, o_orderpriority@3 as o_orderpriority] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderkey", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_shipmode@4 as l_shipmode] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131 -------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false +------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_orderpriority], has_header=false +----------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderpriority], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/q13.slt.part new file mode 100644 index 0000000000000..5cf6ace8b27b5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/tpch/q13.slt.part @@ -0,0 +1,115 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query TT +explain select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc +limit 10; +---- +logical_plan +Limit: skip=0, fetch=10 +--Sort: custdist DESC NULLS FIRST, c_orders.c_count DESC NULLS FIRST, fetch=10 +----Projection: c_orders.c_count, COUNT(*) AS custdist +------Aggregate: groupBy=[[c_orders.c_count]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] +--------SubqueryAlias: c_orders +----------Projection: COUNT(orders.o_orderkey) AS c_count +------------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[COUNT(orders.o_orderkey)]] +--------------Projection: customer.c_custkey, orders.o_orderkey +----------------Left Join: customer.c_custkey = orders.o_custkey +------------------TableScan: customer projection=[c_custkey] +------------------Projection: orders.o_orderkey, orders.o_custkey +--------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") +----------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] +physical_plan +GlobalLimitExec: skip=0, fetch=10 +--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10 +----SortExec: TopK(fetch=10), expr=[custdist@1 DESC,c_count@0 DESC] +------ProjectionExec: expr=[c_count@0 as c_count, COUNT(*)@1 as custdist] +--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(*)] +----------CoalesceBatchesExec: target_batch_size=8192 +------------RepartitionExec: partitioning=Hash([c_count@0], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[COUNT(*)] +----------------ProjectionExec: expr=[COUNT(orders.o_orderkey)@1 as c_count] +------------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[COUNT(orders.o_orderkey)] +--------------------ProjectionExec: expr=[c_custkey@0 as c_custkey, o_orderkey@1 as o_orderkey] +----------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(c_custkey@0, o_custkey@1)] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], has_header=false +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] +--------------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------------FilterExec: o_comment@2 NOT LIKE %special%requests% +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_comment], has_header=false + + + +query II +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc +limit 10; +---- +0 5000 +10 665 +9 657 +11 621 +12 567 +8 564 +13 492 +18 482 +7 480 +20 456 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/q14.slt.part similarity index 86% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q14.slt.part rename to datafusion/sqllogictest/test_files/tpch/q14.slt.part index 08f4c4eb430b6..b584972c25bc8 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q14.slt.part @@ -47,18 +47,17 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") ------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] --------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type] ----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_partkey", index: 0 }, Column { name: "p_partkey", index: 0 })] +------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] --------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4 +----------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 ------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] --------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------FilterExec: l_shipdate@3 >= 9374 AND l_shipdate@3 < 9404 -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false --------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +----------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 ------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/q15.slt.part similarity index 50% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q15.slt.part rename to datafusion/sqllogictest/test_files/tpch/q15.slt.part index f7e428dcfb9d6..a872e96acf04e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q15.slt.part @@ -51,71 +51,63 @@ order by ---- logical_plan Sort: supplier.s_suppkey ASC NULLS LAST ---Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, total_revenue -----Inner Join: total_revenue = __scalar_sq_3.__value -------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, total_revenue ---------Inner Join: supplier.s_suppkey = supplier_no +--Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue +----Inner Join: revenue0.total_revenue = __scalar_sq_1.MAX(revenue0.total_revenue) +------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue +--------Inner Join: supplier.s_suppkey = revenue0.supplier_no ----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone] -----------Projection: revenue0.l_suppkey AS supplier_no, revenue0.SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -------------SubqueryAlias: revenue0 +----------SubqueryAlias: revenue0 +------------Projection: lineitem.l_suppkey AS supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount +------------------Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") +--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9496"), lineitem.l_shipdate < Date32("9587")] +------SubqueryAlias: __scalar_sq_1 +--------Aggregate: groupBy=[[]], aggr=[[MAX(revenue0.total_revenue)]] +----------SubqueryAlias: revenue0 +------------Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue --------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] ----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount ------------------Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") --------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9496"), lineitem.l_shipdate < Date32("9587")] -------SubqueryAlias: __scalar_sq_3 ---------Projection: MAX(total_revenue) AS __value -----------Aggregate: groupBy=[[]], aggr=[[MAX(total_revenue)]] -------------Projection: revenue0.SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue ---------------SubqueryAlias: revenue0 -----------------Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) -------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] ---------------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -----------------------Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") -------------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9496"), lineitem.l_shipdate < Date32("9587")] physical_plan SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] --SortExec: expr=[s_suppkey@0 ASC NULLS LAST] ----ProjectionExec: expr=[s_suppkey@0 as s_suppkey, s_name@1 as s_name, s_address@2 as s_address, s_phone@3 as s_phone, total_revenue@4 as total_revenue] ------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "total_revenue", index: 4 }, Column { name: "__value", index: 0 })] +--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(total_revenue@4, MAX(revenue0.total_revenue)@0)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "total_revenue", index: 4 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([total_revenue@4], 4), input_partitions=4 --------------ProjectionExec: expr=[s_suppkey@0 as s_suppkey, s_name@1 as s_name, s_address@2 as s_address, s_phone@3 as s_phone, total_revenue@5 as total_revenue] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_suppkey", index: 0 }, Column { name: "supplier_no", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 ------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], has_header=false +--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], has_header=false --------------------ProjectionExec: expr=[l_suppkey@0 as supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] ----------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 0 }], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 ----------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ------------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] --------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "__value", index: 0 }], 4), input_partitions=1 ---------------ProjectionExec: expr=[MAX(total_revenue)@0 as __value] -----------------AggregateExec: mode=Final, gby=[], aggr=[MAX(total_revenue)] -------------------CoalescePartitionsExec ---------------------AggregateExec: mode=Partial, gby=[], aggr=[MAX(total_revenue)] -----------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as total_revenue] -------------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ---------------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 0 }], 4), input_partitions=4 ---------------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -----------------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 -----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false - - +------------RepartitionExec: partitioning=Hash([MAX(revenue0.total_revenue)@0], 4), input_partitions=1 +--------------AggregateExec: mode=Final, gby=[], aggr=[MAX(revenue0.total_revenue)] +----------------CoalescePartitionsExec +------------------AggregateExec: mode=Partial, gby=[], aggr=[MAX(revenue0.total_revenue)] +--------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +----------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +----------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +------------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] +--------------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false query ITTTR with revenue0 (supplier_no, total_revenue) as ( diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part similarity index 66% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part rename to datafusion/sqllogictest/test_files/tpch/q16.slt.part index 4f4316b084789..b93872929fe55 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,9 +52,9 @@ limit 10; logical_plan Limit: skip=0, fetch=10 --Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -----Projection: group_alias_0 AS part.p_brand, group_alias_1 AS part.p_type, group_alias_2 AS part.p_size, COUNT(alias1) AS supplier_cnt -------Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]] ---------Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]] +----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt +------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] +--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] ----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey ------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size --------------Inner Join: partsupp.ps_partkey = part.p_partkey @@ -67,41 +67,40 @@ Limit: skip=0, fetch=10 ------------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] physical_plan GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1 as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt] ---------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] +--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 +----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] +------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] +--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "group_alias_0", index: 0 }, Column { name: "group_alias_1", index: 1 }, Column { name: "group_alias_2", index: 2 }], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2], aggr=[COUNT(alias1)] -----------------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as group_alias_2, alias1@3 as alias1], aggr=[] +------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 +--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([Column { name: "group_alias_0", index: 0 }, Column { name: "group_alias_1", index: 1 }, Column { name: "group_alias_2", index: 2 }, Column { name: "alias1", index: 3 }], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as group_alias_0, p_type@2 as group_alias_1, p_size@3 as group_alias_2, ps_suppkey@0 as alias1], aggr=[] +--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "ps_suppkey", index: 0 }, Column { name: "s_suppkey", index: 0 })] +--------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 0 }], 4), input_partitions=4 +------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 --------------------------------ProjectionExec: expr=[ps_suppkey@1 as ps_suppkey, p_brand@3 as p_brand, p_type@4 as p_type, p_size@5 as p_size] ----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_partkey", index: 0 }, Column { name: "p_partkey", index: 0 })] +------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 -------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey], has_header=false +----------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], has_header=false --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +----------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) ----------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], has_header=false +------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], has_header=false ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 --------------------------------ProjectionExec: expr=[s_suppkey@0 as s_suppkey] ----------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints% --------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], has_header=false +----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/q17.slt.part similarity index 54% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q17.slt.part rename to datafusion/sqllogictest/test_files/tpch/q17.slt.part index 522d67811aac9..4d4aa4b1395fd 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q17.slt.part @@ -39,15 +39,15 @@ logical_plan Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly --Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice)]] ----Projection: lineitem.l_extendedprice -------Inner Join: part.p_partkey = __scalar_sq_5.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_5.__value +------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * AVG(lineitem.l_quantity) --------Projection: lineitem.l_quantity, lineitem.l_extendedprice, part.p_partkey ----------Inner Join: lineitem.l_partkey = part.p_partkey ------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] ------------Projection: part.p_partkey --------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") ----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] ---------SubqueryAlias: __scalar_sq_5 -----------Projection: lineitem.l_partkey, CAST(Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)) AS __value +--------SubqueryAlias: __scalar_sq_1 +----------Projection: CAST(Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey ------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[AVG(lineitem.l_quantity)]] --------------TableScan: lineitem projection=[l_partkey, l_quantity] physical_plan @@ -57,30 +57,26 @@ ProjectionExec: expr=[CAST(SUM(lineitem.l_extendedprice)@0 AS Float64) / 7 as av ------AggregateExec: mode=Partial, gby=[], aggr=[SUM(lineitem.l_extendedprice)] --------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice] ----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 2 }, Column { name: "l_partkey", index: 0 })], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < __value@1 ---------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 2 }], 4), input_partitions=4 -------------------ProjectionExec: expr=[l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey] +------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * AVG(lineitem.l_quantity)@1 +--------------ProjectionExec: expr=[l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_partkey", index: 0 }, Column { name: "p_partkey", index: 0 })] -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4 -----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity, l_extendedprice], has_header=false -------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 -----------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX -----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false ---------------ProjectionExec: expr=[l_partkey@0 as l_partkey, CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as __value] +----------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice], has_header=false +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false +--------------ProjectionExec: expr=[CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * AVG(lineitem.l_quantity), l_partkey@0 as l_partkey] ----------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4 +--------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 ----------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity], has_header=false +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q18.slt.part b/datafusion/sqllogictest/test_files/tpch/q18.slt.part similarity index 71% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q18.slt.part rename to datafusion/sqllogictest/test_files/tpch/q18.slt.part index f7a96b8b64ae7..53191a5d44e15 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q18.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q18.slt.part @@ -53,7 +53,7 @@ order by logical_plan Sort: orders.o_totalprice DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST --Aggregate: groupBy=[[customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice]], aggr=[[SUM(lineitem.l_quantity)]] -----LeftSemi Join: orders.o_orderkey = __correlated_sq_3.l_orderkey +----LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey ------Projection: customer.c_custkey, customer.c_name, orders.o_orderkey, orders.o_totalprice, orders.o_orderdate, lineitem.l_quantity --------Inner Join: orders.o_orderkey = lineitem.l_orderkey ----------Projection: customer.c_custkey, customer.c_name, orders.o_orderkey, orders.o_totalprice, orders.o_orderdate @@ -61,7 +61,7 @@ Sort: orders.o_totalprice DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST --------------TableScan: customer projection=[c_custkey, c_name] --------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice, o_orderdate] ----------TableScan: lineitem projection=[l_orderkey, l_quantity] -------SubqueryAlias: __correlated_sq_3 +------SubqueryAlias: __correlated_sq_1 --------Projection: lineitem.l_orderkey ----------Filter: SUM(lineitem.l_quantity) > Decimal128(Some(30000),25,2) ------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_quantity)]] @@ -71,39 +71,36 @@ SortPreservingMergeExec: [o_totalprice@4 DESC,o_orderdate@3 ASC NULLS LAST] --SortExec: expr=[o_totalprice@4 DESC,o_orderdate@3 ASC NULLS LAST] ----AggregateExec: mode=FinalPartitioned, gby=[c_name@0 as c_name, c_custkey@1 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@3 as o_orderdate, o_totalprice@4 as o_totalprice], aggr=[SUM(lineitem.l_quantity)] ------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Column { name: "c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name: "o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column { name: "o_totalprice", index: 4 }], 4), input_partitions=4 +--------RepartitionExec: partitioning=Hash([c_name@0, c_custkey@1, o_orderkey@2, o_orderdate@3, o_totalprice@4], 4), input_partitions=4 ----------AggregateExec: mode=Partial, gby=[c_name@1 as c_name, c_custkey@0 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@4 as o_orderdate, o_totalprice@3 as o_totalprice], aggr=[SUM(lineitem.l_quantity)] ------------CoalesceBatchesExec: target_batch_size=8192 ---------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: "o_orderkey", index: 2 }, Column { name: "l_orderkey", index: 0 })] +--------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(o_orderkey@2, l_orderkey@0)] ----------------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, o_orderkey@2 as o_orderkey, o_totalprice@3 as o_totalprice, o_orderdate@4 as o_orderdate, l_quantity@6 as l_quantity] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_orderkey", index: 2 }, Column { name: "l_orderkey", index: 0 })] +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@2, l_orderkey@0)] ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 2 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([o_orderkey@2], 4), input_partitions=4 --------------------------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, o_orderkey@2 as o_orderkey, o_totalprice@4 as o_totalprice, o_orderdate@5 as o_orderdate] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 1 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 ------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name], has_header=false +--------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 1 }], 4), input_partitions=4 -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_totalprice, o_orderdate], has_header=false +----------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_totalprice, o_orderdate], has_header=false ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 ---------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_quantity], has_header=false +------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_quantity], has_header=false ----------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey] ------------------CoalesceBatchesExec: target_batch_size=8192 --------------------FilterExec: SUM(lineitem.l_quantity)@1 > Some(30000),25,2 ----------------------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey], aggr=[SUM(lineitem.l_quantity)] ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ----------------------------AggregateExec: mode=Partial, gby=[l_orderkey@0 as l_orderkey], aggr=[SUM(lineitem.l_quantity)] -------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_quantity], has_header=false +------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_quantity], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/q19.slt.part similarity index 83% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part rename to datafusion/sqllogictest/test_files/tpch/q19.slt.part index 1a91fed124c00..2df27bd41082d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q19.slt.part @@ -70,20 +70,19 @@ ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_disco ------AggregateExec: mode=Partial, gby=[], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] --------ProjectionExec: expr=[l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount] ----------CoalesceBatchesExec: target_batch_size=8192 -------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_partkey", index: 0 }, Column { name: "p_partkey", index: 0 })], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15 +------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15 --------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }], 4), input_partitions=4 +----------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 ------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount] --------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON -------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false +------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false --------------CoalesceBatchesExec: target_batch_size=8192 -----------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +----------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 ------------------CoalesceBatchesExec: target_batch_size=8192 --------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 ----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], has_header=false +------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/q2.slt.part similarity index 74% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part rename to datafusion/sqllogictest/test_files/tpch/q2.slt.part index fe125c2b3b0cc..ed439348d22de 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q2.slt.part @@ -66,7 +66,7 @@ logical_plan Limit: skip=0, fetch=10 --Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST, fetch=10 ----Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment -------Inner Join: part.p_partkey = __scalar_sq_7.ps_partkey, partsupp.ps_supplycost = __scalar_sq_7.__value +------Inner Join: part.p_partkey = __scalar_sq_1.ps_partkey, partsupp.ps_supplycost = __scalar_sq_1.MIN(partsupp.ps_supplycost) --------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name ----------Inner Join: nation.n_regionkey = region.r_regionkey ------------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name, nation.n_regionkey @@ -84,8 +84,8 @@ Limit: skip=0, fetch=10 ------------Projection: region.r_regionkey --------------Filter: region.r_name = Utf8("EUROPE") ----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] ---------SubqueryAlias: __scalar_sq_7 -----------Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value +--------SubqueryAlias: __scalar_sq_1 +----------Projection: MIN(partsupp.ps_supplycost), partsupp.ps_partkey ------------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] --------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost ----------------Inner Join: nation.n_regionkey = region.r_regionkey @@ -101,96 +101,94 @@ Limit: skip=0, fetch=10 ----------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] physical_plan GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] -----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] +--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10 +----SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] ------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment] --------CoalesceBatchesExec: target_batch_size=8192 -----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 }), (Column { name: "ps_supplycost", index: 7 }, Column { name: "__value", index: 1 })] +----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@1), (ps_supplycost@7, MIN(partsupp.ps_supplycost)@0)] ------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }, Column { name: "ps_supplycost", index: 7 }], 4), input_partitions=4 +--------------RepartitionExec: partitioning=Hash([p_partkey@0, ps_supplycost@7], 4), input_partitions=4 ----------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@2 as s_name, s_address@3 as s_address, s_phone@4 as s_phone, s_acctbal@5 as s_acctbal, s_comment@6 as s_comment, ps_supplycost@7 as ps_supplycost, n_name@8 as n_name] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 9 }, Column { name: "r_regionkey", index: 0 })] +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@9, r_regionkey@0)] ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "n_regionkey", index: 9 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([n_regionkey@9], 4), input_partitions=4 --------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@2 as s_name, s_address@3 as s_address, s_phone@5 as s_phone, s_acctbal@6 as s_acctbal, s_comment@7 as s_comment, ps_supplycost@8 as ps_supplycost, n_name@10 as n_name, n_regionkey@11 as n_regionkey] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 4 }, Column { name: "n_nationkey", index: 0 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@4, n_nationkey@0)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 4 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([s_nationkey@4], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@5 as s_name, s_address@6 as s_address, s_nationkey@7 as s_nationkey, s_phone@8 as s_phone, s_acctbal@9 as s_acctbal, s_comment@10 as s_comment, ps_supplycost@3 as ps_supplycost] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_suppkey", index: 2 }, Column { name: "s_suppkey", index: 0 })] +----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@2, s_suppkey@0)] ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 2 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@2], 4), input_partitions=4 ----------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, ps_suppkey@3 as ps_suppkey, ps_supplycost@4 as ps_supplycost] ------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 })] +--------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@0)] ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 --------------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr] ----------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------------------------------FilterExec: p_size@3 = 15 AND p_type@2 LIKE %BRASS --------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_mfgr, p_type, p_size], has_header=false +----------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_mfgr, p_type, p_size], has_header=false ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 ---------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +------------------------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +--------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 ----------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment], has_header=false +------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false +--------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "r_regionkey", index: 0 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 --------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] ----------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------FilterExec: r_name@1 = EUROPE --------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false ------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }, Column { name: "__value", index: 1 }], 4), input_partitions=4 -----------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, MIN(partsupp.ps_supplycost)@1 as __value] +--------------RepartitionExec: partitioning=Hash([ps_partkey@1, MIN(partsupp.ps_supplycost)@0], 4), input_partitions=4 +----------------ProjectionExec: expr=[MIN(partsupp.ps_supplycost)@1 as MIN(partsupp.ps_supplycost), ps_partkey@0 as ps_partkey] ------------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 ------------------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] --------------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_supplycost@1 as ps_supplycost] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 2 }, Column { name: "r_regionkey", index: 0 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@2, r_regionkey@0)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_regionkey", index: 2 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([n_regionkey@2], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_supplycost@1 as ps_supplycost, n_regionkey@4 as n_regionkey] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 2 }, Column { name: "n_nationkey", index: 0 })] +----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)] ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 2 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 ----------------------------------------------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, ps_supplycost@2 as ps_supplycost, s_nationkey@4 as s_nationkey] ------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_suppkey", index: 1 }, Column { name: "s_suppkey", index: 0 })] +--------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)] ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 1 }], 4), input_partitions=4 ---------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +------------------------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +--------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 --------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ----------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false +------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "r_regionkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------FilterExec: r_name@1 = EUROPE ------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false @@ -240,7 +238,7 @@ order by p_partkey limit 10; ---- -9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily +9828.21 Supplier#000000647 UNITED KINGDOM 13120 Manufacturer#5 x5U7MBZmwfG9 33-258-202-4782 s the slyly even ideas poach fluffily 9508.37 Supplier#000000070 FRANCE 3563 Manufacturer#1 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9508.37 Supplier#000000070 FRANCE 17268 Manufacturer#4 INWNH2w,OOWgNDq0BRCcBwOMQc6PdFDc4 16-821-608-1166 ests sleep quickly express ideas. ironic ideas haggle about the final T 9453.01 Supplier#000000802 ROMANIA 10021 Manufacturer#5 ,6HYXb4uaHITmtMBj4Ak57Pd 29-342-882-6463 gular frets. permanently special multipliers believe blithely alongs diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/q20.slt.part similarity index 69% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q20.slt.part rename to datafusion/sqllogictest/test_files/tpch/q20.slt.part index f6d343d4db30d..e014c6cafd989 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q20.slt.part @@ -58,24 +58,24 @@ order by logical_plan Sort: supplier.s_name ASC NULLS LAST --Projection: supplier.s_name, supplier.s_address -----LeftSemi Join: supplier.s_suppkey = __correlated_sq_5.ps_suppkey +----LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey ------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address --------Inner Join: supplier.s_nationkey = nation.n_nationkey ----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] ----------Projection: nation.n_nationkey ------------Filter: nation.n_name = Utf8("CANADA") --------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] -------SubqueryAlias: __correlated_sq_5 +------SubqueryAlias: __correlated_sq_1 --------Projection: partsupp.ps_suppkey -----------Inner Join: partsupp.ps_partkey = __scalar_sq_9.l_partkey, partsupp.ps_suppkey = __scalar_sq_9.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_9.__value -------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_6.p_partkey +----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * SUM(lineitem.l_quantity) +------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey --------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] ---------------SubqueryAlias: __correlated_sq_6 +--------------SubqueryAlias: __correlated_sq_2 ----------------Projection: part.p_partkey ------------------Filter: part.p_name LIKE Utf8("forest%") --------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] -------------SubqueryAlias: __scalar_sq_9 ---------------Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value +------------SubqueryAlias: __scalar_sq_3 +--------------Projection: Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey ----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] ------------------Projection: lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity --------------------Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") @@ -85,53 +85,51 @@ SortPreservingMergeExec: [s_name@0 ASC NULLS LAST] --SortExec: expr=[s_name@0 ASC NULLS LAST] ----ProjectionExec: expr=[s_name@1 as s_name, s_address@2 as s_address] ------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: "s_suppkey", index: 0 }, Column { name: "ps_suppkey", index: 0 })] +--------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_suppkey@0, ps_suppkey@0)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 --------------ProjectionExec: expr=[s_suppkey@0 as s_suppkey, s_name@1 as s_name, s_address@2 as s_address] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 3 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 ------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey], has_header=false +--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: n_name@1 = CANADA ------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 0 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 --------------ProjectionExec: expr=[ps_suppkey@1 as ps_suppkey] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_partkey", index: 0 }, Column { name: "l_partkey", index: 0 }), (Column { name: "ps_suppkey", index: 1 }, Column { name: "l_suppkey", index: 1 })], filter=CAST(ps_availqty@0 AS Float64) > __value@1 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, l_partkey@1), (ps_suppkey@1, l_suppkey@2)], filter=CAST(ps_availqty@0 AS Float64) > Float64(0.5) * SUM(lineitem.l_quantity)@1 --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }, Column { name: "ps_suppkey", index: 1 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([ps_partkey@0, ps_suppkey@1], 4), input_partitions=4 ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: "ps_partkey", index: 0 }, Column { name: "p_partkey", index: 0 })] +--------------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(ps_partkey@0, p_partkey@0)] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 ---------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey, ps_availqty], has_header=false +------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty], has_header=false ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 --------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] ----------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------FilterExec: p_name@1 LIKE forest% --------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false ---------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey, 0.5 * CAST(SUM(lineitem.l_quantity)@2 AS Float64) as __value] +----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false +--------------------ProjectionExec: expr=[0.5 * CAST(SUM(lineitem.l_quantity)@2 AS Float64) as Float64(0.5) * SUM(lineitem.l_quantity), l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey] ----------------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[SUM(lineitem.l_quantity)] ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 0 }, Column { name: "l_suppkey", index: 1 }], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=Hash([l_partkey@0, l_suppkey@1], 4), input_partitions=4 ----------------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[SUM(lineitem.l_quantity)] ------------------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey, l_quantity@2 as l_quantity] --------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------FilterExec: l_shipdate@3 >= 8766 AND l_shipdate@3 < 9131 -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], has_header=false +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/q21.slt.part similarity index 71% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q21.slt.part rename to datafusion/sqllogictest/test_files/tpch/q21.slt.part index bdc8e2076b6af..147afc603c2c1 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q21.slt.part @@ -59,11 +59,11 @@ order by ---- logical_plan Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST ---Projection: supplier.s_name, COUNT(UInt8(1)) AS numwait -----Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(UInt8(1))]] +--Projection: supplier.s_name, COUNT(*) AS numwait +----Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ------Projection: supplier.s_name ---------LeftAnti Join: l1.l_orderkey = __correlated_sq_10.l_orderkey Filter: __correlated_sq_10.l_suppkey != l1.l_suppkey -----------LeftSemi Join: l1.l_orderkey = __correlated_sq_9.l_orderkey Filter: __correlated_sq_9.l_suppkey != l1.l_suppkey +--------LeftAnti Join: l1.l_orderkey = __correlated_sq_2.l_orderkey Filter: __correlated_sq_2.l_suppkey != l1.l_suppkey +----------LeftSemi Join: l1.l_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_suppkey != l1.l_suppkey ------------Projection: supplier.s_name, l1.l_orderkey, l1.l_suppkey --------------Inner Join: supplier.s_nationkey = nation.n_nationkey ----------------Projection: supplier.s_name, supplier.s_nationkey, l1.l_orderkey, l1.l_suppkey @@ -81,10 +81,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST ----------------Projection: nation.n_nationkey ------------------Filter: nation.n_name = Utf8("SAUDI ARABIA") --------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("SAUDI ARABIA")] -------------SubqueryAlias: __correlated_sq_9 +------------SubqueryAlias: __correlated_sq_1 --------------SubqueryAlias: l2 ----------------TableScan: lineitem projection=[l_orderkey, l_suppkey] -----------SubqueryAlias: __correlated_sq_10 +----------SubqueryAlias: __correlated_sq_2 ------------SubqueryAlias: l3 --------------Projection: lineitem.l_orderkey, lineitem.l_suppkey ----------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate @@ -92,67 +92,63 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST physical_plan SortPreservingMergeExec: [numwait@1 DESC,s_name@0 ASC NULLS LAST] --SortExec: expr=[numwait@1 DESC,s_name@0 ASC NULLS LAST] -----ProjectionExec: expr=[s_name@0 as s_name, COUNT(UInt8(1))@1 as numwait] -------AggregateExec: mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[COUNT(UInt8(1))] +----ProjectionExec: expr=[s_name@0 as s_name, COUNT(*)@1 as numwait] +------AggregateExec: mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[COUNT(*)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "s_name", index: 0 }], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[COUNT(UInt8(1))] +----------RepartitionExec: partitioning=Hash([s_name@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[COUNT(*)] --------------ProjectionExec: expr=[s_name@0 as s_name] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "l_orderkey", index: 1 }, Column { name: "l_orderkey", index: 0 })], filter=l_suppkey@1 != l_suppkey@0 +------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(l_orderkey@1, l_orderkey@0)], filter=l_suppkey@1 != l_suppkey@0 --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: "l_orderkey", index: 1 }, Column { name: "l_orderkey", index: 0 })], filter=l_suppkey@1 != l_suppkey@0 +----------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(l_orderkey@1, l_orderkey@0)], filter=l_suppkey@1 != l_suppkey@0 ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 1 }], 4), input_partitions=4 +--------------------------RepartitionExec: partitioning=Hash([l_orderkey@1], 4), input_partitions=4 ----------------------------ProjectionExec: expr=[s_name@0 as s_name, l_orderkey@2 as l_orderkey, l_suppkey@3 as l_suppkey] ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 1 }, Column { name: "n_nationkey", index: 0 })] +--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@1, n_nationkey@0)] ----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 1 }], 4), input_partitions=4 +------------------------------------RepartitionExec: partitioning=Hash([s_nationkey@1], 4), input_partitions=4 --------------------------------------ProjectionExec: expr=[s_name@0 as s_name, s_nationkey@1 as s_nationkey, l_orderkey@2 as l_orderkey, l_suppkey@3 as l_suppkey] ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 2 }, Column { name: "o_orderkey", index: 0 })] +------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@2, o_orderkey@0)] --------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 2 }], 4), input_partitions=4 +----------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@2], 4), input_partitions=4 ------------------------------------------------ProjectionExec: expr=[s_name@1 as s_name, s_nationkey@2 as s_nationkey, l_orderkey@3 as l_orderkey, l_suppkey@4 as l_suppkey] --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_suppkey", index: 0 }, Column { name: "l_suppkey", index: 1 })] +----------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, l_suppkey@1)] ------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 ----------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_nationkey], has_header=false +------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_nationkey], has_header=false ------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 1 }], 4), input_partitions=4 +--------------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1], 4), input_partitions=4 ----------------------------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey] ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------------------------------------------FilterExec: l_receiptdate@3 > l_commitdate@2 -----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false +----------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false --------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 +----------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 ------------------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey] --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------------------FilterExec: o_orderstatus@1 = F -------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_orderstatus], has_header=false +------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderstatus], has_header=false ----------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 --------------------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------------FilterExec: n_name@1 = SAUDI ARABIA --------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false ------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 -----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_suppkey], has_header=false +--------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: l_receiptdate@3 > l_commitdate@2 -------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false +------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part new file mode 100644 index 0000000000000..2713d5bf6e18e --- /dev/null +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -0,0 +1,153 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query TT +explain select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; +---- +logical_plan +Sort: custsale.cntrycode ASC NULLS LAST +--Projection: custsale.cntrycode, COUNT(*) AS numcust, SUM(custsale.c_acctbal) AS totacctbal +----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(UInt8(1)) AS COUNT(*), SUM(custsale.c_acctbal)]] +------SubqueryAlias: custsale +--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal +----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.AVG(customer.c_acctbal) +------------Projection: customer.c_phone, customer.c_acctbal +--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey +----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) +------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] +----------------SubqueryAlias: __correlated_sq_1 +------------------TableScan: orders projection=[o_custkey] +------------SubqueryAlias: __scalar_sq_2 +--------------Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] +----------------Projection: customer.c_acctbal +------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) +--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] +physical_plan +SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] +--SortExec: expr=[cntrycode@0 ASC NULLS LAST] +----ProjectionExec: expr=[cntrycode@0 as cntrycode, COUNT(*)@1 as numcust, SUM(custsale.c_acctbal)@2 as totacctbal] +------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), SUM(custsale.c_acctbal)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), SUM(custsale.c_acctbal)] +--------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] +----------------NestedLoopJoinExec: join_type=Inner, filter=CAST(c_acctbal@0 AS Decimal128(19, 6)) > AVG(customer.c_acctbal)@1 +------------------ProjectionExec: expr=[c_phone@1 as c_phone, c_acctbal@2 as c_acctbal] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)] +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +----------------------------CoalesceBatchesExec: target_batch_size=8192 +------------------------------FilterExec: Use substr(c_phone@1, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) +--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], has_header=false +------------------------CoalesceBatchesExec: target_batch_size=8192 +--------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 +----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], has_header=false +------------------AggregateExec: mode=Final, gby=[], aggr=[AVG(customer.c_acctbal)] +--------------------CoalescePartitionsExec +----------------------AggregateExec: mode=Partial, gby=[], aggr=[AVG(customer.c_acctbal)] +------------------------ProjectionExec: expr=[c_acctbal@1 as c_acctbal] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], has_header=false + + +query TIR +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; +---- +13 94 714035.05 +17 96 722560.15 +18 99 738012.52 +23 93 708285.25 +29 85 632693.46 +30 87 646748.02 +31 87 647372.5 diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/q3.slt.part similarity index 78% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part rename to datafusion/sqllogictest/test_files/tpch/q3.slt.part index 381ea531d602b..85f2d9986c277 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q3.slt.part @@ -60,41 +60,39 @@ Limit: skip=0, fetch=10 ----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")] physical_plan GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] -----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] +--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10 +----SortExec: TopK(fetch=10), expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] ------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority] --------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 1 }, Column { name: "o_shippriority", index: 2 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([l_orderkey@0, o_orderdate@1, o_shippriority@2], 4), input_partitions=4 --------------AggregateExec: mode=Partial, gby=[l_orderkey@2 as l_orderkey, o_orderdate@0 as o_orderdate, o_shippriority@1 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----------------ProjectionExec: expr=[o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority, l_orderkey@3 as l_orderkey, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_orderkey", index: 0 }, Column { name: "l_orderkey", index: 0 })] +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@0, l_orderkey@0)] ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 --------------------------ProjectionExec: expr=[o_orderkey@1 as o_orderkey, o_orderdate@3 as o_orderdate, o_shippriority@4 as o_shippriority] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 1 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[c_custkey@0 as c_custkey] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------FilterExec: c_mktsegment@1 = BUILDING ------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_mktsegment], has_header=false +--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_mktsegment], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 1 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 ------------------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------------------FilterExec: o_orderdate@2 < 9204 -----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], has_header=false +----------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], has_header=false ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 --------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] ----------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------FilterExec: l_shipdate@3 > 9204 ---------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], has_header=false +--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q4.slt.part b/datafusion/sqllogictest/test_files/tpch/q4.slt.part similarity index 69% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q4.slt.part rename to datafusion/sqllogictest/test_files/tpch/q4.slt.part index 109bcd6d5cf5a..690ef64bc28d3 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q4.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q4.slt.part @@ -41,42 +41,40 @@ order by ---- logical_plan Sort: orders.o_orderpriority ASC NULLS LAST ---Projection: orders.o_orderpriority, COUNT(UInt8(1)) AS order_count -----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] +--Projection: orders.o_orderpriority, COUNT(*) AS order_count +----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ------Projection: orders.o_orderpriority ---------LeftSemi Join: orders.o_orderkey = __correlated_sq_15.l_orderkey +--------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey ----------Projection: orders.o_orderkey, orders.o_orderpriority ------------Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674") --------------TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority], partial_filters=[orders.o_orderdate >= Date32("8582"), orders.o_orderdate < Date32("8674")] -----------SubqueryAlias: __correlated_sq_15 +----------SubqueryAlias: __correlated_sq_1 ------------Projection: lineitem.l_orderkey --------------Filter: lineitem.l_commitdate < lineitem.l_receiptdate ----------------TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_commitdate < lineitem.l_receiptdate] physical_plan SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST] --SortExec: expr=[o_orderpriority@0 ASC NULLS LAST] -----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, COUNT(UInt8(1))@1 as order_count] -------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(UInt8(1))] +----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, COUNT(*)@1 as order_count] +------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "o_orderpriority", index: 0 }], 4), input_partitions=4 -------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(UInt8(1))] +----------RepartitionExec: partitioning=Hash([o_orderpriority@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] --------------ProjectionExec: expr=[o_orderpriority@1 as o_orderpriority] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(Column { name: "o_orderkey", index: 0 }, Column { name: "l_orderkey", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(o_orderkey@0, l_orderkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_orderpriority@2 as o_orderpriority] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: o_orderdate@1 >= 8582 AND o_orderdate@1 < 8674 -------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_orderdate, o_orderpriority], has_header=false +------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate, o_orderpriority], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: l_commitdate@1 < l_receiptdate@2 -------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_commitdate, l_receiptdate], has_header=false +------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_commitdate, l_receiptdate], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q5.slt.part b/datafusion/sqllogictest/test_files/tpch/q5.slt.part similarity index 73% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q5.slt.part rename to datafusion/sqllogictest/test_files/tpch/q5.slt.part index feac9d2a04a5c..af3a33497026d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q5.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q5.slt.part @@ -72,61 +72,59 @@ SortPreservingMergeExec: [revenue@1 DESC] ----ProjectionExec: expr=[n_name@0 as n_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as revenue] ------AggregateExec: mode=FinalPartitioned, gby=[n_name@0 as n_name], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "n_name", index: 0 }], 4), input_partitions=4 +----------RepartitionExec: partitioning=Hash([n_name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[n_name@2 as n_name], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] --------------ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, n_name@2 as n_name] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 3 }, Column { name: "r_regionkey", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@3, r_regionkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "n_regionkey", index: 3 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([n_regionkey@3], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, n_name@4 as n_name, n_regionkey@5 as n_regionkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 2 }, Column { name: "n_nationkey", index: 0 })] +----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)] ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 2 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 ----------------------------------ProjectionExec: expr=[l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount, s_nationkey@5 as s_nationkey] ------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_suppkey", index: 1 }, Column { name: "s_suppkey", index: 0 }), (Column { name: "c_nationkey", index: 0 }, Column { name: "s_nationkey", index: 1 })] +--------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@1, s_suppkey@0), (c_nationkey@0, s_nationkey@1)] ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 1 }, Column { name: "c_nationkey", index: 0 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1, c_nationkey@0], 4), input_partitions=4 --------------------------------------------ProjectionExec: expr=[c_nationkey@0 as c_nationkey, l_suppkey@3 as l_suppkey, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount] ----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_orderkey", index: 1 }, Column { name: "l_orderkey", index: 0 })] +------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@1, l_orderkey@0)] --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 1 }], 4), input_partitions=4 +----------------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@1], 4), input_partitions=4 ------------------------------------------------------ProjectionExec: expr=[c_nationkey@1 as c_nationkey, o_orderkey@2 as o_orderkey] --------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_custkey", index: 0 }, Column { name: "o_custkey", index: 1 })] +----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)] ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 ----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false +------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 1 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 ----------------------------------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] ------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------------------------------------------------FilterExec: o_orderdate@2 >= 8766 AND o_orderdate@2 < 9131 -----------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +----------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 -------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount], has_header=false +----------------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount], has_header=false ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }, Column { name: "s_nationkey", index: 1 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0, s_nationkey@1], 4), input_partitions=4 --------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false +------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "r_regionkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: r_name@1 = ASIA ------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q6.slt.part b/datafusion/sqllogictest/test_files/tpch/q6.slt.part similarity index 84% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q6.slt.part rename to datafusion/sqllogictest/test_files/tpch/q6.slt.part index e388f800725bb..8e53be297db04 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q6.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q6.slt.part @@ -41,8 +41,7 @@ ProjectionExec: expr=[SUM(lineitem.l_extendedprice * lineitem.l_discount)@0 as r --------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] ----------CoalesceBatchesExec: target_batch_size=8192 ------------FilterExec: l_shipdate@3 >= 8766 AND l_shipdate@3 < 9131 AND l_discount@2 >= Some(5),15,2 AND l_discount@2 <= Some(7),15,2 AND l_quantity@0 < Some(2400),15,2 ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], has_header=false +--------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q7.slt.part b/datafusion/sqllogictest/test_files/tpch/q7.slt.part similarity index 77% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q7.slt.part rename to datafusion/sqllogictest/test_files/tpch/q7.slt.part index dd538ebfe8538..5186c46a896f0 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q7.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q7.slt.part @@ -89,61 +89,59 @@ SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST,cust_nation@1 ASC NULLS L ----ProjectionExec: expr=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year, SUM(shipping.volume)@3 as revenue] ------AggregateExec: mode=FinalPartitioned, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[SUM(shipping.volume)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 4), input_partitions=4 +----------RepartitionExec: partitioning=Hash([supp_nation@0, cust_nation@1, l_year@2], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[SUM(shipping.volume)] --------------ProjectionExec: expr=[n_name@4 as supp_nation, n_name@6 as cust_nation, date_part(YEAR, l_shipdate@2) as l_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })], filter=n_name@0 = FRANCE AND n_name@1 = GERMANY OR n_name@0 = GERMANY AND n_name@1 = FRANCE +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@3, n_nationkey@0)], filter=n_name@0 = FRANCE AND n_name@1 = GERMANY OR n_name@0 = GERMANY AND n_name@1 = FRANCE --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "c_nationkey", index: 3 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([c_nationkey@3], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_shipdate@3 as l_shipdate, c_nationkey@4 as c_nationkey, n_name@6 as n_name] --------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 0 }, Column { name: "n_nationkey", index: 0 })] +----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@0, n_nationkey@0)] ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 0 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([s_nationkey@0], 4), input_partitions=4 ----------------------------------ProjectionExec: expr=[s_nationkey@0 as s_nationkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_shipdate@3 as l_shipdate, c_nationkey@6 as c_nationkey] ------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_custkey", index: 4 }, Column { name: "c_custkey", index: 0 })] +--------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_custkey@4, c_custkey@0)] ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 4 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([o_custkey@4], 4), input_partitions=4 --------------------------------------------ProjectionExec: expr=[s_nationkey@0 as s_nationkey, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount, l_shipdate@4 as l_shipdate, o_custkey@6 as o_custkey] ----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 1 }, Column { name: "o_orderkey", index: 0 })] +------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@1, o_orderkey@0)] --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 1 }], 4), input_partitions=4 +----------------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@1], 4), input_partitions=4 ------------------------------------------------------ProjectionExec: expr=[s_nationkey@1 as s_nationkey, l_orderkey@2 as l_orderkey, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount, l_shipdate@6 as l_shipdate] --------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_suppkey", index: 0 }, Column { name: "l_suppkey", index: 1 })] +----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, l_suppkey@1)] ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 ----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 1 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1], 4), input_partitions=4 ----------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------------------------------------FilterExec: l_shipdate@4 >= 9131 AND l_shipdate@4 <= 9861 ---------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false +--------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 -------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey], has_header=false +----------------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey], has_header=false ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 --------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false +----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ----------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------FilterExec: n_name@1 = FRANCE OR n_name@1 = GERMANY --------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ------------------------CoalesceBatchesExec: target_batch_size=8192 --------------------------FilterExec: n_name@1 = GERMANY OR n_name@1 = FRANCE ----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q8.slt.part b/datafusion/sqllogictest/test_files/tpch/q8.slt.part similarity index 78% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q8.slt.part rename to datafusion/sqllogictest/test_files/tpch/q8.slt.part index 38ee2119df6f8..760b40ad1ae85 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q8.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q8.slt.part @@ -95,81 +95,79 @@ SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] ----ProjectionExec: expr=[o_year@0 as o_year, CAST(CAST(SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END)@1 AS Decimal128(12, 2)) / CAST(SUM(all_nations.volume)@2 AS Decimal128(12, 2)) AS Decimal128(15, 2)) as mkt_share] ------AggregateExec: mode=FinalPartitioned, gby=[o_year@0 as o_year], aggr=[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "o_year", index: 0 }], 4), input_partitions=4 +----------RepartitionExec: partitioning=Hash([o_year@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[o_year@0 as o_year], aggr=[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)] --------------ProjectionExec: expr=[date_part(YEAR, o_orderdate@2) as o_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume, n_name@4 as nation] ----------------CoalesceBatchesExec: target_batch_size=8192 -------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 3 }, Column { name: "r_regionkey", index: 0 })] +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@3, r_regionkey@0)] --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "n_regionkey", index: 3 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([n_regionkey@3], 4), input_partitions=4 ------------------------ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, o_orderdate@3 as o_orderdate, n_regionkey@4 as n_regionkey, n_name@6 as n_name] --------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 2 }, Column { name: "n_nationkey", index: 0 })] +----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)] ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 2 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 ----------------------------------ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, s_nationkey@2 as s_nationkey, o_orderdate@3 as o_orderdate, n_regionkey@6 as n_regionkey] ------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_nationkey", index: 4 }, Column { name: "n_nationkey", index: 0 })] +--------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@4, n_nationkey@0)] ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_nationkey", index: 4 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([c_nationkey@4], 4), input_partitions=4 --------------------------------------------ProjectionExec: expr=[l_extendedprice@0 as l_extendedprice, l_discount@1 as l_discount, s_nationkey@2 as s_nationkey, o_orderdate@4 as o_orderdate, c_nationkey@6 as c_nationkey] ----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "o_custkey", index: 3 }, Column { name: "c_custkey", index: 0 })] +------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_custkey@3, c_custkey@0)] --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_custkey", index: 3 }], 4), input_partitions=4 +----------------------------------------------------RepartitionExec: partitioning=Hash([o_custkey@3], 4), input_partitions=4 ------------------------------------------------------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, s_nationkey@3 as s_nationkey, o_custkey@5 as o_custkey, o_orderdate@6 as o_orderdate] --------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderkey", index: 0 })] +----------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)] ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ----------------------------------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount, s_nationkey@5 as s_nationkey] ------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_suppkey", index: 1 }, Column { name: "s_suppkey", index: 0 })] +--------------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@1, s_suppkey@0)] ----------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 1 }], 4), input_partitions=4 +------------------------------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1], 4), input_partitions=4 --------------------------------------------------------------------------ProjectionExec: expr=[l_orderkey@1 as l_orderkey, l_suppkey@3 as l_suppkey, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount] ----------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "l_partkey", index: 1 })] +------------------------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, l_partkey@1)] --------------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +----------------------------------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 ------------------------------------------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] --------------------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------------------------------------------------------FilterExec: p_type@1 = ECONOMY ANODIZED STEEL ------------------------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +--------------------------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false --------------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 1 }], 4), input_partitions=4 -------------------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount], has_header=false +----------------------------------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 +------------------------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount], has_header=false ----------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +------------------------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 --------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +----------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 +--------------------------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 ----------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ------------------------------------------------------------------FilterExec: o_orderdate@2 >= 9131 AND o_orderdate@2 <= 9861 ---------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +--------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false --------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "c_custkey", index: 0 }], 4), input_partitions=4 +----------------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 ------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false +--------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false ----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +------------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 --------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false +----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false ------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +--------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 ----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false --------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------RepartitionExec: partitioning=Hash([Column { name: "r_regionkey", index: 0 }], 4), input_partitions=4 +----------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 ------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] --------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------FilterExec: r_name@1 = AMERICA ------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/q9.slt.part similarity index 74% rename from datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part rename to datafusion/sqllogictest/test_files/tpch/q9.slt.part index b2c49141c56fa..5db97f79bdb1d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q9.slt.part @@ -77,63 +77,60 @@ Limit: skip=0, fetch=10 --------------TableScan: nation projection=[n_nationkey, n_name] physical_plan GlobalLimitExec: skip=0, fetch=10 ---SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC] -----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC] +--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10 +----SortExec: TopK(fetch=10), expr=[nation@0 ASC NULLS LAST,o_year@1 DESC] ------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit] --------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] ----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 4), input_partitions=4 +------------RepartitionExec: partitioning=Hash([nation@0, o_year@1], 4), input_partitions=4 --------------AggregateExec: mode=Partial, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] ----------------ProjectionExec: expr=[n_name@7 as nation, date_part(YEAR, o_orderdate@5) as o_year, l_extendedprice@1 * (Some(1),20,0 - l_discount@2) - ps_supplycost@4 * l_quantity@0 as amount] ------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })] +--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)] ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 3 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 --------------------------ProjectionExec: expr=[l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount, s_nationkey@4 as s_nationkey, ps_supplycost@5 as ps_supplycost, o_orderdate@7 as o_orderdate] ----------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderkey", index: 0 })] +------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)] --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_orderkey", index: 0 }], 4), input_partitions=4 +----------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 ------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_quantity@3 as l_quantity, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount, s_nationkey@6 as s_nationkey, ps_supplycost@9 as ps_supplycost] --------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_suppkey", index: 2 }, Column { name: "ps_suppkey", index: 1 }), (Column { name: "l_partkey", index: 1 }, Column { name: "ps_partkey", index: 0 })] +----------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, ps_suppkey@1), (l_partkey@1, ps_partkey@0)] ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 2 }, Column { name: "l_partkey", index: 1 }], 4), input_partitions=4 +--------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2, l_partkey@1], 4), input_partitions=4 ----------------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_partkey@1 as l_partkey, l_suppkey@2 as l_suppkey, l_quantity@3 as l_quantity, l_extendedprice@4 as l_extendedprice, l_discount@5 as l_discount, s_nationkey@7 as s_nationkey] ------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_suppkey", index: 2 }, Column { name: "s_suppkey", index: 0 })] +--------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, s_suppkey@0)] ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_suppkey", index: 2 }], 4), input_partitions=4 +------------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2], 4), input_partitions=4 --------------------------------------------------------ProjectionExec: expr=[l_orderkey@1 as l_orderkey, l_partkey@2 as l_partkey, l_suppkey@3 as l_suppkey, l_quantity@4 as l_quantity, l_extendedprice@5 as l_extendedprice, l_discount@6 as l_discount] ----------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "l_partkey", index: 1 })] +------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, l_partkey@1)] --------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }], 4), input_partitions=4 +----------------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 ------------------------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] --------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ----------------------------------------------------------------------FilterExec: p_name@1 LIKE %green% ------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false +--------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false --------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "l_partkey", index: 1 }], 4), input_partitions=4 -------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/lineitem.tbl]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount], has_header=false +----------------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 +------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount], has_header=false ----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "s_suppkey", index: 0 }], 4), input_partitions=4 +------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 --------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false ------------------------------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------------------------------RepartitionExec: partitioning=Hash([Column { name: "ps_suppkey", index: 1 }, Column { name: "ps_partkey", index: 0 }], 4), input_partitions=4 -----------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/partsupp.tbl]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +--------------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1, ps_partkey@0], 4), input_partitions=4 +----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false --------------------------------CoalesceBatchesExec: target_batch_size=8192 -----------------------------------RepartitionExec: partitioning=Hash([Column { name: "o_orderkey", index: 0 }], 4), input_partitions=4 -------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/orders.tbl]]}, projection=[o_orderkey, o_orderdate], has_header=false +----------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate], has_header=false ----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 --------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/sqllogictests/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/tpch.slt b/datafusion/sqllogictest/test_files/tpch/tpch.slt similarity index 100% rename from datafusion/core/tests/sqllogictests/test_files/tpch/tpch.slt rename to datafusion/sqllogictest/test_files/tpch/tpch.slt diff --git a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt similarity index 82% rename from datafusion/core/tests/sqllogictests/test_files/type_coercion.slt rename to datafusion/sqllogictest/test_files/type_coercion.slt index 9aced0a3fd4e8..aa1e6826eca55 100644 --- a/datafusion/core/tests/sqllogictests/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -43,9 +43,9 @@ SELECT '2023-05-01 12:30:00'::timestamp - interval '1 month'; 2023-04-01T12:30:00 # interval - date -query error DataFusion error: type_coercion +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select interval '1 month' - '2023-05-01'::date; # interval - timestamp -query error DataFusion error: type_coercion +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types SELECT interval '1 month' - '2023-05-01 12:30:00'::timestamp; diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt similarity index 60% rename from datafusion/core/tests/sqllogictests/test_files/union.slt rename to datafusion/sqllogictest/test_files/union.slt index 2f33437ca1ad2..b4e338875e247 100644 --- a/datafusion/core/tests/sqllogictests/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -82,6 +82,11 @@ SELECT 2 as x 1 2 +query I +select count(*) from (select id from t1 union all select id from t2) +---- +6 + # csv_union_all statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( @@ -174,6 +179,76 @@ UNION ALL Alice John +# nested_union +query T rowsort +SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) +---- +Alex +Alex_new +Alice +Bob +Bob_new +John +John_new + +# should be un-nested, with a single (logical) aggregate +query TT +EXPLAIN SELECT name FROM t1 UNION (SELECT name from t2 UNION SELECT name || '_new' from t2) +---- +logical_plan +Aggregate: groupBy=[[t1.name]], aggr=[[]] +--Union +----TableScan: t1 projection=[name] +----TableScan: t2 projection=[name] +----Projection: t2.name || Utf8("_new") AS name +------TableScan: t2 projection=[name] +physical_plan +AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] +--CoalesceBatchesExec: target_batch_size=8192 +----RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +--------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] +----------UnionExec +------------MemoryExec: partitions=1, partition_sizes=[1] +------------MemoryExec: partitions=1, partition_sizes=[1] +------------ProjectionExec: expr=[name@0 || _new as name] +--------------MemoryExec: partitions=1, partition_sizes=[1] + +# nested_union_all +query T rowsort +SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2) +---- +Alex +Alex +Alex_new +Alice +Bob +Bob +Bob_new +John +John_new + +# Plan is unnested +query TT +EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2) +---- +logical_plan +Union +--TableScan: t1 projection=[name] +--TableScan: t2 projection=[name] +--Projection: t2.name || Utf8("_new") AS name +----TableScan: t2 projection=[name] +physical_plan +UnionExec +--MemoryExec: partitions=1, partition_sizes=[1] +--MemoryExec: partitions=1, partition_sizes=[1] +--ProjectionExec: expr=[name@0 || _new as name] +----MemoryExec: partitions=1, partition_sizes=[1] + +# Make sure to choose a small batch size to introduce parallelism to the plan. +statement ok +set datafusion.execution.batch_size = 2; + # union_with_type_coercion query TT explain @@ -202,33 +277,36 @@ Union ------TableScan: t1 projection=[id, name] physical_plan UnionExec ---ProjectionExec: expr=[id@0 as id, name@1 as name] -----CoalesceBatchesExec: target_batch_size=8192 -------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "id", index: 0 }, Column { name: "CAST(t2.id AS Int32)", index: 2 }), (Column { name: "name", index: 1 }, Column { name: "name", index: 1 })] ---------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "id", index: 0 }, Column { name: "name", index: 1 }], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "CAST(t2.id AS Int32)", index: 2 }, Column { name: "name", index: 1 }], 4), input_partitions=4 -------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] +------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +----------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] --ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] -----ProjectionExec: expr=[id@0 as id, name@1 as name] -------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "CAST(t2.id AS Int32)", index: 2 }, Column { name: "id", index: 0 }), (Column { name: "name", index: 1 }, Column { name: "name", index: 1 })] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "CAST(t2.id AS Int32)", index: 2 }, Column { name: "name", index: 1 }], 4), input_partitions=4 ---------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] -----------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] -------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------RepartitionExec: partitioning=Hash([Column { name: "id", index: 0 }, Column { name: "name", index: 1 }], 4), input_partitions=4 -----------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] -------------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([Column { name: "id", index: 0 }, Column { name: "name", index: 1 }], 4), input_partitions=4 ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +----CoalesceBatchesExec: target_batch_size=2 +------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([CAST(t2.id AS Int32)@2, name@1], 4), input_partitions=4 +------------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] +--------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] +----------------CoalesceBatchesExec: target_batch_size=2 +------------------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +--------------------AggregateExec: mode=Partial, gby=[id@0 as id, name@1 as name], aggr=[] +----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([id@0, name@1], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------MemoryExec: partitions=1, partition_sizes=[1] + query IT rowsort ( @@ -273,26 +351,30 @@ Union ----TableScan: t1 projection=[name] physical_plan InterleaveExec ---CoalesceBatchesExec: target_batch_size=8192 -----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "name", index: 0 }, Column { name: "name", index: 0 })] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(name@0, name@0)] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 -----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] ---CoalesceBatchesExec: target_batch_size=8192 -----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(Column { name: "name", index: 0 }, Column { name: "name", index: 0 })] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] +--CoalesceBatchesExec: target_batch_size=2 +----HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(name@0, name@0)] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 -----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------MemoryExec: partitions=1, partition_sizes=[1] # union_upcast_types query TT @@ -308,7 +390,7 @@ Limit: skip=0, fetch=5 --------TableScan: aggregate_test_100 projection=[c1, c3] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortPreservingMergeExec: [c9@1 DESC] +--SortPreservingMergeExec: [c9@1 DESC], fetch=5 ----UnionExec ------SortExec: expr=[c9@1 DESC] --------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9] @@ -338,27 +420,29 @@ SELECT count(*) FROM ( ) GROUP BY name ---- logical_plan -Projection: COUNT(UInt8(1)) ---Aggregate: groupBy=[[t1.name]], aggr=[[COUNT(UInt8(1))]] +Projection: COUNT(*) +--Aggregate: groupBy=[[t1.name]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----Union ------Aggregate: groupBy=[[t1.name]], aggr=[[]] --------TableScan: t1 projection=[name] ------Aggregate: groupBy=[[t2.name]], aggr=[[]] --------TableScan: t2 projection=[name] physical_plan -ProjectionExec: expr=[COUNT(UInt8(1))@1 as COUNT(UInt8(1))] ---AggregateExec: mode=Single, gby=[name@0 as name], aggr=[COUNT(UInt8(1))] +ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] +--AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[COUNT(*)] ----InterleaveExec ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] ------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "name", index: 0 }], 4), input_partitions=4 +--------CoalesceBatchesExec: target_batch_size=2 +----------RepartitionExec: partitioning=Hash([name@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[name@0 as name], aggr=[] ---------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] ######## @@ -464,15 +548,14 @@ physical_plan UnionExec --ProjectionExec: expr=[Int64(1)@0 as a] ----AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1)], aggr=[] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Column { name: "Int64(1)", index: 0 }], 4), input_partitions=4 -----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] ---------------EmptyExec: produce_one_row=true +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 +----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[] +------------PlaceholderRowExec --ProjectionExec: expr=[2 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec --ProjectionExec: expr=[3 as a] -----EmptyExec: produce_one_row=true +----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -482,8 +565,8 @@ select x, y from (select 1 as x , max(10) as y) b ---- logical_plan Union ---Projection: COUNT(UInt8(1)) AS count, a.n -----Aggregate: groupBy=[[a.n]], aggr=[[COUNT(UInt8(1))]] +--Projection: COUNT(*) AS count, a.n +----Aggregate: groupBy=[[a.n]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ------SubqueryAlias: a --------Projection: Int64(5) AS n ----------EmptyRelation @@ -494,15 +577,13 @@ Union ----------EmptyRelation physical_plan UnionExec ---ProjectionExec: expr=[COUNT(UInt8(1))@1 as count, n@0 as n] -----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[COUNT(UInt8(1))] -------CoalesceBatchesExec: target_batch_size=8192 ---------RepartitionExec: partitioning=Hash([Column { name: "n", index: 0 }], 4), input_partitions=4 -----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(UInt8(1))] ---------------ProjectionExec: expr=[5 as n] -----------------EmptyExec: produce_one_row=true ---ProjectionExec: expr=[x@0 as count, y@1 as n] -----ProjectionExec: expr=[1 as x, MAX(Int64(10))@0 as y] -------AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] ---------EmptyExec: produce_one_row=true +--ProjectionExec: expr=[COUNT(*)@1 as count, n@0 as n] +----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[COUNT(*)] +------CoalesceBatchesExec: target_batch_size=2 +--------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 +----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)] +------------ProjectionExec: expr=[5 as n] +--------------PlaceholderRowExec +--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] +----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] +------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt new file mode 100644 index 0000000000000..6412c3ca859e4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/update.slt @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Update Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + +query TT +explain update t1 set a=1, b=2, c=3.0, d=NULL; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +----TableScan: t1 + +query TT +explain update t1 set a=c+1, b=a, c=c+1.0, d=b; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +----TableScan: t1 + +statement ok +create table t2(a int, b varchar, c double, d int); + +## set from subquery +query TT +explain update t1 set b = (select max(b) from t2 where t1.a = t2.a) +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, () AS b, t1.c AS c, t1.d AS d +----Subquery: +------Projection: MAX(t2.b) +--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]] +----------Filter: outer_ref(t1.a) = t2.a +------------TableScan: t2 +----TableScan: t1 + +# set from other table +query TT +explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------TableScan: t1 +--------TableScan: t2 + +statement ok +create table t3(a int, b varchar, c double, d int); + +# set from mutiple tables, sqlparser only supports from one table +query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) +explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; + +# test table alias +query TT +explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------SubqueryAlias: t +----------TableScan: t1 +--------TableScan: t2 diff --git a/datafusion/core/tests/sqllogictests/test_files/wildcard.slt b/datafusion/sqllogictest/test_files/wildcard.slt similarity index 98% rename from datafusion/core/tests/sqllogictests/test_files/wildcard.slt rename to datafusion/sqllogictest/test_files/wildcard.slt index cc43ff4376e68..f83e84804a377 100644 --- a/datafusion/core/tests/sqllogictests/test_files/wildcard.slt +++ b/datafusion/sqllogictest/test_files/wildcard.slt @@ -41,7 +41,7 @@ CREATE EXTERNAL TABLE aggregate_simple ( ) STORED AS CSV WITH HEADER ROW -LOCATION 'tests/data/aggregate_simple.csv' +LOCATION '../core/tests/data/aggregate_simple.csv' ########## diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt similarity index 51% rename from datafusion/core/tests/sqllogictests/test_files/window.slt rename to datafusion/sqllogictest/test_files/window.slt index c0b861fd8a10b..f3de5b54fc8b3 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -43,13 +43,13 @@ CREATE EXTERNAL TABLE null_cases( ) STORED AS CSV WITH HEADER ROW -LOCATION 'tests/data/null_cases.csv'; +LOCATION '../core/tests/data/null_cases.csv'; ### This is the same table as ### execute_with_partition with 4 partitions statement ok CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION 'tests/data/partitioned_csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; # for window functions without order by the first, last, and nth function call does not make sense @@ -275,17 +275,17 @@ SortPreservingMergeExec: [b@0 ASC NULLS LAST] ----ProjectionExec: expr=[b@0 as b, MAX(d.a)@1 as max_a] ------AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[MAX(d.a)] --------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([Column { name: "b", index: 0 }], 4), input_partitions=4 +----------RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 ------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] --------------UnionExec ----------------ProjectionExec: expr=[1 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[3 as a, aa as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[5 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec ----------------ProjectionExec: expr=[7 as a, bb as b] -------------------EmptyExec: produce_one_row=true +------------------PlaceholderRowExec # Check actual result: query TI @@ -357,21 +357,21 @@ Sort: d.b ASC NULLS LAST physical_plan SortPreservingMergeExec: [b@0 ASC NULLS LAST] --ProjectionExec: expr=[b@0 as b, MAX(d.a)@1 as max_a, MAX(d.seq)@2 as MAX(d.seq)] -----AggregateExec: mode=Single, gby=[b@2 as b], aggr=[MAX(d.a), MAX(d.seq)], ordering_mode=FullyOrdered +----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[MAX(d.a), MAX(d.seq)], ordering_mode=Sorted ------ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as seq, a@0 as a, b@1 as b] ---------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[b@1 ASC NULLS LAST,a@0 ASC NULLS LAST] ------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([Column { name: "b", index: 1 }], 4), input_partitions=4 +--------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 ----------------UnionExec ------------------ProjectionExec: expr=[1 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[3 as a, aa as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[5 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec ------------------ProjectionExec: expr=[7 as a, bb as b] ---------------------EmptyExec: produce_one_row=true +--------------------PlaceholderRowExec # check actual result @@ -417,6 +417,7 @@ sum(amount) - lag(sum(amount), 1) over (order by idx) as difference from ( select * from (values ('a', 1, 100), ('a', 2, 150)) as t (col1, idx, amount) ) a group by col1, idx +ORDER BY idx ---- a 1 1 100 NULL NULL a 2 1 150 100 50 @@ -448,7 +449,7 @@ ORDER BY c9 LIMIT 5 ---- -48302 -16100.666666666666 3 -11243 3747.666666666667 3 +11243 3747.666666666666 3 -51311 -17103.666666666668 3 -2391 -797 3 46756 15585.333333333334 3 @@ -468,7 +469,7 @@ LIMIT 5 46721.33333333174 31147.555555554496 216.151181660734 176.486700789477 2639429.333333332 1759619.5555555548 1624.632060908971 1326.50652299774 746202.3333333324 497468.2222222216 863.830037295146 705.314271954156 -768422.9999999981 512281.9999999988 876.597399037893 715.738779164577 +768422.9999999981 512281.9999999988 876.597399037892 715.738779164577 66526.3333333288 44350.88888888587 257.926992254259 210.596507304575 # window_frame_rows_preceding_with_partition_unique_order_by @@ -714,8 +715,13 @@ LIMIT 5 26861 3 -#fn window_frame_ranges_preceding_following -statement error DataFusion error: Internal error: Operator \- is not implemented for types +# fn window_frame_ranges_preceding_following +# when value is outside type range (10000 is outside range of tiny int (type of c2)), +# we should treat values as infinite, hence +# "SUM(c3) OVER(ORDER BY c2 RANGE BETWEEN 10000 PRECEDING AND 10000 FOLLOWING)," +# is functionally equivalent to +# "SUM(c3) OVER(ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)," +query III SELECT SUM(c4) OVER(ORDER BY c2 RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), SUM(c3) OVER(ORDER BY c2 RANGE BETWEEN 10000 PRECEDING AND 10000 FOLLOWING), @@ -723,6 +729,13 @@ COUNT(*) OVER(ORDER BY c2 RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +52276 781 56 +260620 781 63 +-28623 781 37 +260620 781 63 +260620 781 63 + #fn window_frame_ranges_ntile @@ -882,14 +895,14 @@ SELECT statement ok create table temp as values -(1664264591000000000), -(1664264592000000000), -(1664264592000000000), -(1664264593000000000), -(1664264594000000000), -(1664364594000000000), -(1664464594000000000), -(1664564594000000000); +(1664264591), +(1664264592), +(1664264592), +(1664264593), +(1664264594), +(1664364594), +(1664464594), +(1664564594); statement ok create table t as select cast(column1 as timestamp) as ts from temp; @@ -928,23 +941,6 @@ FROM aggregate_test_100 ORDER BY c9 LIMIT 5 - -#fn window_frame_groups_preceding_following_desc -query III -SELECT -SUM(c4) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING), -SUM(c3) OVER(ORDER BY c2 DESC GROUPS BETWEEN 10000 PRECEDING AND 10000 FOLLOWING), -COUNT(*) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) -FROM aggregate_test_100 -ORDER BY c9 -LIMIT 5 ----- -52276 781 56 -260620 781 63 --28623 781 37 -260620 781 63 -260620 781 63 - #fn window_frame_groups_order_by_null_desc query I SELECT @@ -1213,9 +1209,9 @@ Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregat --------TableScan: aggregate_test_100 projection=[c8, c9] physical_plan ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum2] ---BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as SUM(aggregate_test_100.c9)] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@1 ASC NULLS LAST,c8@0 ASC NULLS LAST] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c8, c9], has_header=true @@ -1233,10 +1229,10 @@ Projection: aggregate_test_100.c2, MAX(aggregate_test_100.c9) ORDER BY [aggregat ------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] --------TableScan: aggregate_test_100 projection=[c2, c9] physical_plan -ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)] ---WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] -----BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: "MAX(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] -------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: "MIN(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] +ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +----BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], has_header=true @@ -1257,11 +1253,11 @@ Sort: aggregate_test_100.c2 ASC NULLS LAST ----------TableScan: aggregate_test_100 projection=[c2, c9] physical_plan SortExec: expr=[c2@0 ASC NULLS LAST] ---ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)] -----WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] -------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: "MAX(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +----WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@1 ASC NULLS LAST,c2@0 ASC NULLS LAST] -----------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: "MIN(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------SortExec: expr=[c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] --------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], has_header=true @@ -1276,22 +1272,22 @@ EXPLAIN SELECT FROM aggregate_test_100 ---- logical_plan -Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING ---WindowAggr: windowExpr=[[COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING +--WindowAggr: windowExpr=[[COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ----Projection: aggregate_test_100.c1, aggregate_test_100.c2, SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING -------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] --------TableScan: aggregate_test_100 projection=[c1, c2, c4] physical_plan -ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as COUNT(UInt8(1))] ---BoundedWindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +--BoundedWindowAggExec: wdw=[COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ----SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] ------CoalesceBatchesExec: target_batch_size=4096 ---------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 2), input_partitions=2 -----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4)] -------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: "SUM(aggregate_test_100.c4)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] ----------------CoalesceBatchesExec: target_batch_size=4096 -------------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }, Column { name: "c2", index: 1 }], 2), input_partitions=2 +------------------RepartitionExec: partitioning=Hash([c1@0, c2@1], 2), input_partitions=2 --------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c4], has_header=true @@ -1315,8 +1311,8 @@ Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregat physical_plan ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -1356,8 +1352,8 @@ Projection: aggregate_test_100.c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [ physical_plan ProjectionExec: expr=[c9@0 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101))", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)) }], mode=[Sorted] -------BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101))", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -1399,9 +1395,9 @@ Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 physical_plan ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ------SortExec: expr=[c9@0 ASC NULLS LAST] ---------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ----------SortExec: expr=[c9@0 DESC] ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -1441,10 +1437,10 @@ Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregat physical_plan ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as rn2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ------SortExec: expr=[c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] ---------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ------------SortExec: expr=[c9@2 DESC,c1@0 DESC] --------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9], has_header=true @@ -1523,19 +1519,19 @@ Projection: SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BET physical_plan ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@15 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as o11] --GlobalLimitExec: skip=0, fetch=5 -----WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] -------ProjectionExec: expr=[c1@0 as c1, c3@2 as c3, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as SUM(null_cases.c1), SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as SUM(null_cases.c1)] ---------BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] +------ProjectionExec: expr=[c1@0 as c1, c3@2 as c3, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[c3@2 ASC NULLS LAST,c2@1 ASC NULLS LAST] -------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------------SortExec: expr=[c3@2 ASC NULLS LAST,c1@0 ASC] -----------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------------SortExec: expr=[c3@2 ASC NULLS LAST,c1@0 DESC] ---------------------WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }] -----------------------WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] +--------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }] +----------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] ------------------------SortExec: expr=[c3@2 DESC NULLS LAST] ---------------------------WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] -----------------------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: "SUM(null_cases.c1)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }] +----------------------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------------------------------SortExec: expr=[c3@2 DESC,c1@0 ASC NULLS LAST] --------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/null_cases.csv]]}, projection=[c1, c2, c3], has_header=true @@ -1609,8 +1605,8 @@ Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregat physical_plan ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] + ----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] --------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true @@ -1653,8 +1649,8 @@ Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) PARTITION BY [aggr physical_plan ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] --------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true @@ -1698,9 +1694,9 @@ Projection: aggregate_test_100.c3, SUM(aggregate_test_100.c9) ORDER BY [aggregat physical_plan ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)) }] -------ProjectionExec: expr=[c3@1 as c3, c4@2 as c4, c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(aggregate_test_100.c9)] ---------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)) }] +------ProjectionExec: expr=[c3@1 as c3, c4@2 as c4, c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[c3@1 + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST] ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c4, c9], has_header=true @@ -1732,26 +1728,26 @@ EXPLAIN SELECT count(*) as global_count FROM ORDER BY c1 ) AS a ---- logical_plan -Projection: COUNT(UInt8(1)) AS global_count ---Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] +Projection: COUNT(*) AS global_count +--Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]] ----SubqueryAlias: a -------Sort: aggregate_test_100.c1 ASC NULLS LAST ---------Projection: aggregate_test_100.c1 -----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(UInt8(1))]] +------Projection: +--------Sort: aggregate_test_100.c1 ASC NULLS LAST +----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] ------------Projection: aggregate_test_100.c1 --------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") ----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan -ProjectionExec: expr=[COUNT(UInt8(1))@0 as global_count] ---AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))] +ProjectionExec: expr=[COUNT(*)@0 as global_count] +--AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))] +------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 -----------ProjectionExec: expr=[c1@0 as c1] -------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))] +----------ProjectionExec: expr=[] +------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] --------------CoalesceBatchesExec: target_batch_size=4096 -----------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 2), input_partitions=2 -------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))] +----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] --------------------ProjectionExec: expr=[c1@0 as c1] ----------------------CoalesceBatchesExec: target_batch_size=4096 ------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 @@ -1792,15 +1788,15 @@ Limit: skip=0, fetch=5 ------------TableScan: aggregate_test_100 projection=[c2, c3, c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortPreservingMergeExec: [c3@0 ASC NULLS LAST] +--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5 ----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC] ----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([Column { name: "c3", index: 0 }], 2), input_partitions=2 +------------RepartitionExec: partitioning=Hash([c3@0], 2), input_partitions=2 --------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------ProjectionExec: expr=[c3@1 as c3, c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(aggregate_test_100.c9)] -------------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----------------ProjectionExec: expr=[c3@1 as c3, c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------------------SortExec: expr=[c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST] ----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c9], has_header=true @@ -1836,10 +1832,10 @@ Sort: aggregate_test_100.c1 ASC NULLS LAST physical_plan SortPreservingMergeExec: [c1@0 ASC NULLS LAST] --ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -----BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ------SortExec: expr=[c1@0 ASC NULLS LAST] --------CoalesceBatchesExec: target_batch_size=4096 -----------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 2), input_partitions=2 +----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 ------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true @@ -1948,6 +1944,9 @@ e 20 e 21 # test_window_agg_global_sort_parallelize_sort_disabled +# even if, parallelize sort is disabled, we should use SortPreservingMergeExec +# instead of CoalescePartitionsExec + SortExec stack. Because at the end +# we already have the desired ordering. statement ok set datafusion.optimizer.repartition_sorts = false; @@ -1960,15 +1959,14 @@ Sort: aggregate_test_100.c1 ASC NULLS LAST ----WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1] physical_plan -SortExec: expr=[c1@0 ASC NULLS LAST] ---CoalescePartitionsExec -----ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] ---------SortExec: expr=[c1@0 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=4096 -------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 2), input_partitions=2 ---------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true +SortPreservingMergeExec: [c1@0 ASC NULLS LAST,rn1@1 ASC NULLS LAST] +--ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }], mode=[Sorted] +------SortExec: expr=[c1@0 ASC NULLS LAST] +--------CoalesceBatchesExec: target_batch_size=4096 +----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true statement ok set datafusion.optimizer.repartition_sorts = true; @@ -1989,13 +1987,13 @@ Sort: aggregate_test_100.c1 ASC NULLS LAST physical_plan SortExec: expr=[c1@0 ASC NULLS LAST] --ProjectionExec: expr=[c1@0 as c1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] -----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ------SortPreservingMergeExec: [c9@1 ASC NULLS LAST] --------SortExec: expr=[c9@1 ASC NULLS LAST] -----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(3)) }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(3)) }], mode=[Sorted] ------------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] --------------CoalesceBatchesExec: target_batch_size=4096 -----------------RepartitionExec: partitioning=Hash([Column { name: "c1", index: 0 }], 2), input_partitions=2 +----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 ------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 --------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true @@ -2019,7 +2017,7 @@ ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1] ------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------GlobalLimitExec: skip=0, fetch=1 -------------SortExec: fetch=1, expr=[c13@0 ASC NULLS LAST] +------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST] --------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true @@ -2079,13 +2077,13 @@ Limit: skip=0, fetch=5 ----------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[c9@0 ASC NULLS LAST] +--SortExec: TopK(fetch=5), expr=[c9@0 ASC NULLS LAST] ----ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ---------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9), SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(aggregate_test_100.c9), SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(aggregate_test_100.c9)] -----------WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] -------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ---------------WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: "SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] +----------WindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] +------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------------WindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] ----------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] ------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true @@ -2137,18 +2135,15 @@ Projection: t1.c9, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NU physical_plan ProjectionExec: expr=[c9@1 as c9, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(t1.c9): Ok(Field { name: "SUM(t1.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(t1.c9), SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(t1.c9), SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(t1.c9)] ---------WindowAggExec: wdw=[SUM(t1.c9): Ok(Field { name: "SUM(t1.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] -----------SortExec: expr=[c2@0 ASC NULLS LAST,c1_alias@3 ASC NULLS LAST,c9@2 ASC NULLS LAST,c8@1 ASC NULLS LAST] -------------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as SUM(t1.c9), SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(t1.c9)] ---------------BoundedWindowAggExec: wdw=[SUM(t1.c9): Ok(Field { name: "SUM(t1.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------------WindowAggExec: wdw=[SUM(t1.c9): Ok(Field { name: "SUM(t1.c9)", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] -------------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] ---------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] -----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true - - +----BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] +--------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] +----------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] +------------BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)) }] +----------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] +------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true query IIIII SELECT c9, @@ -2186,11 +2181,11 @@ Projection: sum1, sum2 physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c9@2 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] ---------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)) }], mode=[Sorted] -----------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as SUM(aggregate_test_100.c12)] -------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)) }], mode=[Sorted] +----------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] --------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] ----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9, c12], has_header=true @@ -2226,7 +2221,7 @@ Limit: skip=0, fetch=5 physical_plan GlobalLimitExec: skip=0, fetch=5 --ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -----BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------SortExec: expr=[c9@0 ASC NULLS LAST] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2265,7 +2260,7 @@ Limit: skip=0, fetch=5 physical_plan GlobalLimitExec: skip=0, fetch=5 --ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -----BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------SortExec: expr=[c9@0 DESC] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2303,9 +2298,9 @@ Limit: skip=0, fetch=5 ----------TableScan: aggregate_test_100 projection=[c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[rn1@1 DESC] +--SortExec: TopK(fetch=5), expr=[rn1@1 DESC] ----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2346,9 +2341,9 @@ Limit: skip=0, fetch=5 ----------TableScan: aggregate_test_100 projection=[c9] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[rn1@1 ASC NULLS LAST,c9@0 ASC NULLS LAST] +--SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST,c9@0 ASC NULLS LAST] ----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -------BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------SortExec: expr=[c9@0 DESC] ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2401,7 +2396,7 @@ Limit: skip=0, fetch=5 physical_plan GlobalLimitExec: skip=0, fetch=5 --ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -----BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ------SortExec: expr=[c9@0 DESC] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2429,6 +2424,29 @@ GlobalLimitExec: skip=0, fetch=5 ------SortExec: expr=[CAST(c9@1 AS Int32) + c5@0 DESC] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5, c9], has_header=true +# Ordering equivalence should be preserved during cast expression +query TT +EXPLAIN SELECT c9, rn1 FROM (SELECT c9, + CAST(ROW_NUMBER() OVER(ORDER BY c9 DESC) as BIGINT) as rn1 + FROM aggregate_test_100 + ORDER BY c9 DESC) + ORDER BY rn1 ASC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: rn1 ASC NULLS LAST, fetch=5 +----Sort: aggregate_test_100.c9 DESC NULLS FIRST +------Projection: aggregate_test_100.c9, CAST(ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS Int64) AS rn1 +--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------TableScan: aggregate_test_100 projection=[c9] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--ProjectionExec: expr=[c9@0 as c9, CAST(ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 AS Int64) as rn1] +----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------SortExec: expr=[c9@0 DESC] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true + # The following query has type error. We should test the error could be detected # from either the logical plan (when `skip_failed_rules` is set to `false`) or # the physical plan (when `skip_failed_rules` is set to `true`). @@ -2468,7 +2486,7 @@ CREATE EXTERNAL TABLE annotated_data_finite ( STORED AS CSV WITH HEADER ROW WITH ORDER (ts ASC) -LOCATION 'tests/data/window_1.csv' +LOCATION '../core/tests/data/window_1.csv' ; # 100 rows. Columns in the table are ts, inc_col, desc_col. @@ -2483,7 +2501,7 @@ CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite ( STORED AS CSV WITH HEADER ROW WITH ORDER (ts ASC) -LOCATION 'tests/data/window_1.csv'; +LOCATION '../core/tests/data/window_1.csv'; # test_source_sorted_aggregate @@ -2521,22 +2539,26 @@ logical_plan Projection: sum1, sum2, sum3, min1, min2, min3, max1, max2, max3, cnt1, cnt2, sumr1, sumr2, sumr3, minr1, minr2, minr3, maxr1, maxr2, maxr3, cntr1, cntr2, sum4, cnt3 --Limit: skip=0, fetch=5 ----Sort: annotated_data_finite.inc_col DESC NULLS FIRST, fetch=5 -------Projection: SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col ---------WindowAggr: windowExpr=[[SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -----------Projection: annotated_data_finite.inc_col, annotated_data_finite.desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING -------------WindowAggr: windowExpr=[[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] ---------------WindowAggr: windowExpr=[[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -----------------TableScan: annotated_data_finite projection=[ts, inc_col, desc_col] +------Projection: SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.desc_col AS Int64)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +----------Projection: annotated_data_finite.inc_col, annotated_data_finite.desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING +------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +--------------Projection: CAST(annotated_data_finite.inc_col AS Int64) AS CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, annotated_data_finite.ts, annotated_data_finite.inc_col, annotated_data_finite.desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING +----------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +------------------Projection: CAST(annotated_data_finite.desc_col AS Int64) AS CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, annotated_data_finite.ts, annotated_data_finite.inc_col, annotated_data_finite.desc_col +--------------------TableScan: annotated_data_finite projection=[ts, inc_col, desc_col] physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@24 DESC] -------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@0 as inc_col] ---------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.desc_col): Ok(Field { name: "SUM(annotated_data_finite.desc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------ProjectionExec: expr=[inc_col@1 as inc_col, desc_col@2 as desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@3 as SUM(annotated_data_finite.inc_col), SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@4 as SUM(annotated_data_finite.desc_col), SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(annotated_data_finite.desc_col), MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@6 as MIN(annotated_data_finite.inc_col), MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@7 as MIN(annotated_data_finite.desc_col), MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as MIN(annotated_data_finite.inc_col), MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as MAX(annotated_data_finite.inc_col), MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@10 as MAX(annotated_data_finite.desc_col), MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@11 as MAX(annotated_data_finite.inc_col), COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@12 as COUNT(UInt8(1)), COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@13 as COUNT(UInt8(1)), SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as SUM(annotated_data_finite.inc_col), SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@15 as SUM(annotated_data_finite.desc_col), SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as SUM(annotated_data_finite.inc_col), MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as MIN(annotated_data_finite.inc_col), MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@18 as MIN(annotated_data_finite.desc_col), MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@19 as MIN(annotated_data_finite.inc_col), MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@20 as MAX(annotated_data_finite.inc_col), MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@21 as MAX(annotated_data_finite.desc_col), MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as MAX(annotated_data_finite.inc_col), COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@23 as COUNT(UInt8(1)), COUNT(UInt8(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as COUNT(UInt8(1))] -------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col): Ok(Field { name: "SUM(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col): Ok(Field { name: "SUM(annotated_data_finite.desc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.inc_col): Ok(Field { name: "SUM(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MIN(annotated_data_finite.desc_col): Ok(Field { name: "MIN(annotated_data_finite.desc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MAX(annotated_data_finite.desc_col): Ok(Field { name: "MAX(annotated_data_finite.desc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ---------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col): Ok(Field { name: "SUM(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col): Ok(Field { name: "SUM(annotated_data_finite.desc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col): Ok(Field { name: "SUM(annotated_data_finite.desc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MIN(annotated_data_finite.desc_col): Ok(Field { name: "MIN(annotated_data_finite.desc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MAX(annotated_data_finite.desc_col): Ok(Field { name: "MAX(annotated_data_finite.desc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: "COUNT(UInt8(1))", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +----SortExec: TopK(fetch=5), expr=[inc_col@24 DESC] +------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@0 as inc_col] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----------ProjectionExec: expr=[inc_col@2 as inc_col, desc_col@3 as desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@4 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@5 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@8 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@9 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@12 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@13 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@14 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@16 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@17 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@18 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@19 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@22 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@23 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@24 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] +------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------------ProjectionExec: expr=[CAST(inc_col@2 AS Int64) as CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, ts@1 as ts, inc_col@2 as inc_col, desc_col@3 as desc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@4 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@5 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@8 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@9 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@12 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@13 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@14 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] +----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }], mode=[Sorted] +------------------ProjectionExec: expr=[CAST(desc_col@2 AS Int64) as CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, ts@0 as ts, inc_col@1 as inc_col, desc_col@2 as desc_col] +--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIII SELECT @@ -2579,6 +2601,7 @@ SELECT # test_source_sorted_builtin query TT EXPLAIN SELECT + ts, FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, @@ -2608,24 +2631,23 @@ EXPLAIN SELECT LIMIT 5; ---- logical_plan -Projection: fv1, fv2, lv1, lv2, nv1, nv2, rn1, rn2, rank1, rank2, dense_rank1, dense_rank2, lag1, lag2, lead1, lead2, fvr1, fvr2, lvr1, lvr2, lagr1, lagr2, leadr1, leadr2 ---Limit: skip=0, fetch=5 -----Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -------Projection: FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2, annotated_data_finite.ts ---------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -----------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -------------TableScan: annotated_data_finite projection=[ts, inc_col] +Limit: skip=0, fetch=5 +--Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 +----Projection: annotated_data_finite.ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +----------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan -ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as lv1, lv2@3 as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2, rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1, dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1, lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1, leadr2@23 as leadr2] ---GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[ts@24 DESC] -------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2, ts@0 as ts] ---------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)): Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)): Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER(): Ok(Field { name: "ROW_NUMBER()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK(): Ok(Field { name: "RANK()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK(): Ok(Field { name: "RANK()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK(): Ok(Field { name: "DENSE_RANK()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK(): Ok(Field { name: "DENSE_RANK()", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)): Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)): Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)): Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)): Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)): Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)): Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }], mode=[Sorted] -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[ts@0 DESC] +----ProjectionExec: expr=[ts@0 as ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }], mode=[Sorted] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true -query IIIIIIIIIIIIIIIIIIIIIIII +query IIIIIIIIIIIIIIIIIIIIIIIII SELECT + ts, FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, @@ -2651,14 +2673,14 @@ SELECT LEAD(inc_col, -1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS leadr1, LEAD(inc_col, 4, 1004) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as leadr2 FROM annotated_data_finite - ORDER BY ts DESC + ORDER BY ts DESC, fv2 LIMIT 5; ---- -289 269 305 305 305 283 100 100 99 99 86 86 301 296 301 1004 305 305 301 301 1001 1002 1001 289 -289 266 305 305 305 278 99 99 99 99 86 86 296 291 296 1004 305 305 301 296 305 1002 305 286 -289 261 296 301 NULL 275 98 98 98 98 85 85 291 289 291 1004 305 305 296 291 301 305 301 283 -286 259 291 296 NULL 272 97 97 97 97 84 84 289 286 289 1004 305 305 291 289 296 301 296 278 -275 254 289 291 289 269 96 96 96 96 83 83 286 283 286 305 305 305 289 286 291 296 291 275 +264 289 266 305 305 305 278 99 99 99 99 86 86 296 291 296 1004 305 305 301 296 305 1002 305 286 +264 289 269 305 305 305 283 100 100 99 99 86 86 301 296 301 1004 305 305 301 301 1001 1002 1001 289 +262 289 261 296 301 NULL 275 98 98 98 98 85 85 291 289 291 1004 305 305 296 291 301 305 301 283 +258 286 259 291 296 NULL 272 97 97 97 97 84 84 289 286 289 1004 305 305 291 289 296 301 296 278 +254 275 254 289 291 289 269 96 96 96 96 83 83 286 283 286 305 305 305 289 286 291 296 291 275 # test_source_sorted_unbounded_preceding @@ -2684,16 +2706,16 @@ Projection: sum1, sum2, min1, min2, max1, max2, count1, count2, avg1, avg2 --Limit: skip=0, fetch=5 ----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 ------Projection: SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS sum1, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS avg2, annotated_data_finite.inc_col ---------WindowAggr: windowExpr=[[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] -----------WindowAggr: windowExpr=[[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, AVG(CAST(annotated_data_finite.inc_col AS Float64)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, AVG(CAST(annotated_data_finite.inc_col AS Float64)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] ------------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@10 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST] ------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@7 as sum1, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@8 as min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as avg2, inc_col@1 as inc_col] ---------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col): Ok(Field { name: "SUM(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, COUNT(annotated_data_finite.inc_col): Ok(Field { name: "COUNT(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, AVG(annotated_data_finite.inc_col): Ok(Field { name: "AVG(annotated_data_finite.inc_col)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col): Ok(Field { name: "SUM(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MIN(annotated_data_finite.inc_col): Ok(Field { name: "MIN(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MAX(annotated_data_finite.inc_col): Ok(Field { name: "MAX(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, COUNT(annotated_data_finite.inc_col): Ok(Field { name: "COUNT(annotated_data_finite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, AVG(annotated_data_finite.inc_col): Ok(Field { name: "AVG(annotated_data_finite.inc_col)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }], mode=[Sorted] ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIRR @@ -2743,10 +2765,10 @@ Projection: first_value1, first_value2, last_value1, last_value2, nth_value1 physical_plan ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[inc_col@5 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST] ------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] ---------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)): Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2))", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, LAST_VALUE(annotated_data_finite.inc_col): Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIII @@ -2783,16 +2805,16 @@ Projection: sum1, sum2, count1, count2 --Limit: skip=0, fetch=5 ----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 ------Projection: SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts ---------WindowAggr: windowExpr=[[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -----------WindowAggr: windowExpr=[[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] ------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] --GlobalLimitExec: skip=0, fetch=5 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] -------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col): Ok(Field { name: "SUM(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col): Ok(Field { name: "COUNT(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ---------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col): Ok(Field { name: "SUM(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col): Ok(Field { name: "COUNT(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2829,16 +2851,16 @@ Projection: sum1, sum2, count1, count2 --Limit: skip=0, fetch=5 ----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 ------Projection: SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts ---------WindowAggr: windowExpr=[[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -----------WindowAggr: windowExpr=[[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] ------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] --GlobalLimitExec: skip=0, fetch=5 ----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as count2, ts@0 as ts] -------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col): Ok(Field { name: "SUM(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col): Ok(Field { name: "COUNT(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ---------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col): Ok(Field { name: "SUM(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col): Ok(Field { name: "COUNT(annotated_data_infinite.inc_col)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }], mode=[Sorted] +----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2883,7 +2905,7 @@ CREATE EXTERNAL TABLE annotated_data_finite2 ( STORED AS CSV WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION 'tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv'; # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2899,7 +2921,7 @@ CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite2 ( STORED AS CSV WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION 'tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv'; # test_infinite_source_partition_by @@ -2924,23 +2946,25 @@ EXPLAIN SELECT a, b, c, logical_plan Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 --Limit: skip=0, fetch=5 -----WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] -------WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] ---------WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -----------WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] -------------WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] ---------------WindowAggr: windowExpr=[[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -----------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] +----WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] +------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] +--------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +----------------Projection: CAST(annotated_data_infinite2.c AS Int64) AS CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d +------------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan -ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@8 as sum1, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@9 as sum2, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@14 as sum3, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@15 as sum4, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@4 as sum5, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@5 as sum6, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@10 as sum7, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@11 as sum8, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@6 as sum9, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@7 as sum10, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@12 as sum11, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@13 as sum12] +ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] --GlobalLimitExec: skip=0, fetch=5 -----BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)) }], mode=[Linear] -------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)) }], mode=[PartiallySorted([1, 0])] ---------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[PartiallySorted([0])] -------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[PartiallySorted([0, 1])] ---------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c): Ok(Field { name: "SUM(annotated_data_infinite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)) }], mode=[Linear] +------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)) }], mode=[PartiallySorted([1, 0])] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[PartiallySorted([0])] +------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[PartiallySorted([0, 1])] +--------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query IIIIIIIIIIIIIII @@ -2992,29 +3016,31 @@ logical_plan Limit: skip=0, fetch=5 --Sort: annotated_data_finite2.c ASC NULLS LAST, fetch=5 ----Projection: annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 -------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] ---------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] -----------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -------------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] ---------------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] -----------------WindowAggr: windowExpr=[[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -------------------TableScan: annotated_data_finite2 projection=[a, b, c, d] +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] +--------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] +----------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +------------------Projection: CAST(annotated_data_finite2.c AS Int64) AS CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c, annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, annotated_data_finite2.d +--------------------TableScan: annotated_data_finite2 projection=[a, b, c, d] physical_plan GlobalLimitExec: skip=0, fetch=5 ---SortExec: fetch=5, expr=[c@2 ASC NULLS LAST] -----ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@8 as sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@9 as sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@14 as sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@15 as sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@4 as sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@5 as sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@10 as sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@11 as sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@6 as sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@7 as sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@12 as sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@13 as sum12] -------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)) }], mode=[Sorted] ---------SortExec: expr=[d@3 ASC NULLS LAST,a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST] -----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)) }], mode=[Sorted] -------------SortExec: expr=[b@1 ASC NULLS LAST,a@0 ASC NULLS LAST,d@3 ASC NULLS LAST,c@2 ASC NULLS LAST] ---------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------------SortExec: expr=[b@1 ASC NULLS LAST,a@0 ASC NULLS LAST,c@2 ASC NULLS LAST] -------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] ---------------------SortExec: expr=[a@0 ASC NULLS LAST,d@3 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST] -----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[Sorted] -------------------------SortExec: expr=[a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,d@3 ASC NULLS LAST,c@2 ASC NULLS LAST] ---------------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c): Ok(Field { name: "SUM(annotated_data_finite2.c)", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] -----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +--SortExec: TopK(fetch=5), expr=[c@2 ASC NULLS LAST] +----ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] +------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)) }], mode=[Sorted] +--------SortExec: expr=[d@4 ASC NULLS LAST,a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)) }], mode=[Sorted] +------------SortExec: expr=[b@2 ASC NULLS LAST,a@1 ASC NULLS LAST,d@4 ASC NULLS LAST,c@3 ASC NULLS LAST] +--------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----------------SortExec: expr=[b@2 ASC NULLS LAST,a@1 ASC NULLS LAST,c@3 ASC NULLS LAST] +------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +--------------------SortExec: expr=[a@1 ASC NULLS LAST,d@4 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow }], mode=[Sorted] +------------------------SortExec: expr=[a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,d@4 ASC NULLS LAST,c@3 ASC NULLS LAST] +--------------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)) }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)) }], mode=[Sorted] +----------------------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIII @@ -3055,6 +3081,188 @@ statement error DataFusion error: Error during planning: Aggregate ORDER BY is n EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a ASC) OVER (order by a ASC) as last_c FROM annotated_data_infinite2 +# ordering equivalence information +# should propagate through FilterExec, LimitExec, CoalesceBatchesExec, etc. +# Below query should work without breaking pipeline +query TT +EXPLAIN SELECT * FROM (SELECT *, ROW_NUMBER() OVER(ORDER BY a ASC) as rn1 + FROM annotated_data_infinite2 + ORDER BY rn1 ASC + LIMIT 5) + WHERE rn1 < 50 + ORDER BY rn1 ASC +---- +logical_plan +Sort: rn1 ASC NULLS LAST +--Filter: rn1 < UInt64(50) +----Limit: skip=0, fetch=5 +------Sort: rn1 ASC NULLS LAST, fetch=5 +--------Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +----------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +------------TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] +physical_plan +CoalesceBatchesExec: target_batch_size=4096 +--FilterExec: rn1@5 < 50 +----GlobalLimitExec: skip=0, fetch=5 +------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] +--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] + +# this is a negative test for asserting that window functions (other than ROW_NUMBER) +# are not added to ordering equivalence +# physical plan should contain SortExec. +query TT +EXPLAIN SELECT c9, sum1 FROM (SELECT c9, + SUM(c9) OVER(ORDER BY c9 DESC) as sum1 + FROM aggregate_test_100 + ORDER BY c9 DESC) + ORDER BY sum1, c9 DESC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: sum1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 +----Sort: aggregate_test_100.c9 DESC NULLS FIRST +------Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 +--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------TableScan: aggregate_test_100 projection=[c9] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[sum1@1 ASC NULLS LAST,c9@0 DESC] +----ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1] +------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------SortExec: expr=[c9@0 DESC] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true + +# Query below should work when its input is unbounded +# because ordering of ROW_NUMBER, RANK result is added to the ordering equivalence +# and final plan doesn't contain SortExec. +query IIII +SELECT a, d, rn1, rank1 FROM (SELECT a, d, + ROW_NUMBER() OVER(ORDER BY a ASC) as rn1, + RANK() OVER(ORDER BY a ASC) as rank1 + FROM annotated_data_infinite2 + ORDER BY a ASC) + ORDER BY rn1, rank1, a ASC + LIMIT 5 +---- +0 0 1 1 +0 2 2 1 +0 0 3 1 +0 0 4 1 +0 1 5 1 + +# this is a negative test for asserting that ROW_NUMBER is not +# added to the ordering equivalence when it contains partition by. +# physical plan should contain SortExec. Since source is unbounded +# pipeline checker should raise error, when plan contains SortExec. +statement error DataFusion error: PipelineChecker +SELECT a, d, rn1 FROM (SELECT a, d, + ROW_NUMBER() OVER(PARTITION BY d ORDER BY a ASC) as rn1 + FROM annotated_data_infinite2 + ORDER BY a ASC) + ORDER BY rn1, a ASC + LIMIT 5 + +# when partition by expressions match with existing ordering +# row number can be appended to existing ordering +# below query should work, without breaking pipeline. +query III +SELECT a, d, rn1 FROM (SELECT a, b, c, d, + ROW_NUMBER() OVER(PARTITION BY b, c, a) as rn1 + FROM annotated_data_infinite2 + ORDER BY a ASC) + ORDER BY a, b, c, rn1 + LIMIT 5 +---- +0 0 1 +0 2 1 +0 0 1 +0 0 1 +0 1 1 + +# projection should propagate ordering equivalence successfully +# when expression contains alias +query III +SELECT a_new, d, rn1 FROM (SELECT d, a as a_new, + ROW_NUMBER() OVER(ORDER BY a ASC) as rn1 + FROM annotated_data_infinite2 + ORDER BY a_new ASC) + ORDER BY a_new ASC, rn1 + LIMIT 5 +---- +0 0 1 +0 2 2 +0 0 3 +0 0 4 +0 1 5 + +query TT +EXPLAIN SELECT SUM(a) OVER(partition by a, b order by c) as sum1, +SUM(a) OVER(partition by b, a order by c) as sum2, + SUM(a) OVER(partition by a, d order by b) as sum3, + SUM(a) OVER(partition by d order by a) as sum4 +FROM annotated_data_infinite2; +---- +logical_plan +Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] +--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] +----ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +statement ok +set datafusion.execution.target_partitions = 2; + +# re-execute the same query in multi partitions. +# final plan should still be streamable +query TT +EXPLAIN SELECT SUM(a) OVER(partition by a, b order by c) as sum1, + SUM(a) OVER(partition by b, a order by c) as sum2, + SUM(a) OVER(partition by a, d order by b) as sum3, + SUM(a) OVER(partition by d order by a) as sum4 +FROM annotated_data_infinite2; +---- +logical_plan +Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum4] +--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Linear] +----CoalesceBatchesExec: target_batch_size=4096 +------SortPreservingRepartitionExec: partitioning=Hash([d@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST +--------ProjectionExec: expr=[a@0 as a, d@3 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------CoalesceBatchesExec: target_batch_size=4096 +--------------SortPreservingRepartitionExec: partitioning=Hash([b@1, a@0], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[PartiallySorted([0])] +------------------CoalesceBatchesExec: target_batch_size=4096 +--------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, d@3], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +------------------------CoalesceBatchesExec: target_batch_size=4096 +--------------------------SortPreservingRepartitionExec: partitioning=Hash([a@0, b@1], 2), input_partitions=2, sort_exprs=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,c@2 ASC NULLS LAST +----------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] + +# reset the partition number 1 again +statement ok +set datafusion.execution.target_partitions = 1; + statement ok drop table annotated_data_finite2 @@ -3063,25 +3271,27 @@ drop table annotated_data_infinite2 # window3 spec is not used in window functions. # The query should still work. -query RR +query IRR SELECT - MAX(c12) OVER window1, - MIN(c12) OVER window2 as max1 + C3, + MAX(c12) OVER window1 as max1, + MIN(c12) OVER window2 as max2 FROM aggregate_test_100 WINDOW window1 AS (ORDER BY C12), window2 AS (PARTITION BY C11), window3 AS (ORDER BY C1) - ORDER BY C3 + ORDER BY C3, max2 LIMIT 5 ---- -0.970671228336 0.970671228336 -0.850672105305 0.850672105305 -0.152498292972 0.152498292972 -0.369363046006 0.369363046006 -0.56535284223 0.56535284223 +-117 0.850672105305 0.850672105305 +-117 0.970671228336 0.970671228336 +-111 0.152498292972 0.152498292972 +-107 0.369363046006 0.369363046006 +-106 0.56535284223 0.56535284223 query TT EXPLAIN SELECT + C3, MAX(c12) OVER window1 as min1, MIN(c12) OVER window2 as max1 FROM aggregate_test_100 @@ -3092,42 +3302,41 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -Projection: min1, max1 ---Limit: skip=0, fetch=5 -----Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 -------Projection: MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1, aggregate_test_100.c3 ---------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -----------Projection: aggregate_test_100.c3, aggregate_test_100.c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING -------------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] ---------------TableScan: aggregate_test_100 projection=[c3, c11, c12] +Limit: skip=0, fetch=5 +--Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 +----Projection: aggregate_test_100.c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1 +------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------Projection: aggregate_test_100.c3, aggregate_test_100.c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +----------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +------------TableScan: aggregate_test_100 projection=[c3, c11, c12] physical_plan -ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] ---GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c3@2 ASC NULLS LAST] -------ProjectionExec: expr=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1, c3@0 as c3] ---------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12): Ok(Field { name: "MAX(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] -----------SortExec: expr=[c12@1 ASC NULLS LAST] -------------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as MIN(aggregate_test_100.c12)] ---------------WindowAggExec: wdw=[MIN(aggregate_test_100.c12): Ok(Field { name: "MIN(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] -----------------SortExec: expr=[c11@1 ASC NULLS LAST] -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true +GlobalLimitExec: skip=0, fetch=5 +--SortExec: TopK(fetch=5), expr=[c3@0 ASC NULLS LAST] +----ProjectionExec: expr=[c3@0 as c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1] +------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------SortExec: expr=[c12@1 ASC NULLS LAST] +----------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +------------WindowAggExec: wdw=[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }] +--------------SortExec: expr=[c11@1 ASC NULLS LAST] +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true # window1 spec is used multiple times under different aggregations. # The query should still work. -query RR +query IRR SELECT + C3, MAX(c12) OVER window1 as min1, MIN(c12) OVER window1 as max1 FROM aggregate_test_100 WINDOW window1 AS (ORDER BY C12) - ORDER BY C3 + ORDER BY C3, min1 LIMIT 5 ---- -0.970671228336 0.014793053078 -0.850672105305 0.014793053078 -0.152498292972 0.014793053078 -0.369363046006 0.014793053078 -0.56535284223 0.014793053078 +-117 0.850672105305 0.014793053078 +-117 0.970671228336 0.014793053078 +-111 0.152498292972 0.014793053078 +-107 0.369363046006 0.014793053078 +-106 0.56535284223 0.014793053078 query TT EXPLAIN SELECT @@ -3148,9 +3357,9 @@ Projection: min1, max1 physical_plan ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] --GlobalLimitExec: skip=0, fetch=5 -----SortExec: fetch=5, expr=[c3@2 ASC NULLS LAST] +----SortExec: TopK(fetch=5), expr=[c3@2 ASC NULLS LAST] ------ProjectionExec: expr=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min1, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max1, c3@0 as c3] ---------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12): Ok(Field { name: "MAX(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }, MIN(aggregate_test_100.c12): Ok(Field { name: "MIN(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow }], mode=[Sorted] ----------SortExec: expr=[c12@1 ASC NULLS LAST] ------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c12], has_header=true @@ -3173,3 +3382,414 @@ SELECT window1 AS (ORDER BY C3) ORDER BY C3 LIMIT 5 + +# Create a source where there is multiple orderings. +statement ok +CREATE EXTERNAL TABLE multiple_ordered_table ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# Create an unbounded source where there is multiple orderings. +statement ok +CREATE UNBOUNDED EXTERNAL TABLE multiple_ordered_table_inf ( + a0 INTEGER, + a INTEGER, + b INTEGER, + c INTEGER, + d INTEGER +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (a ASC, b ASC) +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv'; + +# All of the window execs in the physical plan should work in the +# sorted mode. +query TT +EXPLAIN SELECT MIN(d) OVER(ORDER BY c ASC) as min1, + MAX(d) OVER(PARTITION BY b, a ORDER BY c ASC) as max1 +FROM multiple_ordered_table +---- +logical_plan +Projection: MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1 +--WindowAggr: windowExpr=[[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Projection: multiple_ordered_table.c, multiple_ordered_table.d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +------WindowAggr: windowExpr=[[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max1] +--BoundedWindowAggExec: wdw=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----ProjectionExec: expr=[c@2 as c, d@3 as d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +------BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +query TT +EXPLAIN SELECT MAX(c) OVER(PARTITION BY d ORDER BY c ASC) as max_c +FROM( + SELECT * + FROM multiple_ordered_table + WHERE d=0) +---- +logical_plan +Projection: MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max_c +--WindowAggr: windowExpr=[[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----Filter: multiple_ordered_table.d = Int32(0) +------TableScan: multiple_ordered_table projection=[c, d], partial_filters=[multiple_ordered_table.d = Int32(0)] +physical_plan +ProjectionExec: expr=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max_c] +--BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CoalesceBatchesExec: target_batch_size=4096 +------FilterExec: d@1 = 0 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query TT +explain SELECT SUM(d) OVER(PARTITION BY c ORDER BY a ASC) +FROM multiple_ordered_table; +---- +logical_plan +Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----TableScan: multiple_ordered_table projection=[a, c, d] +physical_plan +ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true + +query TT +explain SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) +FROM multiple_ordered_table; +---- +logical_plan +Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----TableScan: multiple_ordered_table projection=[a, b, c, d] +physical_plan +ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow }], mode=[Sorted] +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true + +query I +SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) +FROM multiple_ordered_table +LIMIT 5; +---- +0 +2 +0 +0 +1 + +# simple window query +query II +select sum(1) over() x, sum(1) over () y +---- +1 1 + +# NTH_VALUE requirement is c DESC, However existing ordering is c ASC +# if we reverse window expression: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" +# as "NTH_VALUE(c, -2) OVER(order by c ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as nv1" +# Please note that: "NTH_VALUE(c, 2) OVER(order by c DESC ) as nv1" is same with +# "NTH_VALUE(c, 2) OVER(order by c DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as nv1" " +# we can produce same result without re-sorting the table. +# Unfortunately since window expression names are string, this change is not seen the plan (we do not do string manipulation). +# TODO: Reflect window expression reversal in the plans. +query TT +EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +logical_plan +Limit: skip=0, fetch=5 +--Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 +----Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 +------WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +--------TableScan: multiple_ordered_table projection=[c] +physical_plan +GlobalLimitExec: skip=0, fetch=5 +--ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +----WindowAggExec: wdw=[NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int32(NULL)) }] +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c ASC + LIMIT 5 +---- +0 98 +1 98 +2 98 +3 98 +4 98 + +query II +SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 + FROM multiple_ordered_table + ORDER BY c DESC + LIMIT 5 +---- +99 NULL +98 98 +97 98 +96 98 +95 98 + +statement ok +set datafusion.execution.target_partitions = 2; + +# source is ordered by [a ASC, b ASC], [c ASC] +# after sort preserving repartition and sort preserving merge +# we should still have the orderings [a ASC, b ASC], [c ASC]. +query TT +EXPLAIN SELECT *, + AVG(d) OVER sliding_window AS avg_d +FROM multiple_ordered_table_inf +WINDOW sliding_window AS ( + PARTITION BY d + ORDER BY a RANGE 10 PRECEDING +) +ORDER BY c +---- +logical_plan +Sort: multiple_ordered_table_inf.c ASC NULLS LAST +--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] +physical_plan +SortPreservingMergeExec: [c@3 ASC NULLS LAST] +--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow }], mode=[Linear] +------CoalesceBatchesExec: target_batch_size=4096 +--------SortPreservingRepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST +----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST] + +# CTAS with NTILE function +statement ok +CREATE TABLE new_table AS SELECT NTILE(2) OVER(ORDER BY c1) AS ntile_2 FROM aggregate_test_100; + +statement ok +DROP TABLE new_table; + +statement ok +CREATE TABLE t1 (a int) AS VALUES (1), (2), (3); + +query I +SELECT NTILE(9223377) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query I +SELECT NTILE(9223372036854775809) OVER(ORDER BY a) FROM t1; +---- +1 +2 +3 + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT NTILE(-922337203685477580) OVER(ORDER BY a) FROM t1; + +query error DataFusion error: Execution error: Table 't' doesn't exist\. +DROP TABLE t; + +# NTILE with PARTITION BY, those tests from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/window/test_ntile.test +statement ok +CREATE TABLE score_board (team_name VARCHAR, player VARCHAR, score INTEGER) as VALUES + ('Mongrels', 'Apu', 350), + ('Mongrels', 'Ned', 666), + ('Mongrels', 'Meg', 1030), + ('Mongrels', 'Burns', 1270), + ('Simpsons', 'Homer', 1), + ('Simpsons', 'Lisa', 710), + ('Simpsons', 'Marge', 990), + ('Simpsons', 'Bart', 2010) + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(2) OVER (ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY score; +---- +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1000) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 + +query TTII +SELECT + team_name, + player, + score, + NTILE(1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s +ORDER BY team_name, score; +---- +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 + +# incorrect number of parameters for ntile +query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL +SELECT + NTILE(NULL) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(-1) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +query error DataFusion error: Execution error: NTILE requires a positive integer +SELECT + NTILE(0) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE() OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement error +SELECT + NTILE(1,2,3,4) OVER (PARTITION BY team_name ORDER BY score ASC) AS NTILE +FROM score_board s + +statement ok +DROP TABLE score_board; + +# Regularize RANGE frame +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query error DataFusion error: Error during planning: RANGE requires exactly one ORDER BY column +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a + +query II +select a, + rank() over (order by a, a + 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (order by a RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 2 + +query II +select a, + rank() over (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query I +select rank() over (RANGE between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q; +---- +1 +1 + +query II +select a, + rank() over (order by 1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 + +query II +select a, + rank() over (order by null RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) rnk + from (select 1 a union select 2 a) q ORDER BY a +---- +1 1 +2 1 diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index bfa9b26407568..42ebe56c298b0 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -18,24 +18,24 @@ [package] name = "datafusion-substrait" description = "DataFusion Substrait Producer and Consumer" +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = { workspace = true } +rust-version = "1.70" [dependencies] async-recursion = "1.0" -chrono = { version = "0.4.23", default-features = false } -datafusion = { version = "26.0.0", path = "../core" } -itertools = "0.10.5" -object_store = "0.6.1" -prost = "0.11" -prost-types = "0.11" -substrait = "0.11.0" +chrono = { workspace = true } +datafusion = { workspace = true } +itertools = { workspace = true } +object_store = { workspace = true } +prost = "0.12" +prost-types = "0.12" +substrait = "0.20.0" tokio = "1.17" [features] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f914b62a1452d..ffc9d094ab910 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,22 +17,26 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; -use datafusion::common::{DFField, DFSchema, DFSchemaRef}; +use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; + +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ aggregate_function, window_function::find_df_window_func, BinaryExpr, BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; -use datafusion::logical_expr::{build_join_schema, Extension, LogicalPlanBuilder}; -use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits}; +use datafusion::logical_expr::{ + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, + WindowFrameUnits, +}; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ error::{DataFusionError, Result}, - optimizer::utils::split_conjunction, + logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::Literal; +use substrait::proto::expression::{Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -47,12 +51,14 @@ use substrait::proto::{ join_rel, plan_rel, r#type, read_rel::ReadType, rel::RelType, + set_rel, sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, Plan, Rel, Type, }; use substrait::proto::{FunctionArgument, SortField}; -use datafusion::logical_expr::expr::Sort; +use datafusion::common::plan_err; +use datafusion::logical_expr::expr::{InList, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -67,6 +73,16 @@ use crate::variation_const::{ enum ScalarFunctionType { Builtin(BuiltinScalarFunction), Op(Operator), + /// [Expr::Not] + Not, + /// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive + Like, + /// [Expr::Like] Case insensitive operator counterpart of `Like` + ILike, + /// [Expr::IsNull] + IsNull, + /// [Expr::IsNotNull] + IsNotNull, } pub fn name_to_op(name: &str) -> Result { @@ -93,16 +109,16 @@ pub fn name_to_op(name: &str) -> Result { "bitwise_and" => Ok(Operator::BitwiseAnd), "bitwise_or" => Ok(Operator::BitwiseOr), "str_concat" => Ok(Operator::StringConcat), + "at_arrow" => Ok(Operator::AtArrow), + "arrow_at" => Ok(Operator::ArrowAt), "bitwise_xor" => Ok(Operator::BitwiseXor), "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported function name: {name:?}" - ))), + _ => not_impl_err!("Unsupported function name: {name:?}"), } } -fn name_to_op_or_scalar_function(name: &str) -> Result { +fn scalar_function_type_from_str(name: &str) -> Result { if let Ok(op) = name_to_op(name) { return Ok(ScalarFunctionType::Op(op)); } @@ -111,14 +127,64 @@ fn name_to_op_or_scalar_function(name: &str) -> Result { return Ok(ScalarFunctionType::Builtin(fun)); } - Err(DataFusionError::NotImplemented(format!( - "Unsupported function name: {name:?}" - ))) + match name { + "not" => Ok(ScalarFunctionType::Not), + "like" => Ok(ScalarFunctionType::Like), + "ilike" => Ok(ScalarFunctionType::ILike), + "is_null" => Ok(ScalarFunctionType::IsNull), + "is_not_null" => Ok(ScalarFunctionType::IsNotNull), + others => not_impl_err!("Unsupported function name: {others:?}"), + } +} + +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, + right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) } /// Convert Substrait Plan to DataFusion DataFrame pub async fn from_substrait_plan( - ctx: &mut SessionContext, + ctx: &SessionContext, plan: &Plan, ) -> Result { // Register function extension @@ -130,13 +196,9 @@ pub async fn from_substrait_plan( MappingType::ExtensionFunction(ext_f) => { Ok((ext_f.function_anchor, &ext_f.name)) } - _ => Err(DataFusionError::NotImplemented(format!( - "Extension type not supported: {ext:?}" - ))), + _ => not_impl_err!("Extension type not supported: {ext:?}"), }, - None => Err(DataFusionError::NotImplemented( - "Cannot parse empty extension".to_string(), - )), + None => not_impl_err!("Cannot parse empty extension"), }) .collect::>>()?; // Parse relations @@ -151,20 +213,20 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) } }, - None => Err(DataFusionError::Internal("Cannot parse plan relation: None".to_string())) + None => plan_err!("Cannot parse plan relation: None") } }, - _ => Err(DataFusionError::NotImplemented(format!( + _ => not_impl_err!( "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", plan.relations.len() - ))) + ) } } /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( - ctx: &mut SessionContext, + ctx: &SessionContext, rel: &Rel, extensions: &HashMap, ) -> Result { @@ -193,9 +255,7 @@ pub async fn from_substrait_rel( } input.project(exprs)?.build() } else { - Err(DataFusionError::NotImplemented( - "Projection without an input is not supported".to_string(), - )) + not_impl_err!("Projection without an input is not supported") } } Some(RelType::Filter(filter)) => { @@ -208,14 +268,10 @@ pub async fn from_substrait_rel( from_substrait_rex(condition, input.schema(), extensions).await?; input.filter(expr.as_ref().clone())?.build() } else { - Err(DataFusionError::NotImplemented( - "Filter without an condition is not valid".to_string(), - )) + not_impl_err!("Filter without an condition is not valid") } } else { - Err(DataFusionError::NotImplemented( - "Filter without an input is not valid".to_string(), - )) + not_impl_err!("Filter without an input is not valid") } } Some(RelType::Fetch(fetch)) => { @@ -224,12 +280,15 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - let count = fetch.count as usize; - input.limit(offset, Some(count))?.build() + // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` + let count = if fetch.count as usize == usize::MAX { + None + } else { + Some(fetch.count as usize) + }; + input.limit(offset, count)?.build() } else { - Err(DataFusionError::NotImplemented( - "Fetch without an input is not valid".to_string(), - )) + not_impl_err!("Fetch without an input is not valid") } } Some(RelType::Sort(sort)) => { @@ -241,9 +300,7 @@ pub async fn from_substrait_rel( from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; input.sort(sorts)?.build() } else { - Err(DataFusionError::NotImplemented( - "Sort without an input is not valid".to_string(), - )) + not_impl_err!("Sort without an input is not valid") } } Some(RelType::Aggregate(agg)) => { @@ -254,19 +311,35 @@ pub async fn from_substrait_rel( let mut group_expr = vec![]; let mut aggr_expr = vec![]; - let groupings = match agg.groupings.len() { - 1 => Ok(&agg.groupings[0]), - _ => Err(DataFusionError::NotImplemented( - "Aggregate with multiple grouping sets is not supported" - .to_string(), - )), + match agg.groupings.len() { + 1 => { + for e in &agg.groupings[0].grouping_expressions { + let x = + from_substrait_rex(e, input.schema(), extensions).await?; + group_expr.push(x.as_ref().clone()); + } + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let mut grouping_set = vec![]; + for e in &grouping.grouping_expressions { + let x = from_substrait_rex(e, input.schema(), extensions) + .await?; + grouping_set.push(x.as_ref().clone()); + } + grouping_sets.push(grouping_set); + } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. + group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets( + grouping_sets, + ))); + } }; - for e in &groupings?.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions).await?; - group_expr.push(x.as_ref().clone()); - } - for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( @@ -293,6 +366,7 @@ pub async fn from_substrait_rel( _ => false, }; from_substrait_agg_func( + ctx, f, input.schema(), extensions, @@ -303,23 +377,26 @@ pub async fn from_substrait_rel( ) .await } - None => Err(DataFusionError::NotImplemented( + None => not_impl_err!( "Aggregate without aggregate function is not supported" - .to_string(), - )), + ), }; aggr_expr.push(agg_func?.as_ref().clone()); } input.aggregate(group_expr, aggr_expr)?.build() } else { - Err(DataFusionError::NotImplemented( - "Aggregate without an input is not valid".to_string(), - )) + not_impl_err!("Aggregate without an input is not valid") } } Some(RelType::Join(join)) => { - let left = LogicalPlanBuilder::from( + if join.post_join_filter.is_some() { + return not_impl_err!( + "JoinRel with post_join_filter is not yet supported" + ); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, ); let right = LogicalPlanBuilder::from( @@ -328,71 +405,49 @@ pub async fn from_substrait_rel( let join_type = from_substrait_jointype(join.r#type)?; // The join condition expression needs full input schema and not the output schema from join since we lose columns from // certain join types such as semi and anti joins - // - if left and right schemas are different, we combine (join) the schema to include all fields - // - if left and right schemas are the same, we handle the duplicate fields by using `build_join_schema()`, which discard the unused schema - // TODO: Handle duplicate fields error for other join types (non-semi/anti). The current approach does not work due to Substrait's inability - // to encode aliases - let join_schema = match left.schema().join(right.schema()) { - Ok(schema) => Ok(schema), - Err(DataFusionError::SchemaError( - datafusion::common::SchemaError::DuplicateQualifiedField { - qualifier: _, - name: _, - }, - )) => build_join_schema(left.schema(), right.schema(), &join_type), - Err(e) => Err(e), - }; - let on = from_substrait_rex( - join.expression.as_ref().unwrap(), - &join_schema?, - extensions, - ) - .await?; - let predicates = split_conjunction(&on); - // TODO: collect only one null_eq_null - let join_exprs: Vec<(Column, Column, bool)> = predicates - .iter() - .map(|p| match p { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => match op { - Operator::Eq => Ok((l.clone(), r.clone(), false)), - Operator::IsNotDistinctFrom => { - Ok((l.clone(), r.clone(), true)) - } - _ => Err(DataFusionError::Internal( - "invalid join condition op".to_string(), - )), - }, - _ => Err(DataFusionError::Internal( - "invalid join condition expresssion".to_string(), - )), - } - } - _ => Err(DataFusionError::Internal( - "Non-binary expression is not supported in join condition" - .to_string(), - )), - }) - .collect::>>()?; - let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) = - itertools::multiunzip(join_exprs); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - None, - null_eq_nulls[0], - )? - .build() + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = + from_substrait_rex(expr, &in_join_schema, extensions).await?; + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + nulls_equal_nulls, + )? + .build() + } + None => plan_err!("JoinRel without join condition is not allowed"), + } + } + Some(RelType::Cross(cross)) => { + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + ); + let right = + from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + .await?; + left.cross_join(right)?.build() } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { 0 => { - return Err(DataFusionError::Internal( - "No table name found in NamedTable".to_string(), - )); + return plan_err!("No table name found in NamedTable"); } 1 => TableReference::Bare { table: (&nt.names[0]).into(), @@ -432,9 +487,7 @@ pub async fn from_substrait_rel( )?); Ok(LogicalPlan::TableScan(scan)) } - _ => Err(DataFusionError::Internal( - "unexpected plan for table".to_string(), - )), + _ => plan_err!("unexpected plan for table"), } } _ => Ok(t), @@ -442,9 +495,27 @@ pub async fn from_substrait_rel( _ => Ok(t), } } - _ => Err(DataFusionError::NotImplemented( - "Only NamedTable reads are supported".to_string(), - )), + _ => not_impl_err!("Only NamedTable reads are supported"), + }, + Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { + Ok(set_op) => match set_op { + set_rel::SetOp::UnionAll => { + if !set.inputs.is_empty() { + let mut union_builder = Ok(LogicalPlanBuilder::from( + from_substrait_rel(ctx, &set.inputs[0], extensions).await?, + )); + for input in &set.inputs[1..] { + union_builder = union_builder? + .union(from_substrait_rel(ctx, input, extensions).await?); + } + union_builder?.build() + } else { + not_impl_err!("Union relation requires at least one input") + } + } + _ => not_impl_err!("Unsupported set operator: {set_op:?}"), + }, + Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), }, Some(RelType::ExtensionLeaf(extension)) => { let Some(ext_detail) = &extension.detail else { @@ -495,15 +566,12 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported RelType: {:?}", - rel.rel_type - ))), + _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } fn from_substrait_jointype(join_type: i32) -> Result { - if let Some(substrait_join_type) = join_rel::JoinType::from_i32(join_type) { + if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { match substrait_join_type { join_rel::JoinType::Inner => Ok(JoinType::Inner), join_rel::JoinType::Left => Ok(JoinType::Left), @@ -511,14 +579,10 @@ fn from_substrait_jointype(join_type: i32) -> Result { join_rel::JoinType::Outer => Ok(JoinType::Full), join_rel::JoinType::Anti => Ok(JoinType::LeftAnti), join_rel::JoinType::Semi => Ok(JoinType::LeftSemi), - _ => Err(DataFusionError::Internal(format!( - "unsupported join type {substrait_join_type:?}" - ))), + _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { - Err(DataFusionError::Internal(format!( - "invalid join type variant {join_type:?}" - ))) + plan_err!("invalid join type variant {join_type:?}") } } @@ -535,10 +599,10 @@ pub async fn from_substrait_sorts( let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { - let Some(direction) = SortDirection::from_i32(*d) else { - return Err(DataFusionError::NotImplemented( - format!("Unsupported Substrait SortDirection value {d}"), - )) + let Ok(direction) = SortDirection::try_from(*d) else { + return not_impl_err!( + "Unsupported Substrait SortDirection value {d}" + ); }; match direction { @@ -546,25 +610,19 @@ pub async fn from_substrait_sorts( SortDirection::AscNullsLast => Ok((true, false)), SortDirection::DescNullsFirst => Ok((false, true)), SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => Err(DataFusionError::NotImplemented( + SortDirection::Clustered => not_impl_err!( "Sort with direction clustered is not yet supported" - .to_string(), - )), + ), SortDirection::Unspecified => { - Err(DataFusionError::NotImplemented( - "Unspecified sort direction is invalid".to_string(), - )) + not_impl_err!("Unspecified sort direction is invalid") } } } - ComparisonFunctionReference(_) => Err(DataFusionError::NotImplemented( + ComparisonFunctionReference(_) => not_impl_err!( "Sort using comparison function reference is not supported" - .to_string(), - )), + ), }, - None => Err(DataFusionError::NotImplemented( - "Sort without sort kind is invalid".to_string(), - )), + None => not_impl_err!("Sort without sort kind is invalid"), }; let (asc, nulls_first) = asc_nullfirst.unwrap(); sorts.push(Expr::Sort(Sort { @@ -602,9 +660,9 @@ pub async fn from_substriat_func_args( Some(ArgType::Value(e)) => { from_substrait_rex(e, input_schema, extensions).await } - _ => Err(DataFusionError::NotImplemented( - "Aggregated function argument non-Value type not supported".to_string(), - )), + _ => { + not_impl_err!("Aggregated function argument non-Value type not supported") + } }; args.push(arg_expr?.as_ref().clone()); } @@ -613,6 +671,7 @@ pub async fn from_substriat_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( + ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, extensions: &HashMap, @@ -626,30 +685,37 @@ pub async fn from_substrait_agg_func( Some(ArgType::Value(e)) => { from_substrait_rex(e, input_schema, extensions).await } - _ => Err(DataFusionError::NotImplemented( - "Aggregated function argument non-Value type not supported".to_string(), - )), + _ => { + not_impl_err!("Aggregated function argument non-Value type not supported") + } }; args.push(arg_expr?.as_ref().clone()); } - let fun = match extensions.get(&f.function_reference) { - Some(function_name) => { - aggregate_function::AggregateFunction::from_str(function_name) - } - None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function anchor = {:?}", + let Some(function_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", f.function_reference - ))), + ); }; - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun: fun.unwrap(), - args, - distinct, - filter, - order_by, - }))) + // try udaf first, then built-in aggr fn. + if let Ok(fun) = ctx.udaf(function_name) { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) + } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) + { + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) + } else { + not_impl_err!( + "Aggregated function {} is not supported: function anchor = {:?}", + function_name, + f.function_reference + ) + } } /// Convert Substrait Rex to DataFusion Expr @@ -660,13 +726,27 @@ pub async fn from_substrait_rex( extensions: &HashMap, ) -> Result> { match &e.rex_type { + Some(RexType::SingularOrList(s)) => { + let substrait_expr = s.value.as_ref().unwrap(); + let substrait_list = s.options.as_ref(); + Ok(Arc::new(Expr::InList(InList { + expr: Box::new( + from_substrait_rex(substrait_expr, input_schema, extensions) + .await? + .as_ref() + .clone(), + ), + list: from_substrait_rex_vec(substrait_list, input_schema, extensions) + .await?, + negated: false, + }))) + } Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => Err(DataFusionError::NotImplemented( + Some(_) => not_impl_err!( "Direct reference StructField with child is not supported" - .to_string(), - )), + ), None => { let column = input_schema.field(x.field as usize).qualified_column(); @@ -676,14 +756,11 @@ pub async fn from_substrait_rex( }))) } }, - _ => Err(DataFusionError::NotImplemented( + _ => not_impl_err!( "Direct reference with types other than StructField is not supported" - .to_string(), - )), + ), }, - _ => Err(DataFusionError::NotImplemented( - "unsupported field ref type".to_string(), - )), + _ => not_impl_err!("unsupported field ref type"), }, Some(RexType::IfThen(if_then)) => { // Parse `ifs` @@ -746,20 +823,44 @@ pub async fn from_substrait_rex( else_expr, }))) } - Some(RexType::ScalarFunction(f)) => match f.arguments.len() { - // BinaryExpr or ScalarFunction - 2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) { - (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { - let op_or_fun = match extensions.get(&f.function_reference) { - Some(fname) => name_to_op_or_scalar_function(fname), - None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", - f.function_reference - ))), - }; - match op_or_fun { - Ok(ScalarFunctionType::Op(op)) => { - return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { + Some(RexType::ScalarFunction(f)) => { + let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { + DataFusionError::NotImplemented(format!( + "Aggregated function not found: function reference = {:?}", + f.function_reference + )) + })?; + let fn_type = scalar_function_type_from_str(fn_name)?; + match fn_type { + ScalarFunctionType::Builtin(fun) => { + let mut args = Vec::with_capacity(f.arguments.len()); + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => { + from_substrait_rex(e, input_schema, extensions).await + } + _ => not_impl_err!( + "Aggregated function argument non-Value type not supported" + ), + }; + args.push(arg_expr?.as_ref().clone()); + } + Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new( + fun, args, + )))) + } + ScalarFunctionType::Op(op) => { + if f.arguments.len() != 2 { + return not_impl_err!( + "Expect two arguments for binary operator {op:?}" + ); + } + let lhs = &f.arguments[0].arg_type; + let rhs = &f.arguments[1].arg_type; + + match (lhs, rhs) { + (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { + Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( from_substrait_rex(l, input_schema, extensions) .await? @@ -775,63 +876,72 @@ pub async fn from_substrait_rex( ), }))) } - Ok(ScalarFunctionType::Builtin(fun)) => { - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun, - args: vec![ - from_substrait_rex(l, input_schema, extensions) - .await? - .as_ref() - .clone(), - from_substrait_rex(r, input_schema, extensions) - .await? - .as_ref() - .clone(), - ], - }))) + (l, r) => not_impl_err!( + "Invalid arguments for binary expression: {l:?} and {r:?}" + ), + } + } + ScalarFunctionType::Not => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `NOT` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::Not(Box::new(expr)))) } - Err(e) => Err(e), + _ => not_impl_err!("Invalid arguments for Not expression"), } } - (l, r) => Err(DataFusionError::NotImplemented(format!( - "Invalid arguments for binary expression: {l:?} and {r:?}" - ))), - }, - // ScalarFunction - _ => { - let fun = match extensions.get(&f.function_reference) { - Some(fname) => BuiltinScalarFunction::from_str(fname), - None => Err(DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", - f.function_reference - ))), - }; - - let mut args: Vec = vec![]; - for arg in f.arguments.iter() { + ScalarFunctionType::Like => { + make_datafusion_like(false, f, input_schema, extensions).await + } + ScalarFunctionType::ILike => { + make_datafusion_like(true, f, input_schema, extensions).await + } + ScalarFunctionType::IsNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NULL` expr".to_string(), + ) + })?; match &arg.arg_type { Some(ArgType::Value(e)) => { - args.push( - from_substrait_rex(e, input_schema, extensions) - .await? - .as_ref() - .clone(), - ); + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNull(Box::new(expr)))) } - e => { - return Err(DataFusionError::NotImplemented(format!( - "Invalid arguments for scalar function: {e:?}" - ))) + _ => not_impl_err!("Invalid arguments for IS NULL expression"), + } + } + ScalarFunctionType::IsNotNull => { + let arg = f.arguments.first().ok_or_else(|| { + DataFusionError::Substrait( + "expect one argument for `IS NOT NULL` expr".to_string(), + ) + })?; + match &arg.arg_type { + Some(ArgType::Value(e)) => { + let expr = from_substrait_rex(e, input_schema, extensions) + .await? + .as_ref() + .clone(); + Ok(Arc::new(Expr::IsNotNull(Box::new(expr)))) + } + _ => { + not_impl_err!("Invalid arguments for IS NOT NULL expression") } } } - - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun: fun?, - args, - }))) } - }, + } Some(RexType::Literal(lit)) => { let scalar_value = from_substrait_literal(lit)?; Ok(Arc::new(Expr::Literal(scalar_value))) @@ -857,10 +967,10 @@ pub async fn from_substrait_rex( Some(RexType::WindowFunction(window)) => { let fun = match extensions.get(&window.function_reference) { Some(function_name) => Ok(find_df_window_func(function_name)), - None => Err(DataFusionError::NotImplemented(format!( + None => not_impl_err!( "Window function not found: function anchor = {:?}", &window.function_reference - ))), + ), }; let order_by = from_substrait_sorts(&window.sorts, input_schema, extensions).await?; @@ -895,9 +1005,7 @@ pub async fn from_substrait_rex( }, }))) } - _ => Err(DataFusionError::NotImplemented( - "unsupported rex_type".to_string(), - )), + _ => not_impl_err!("unsupported rex_type"), } } @@ -908,30 +1016,30 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::I8(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(DataType::Int8), UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt8), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(DataType::Int16), UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt16), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(DataType::Int32), UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt32), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::I64(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(DataType::Int64), UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt64), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::Fp32(_) => Ok(DataType::Float32), r#type::Kind::Fp64(_) => Ok(DataType::Float64), @@ -948,23 +1056,23 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { TIMESTAMP_NANO_TYPE_REF => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::Date(date) => match date.type_variation_reference { DATE_32_TYPE_REF => Ok(DataType::Date32), DATE_64_TYPE_REF => Ok(DataType::Date64), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Binary), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeBinary), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::FixedBinary(fixed) => { Ok(DataType::FixedSizeBinary(fixed.length)) @@ -972,9 +1080,9 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { r#type::Kind::String(string) => match string.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Utf8), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeUtf8), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, r#type::Kind::List(list) => { let inner_type = @@ -987,9 +1095,9 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - )))?, + )?, } } r#type::Kind::Decimal(d) => match d.type_variation_reference { @@ -999,17 +1107,13 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { DECIMAL_256_TYPE_REF => { Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) } - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" - ))), + ), }, - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported Substrait type: {s_kind:?}" - ))), + _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), }, - _ => Err(DataFusionError::NotImplemented( - "`None` Substrait kind is not supported".to_string(), - )), + _ => not_impl_err!("`None` Substrait kind is not supported"), } } @@ -1148,12 +1252,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { ) } Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Unsupported literal_type: {:?}", - lit.literal_type - ))) - } + _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; Ok(scalar_value) @@ -1166,30 +1265,30 @@ fn from_substrait_null(null_type: &Type) -> Result { r#type::Kind::I8(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)), UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)), UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)), UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::I64(integer) => match integer.type_variation_reference { DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)), UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), @@ -1204,44 +1303,88 @@ fn from_substrait_null(null_type: &Type) -> Result { TIMESTAMP_NANO_TYPE_REF => { Ok(ScalarValue::TimestampNanosecond(None, None)) } - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::Date(date) => match date.type_variation_reference { DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)), DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)), LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, // FixedBinary is not supported because `None` doesn't have length r#type::Kind::String(string) => match string.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)), LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)), - v => Err(DataFusionError::NotImplemented(format!( + v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" - ))), + ), }, r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128( None, d.precision as u8, d.scale as i8, )), - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported Substrait type: {kind:?}" - ))), + _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), } } else { - Err(DataFusionError::NotImplemented( - "Null type without kind is not supported".to_string(), - )) + not_impl_err!("Null type without kind is not supported") } } + +async fn make_datafusion_like( + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + extensions: &HashMap, +) -> Result> { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 3 { + return not_impl_err!("Expect three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { + return not_impl_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let escape_char_expr = + from_substrait_rex(escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { + return Err(DataFusionError::Substrait(format!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}", + ))); + }; + + Ok(Arc::new(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: escape_char.map(|c| c.chars().next().unwrap()), + case_insensitive, + }))) +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 785bfa4ea6a7e..c5f1278be6e01 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; +use std::ops::Deref; +use std::sync::Arc; +use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -26,14 +29,18 @@ use datafusion::{ }; use datafusion::common::DFSchemaRef; +use datafusion::common::{exec_err, internal_err, not_impl_err}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::{binary_expr, Expr}; +use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::expression::window_function::BoundsType; +use substrait::proto::CrossRel; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -48,7 +55,7 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -58,10 +65,11 @@ use substrait::{ join_rel, plan_rel, r#type, read_rel::{NamedTable, ReadType}, rel::RelType, + set_rel, sort_field::{SortDirection, SortKind}, AggregateFunction, AggregateRel, AggregationPhase, Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, + JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SetRel, SortField, SortRel, }, version, @@ -156,7 +164,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -172,6 +180,7 @@ pub fn to_substrait_rel( let filter_expr = to_substrait_rex( &filter.predicate, filter.input.schema(), + 0, extension_info, )?; Ok(Box::new(Rel { @@ -185,7 +194,8 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; - let limit_fetch = limit.fetch.unwrap_or(0); + // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` + let limit_fetch = limit.fetch.unwrap_or(usize::MAX); Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, @@ -214,12 +224,11 @@ pub fn to_substrait_rel( } LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; - // Translate aggregate expression to Substrait's groupings (repeated repeated Expression) - let grouping = agg - .group_expr - .iter() - .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info)) - .collect::>>()?; + let groupings = to_substrait_groupings( + &agg.group_expr, + agg.input.schema(), + extension_info, + )?; let measures = agg .aggr_expr .iter() @@ -230,19 +239,17 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), - groupings: vec![Grouping { - grouping_expressions: grouping, - }], //groupings, + groupings, measures, advanced_extension: None, }))), })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; // Get grouping keys from the input relation's number of output fields - let grouping = (0..distinct.input.schema().fields().len()) + let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?; @@ -263,17 +270,24 @@ pub fn to_substrait_rel( let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported - if join.filter.is_some() { - return Err(DataFusionError::NotImplemented("join filter".to_string())); - } match join.join_constraint { JoinConstraint::On => {} - _ => { - return Err(DataFusionError::NotImplemented( - "join constraint".to_string(), - )) + JoinConstraint::Using => { + return not_impl_err!("join constraint: `using`") } } + // parse filter if exists + let in_join_schema = join.left.schema().join(join.right.schema())?; + let join_filter = match &join.filter { + Some(filter) => Some(to_substrait_rex( + filter, + &Arc::new(in_join_schema), + 0, + extension_info, + )?), + None => None, + }; + // map the left and right columns to binary expressions in the form `l = r` // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` let eq_op = if join.null_equals_null { @@ -281,51 +295,84 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - let join_expression = join - .on - .iter() - .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone())) - .reduce(|acc: Expr, expr: Expr| acc.and(expr)); - // join schema from left and right to maintain all nececesary columns from inputs - // note that we cannot simple use join.schema here since we discard some input columns - // when performing semi and anti joins - let join_schema = match join.left.schema().join(join.right.schema()) { - Ok(schema) => Ok(schema), - Err(DataFusionError::SchemaError( - datafusion::common::SchemaError::DuplicateQualifiedField { - qualifier: _, - name: _, - }, - )) => Ok(join.schema.as_ref().clone()), - Err(e) => Err(e), + let join_on = to_substrait_join_expr( + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + extension_info, + )?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + on_expr, + filter, + Operator::And, + extension_info, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, }; - if let Some(e) = join_expression { - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: Some(Box::new(to_substrait_rex( - &e, - &Arc::new(join_schema?), - extension_info, - )?)), - post_join_filter: None, - advanced_extension: None, - }))), - })) - } else { - Err(DataFusionError::NotImplemented( - "Empty join condition".to_string(), - )) - } + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: join_expr, + post_join_filter: None, + advanced_extension: None, + }))), + })) + } + LogicalPlan::CrossJoin(cross_join) => { + let CrossJoin { + left, + right, + schema: _, + } = cross_join; + let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; + let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Cross(Box::new(CrossRel { + common: None, + left: Some(left), + right: Some(right), + advanced_extension: None, + }))), + })) } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait to_substrait_rel(alias.input.as_ref(), ctx, extension_info) } + LogicalPlan::Union(union) => { + let input_rels = union + .inputs + .iter() + .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info)) + .collect::>>()? + .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(substrait::proto::rel::RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) + } LogicalPlan::Window(window) => { let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?; // If the input is a Project relation, we can just append the WindowFunction expressions @@ -353,6 +400,7 @@ pub fn to_substrait_rel( window_exprs.push(to_substrait_rex( expr, window.input.schema(), + 0, extension_info, )?); } @@ -397,10 +445,40 @@ pub fn to_substrait_rel( rel_type: Some(rel_type), })) } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported operator: {plan:?}" - ))), + _ => not_impl_err!("Unsupported operator: {plan:?}"), + } +} + +fn to_substrait_join_expr( + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result> { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + // Parse left + let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + // Parse right + let r = to_substrait_rex( + right, + right_schema, + left_schema.fields().len(), // offset to return the correct index + extension_info, + )?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); } + let join_expr: Option = + exprs.into_iter().reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + }); + Ok(join_expr) } fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { @@ -424,7 +502,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Gt => "gt", Operator::GtEq => "gte", Operator::Plus => "add", - Operator::Minus => "substract", + Operator::Minus => "subtract", Operator::Multiply => "multiply", Operator::Divide => "divide", Operator::Modulo => "mod", @@ -439,12 +517,75 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::BitwiseAnd => "bitwise_and", Operator::BitwiseOr => "bitwise_or", Operator::StringConcat => "str_concat", + Operator::AtArrow => "at_arrow", + Operator::ArrowAt => "arrow_at", Operator::BitwiseXor => "bitwise_xor", Operator::BitwiseShiftRight => "bitwise_shift_right", Operator::BitwiseShiftLeft => "bitwise_shift_left", } } +pub fn parse_flat_grouping_exprs( + exprs: &[Expr], + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let grouping_expressions = exprs + .iter() + .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .collect::>>()?; + Ok(Grouping { + grouping_expressions, + }) +} + +pub fn to_substrait_groupings( + exprs: &Vec, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result> { + match exprs.len() { + 1 => match &exprs[0] { + Expr::GroupingSet(gs) => match gs { + GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( + "GroupingSet CUBE is not yet supported".to_string(), + )), + GroupingSet::GroupingSets(sets) => Ok(sets + .iter() + .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .collect::>>()?), + GroupingSet::Rollup(set) => { + let mut sets: Vec> = vec![vec![]]; + for i in 0..set.len() { + sets.push(set[..=i].to_vec()); + } + Ok(sets + .iter() + .rev() + .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .collect::>>()?) + } + }, + _ => Ok(vec![parse_flat_grouping_exprs( + exprs, + schema, + extension_info, + )?]), + }, + _ => Ok(vec![parse_flat_grouping_exprs( + exprs, + schema, + extension_info, + )?]), + } +} + #[allow(deprecated)] pub fn to_substrait_agg_measure( expr: &Expr, @@ -455,42 +596,112 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - // TODO: Once substrait supports order by, add handling for it. - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => { - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); - } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts: vec![], - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn (fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) + } + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + } - Expr::Alias(expr, _name) => { + Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) } - _ => Err(DataFusionError::Internal(format!( + _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", expr, expr.variant_name() - ))), + ), + } +} + +/// Converts sort expression to corresponding substrait `SortField` +fn to_substrait_sort_field( + expr: &Expr, + schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + match expr { + Expr::Sort(sort) => { + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(to_substrait_rex( + sort.expr.deref(), + schema, + 0, + extension_info, + )?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) + } + _ => exec_err!("expects to receive sort expression"), } } @@ -545,8 +756,8 @@ pub fn make_binary_op_scalar_func( HashMap, ), ) -> Expression { - let function_name = operator_to_name(op).to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = + _register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -566,29 +777,97 @@ pub fn make_binary_op_scalar_func( } /// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// +/// * `expr` - DataFusion expression to be parse into a Substrait expression +/// * `schema` - DataFusion input schema for looking up field qualifiers +/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// This should only be set by caller with more than one input relations i.e. Join. +/// Substrait expects one set of indices when joining two relations. +/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` +/// relation will have column indices from `0` to `n-1`, however, Substrait will expect +/// the `right` indices to be offset by the `left`. This means Substrait will expect to +/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: +/// ```SELECT * +/// FROM t1 +/// JOIN t2 +/// ON t1.c1 = t2.c0;``` +/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] +/// the join condition should become +/// `col_ref(1) = col_ref(3 + 0)` +/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index +/// of the join key column from `right` +/// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, + col_ref_offset: usize, extension_info: &mut ( Vec, HashMap, ), ) -> Result { match expr { - Expr::ScalarFunction(DFScalarFunction { fun, args }) => { + Expr::InList(InList { + expr, + list, + negated, + }) => { + let substrait_list = list + .iter() + .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .collect::>>()?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; + + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } + } + Expr::ScalarFunction(fun) => { let mut arguments: Vec = vec![]; - for arg in args { + for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + + // function should be resolved during `AnalyzerRule` + if let ScalarFunctionDefinition::Name(_) = fun.func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + let function_anchor = + _register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -607,9 +886,12 @@ pub fn to_substrait_rex( }) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -632,9 +914,12 @@ pub fn to_substrait_rex( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -659,11 +944,11 @@ pub fn to_substrait_rex( } Expr::Column(col) => { let index = schema.index_of_column(col)?; - substrait_field_ref(index) + substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, extension_info)?; - let r = to_substrait_rex(right, schema, extension_info)?; + let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -677,21 +962,41 @@ pub fn to_substrait_rex( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(to_substrait_rex(e, schema, extension_info)?), + r#if: Some(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(to_substrait_rex(r#if, schema, extension_info)?), - then: Some(to_substrait_rex(then, schema, extension_info)?), + r#if: Some(to_substrait_rex( + r#if, + schema, + col_ref_offset, + extension_info, + )?), + then: Some(to_substrait_rex( + then, + schema, + col_ref_offset, + extension_info, + )?), }); } // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), + Some(e) => Some(Box::new(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?)), None => None, }; @@ -707,6 +1012,7 @@ pub fn to_substrait_rex( input: Some(Box::new(to_substrait_rex( expr, schema, + col_ref_offset, extension_info, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED @@ -715,7 +1021,9 @@ pub fn to_substrait_rex( }) } Expr::Literal(value) => to_substrait_literal(value), - Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), + Expr::Alias(Alias { expr, .. }) => { + to_substrait_rex(expr, schema, col_ref_offset, extension_info) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -724,8 +1032,7 @@ pub fn to_substrait_rex( window_frame, }) => { // function reference - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -733,6 +1040,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -740,7 +1048,7 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, extension_info)) + .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by @@ -749,26 +1057,86 @@ pub fn to_substrait_rex( .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; Ok(make_substrait_window_function( function_anchor, arguments, partition_by, order_by, bounds, + bound_type, )) } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported expression: {expr:?}" - ))), + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => make_substrait_like_expr( + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + col_ref_offset, + extension_info, + ), + Expr::IsNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + Expr::IsNotNull(arg) => { + let arguments: Vec = vec![FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex( + arg, + schema, + col_ref_offset, + extension_info, + )?)), + }]; + + let function_name = "is_not_null".to_string(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }) + } + _ => { + not_impl_err!("Unsupported expression: {expr:?}") + } } } fn to_substrait_type(dt: &DataType) -> Result { let default_nullability = r#type::Nullability::Required as i32; match dt { - DataType::Null => Err(DataFusionError::Internal( - "Null cast is not valid".to_string(), - )), + DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { type_variation_reference: DEFAULT_TYPE_REF, @@ -943,9 +1311,7 @@ fn to_substrait_type(dt: &DataType) -> Result { precision: *p as i32, })), }), - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported cast type: {dt:?}" - ))), + _ => not_impl_err!("Unsupported cast type: {dt:?}"), } } @@ -956,6 +1322,7 @@ fn make_substrait_window_function( partitions: Vec, sorts: Vec, bounds: (Bound, Bound), + bounds_type: BoundsType, ) -> Expression { Expression { rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { @@ -970,7 +1337,73 @@ fn make_substrait_window_function( lower_bound: Some(bounds.0), upper_bound: Some(bounds.1), args: vec![], + bounds_type: bounds_type as i32, + })), + } +} + +#[allow(deprecated)] +#[allow(clippy::too_many_arguments)] +fn make_substrait_like_expr( + ignore_case: bool, + negated: bool, + expr: &Expr, + pattern: &Expr, + escape_char: Option, + schema: &DFSchemaRef, + col_ref_offset: usize, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + let function_anchor = if ignore_case { + _register_function("ilike".to_string(), extension_info) + } else { + _register_function("like".to_string(), extension_info) + }; + let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let escape_char = + to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; + let arguments = vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(pattern)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(escape_char)), + }, + ]; + + let substrait_like = Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], })), + }; + + if negated { + let function_anchor = _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_like)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_like) } } @@ -1072,6 +1505,15 @@ fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { } } +fn to_substrait_bound_type(window_frame: &WindowFrame) -> Result { + match window_frame.units { + WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS + WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE + // TODO: Support GROUPS + unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), + } +} + fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { Ok(( to_substrait_bound(&window_frame.start_bound), @@ -1305,9 +1747,7 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { })) } // TODO: Extend support for remaining data types - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported literal: {v:?}" - ))), + _ => not_impl_err!("Unsupported literal: {v:?}"), } } @@ -1325,7 +1765,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, extension_info)?; + let e = to_substrait_rex(expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -1337,9 +1777,7 @@ fn substrait_sort_field( sort_kind: Some(SortKind::Direction(d as i32)), }) } - _ => Err(DataFusionError::NotImplemented(format!( - "Expecting sort expression but got {expr:?}" - ))), + _ => not_impl_err!("Expecting sort expression but got {expr:?}"), } } @@ -1406,7 +1844,10 @@ mod test { println!("Checking round trip of {scalar:?}"); let substrait = to_substrait_literal(&scalar)?; - let Expression { rex_type: Some(RexType::Literal(substrait_literal)) } = substrait else { + let Expression { + rex_type: Some(RexType::Literal(substrait_literal)), + } = substrait + else { panic!("Expected Literal expression, got {substrait:?}"); }; diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 5d2f22b857e99..942798173e0e7 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use async_recursion::async_recursion; -use chrono::DateTime; +use std::collections::HashMap; +use std::sync::Arc; + use datafusion::arrow::datatypes::Schema; +use datafusion::common::not_impl_err; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::error::{DataFusionError, Result}; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::{ExecutionPlan, Statistics}; use datafusion::prelude::SessionContext; + +use async_recursion::async_recursion; +use chrono::DateTime; use object_store::ObjectMeta; -use std::collections::HashMap; -use std::sync::Arc; use substrait::proto::read_rel::local_files::file_or_files::PathType; use substrait::proto::{ expression::MaskExpression, read_rel::ReadType, rel::RelType, Rel, @@ -35,26 +38,20 @@ use substrait::proto::{ /// Convert Substrait Rel to DataFusion ExecutionPlan #[async_recursion] pub async fn from_substrait_rel( - _ctx: &mut SessionContext, + _ctx: &SessionContext, rel: &Rel, _extensions: &HashMap, ) -> Result> { match &rel.rel_type { Some(RelType::Read(read)) => { if read.filter.is_some() || read.best_effort_filter.is_some() { - return Err(DataFusionError::NotImplemented( - "Read with filter is not supported".to_string(), - )); + return not_impl_err!("Read with filter is not supported"); } if read.base_schema.is_some() { - return Err(DataFusionError::NotImplemented( - "Read with schema is not supported".to_string(), - )); + return not_impl_err!("Read with schema is not supported"); } if read.advanced_extension.is_some() { - return Err(DataFusionError::NotImplemented( - "Read with AdvancedExtension is not supported".to_string(), - )); + return not_impl_err!("Read with AdvancedExtension is not supported"); } match &read.as_ref().read_type { Some(ReadType::LocalFiles(files)) => { @@ -92,6 +89,7 @@ pub async fn from_substrait_rel( location: path.into(), size, e_tag: None, + version: None, }, partition_values: vec![], range: None, @@ -109,7 +107,7 @@ pub async fn from_substrait_rel( object_store_url: ObjectStoreUrl::local_filesystem(), file_schema: Arc::new(Schema::empty()), file_groups, - statistics: Default::default(), + statistics: Statistics::new_unknown(&Schema::empty()), projection: None, limit: None, table_partition_cols: vec![], @@ -131,15 +129,11 @@ pub async fn from_substrait_rel( Ok(Arc::new(ParquetExec::new(base_config, None, None)) as Arc) } - _ => Err(DataFusionError::NotImplemented( + _ => not_impl_err!( "Only LocalFile reads are supported when parsing physical" - .to_string(), - )), + ), } } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported RelType: {:?}", - rel.rel_type - ))), + _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs new file mode 100644 index 0000000000000..b17289205f3de --- /dev/null +++ b/datafusion/substrait/tests/cases/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod roundtrip_logical_plan; +mod roundtrip_physical_plan; +mod serialize; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs new file mode 100644 index 0000000000000..691fba8644497 --- /dev/null +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -0,0 +1,995 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::array::ArrayRef; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; +use datafusion_substrait::logical_plan::{ + consumer::from_substrait_plan, producer::to_substrait_plan, +}; + +use std::hash::Hash; +use std::sync::Arc; + +use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::context::SessionState; +use datafusion::execution::registry::SerializerRegistry; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::logical_expr::{ + Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, +}; +use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; +use datafusion::prelude::*; + +use substrait::proto::extensions::simple_extension_declaration::MappingType; +use substrait::proto::rel::RelType; +use substrait::proto::{plan_rel, Plan, Rel}; + +struct MockSerializerRegistry; + +impl SerializerRegistry for MockSerializerRegistry { + fn serialize_logical_plan( + &self, + node: &dyn UserDefinedLogicalNode, + ) -> Result> { + if node.name() == "MockUserDefinedLogicalPlan" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); + node.serialize() + } else { + unreachable!() + } + } + + fn deserialize_logical_plan( + &self, + name: &str, + bytes: &[u8], + ) -> Result> + { + if name == "MockUserDefinedLogicalPlan" { + MockUserDefinedLogicalPlan::deserialize(bytes) + } else { + unreachable!() + } + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +struct MockUserDefinedLogicalPlan { + /// Replacement for serialize/deserialize data + validation_bytes: Vec, + inputs: Vec, + empty_schema: DFSchemaRef, +} + +impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MockUserDefinedLogicalPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.inputs.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.empty_schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "MockUserDefinedLogicalPlan [validation_bytes={:?}]", + self.validation_bytes + ) + } + + fn from_template( + &self, + _: &[Expr], + inputs: &[LogicalPlan], + ) -> Arc { + Arc::new(Self { + validation_bytes: self.validation_bytes.clone(), + inputs: inputs.to_vec(), + empty_schema: Arc::new(DFSchema::empty()), + }) + } + + fn dyn_hash(&self, _: &mut dyn std::hash::Hasher) { + unimplemented!() + } + + fn dyn_eq(&self, _: &dyn UserDefinedLogicalNode) -> bool { + unimplemented!() + } +} + +impl MockUserDefinedLogicalPlan { + pub fn new(validation_bytes: Vec) -> Self { + Self { + validation_bytes, + inputs: vec![], + empty_schema: Arc::new(DFSchema::empty()), + } + } + + fn serialize(&self) -> Result> { + Ok(self.validation_bytes.clone()) + } + + fn deserialize(bytes: &[u8]) -> Result> + where + Self: Sized, + { + Ok(Arc::new(MockUserDefinedLogicalPlan::new(bytes.to_vec()))) + } +} + +#[tokio::test] +async fn simple_select() -> Result<()> { + roundtrip("SELECT a, b FROM data").await +} + +#[tokio::test] +async fn wildcard_select() -> Result<()> { + roundtrip("SELECT * FROM data").await +} + +#[tokio::test] +async fn select_with_filter() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1").await +} + +#[tokio::test] +async fn select_with_reused_functions() -> Result<()> { + let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; + roundtrip(sql).await?; + let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; + function_names.sort(); + function_anchors.sort(); + + assert_eq!(function_names, ["and", "gt", "lt"]); + assert_eq!(function_anchors, [0, 1, 2]); + + Ok(()) +} + +#[tokio::test] +async fn select_with_filter_date() -> Result<()> { + roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await +} + +#[tokio::test] +async fn select_with_filter_bool_expr() -> Result<()> { + roundtrip("SELECT * FROM data WHERE d AND a > 1").await +} + +#[tokio::test] +async fn select_with_limit() -> Result<()> { + roundtrip_fill_na("SELECT * FROM data LIMIT 100").await +} + +#[tokio::test] +async fn select_without_limit() -> Result<()> { + roundtrip_fill_na("SELECT * FROM data OFFSET 10").await +} + +#[tokio::test] +async fn select_with_limit_offset() -> Result<()> { + roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await +} + +#[tokio::test] +async fn simple_aggregate() -> Result<()> { + roundtrip("SELECT a, sum(b) FROM data GROUP BY a").await +} + +#[tokio::test] +async fn aggregate_distinct_with_having() -> Result<()> { + roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100") + .await +} + +#[tokio::test] +async fn aggregate_multiple_keys() -> Result<()> { + roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await +} + +#[tokio::test] +async fn aggregate_grouping_sets() -> Result<()> { + roundtrip( + "SELECT a, c, d, avg(b) FROM data GROUP BY GROUPING SETS ((a, c), (a), (d), ())", + ) + .await +} + +#[tokio::test] +async fn aggregate_grouping_rollup() -> Result<()> { + assert_expected_plan( + "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", + "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\ + \n TableScan: data projection=[a, b, c, e]" + ).await +} + +#[tokio::test] +async fn decimal_literal() -> Result<()> { + roundtrip("SELECT * FROM data WHERE b > 2.5").await +} + +#[tokio::test] +async fn null_decimal_literal() -> Result<()> { + roundtrip("SELECT * FROM data WHERE b = NULL").await +} + +#[tokio::test] +async fn u32_literal() -> Result<()> { + roundtrip("SELECT * FROM data WHERE e > 4294967295").await +} + +#[tokio::test] +async fn simple_distinct() -> Result<()> { + test_alias( + "SELECT distinct a FROM data", + "SELECT a FROM data GROUP BY a", + ) + .await +} + +#[tokio::test] +async fn select_distinct_two_fields() -> Result<()> { + test_alias( + "SELECT distinct a, b FROM data", + "SELECT a, b FROM data GROUP BY a, b", + ) + .await +} + +#[tokio::test] +async fn simple_alias() -> Result<()> { + test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await +} + +#[tokio::test] +async fn two_table_alias() -> Result<()> { + test_alias( + "SELECT d1.a FROM data d1 JOIN data2 d2 ON d1.a = d2.a", + "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", + ) + .await +} + +#[tokio::test] +async fn between_integers() -> Result<()> { + test_alias( + "SELECT * FROM data WHERE a BETWEEN 2 AND 6", + "SELECT * FROM data WHERE a >= 2 AND a <= 6", + ) + .await +} + +#[tokio::test] +async fn not_between_integers() -> Result<()> { + test_alias( + "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6", + "SELECT * FROM data WHERE a < 2 OR a > 6", + ) + .await +} + +#[tokio::test] +async fn simple_scalar_function_abs() -> Result<()> { + roundtrip("SELECT ABS(a) FROM data").await +} + +#[tokio::test] +async fn simple_scalar_function_pow() -> Result<()> { + roundtrip("SELECT POW(a, 2) FROM data").await +} + +#[tokio::test] +async fn simple_scalar_function_substr() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await +} + +#[tokio::test] +async fn simple_scalar_function_is_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NULL").await +} + +#[tokio::test] +async fn simple_scalar_function_is_not_null() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IS NOT NULL").await +} + +#[tokio::test] +async fn case_without_base_expression() -> Result<()> { + roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data") + .await +} + +#[tokio::test] +async fn case_with_base_expression() -> Result<()> { + roundtrip( + "SELECT (CASE a + WHEN 0 THEN 'zero' + WHEN 1 THEN 'one' + ELSE 'other' + END) FROM data", + ) + .await +} + +#[tokio::test] +async fn cast_decimal_to_int() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a = CAST(2.5 AS int)").await +} + +#[tokio::test] +async fn implicit_cast() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a = b").await +} + +#[tokio::test] +async fn aggregate_case() -> Result<()> { + assert_expected_plan( + "SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", + "Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ + \n TableScan: data projection=[a]", + ) + .await +} + +#[tokio::test] +async fn roundtrip_inlist_1() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a IN (1, 2, 3)").await +} + +#[tokio::test] +// Test with length <= datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST +async fn roundtrip_inlist_2() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f IN ('a', 'b', 'c')").await +} + +#[tokio::test] +// Test with length > datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST +async fn roundtrip_inlist_3() -> Result<()> { + let inlist = (0..THRESHOLD_INLINE_INLIST + 1) + .map(|i| format!("'{i}'")) + .collect::>() + .join(", "); + + roundtrip(&format!("SELECT * FROM data WHERE f IN ({inlist})")).await +} + +#[tokio::test] +async fn roundtrip_inlist_4() -> Result<()> { + roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await +} + +#[tokio::test] +async fn roundtrip_cross_join() -> Result<()> { + roundtrip("SELECT * FROM data CROSS JOIN data2").await +} + +#[tokio::test] +async fn roundtrip_inner_join() -> Result<()> { + roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await +} + +#[tokio::test] +async fn roundtrip_non_equi_inner_join() -> Result<()> { + roundtrip_verify_post_join_filter( + "SELECT data.a FROM data JOIN data2 ON data.a <> data2.a", + ) + .await +} + +#[tokio::test] +async fn roundtrip_non_equi_join() -> Result<()> { + roundtrip_verify_post_join_filter( + "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a", + ) + .await +} + +#[tokio::test] +async fn roundtrip_exists_filter() -> Result<()> { + assert_expected_plan( + "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a = d1.a AND d2.e != d1.e)", + "Projection: data.b\ + \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64)\ + \n TableScan: data projection=[a, b, e]\ + \n TableScan: data2 projection=[a, e]" + ).await +} + +#[tokio::test] +async fn inner_join() -> Result<()> { + assert_expected_plan( + "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", + "Projection: data.a\ + \n Inner Join: data.a = data2.a\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", + ) + .await +} + +#[tokio::test] +async fn roundtrip_left_join() -> Result<()> { + roundtrip("SELECT data.a FROM data LEFT JOIN data2 ON data.a = data2.a").await +} + +#[tokio::test] +async fn roundtrip_right_join() -> Result<()> { + roundtrip("SELECT data.a FROM data RIGHT JOIN data2 ON data.a = data2.a").await +} + +#[tokio::test] +async fn roundtrip_outer_join() -> Result<()> { + roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await +} + +#[tokio::test] +async fn roundtrip_arithmetic_ops() -> Result<()> { + roundtrip("SELECT a - a FROM data").await?; + roundtrip("SELECT a + a FROM data").await?; + roundtrip("SELECT a * a FROM data").await?; + roundtrip("SELECT a / a FROM data").await?; + roundtrip("SELECT a = a FROM data").await?; + roundtrip("SELECT a != a FROM data").await?; + roundtrip("SELECT a > a FROM data").await?; + roundtrip("SELECT a >= a FROM data").await?; + roundtrip("SELECT a < a FROM data").await?; + roundtrip("SELECT a <= a FROM data").await?; + Ok(()) +} + +#[tokio::test] +async fn roundtrip_like() -> Result<()> { + roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await +} + +#[tokio::test] +async fn roundtrip_ilike() -> Result<()> { + roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await +} + +#[tokio::test] +async fn roundtrip_union() -> Result<()> { + roundtrip("SELECT a, e FROM data UNION SELECT a, e FROM data").await +} + +#[tokio::test] +async fn roundtrip_union2() -> Result<()> { + roundtrip( + "SELECT a, b FROM data UNION SELECT a, b FROM data UNION SELECT a, b FROM data", + ) + .await +} + +#[tokio::test] +async fn roundtrip_union_all() -> Result<()> { + roundtrip("SELECT a, e FROM data UNION ALL SELECT a, e FROM data").await +} + +#[tokio::test] +async fn simple_intersect() -> Result<()> { + assert_expected_plan( + "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", + "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Projection: \ + \n LeftSemi Join: data.a = data2.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data2 projection=[a]", + ) + .await +} + +#[tokio::test] +async fn simple_intersect_table_reuse() -> Result<()> { + assert_expected_plan( + "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", + "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Projection: \ + \n LeftSemi Join: data.a = data.a\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n TableScan: data projection=[a]", + ) + .await +} + +#[tokio::test] +async fn simple_window_function() -> Result<()> { + roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;").await +} + +#[tokio::test] +async fn qualified_schema_table_reference() -> Result<()> { + roundtrip("SELECT * FROM public.data;").await +} + +#[tokio::test] +async fn qualified_catalog_schema_table_reference() -> Result<()> { + roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await +} + +#[tokio::test] +async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", + "Projection: data.b, data.c\ + \n Inner Join: data.a = data.a\ + \n TableScan: data projection=[a, b]\ + \n TableScan: data projection=[a, c]", + ) + .await +} + +#[tokio::test] +async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: data.b, data.c\ + \n Inner Join: data.b = data.b\ + \n TableScan: data projection=[b]\ + \n TableScan: data projection=[b, c]", + ) + .await +} + +/// Construct a plan that contains several literals of types that are currently supported. +/// This case ignores: +/// - Date64, for this literal is not supported +/// - FixedSizeBinary, for converting UTF-8 literal to FixedSizeBinary is not supported +/// - List, this nested type is not supported in arrow_cast +/// - Decimal128 and Decimal256, them will fallback to UTF8 cast expr rather than plain literal. +#[tokio::test] +async fn all_type_literal() -> Result<()> { + roundtrip_all_types( + "select * from data where + bool_col = TRUE AND + int8_col = arrow_cast('0', 'Int8') AND + uint8_col = arrow_cast('0', 'UInt8') AND + int16_col = arrow_cast('0', 'Int16') AND + uint16_col = arrow_cast('0', 'UInt16') AND + int32_col = arrow_cast('0', 'Int32') AND + uint32_col = arrow_cast('0', 'UInt32') AND + int64_col = arrow_cast('0', 'Int64') AND + uint64_col = arrow_cast('0', 'UInt64') AND + float32_col = arrow_cast('0', 'Float32') AND + float64_col = arrow_cast('0', 'Float64') AND + sec_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Second, None)') AND + ms_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Millisecond, None)') AND + us_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Microsecond, None)') AND + ns_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Nanosecond, None)') AND + date32_col = arrow_cast('2020-01-01', 'Date32') AND + binary_col = arrow_cast('binary', 'Binary') AND + large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND + utf8_col = arrow_cast('utf8', 'Utf8') AND + large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');", + ) + .await +} + +/// Construct a plan that cast columns. Only those SQL types are supported for now. +#[tokio::test] +async fn new_test_grammar() -> Result<()> { + roundtrip_all_types( + "select + bool_col::boolean, + int8_col::tinyint, + uint8_col::tinyint unsigned, + int16_col::smallint, + uint16_col::smallint unsigned, + int32_col::integer, + uint32_col::integer unsigned, + int64_col::bigint, + uint64_col::bigint unsigned, + float32_col::float, + float64_col::double, + decimal_128_col::decimal(10, 2), + date32_col::date, + binary_col::bytea + from data", + ) + .await +} + +#[tokio::test] +async fn extension_logical_plan() -> Result<()> { + let ctx = create_context().await?; + let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec(); + let ext_plan = LogicalPlan::Extension(Extension { + node: Arc::new(MockUserDefinedLogicalPlan { + validation_bytes, + inputs: vec![], + empty_schema: Arc::new(DFSchema::empty()), + }), + }); + + let proto = to_substrait_plan(&ext_plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + let plan1str = format!("{ext_plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_aggregate_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Int64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Int64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let ctx = create_context().await?; + ctx.register_udaf(dummy_agg); + + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await +} + +fn check_post_join_filters(rel: &Rel) -> Result<()> { + // search for target_rel and field value in proto + match &rel.rel_type { + Some(RelType::Join(join)) => { + // check if join filter is None + if join.post_join_filter.is_some() { + plan_err!( + "DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel" + ) + } else { + // recursively check JoinRels + match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) { + Err(e) => Err(e), + Ok(_) => { + check_post_join_filters(join.right.as_ref().unwrap().as_ref()) + } + } + } + } + Some(RelType::Project(p)) => { + check_post_join_filters(p.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Filter(filter)) => { + check_post_join_filters(filter.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Fetch(fetch)) => { + check_post_join_filters(fetch.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Sort(sort)) => { + check_post_join_filters(sort.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Aggregate(agg)) => { + check_post_join_filters(agg.input.as_ref().unwrap().as_ref()) + } + Some(RelType::Set(set)) => { + for input in &set.inputs { + match check_post_join_filters(input) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionSingle(ext)) => { + check_post_join_filters(ext.input.as_ref().unwrap().as_ref()) + } + Some(RelType::ExtensionMulti(ext)) => { + for input in &ext.inputs { + match check_post_join_filters(input) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + Ok(()) + } + Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()), + _ => not_impl_err!( + "Unsupported RelType: {:?} in post join filter check", + rel.rel_type + ), + } +} + +async fn verify_post_join_filter_value(proto: Box) -> Result<()> { + for relation in &proto.relations { + match relation.rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) { + Err(e) => return Err(e), + Ok(_) => continue, + }, + plan_rel::RelType::Root(root) => { + match check_post_join_filters(root.input.as_ref().unwrap()) { + Err(e) => return Err(e), + Ok(_) => continue, + } + } + }, + None => return plan_err!("Cannot parse plan relation: None"), + } + } + + Ok(()) +} + +async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + let plan2str = format!("{plan2:?}"); + assert_eq!(expected_plan_str, &plan2str); + Ok(()) +} + +async fn roundtrip_fill_na(sql: &str) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan1 = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan1, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + // Format plan string and replace all None's with 0 + let plan1str = format!("{plan1:?}").replace("None", "0"); + let plan2str = format!("{plan2:?}").replace("None", "0"); + + assert_eq!(plan1str, plan2str); + Ok(()) +} + +async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { + // Since we ignore the SubqueryAlias in the producer, the result should be + // the same as producing a Substrait plan from the same query without aliases + // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan + let ctx = create_context().await?; + + let df_a = ctx.sql(sql_with_alias).await?; + let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; + let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?; + + let df = ctx.sql(sql_no_alias).await?; + let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; + let plan = from_substrait_plan(&ctx, &proto).await?; + + println!("{plan_with_alias:#?}"); + println!("{plan:#?}"); + + let plan1str = format!("{plan_with_alias:?}"); + let plan2str = format!("{plan:?}"); + assert_eq!(plan1str, plan2str); + Ok(()) +} + +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + Ok(()) +} + +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await +} + +async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + + // verify that the join filters are None + verify_post_join_filter_value(proto).await +} + +async fn roundtrip_all_types(sql: &str) -> Result<()> { + let ctx = create_all_type_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + println!("{plan:#?}"); + println!("{plan2:#?}"); + + let plan1str = format!("{plan:?}"); + let plan2str = format!("{plan2:?}"); + assert_eq!(plan1str, plan2str); + Ok(()) +} + +async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; + + let mut function_names: Vec = vec![]; + let mut function_anchors: Vec = vec![]; + for e in &proto.extensions { + let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), + _ => unreachable!("Producer does not generate a non-function extension"), + }; + function_names.push(function_name.to_string()); + function_anchors.push(function_anchor); + } + + Ok((function_names, function_anchors)) +} + +async fn create_context() -> Result { + let state = SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ) + .with_serializer_registry(Arc::new(MockSerializerRegistry)); + let ctx = SessionContext::new_with_state(state); + let mut explicit_options = CsvReadOptions::new(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(5, 2), true), + Field::new("c", DataType::Date32, true), + Field::new("d", DataType::Boolean, true), + Field::new("e", DataType::UInt32, true), + Field::new("f", DataType::Utf8, true), + ]); + explicit_options.schema = Some(&schema); + ctx.register_csv("data", "tests/testdata/data.csv", explicit_options) + .await?; + ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + Ok(ctx) +} + +/// Cover all supported types +async fn create_all_type_context() -> Result { + let ctx = SessionContext::new(); + let mut explicit_options = CsvReadOptions::new(); + let schema = Schema::new(vec![ + Field::new("bool_col", DataType::Boolean, true), + Field::new("int8_col", DataType::Int8, true), + Field::new("uint8_col", DataType::UInt8, true), + Field::new("int16_col", DataType::Int16, true), + Field::new("uint16_col", DataType::UInt16, true), + Field::new("int32_col", DataType::Int32, true), + Field::new("uint32_col", DataType::UInt32, true), + Field::new("int64_col", DataType::Int64, true), + Field::new("uint64_col", DataType::UInt64, true), + Field::new("float32_col", DataType::Float32, true), + Field::new("float64_col", DataType::Float64, true), + Field::new( + "sec_timestamp_col", + DataType::Timestamp(TimeUnit::Second, None), + true, + ), + Field::new( + "ms_timestamp_col", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "us_timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "ns_timestamp_col", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new("date32_col", DataType::Date32, true), + Field::new("date64_col", DataType::Date64, true), + Field::new("binary_col", DataType::Binary, true), + Field::new("large_binary_col", DataType::LargeBinary, true), + Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true), + Field::new("utf8_col", DataType::Utf8, true), + Field::new("large_utf8_col", DataType::LargeUtf8, true), + Field::new_list("list_col", Field::new("item", DataType::Int64, true), true), + Field::new_list( + "large_list_col", + Field::new("item", DataType::Int64, true), + true, + ), + Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), + Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), + ]); + explicit_options.schema = Some(&schema); + explicit_options.has_header = false; + ctx.register_csv("data", "tests/testdata/empty.csv", explicit_options) + .await?; + + Ok(ctx) +} diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs new file mode 100644 index 0000000000000..b64dd2c138fc9 --- /dev/null +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::arrow::datatypes::Schema; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::object_store::ObjectStoreUrl; +use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::error::Result; +use datafusion::physical_plan::{displayable, ExecutionPlan, Statistics}; +use datafusion::prelude::SessionContext; +use datafusion_substrait::physical_plan::{consumer, producer}; + +use substrait::proto::extensions; + +#[tokio::test] +async fn parquet_exec() -> Result<()> { + let scan_config = FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_schema: Arc::new(Schema::empty()), + file_groups: vec![ + vec![PartitionedFile::new( + "file://foo/part-0.parquet".to_string(), + 123, + )], + vec![PartitionedFile::new( + "file://foo/part-1.parquet".to_string(), + 123, + )], + ], + statistics: Statistics::new_unknown(&Schema::empty()), + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + infinite_source: false, + }; + let parquet_exec: Arc = + Arc::new(ParquetExec::new(scan_config, None, None)); + + let mut extension_info: ( + Vec, + HashMap, + ) = (vec![], HashMap::new()); + + let substrait_rel = + producer::to_substrait_rel(parquet_exec.as_ref(), &mut extension_info)?; + + let ctx = SessionContext::new(); + + let parquet_exec_roundtrip = + consumer::from_substrait_rel(&ctx, substrait_rel.as_ref(), &HashMap::new()) + .await?; + + let expected = format!("{}", displayable(parquet_exec.as_ref()).indent(true)); + let actual = format!( + "{}", + displayable(parquet_exec_roundtrip.as_ref()).indent(true) + ); + assert_eq!(expected, actual); + + Ok(()) +} diff --git a/datafusion/substrait/tests/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs similarity index 96% rename from datafusion/substrait/tests/serialize.rs rename to datafusion/substrait/tests/cases/serialize.rs index d6dc5d7e58f2d..f6736ca222790 100644 --- a/datafusion/substrait/tests/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -30,7 +30,7 @@ mod tests { #[tokio::test] async fn serialize_simple_select() -> Result<()> { - let mut ctx = create_context().await?; + let ctx = create_context().await?; let path = "tests/simple_select.bin"; let sql = "SELECT a, b FROM data"; // Test reference @@ -42,7 +42,7 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let plan = from_substrait_plan(&mut ctx, &proto).await?; + let plan = from_substrait_plan(&ctx, &proto).await?; let plan_str_ref = format!("{plan_ref:?}"); let plan_str = format!("{plan:?}"); assert_eq!(plan_str_ref, plan_str); diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs deleted file mode 100644 index 8cdf89b294730..0000000000000 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ /dev/null @@ -1,685 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_substrait::logical_plan::{consumer, producer}; - -#[cfg(test)] -mod tests { - - use std::hash::Hash; - use std::sync::Arc; - - use crate::{consumer::from_substrait_plan, producer::to_substrait_plan}; - use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; - use datafusion::common::{DFSchema, DFSchemaRef}; - use datafusion::error::Result; - use datafusion::execution::context::SessionState; - use datafusion::execution::registry::SerializerRegistry; - use datafusion::execution::runtime_env::RuntimeEnv; - use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; - use datafusion::prelude::*; - use substrait::proto::extensions::simple_extension_declaration::MappingType; - - struct MockSerializerRegistry; - - impl SerializerRegistry for MockSerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - if node.name() == "MockUserDefinedLogicalPlan" { - let node = node - .as_any() - .downcast_ref::() - .unwrap(); - node.serialize() - } else { - unreachable!() - } - } - - fn deserialize_logical_plan( - &self, - name: &str, - bytes: &[u8], - ) -> Result> - { - if name == "MockUserDefinedLogicalPlan" { - MockUserDefinedLogicalPlan::deserialize(bytes) - } else { - unreachable!() - } - } - } - - #[derive(Debug, PartialEq, Eq, Hash)] - struct MockUserDefinedLogicalPlan { - /// Replacement for serialize/deserialize data - validation_bytes: Vec, - inputs: Vec, - empty_schema: DFSchemaRef, - } - - impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "MockUserDefinedLogicalPlan" - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - self.inputs.iter().collect() - } - - fn schema(&self) -> &DFSchemaRef { - &self.empty_schema - } - - fn expressions(&self) -> Vec { - vec![] - } - - fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "MockUserDefinedLogicalPlan [validation_bytes={:?}]", - self.validation_bytes - ) - } - - fn from_template( - &self, - _: &[Expr], - inputs: &[LogicalPlan], - ) -> Arc { - Arc::new(Self { - validation_bytes: self.validation_bytes.clone(), - inputs: inputs.to_vec(), - empty_schema: Arc::new(DFSchema::empty()), - }) - } - - fn dyn_hash(&self, _: &mut dyn std::hash::Hasher) { - unimplemented!() - } - - fn dyn_eq(&self, _: &dyn UserDefinedLogicalNode) -> bool { - unimplemented!() - } - } - - impl MockUserDefinedLogicalPlan { - pub fn new(validation_bytes: Vec) -> Self { - Self { - validation_bytes, - inputs: vec![], - empty_schema: Arc::new(DFSchema::empty()), - } - } - - fn serialize(&self) -> Result> { - Ok(self.validation_bytes.clone()) - } - - fn deserialize(bytes: &[u8]) -> Result> - where - Self: Sized, - { - Ok(Arc::new(MockUserDefinedLogicalPlan::new(bytes.to_vec()))) - } - } - - #[tokio::test] - async fn simple_select() -> Result<()> { - roundtrip("SELECT a, b FROM data").await - } - - #[tokio::test] - async fn wildcard_select() -> Result<()> { - roundtrip("SELECT * FROM data").await - } - - #[tokio::test] - async fn select_with_filter() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a > 1").await - } - - #[tokio::test] - async fn select_with_reused_functions() -> Result<()> { - let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; - roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = - function_extension_info(sql).await?; - function_names.sort(); - function_anchors.sort(); - - assert_eq!(function_names, ["and", "gt", "lt"]); - assert_eq!(function_anchors, [0, 1, 2]); - - Ok(()) - } - - #[tokio::test] - async fn select_with_filter_date() -> Result<()> { - roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await - } - - #[tokio::test] - async fn select_with_filter_bool_expr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE d AND a > 1").await - } - - #[tokio::test] - async fn select_with_limit() -> Result<()> { - roundtrip_fill_na("SELECT * FROM data LIMIT 100").await - } - - #[tokio::test] - async fn select_with_limit_offset() -> Result<()> { - roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await - } - - #[tokio::test] - async fn simple_aggregate() -> Result<()> { - roundtrip("SELECT a, sum(b) FROM data GROUP BY a").await - } - - #[tokio::test] - async fn aggregate_distinct_with_having() -> Result<()> { - roundtrip( - "SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100", - ) - .await - } - - #[tokio::test] - async fn aggregate_multiple_keys() -> Result<()> { - roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await - } - - #[tokio::test] - async fn decimal_literal() -> Result<()> { - roundtrip("SELECT * FROM data WHERE b > 2.5").await - } - - #[tokio::test] - async fn null_decimal_literal() -> Result<()> { - roundtrip("SELECT * FROM data WHERE b = NULL").await - } - - #[tokio::test] - async fn u32_literal() -> Result<()> { - roundtrip("SELECT * FROM data WHERE e > 4294967295").await - } - - #[tokio::test] - async fn simple_distinct() -> Result<()> { - test_alias( - "SELECT distinct a FROM data", - "SELECT a FROM data GROUP BY a", - ) - .await - } - - #[tokio::test] - async fn select_distinct_two_fields() -> Result<()> { - test_alias( - "SELECT distinct a, b FROM data", - "SELECT a, b FROM data GROUP BY a, b", - ) - .await - } - - #[tokio::test] - async fn simple_alias() -> Result<()> { - test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await - } - - #[tokio::test] - async fn two_table_alias() -> Result<()> { - test_alias( - "SELECT d1.a FROM data d1 JOIN data2 d2 ON d1.a = d2.a", - "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", - ) - .await - } - - #[tokio::test] - async fn between_integers() -> Result<()> { - test_alias( - "SELECT * FROM data WHERE a BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a >= 2 AND a <= 6", - ) - .await - } - - #[tokio::test] - async fn not_between_integers() -> Result<()> { - test_alias( - "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6", - "SELECT * FROM data WHERE a < 2 OR a > 6", - ) - .await - } - - #[tokio::test] - async fn simple_scalar_function_abs() -> Result<()> { - roundtrip("SELECT ABS(a) FROM data").await - } - - #[tokio::test] - async fn simple_scalar_function_pow() -> Result<()> { - roundtrip("SELECT POW(a, 2) FROM data").await - } - - #[tokio::test] - async fn simple_scalar_function_substr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await - } - - #[tokio::test] - async fn case_without_base_expression() -> Result<()> { - roundtrip( - "SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data", - ) - .await - } - - #[tokio::test] - async fn case_with_base_expression() -> Result<()> { - roundtrip( - "SELECT (CASE a - WHEN 0 THEN 'zero' - WHEN 1 THEN 'one' - ELSE 'other' - END) FROM data", - ) - .await - } - - #[tokio::test] - async fn cast_decimal_to_int() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = CAST(2.5 AS int)").await - } - - #[tokio::test] - async fn implicit_cast() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = b").await - } - - #[tokio::test] - async fn aggregate_case() -> Result<()> { - assert_expected_plan( - "SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ - \n TableScan: data projection=[a]", - ) - .await - } - - #[tokio::test] - async fn roundtrip_inlist() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a IN (1, 2, 3)").await - } - - #[tokio::test] - async fn roundtrip_inner_join() -> Result<()> { - roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await - } - - #[tokio::test] - async fn inner_join() -> Result<()> { - assert_expected_plan( - "SELECT data.a FROM data JOIN data2 ON data.a = data2.a", - "Projection: data.a\ - \n Inner Join: data.a = data2.a\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", - ) - .await - } - - #[tokio::test] - async fn roundtrip_left_join() -> Result<()> { - roundtrip("SELECT data.a FROM data LEFT JOIN data2 ON data.a = data2.a").await - } - - #[tokio::test] - async fn roundtrip_right_join() -> Result<()> { - roundtrip("SELECT data.a FROM data RIGHT JOIN data2 ON data.a = data2.a").await - } - - #[tokio::test] - async fn roundtrip_outer_join() -> Result<()> { - roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a") - .await - } - - #[tokio::test] - async fn simple_intersect() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data2.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data2 projection=[a]", - ) - .await - } - - #[tokio::test] - async fn simple_intersect_table_reuse() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", - ) - .await - } - - #[tokio::test] - async fn simple_window_function() -> Result<()> { - roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;").await - } - - #[tokio::test] - async fn qualified_schema_table_reference() -> Result<()> { - roundtrip("SELECT * FROM public.data;").await - } - - #[tokio::test] - async fn qualified_catalog_schema_table_reference() -> Result<()> { - roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await - } - - /// Construct a plan that contains several literals of types that are currently supported. - /// This case ignores: - /// - Date64, for this literal is not supported - /// - FixedSizeBinary, for converting UTF-8 literal to FixedSizeBinary is not supported - /// - List, this nested type is not supported in arrow_cast - /// - Decimal128 and Decimal256, them will fallback to UTF8 cast expr rather than plain literal. - #[tokio::test] - async fn all_type_literal() -> Result<()> { - roundtrip_all_types( - "select * from data where - bool_col = TRUE AND - int8_col = arrow_cast('0', 'Int8') AND - uint8_col = arrow_cast('0', 'UInt8') AND - int16_col = arrow_cast('0', 'Int16') AND - uint16_col = arrow_cast('0', 'UInt16') AND - int32_col = arrow_cast('0', 'Int32') AND - uint32_col = arrow_cast('0', 'UInt32') AND - int64_col = arrow_cast('0', 'Int64') AND - uint64_col = arrow_cast('0', 'UInt64') AND - float32_col = arrow_cast('0', 'Float32') AND - float64_col = arrow_cast('0', 'Float64') AND - sec_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Second, None)') AND - ms_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Millisecond, None)') AND - us_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Microsecond, None)') AND - ns_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Nanosecond, None)') AND - date32_col = arrow_cast('2020-01-01', 'Date32') AND - binary_col = arrow_cast('binary', 'Binary') AND - large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND - utf8_col = arrow_cast('utf8', 'Utf8') AND - large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');", - ) - .await - } - - /// Construct a plan that cast columns. Only those SQL types are supported for now. - #[tokio::test] - async fn new_test_grammar() -> Result<()> { - roundtrip_all_types( - "select - bool_col::boolean, - int8_col::tinyint, - uint8_col::tinyint unsigned, - int16_col::smallint, - uint16_col::smallint unsigned, - int32_col::integer, - uint32_col::integer unsigned, - int64_col::bigint, - uint64_col::bigint unsigned, - float32_col::float, - float64_col::double, - decimal_128_col::decimal(10, 2), - date32_col::date, - binary_col::bytea - from data", - ) - .await - } - - #[tokio::test] - async fn extension_logical_plan() -> Result<()> { - let mut ctx = create_context().await?; - let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec(); - let ext_plan = LogicalPlan::Extension(Extension { - node: Arc::new(MockUserDefinedLogicalPlan { - validation_bytes, - inputs: vec![], - empty_schema: Arc::new(DFSchema::empty()), - }), - }); - - let proto = to_substrait_plan(&ext_plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; - - let plan1str = format!("{ext_plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - - Ok(()) - } - - async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { - let mut ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - let plan2str = format!("{plan2:?}"); - assert_eq!(expected_plan_str, &plan2str); - Ok(()) - } - - async fn roundtrip_fill_na(sql: &str) -> Result<()> { - let mut ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan1 = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan1, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - // Format plan string and replace all None's with 0 - let plan1str = format!("{plan1:?}").replace("None", "0"); - let plan2str = format!("{plan2:?}").replace("None", "0"); - - assert_eq!(plan1str, plan2str); - Ok(()) - } - - async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { - // Since we ignore the SubqueryAlias in the producer, the result should be - // the same as producing a Substrait plan from the same query without aliases - // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan - let mut ctx = create_context().await?; - - let df_a = ctx.sql(sql_with_alias).await?; - let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; - let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?; - - let df = ctx.sql(sql_no_alias).await?; - let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; - let plan = from_substrait_plan(&mut ctx, &proto).await?; - - println!("{plan_with_alias:#?}"); - println!("{plan:#?}"); - - let plan1str = format!("{plan_with_alias:?}"); - let plan2str = format!("{plan:?}"); - assert_eq!(plan1str, plan2str); - Ok(()) - } - - async fn roundtrip(sql: &str) -> Result<()> { - let mut ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - Ok(()) - } - - async fn roundtrip_all_types(sql: &str) -> Result<()> { - let mut ctx = create_all_type_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&mut ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); - Ok(()) - } - - async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { - let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - - let mut function_names: Vec = vec![]; - let mut function_anchors: Vec = vec![]; - for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() - { - MappingType::ExtensionFunction(ext_f) => { - (ext_f.function_anchor, &ext_f.name) - } - _ => unreachable!("Producer does not generate a non-function extension"), - }; - function_names.push(function_name.to_string()); - function_anchors.push(function_anchor); - } - - Ok((function_names, function_anchors)) - } - - async fn create_context() -> Result { - let state = SessionState::with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ) - .with_serializer_registry(Arc::new(MockSerializerRegistry)); - let ctx = SessionContext::with_state(state); - let mut explicit_options = CsvReadOptions::new(); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Decimal128(5, 2), true), - Field::new("c", DataType::Date32, true), - Field::new("d", DataType::Boolean, true), - Field::new("e", DataType::UInt32, true), - ]); - explicit_options.schema = Some(&schema); - ctx.register_csv("data", "tests/testdata/data.csv", explicit_options) - .await?; - ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) - .await?; - Ok(ctx) - } - - /// Cover all supported types - async fn create_all_type_context() -> Result { - let ctx = SessionContext::new(); - let mut explicit_options = CsvReadOptions::new(); - let schema = Schema::new(vec![ - Field::new("bool_col", DataType::Boolean, true), - Field::new("int8_col", DataType::Int8, true), - Field::new("uint8_col", DataType::UInt8, true), - Field::new("int16_col", DataType::Int16, true), - Field::new("uint16_col", DataType::UInt16, true), - Field::new("int32_col", DataType::Int32, true), - Field::new("uint32_col", DataType::UInt32, true), - Field::new("int64_col", DataType::Int64, true), - Field::new("uint64_col", DataType::UInt64, true), - Field::new("float32_col", DataType::Float32, true), - Field::new("float64_col", DataType::Float64, true), - Field::new( - "sec_timestamp_col", - DataType::Timestamp(TimeUnit::Second, None), - true, - ), - Field::new( - "ms_timestamp_col", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "us_timestamp_col", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "ns_timestamp_col", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new("date32_col", DataType::Date32, true), - Field::new("date64_col", DataType::Date64, true), - Field::new("binary_col", DataType::Binary, true), - Field::new("large_binary_col", DataType::LargeBinary, true), - Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true), - Field::new("utf8_col", DataType::Utf8, true), - Field::new("large_utf8_col", DataType::LargeUtf8, true), - Field::new_list("list_col", Field::new("item", DataType::Int64, true), true), - Field::new_list( - "large_list_col", - Field::new("item", DataType::Int64, true), - true, - ), - Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), - Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), - ]); - explicit_options.schema = Some(&schema); - explicit_options.has_header = false; - ctx.register_csv("data", "tests/testdata/empty.csv", explicit_options) - .await?; - - Ok(ctx) - } -} diff --git a/datafusion/substrait/tests/roundtrip_physical_plan.rs b/datafusion/substrait/tests/roundtrip_physical_plan.rs deleted file mode 100644 index de549412b61ff..0000000000000 --- a/datafusion/substrait/tests/roundtrip_physical_plan.rs +++ /dev/null @@ -1,80 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#[cfg(test)] -mod tests { - use datafusion::arrow::datatypes::Schema; - use datafusion::datasource::listing::PartitionedFile; - use datafusion::datasource::object_store::ObjectStoreUrl; - use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; - use datafusion::error::Result; - use datafusion::physical_plan::{displayable, ExecutionPlan}; - use datafusion::prelude::SessionContext; - use datafusion_substrait::physical_plan::{consumer, producer}; - use std::collections::HashMap; - use std::sync::Arc; - use substrait::proto::extensions; - - #[tokio::test] - async fn parquet_exec() -> Result<()> { - let scan_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: Arc::new(Schema::empty()), - file_groups: vec![ - vec![PartitionedFile::new( - "file://foo/part-0.parquet".to_string(), - 123, - )], - vec![PartitionedFile::new( - "file://foo/part-1.parquet".to_string(), - 123, - )], - ], - statistics: Default::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - infinite_source: false, - }; - let parquet_exec: Arc = - Arc::new(ParquetExec::new(scan_config, None, None)); - - let mut extension_info: ( - Vec, - HashMap, - ) = (vec![], HashMap::new()); - - let substrait_rel = - producer::to_substrait_rel(parquet_exec.as_ref(), &mut extension_info)?; - - let mut ctx = SessionContext::new(); - - let parquet_exec_roundtrip = consumer::from_substrait_rel( - &mut ctx, - substrait_rel.as_ref(), - &HashMap::new(), - ) - .await?; - - let expected = format!("{}", displayable(parquet_exec.as_ref()).indent()); - let actual = format!("{}", displayable(parquet_exec_roundtrip.as_ref()).indent()); - assert_eq!(expected, actual); - - Ok(()) - } -} diff --git a/datafusion/substrait/tests/substrait_integration.rs b/datafusion/substrait/tests/substrait_integration.rs new file mode 100644 index 0000000000000..6ce41c9de71a8 --- /dev/null +++ b/datafusion/substrait/tests/substrait_integration.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Run all tests that are found in the `cases` directory +mod cases; diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 170457da5812f..1b85b166b1dfb 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ -a,b,c,d,e -1,2.0,2020-01-01,false,4294967296 -3,4.5,2020-01-01,true,2147483648 \ No newline at end of file +a,b,c,d,e,f +1,2.0,2020-01-01,false,4294967296,'a' +3,4.5,2020-01-01,true,2147483648,'b' \ No newline at end of file diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml new file mode 100644 index 0000000000000..c5f795d0653ae --- /dev/null +++ b/datafusion/wasmtest/Cargo.toml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-wasmtest" +description = "Test library to compile datafusion crates to wasm" +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = "1.70" + +[lib] +crate-type = ["cdylib", "rlib",] + +[dependencies] + +# The `console_error_panic_hook` crate provides better debugging of panics by +# logging them with `console.error`. This is great for development, but requires +# all the `std::fmt` and `std::panicking` infrastructure, so isn't great for +# code size when deploying. +console_error_panic_hook = { version = "0.1.1", optional = true } + +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-optimizer = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-sql = { workspace = true } + +# getrandom must be compiled with js feature +getrandom = { version = "0.2.8", features = ["js"] } +parquet = { workspace = true } +wasm-bindgen = "0.2.87" diff --git a/datafusion/wasmtest/README.md b/datafusion/wasmtest/README.md new file mode 100644 index 0000000000000..d26369a18ab9e --- /dev/null +++ b/datafusion/wasmtest/README.md @@ -0,0 +1,68 @@ + + +# DataFusion wasmtest + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion used to verify that various DataFusion crates compile successfully to the +`wasm32-unknown-unknown` target with wasm-pack. + +[df]: https://crates.io/crates/datafusion + +## wasmtest + +Some of DataFusion's downstream projects compile to WASM to run in the browser. Doing so requires special care that certain library dependencies are not included in DataFusion. + +## Setup + +First, [install wasm-pack](https://rustwasm.github.io/wasm-pack/installer/) + +Then use wasm-pack to compile the crate from within this directory + +``` +wasm-pack build +``` + +## Try it out + +The `datafusion-wasm-app` directory contains a simple app (created with [`create-wasm-app`](https://github.com/rustwasm/create-wasm-app) and then manually updated to WebPack 5) that invokes DataFusion and writes results to the browser console. + +From within the `datafusion/wasmtest/datafusion-wasm-app` directory: + +``` +npm install +npm run start +``` + +Then open http://localhost:8080/ in a web browser and check the console to see the results of using various DataFusion crates. + +**Note:** In GitHub Actions we test the compilation with `wasm-build`, but we don't currently invoke `datafusion-wasm-app`. In the future we may want to test the behavior of the WASM build using [`wasm-pack test`](https://rustwasm.github.io/wasm-pack/book/tutorials/npm-browser-packages/testing-your-project.html). + +## Compatibility + +The following DataFusion crates are verified to work in a wasm-pack environment using the default `wasm32-unknown-unknown` target: + +- `datafusion-common` with default-features disabled to remove the `parquet` dependency (see below) +- `datafusion-expr` +- `datafusion-optimizer` +- `datafusion-physical-expr` +- `datafusion-sql` + +The difficulty with getting the remaining DataFusion crates compiled to WASM is that they have non-optional dependencies on the [`parquet`](https://docs.rs/crate/parquet/) crate with its default features enabled. Several of the default parquet crate features require native dependencies that are not compatible with WASM, in particular the `lz4` and `zstd` features. If we can arrange our feature flags to make it possible to depend on parquet with these features disabled, then it should be possible to compile the core `datafusion` crate to WASM as well. diff --git a/datafusion/wasmtest/datafusion-wasm-app/.gitignore b/datafusion/wasmtest/datafusion-wasm-app/.gitignore new file mode 100644 index 0000000000000..f06235c460c2d --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/.gitignore @@ -0,0 +1,2 @@ +node_modules +dist diff --git a/datafusion/wasmtest/datafusion-wasm-app/README.md b/datafusion/wasmtest/datafusion-wasm-app/README.md new file mode 100644 index 0000000000000..3cc362de8f746 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/README.md @@ -0,0 +1,68 @@ +
+ +

create-wasm-app

+ +An npm init template for kick starting a project that uses NPM packages containing Rust-generated WebAssembly and bundles them with Webpack. + +

+ Build Status +

+ +

+ Usage + | + Chat +

+ +Built with 🦀🕸 by The Rust and WebAssembly Working Group + +
+ +## About + +This template is designed for depending on NPM packages that contain +Rust-generated WebAssembly and using them to create a Website. + +- Want to create an NPM package with Rust and WebAssembly? [Check out + `wasm-pack-template`.](https://github.com/rustwasm/wasm-pack-template) +- Want to make a monorepo-style Website without publishing to NPM? Check out + [`rust-webpack-template`](https://github.com/rustwasm/rust-webpack-template) + and/or + [`rust-parcel-template`](https://github.com/rustwasm/rust-parcel-template). + +## 🚴 Usage + +``` +npm init wasm-app +``` + +## 🔋 Batteries Included + +- `.gitignore`: ignores `node_modules` +- `LICENSE-APACHE` and `LICENSE-MIT`: most Rust projects are licensed this way, so these are included for you +- `README.md`: the file you are reading now! +- `index.html`: a bare bones html document that includes the webpack bundle +- `index.js`: example js file with a comment showing how to import and use a wasm pkg +- `package.json` and `package-lock.json`: + - pulls in devDependencies for using webpack: + - [`webpack`](https://www.npmjs.com/package/webpack) + - [`webpack-cli`](https://www.npmjs.com/package/webpack-cli) + - [`webpack-dev-server`](https://www.npmjs.com/package/webpack-dev-server) + - defines a `start` script to run `webpack-dev-server` +- `webpack.config.js`: configuration file for bundling your js with webpack + +## License + +Licensed under either of + +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +- MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, as defined in the Apache-2.0 +license, shall be dual licensed as above, without any additional terms or +conditions. diff --git a/datafusion/wasmtest/datafusion-wasm-app/bootstrap.js b/datafusion/wasmtest/datafusion-wasm-app/bootstrap.js new file mode 100644 index 0000000000000..4ad835cb679f1 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/bootstrap.js @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A dependency graph that contains any wasm must all be imported +// asynchronously. This `bootstrap.js` file does the single async import, so +// that no one else needs to worry about it again. +import("./index.js") + .catch(e => console.error("Error importing `index.js`:", e)); diff --git a/datafusion/wasmtest/datafusion-wasm-app/index.html b/datafusion/wasmtest/datafusion-wasm-app/index.html new file mode 100644 index 0000000000000..4d50e2c01e416 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/index.html @@ -0,0 +1,12 @@ + + + + + Hello wasm-pack! + + +

See console

+ + + + diff --git a/datafusion/wasmtest/datafusion-wasm-app/index.js b/datafusion/wasmtest/datafusion-wasm-app/index.js new file mode 100644 index 0000000000000..7ee31b7d3802e --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/index.js @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import * as wasm from "datafusion-wasmtest"; + +wasm.try_datafusion(); diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json new file mode 100644 index 0000000000000..c7b90cf05f1b1 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -0,0 +1,7461 @@ +{ + "name": "create-wasm-app", + "version": "0.1.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "name": "create-wasm-app", + "version": "0.1.0", + "license": "(MIT OR Apache-2.0)", + "dependencies": { + "datafusion-wasmtest": "../pkg" + }, + "devDependencies": { + "copy-webpack-plugin": "6.4.1", + "webpack": "5.88.2", + "webpack-cli": "5.1.4", + "webpack-dev-server": "4.15.1" + } + }, + "../pkg": { + "name": "datafusion-wasmtest", + "version": "31.0.0", + "license": "Apache-2.0" + }, + "node_modules/@discoveryjs/json-ext": { + "version": "0.5.7", + "resolved": "https://registry.npmjs.org/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz", + "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", + "dev": true, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/@gar/promisify": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.3.tgz", + "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", + "dev": true + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.3", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", + "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "dev": true, + "dependencies": { + "@jridgewell/set-array": "^1.0.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.9" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz", + "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==", + "dev": true, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", + "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", + "dev": true, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/source-map": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.5.tgz", + "integrity": "sha512-UTYAUj/wviwdsMfzoSJspJxbkH5o1snzwX0//0ENX1u/55kkZZkcTZP6u9bwKGkv+dkk9at4m1Cpt0uY80kcpQ==", + "dev": true, + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.0", + "@jridgewell/trace-mapping": "^0.3.9" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.4.15", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz", + "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==", + "dev": true + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.19", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz", + "integrity": "sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw==", + "dev": true, + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@leichtgewicht/ip-codec": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", + "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", + "dev": true + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@npmcli/fs": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-1.1.1.tgz", + "integrity": "sha512-8KG5RD0GVP4ydEzRn/I4BNDuxDtqVbOdm8675T49OIG/NGhaK0pjPX7ZcDlvKYbA+ulvVK3ztfcF4uBdOxuJbQ==", + "dev": true, + "dependencies": { + "@gar/promisify": "^1.0.1", + "semver": "^7.3.5" + } + }, + "node_modules/@npmcli/move-file": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-1.1.2.tgz", + "integrity": "sha512-1SUf/Cg2GzGDyaf15aR9St9TWlb+XvbZXWpDx8YKs7MLzMH/BCeopv+y9vzrzgkfykCGuWOlSu3mZhj2+FQcrg==", + "deprecated": "This functionality has been moved to @npmcli/fs", + "dev": true, + "dependencies": { + "mkdirp": "^1.0.4", + "rimraf": "^3.0.2" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/@types/body-parser": { + "version": "1.19.3", + "resolved": "https://registry.npmjs.org/@types/body-parser/-/body-parser-1.19.3.tgz", + "integrity": "sha512-oyl4jvAfTGX9Bt6Or4H9ni1Z447/tQuxnZsytsCaExKlmJiU8sFgnIBRzJUpKwB5eWn9HuBYlUlVA74q/yN0eQ==", + "dev": true, + "dependencies": { + "@types/connect": "*", + "@types/node": "*" + } + }, + "node_modules/@types/bonjour": { + "version": "3.5.11", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", + "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/connect": { + "version": "3.4.36", + "resolved": "https://registry.npmjs.org/@types/connect/-/connect-3.4.36.tgz", + "integrity": "sha512-P63Zd/JUGq+PdrM1lv0Wv5SBYeA2+CORvbrXbngriYY0jzLUWfQMQQxOhjONEz/wlHOAxOdY7CY65rgQdTjq2w==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/connect-history-api-fallback": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", + "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "dev": true, + "dependencies": { + "@types/express-serve-static-core": "*", + "@types/node": "*" + } + }, + "node_modules/@types/eslint": { + "version": "8.44.2", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.44.2.tgz", + "integrity": "sha512-sdPRb9K6iL5XZOmBubg8yiFp5yS/JdUDQsq5e6h95km91MCYMuvp7mh1fjPEYUhvHepKpZOjnEaMBR4PxjWDzg==", + "dev": true, + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/eslint-scope": { + "version": "3.7.4", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", + "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", + "dev": true, + "dependencies": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.1.tgz", + "integrity": "sha512-LG4opVs2ANWZ1TJoKc937iMmNstM/d0ae1vNbnBvBhqCSezgVUOzcLCqbI5elV8Vy6WKwKjaqR+zO9VKirBBCA==", + "dev": true + }, + "node_modules/@types/express": { + "version": "4.17.17", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", + "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "dev": true, + "dependencies": { + "@types/body-parser": "*", + "@types/express-serve-static-core": "^4.17.33", + "@types/qs": "*", + "@types/serve-static": "*" + } + }, + "node_modules/@types/express-serve-static-core": { + "version": "4.17.36", + "resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.17.36.tgz", + "integrity": "sha512-zbivROJ0ZqLAtMzgzIUC4oNqDG9iF0lSsAqpOD9kbs5xcIM3dTiyuHvBc7R8MtWBp3AAWGaovJa+wzWPjLYW7Q==", + "dev": true, + "dependencies": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "node_modules/@types/http-errors": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@types/http-errors/-/http-errors-2.0.2.tgz", + "integrity": "sha512-lPG6KlZs88gef6aD85z3HNkztpj7w2R7HmR3gygjfXCQmsLloWNARFkMuzKiiY8FGdh1XDpgBdrSf4aKDiA7Kg==", + "dev": true + }, + "node_modules/@types/http-proxy": { + "version": "1.17.12", + "resolved": "https://registry.npmjs.org/@types/http-proxy/-/http-proxy-1.17.12.tgz", + "integrity": "sha512-kQtujO08dVtQ2wXAuSFfk9ASy3sug4+ogFR8Kd8UgP8PEuc1/G/8yjYRmp//PcDNJEUKOza/MrQu15bouEUCiw==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/json-schema": { + "version": "7.0.13", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", + "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "dev": true + }, + "node_modules/@types/mime": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@types/mime/-/mime-1.3.2.tgz", + "integrity": "sha512-YATxVxgRqNH6nHEIsvg6k2Boc1JHI9ZbH5iWFFv/MTkchz3b1ieGDa5T0a9RznNdI0KhVbdbWSN+KWWrQZRxTw==", + "dev": true + }, + "node_modules/@types/node": { + "version": "20.6.3", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.6.3.tgz", + "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", + "dev": true + }, + "node_modules/@types/qs": { + "version": "6.9.8", + "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", + "integrity": "sha512-u95svzDlTysU5xecFNTgfFG5RUWu1A9P0VzgpcIiGZA9iraHOdSzcxMxQ55DyeRaGCSxQi7LxXDI4rzq/MYfdg==", + "dev": true + }, + "node_modules/@types/range-parser": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@types/range-parser/-/range-parser-1.2.4.tgz", + "integrity": "sha512-EEhsLsD6UsDM1yFhAvy0Cjr6VwmpMWqFBCb9w07wVugF7w9nfajxLuVmngTIpgS6svCnm6Vaw+MZhoDCKnOfsw==", + "dev": true + }, + "node_modules/@types/retry": { + "version": "0.12.0", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", + "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "dev": true + }, + "node_modules/@types/send": { + "version": "0.17.1", + "resolved": "https://registry.npmjs.org/@types/send/-/send-0.17.1.tgz", + "integrity": "sha512-Cwo8LE/0rnvX7kIIa3QHCkcuF21c05Ayb0ZfxPiv0W8VRiZiNW/WuRupHKpqqGVGf7SUA44QSOUKaEd9lIrd/Q==", + "dev": true, + "dependencies": { + "@types/mime": "^1", + "@types/node": "*" + } + }, + "node_modules/@types/serve-index": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", + "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "dev": true, + "dependencies": { + "@types/express": "*" + } + }, + "node_modules/@types/serve-static": { + "version": "1.15.2", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", + "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "dev": true, + "dependencies": { + "@types/http-errors": "*", + "@types/mime": "*", + "@types/node": "*" + } + }, + "node_modules/@types/sockjs": { + "version": "0.3.33", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", + "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/ws": { + "version": "8.5.5", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", + "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@webassemblyjs/ast": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.6.tgz", + "integrity": "sha512-IN1xI7PwOvLPgjcf180gC1bqn3q/QaOCwYUahIOhbYUu8KA/3tw2RT/T0Gidi1l7Hhj5D/INhJxiICObqpMu4Q==", + "dev": true, + "dependencies": { + "@webassemblyjs/helper-numbers": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + } + }, + "node_modules/@webassemblyjs/floating-point-hex-parser": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", + "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "dev": true + }, + "node_modules/@webassemblyjs/helper-api-error": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", + "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "dev": true + }, + "node_modules/@webassemblyjs/helper-buffer": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.6.tgz", + "integrity": "sha512-z3nFzdcp1mb8nEOFFk8DrYLpHvhKC3grJD2ardfKOzmbmJvEf/tPIqCY+sNcwZIY8ZD7IkB2l7/pqhUhqm7hLA==", + "dev": true + }, + "node_modules/@webassemblyjs/helper-numbers": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", + "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "dev": true, + "dependencies": { + "@webassemblyjs/floating-point-hex-parser": "1.11.6", + "@webassemblyjs/helper-api-error": "1.11.6", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/helper-wasm-bytecode": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", + "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "dev": true + }, + "node_modules/@webassemblyjs/helper-wasm-section": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.6.tgz", + "integrity": "sha512-LPpZbSOwTpEC2cgn4hTydySy1Ke+XEu+ETXuoyvuyezHO3Kjdu90KK95Sh9xTbmjrCsUwvWwCOQQNta37VrS9g==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6" + } + }, + "node_modules/@webassemblyjs/ieee754": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", + "integrity": "sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "dev": true, + "dependencies": { + "@xtuc/ieee754": "^1.2.0" + } + }, + "node_modules/@webassemblyjs/leb128": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", + "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "dev": true, + "dependencies": { + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/utf8": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", + "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "dev": true + }, + "node_modules/@webassemblyjs/wasm-edit": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.6.tgz", + "integrity": "sha512-Ybn2I6fnfIGuCR+Faaz7YcvtBKxvoLV3Lebn1tM4o/IAJzmi9AWYIPWpyBfU8cC+JxAO57bk4+zdsTjJR+VTOw==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/helper-wasm-section": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6", + "@webassemblyjs/wasm-opt": "1.11.6", + "@webassemblyjs/wasm-parser": "1.11.6", + "@webassemblyjs/wast-printer": "1.11.6" + } + }, + "node_modules/@webassemblyjs/wasm-gen": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.6.tgz", + "integrity": "sha512-3XOqkZP/y6B4F0PBAXvI1/bky7GryoogUtfwExeP/v7Nzwo1QLcq5oQmpKlftZLbT+ERUOAZVQjuNVak6UXjPA==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/ieee754": "1.11.6", + "@webassemblyjs/leb128": "1.11.6", + "@webassemblyjs/utf8": "1.11.6" + } + }, + "node_modules/@webassemblyjs/wasm-opt": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.6.tgz", + "integrity": "sha512-cOrKuLRE7PCe6AsOVl7WasYf3wbSo4CeOk6PkrjS7g57MFfVUF9u6ysQBBODX0LdgSvQqRiGz3CXvIDKcPNy4g==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6", + "@webassemblyjs/wasm-parser": "1.11.6" + } + }, + "node_modules/@webassemblyjs/wasm-parser": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.6.tgz", + "integrity": "sha512-6ZwPeGzMJM3Dqp3hCsLgESxBGtT/OeCvCZ4TA1JUPYgmhAx38tTPR9JaKy0S5H3evQpO/h2uWs2j6Yc/fjkpTQ==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/ieee754": "1.11.6", + "@webassemblyjs/leb128": "1.11.6", + "@webassemblyjs/utf8": "1.11.6" + } + }, + "node_modules/@webassemblyjs/wast-printer": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.6.tgz", + "integrity": "sha512-JM7AhRcE+yW2GWYaKeHL5vt4xqee5N2WcezptmgyhNS+ScggqcT1OtXykhAb13Sn5Yas0j2uv9tHgrjwvzAP4A==", + "dev": true, + "dependencies": { + "@webassemblyjs/ast": "1.11.6", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webpack-cli/configtest": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.1.1.tgz", + "integrity": "sha512-wy0mglZpDSiSS0XHrVR+BAdId2+yxPSoJW8fsna3ZpYSlufjvxnP4YbKTCBZnNIcGN4r6ZPXV55X4mYExOfLmw==", + "dev": true, + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + } + }, + "node_modules/@webpack-cli/info": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@webpack-cli/info/-/info-2.0.2.tgz", + "integrity": "sha512-zLHQdI/Qs1UyT5UBdWNqsARasIA+AaF8t+4u2aS2nEpBQh2mWIVb8qAklq0eUENnC5mOItrIB4LiS9xMtph18A==", + "dev": true, + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + } + }, + "node_modules/@webpack-cli/serve": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@webpack-cli/serve/-/serve-2.0.5.tgz", + "integrity": "sha512-lqaoKnRYBdo1UgDX8uF24AfGMifWK19TxPmM5FHc2vAGxrJ/qtyUyFBWoY1tISZdelsQ5fBcOusifo5o5wSJxQ==", + "dev": true, + "engines": { + "node": ">=14.15.0" + }, + "peerDependencies": { + "webpack": "5.x.x", + "webpack-cli": "5.x.x" + }, + "peerDependenciesMeta": { + "webpack-dev-server": { + "optional": true + } + } + }, + "node_modules/@xtuc/ieee754": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@xtuc/ieee754/-/ieee754-1.2.0.tgz", + "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", + "dev": true + }, + "node_modules/@xtuc/long": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", + "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", + "dev": true + }, + "node_modules/accepts": { + "version": "1.3.8", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", + "integrity": "sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw==", + "dev": true, + "dependencies": { + "mime-types": "~2.1.34", + "negotiator": "0.6.3" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/acorn": { + "version": "8.10.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", + "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-import-assertions": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.9.0.tgz", + "integrity": "sha512-cmMwop9x+8KFhxvKrKfPYmN6/pKTYYHBqLa0DfvVZcKMJWNyWLnaqND7dx/qn66R7ewM1UX5XMaDVP5wlVTaVA==", + "dev": true, + "peerDependencies": { + "acorn": "^8" + } + }, + "node_modules/aggregate-error": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz", + "integrity": "sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA==", + "dev": true, + "dependencies": { + "clean-stack": "^2.0.0", + "indent-string": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/ajv-formats/-/ajv-formats-2.1.1.tgz", + "integrity": "sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==", + "dev": true, + "dependencies": { + "ajv": "^8.0.0" + }, + "peerDependencies": { + "ajv": "^8.0.0" + }, + "peerDependenciesMeta": { + "ajv": { + "optional": true + } + } + }, + "node_modules/ajv-formats/node_modules/ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "node_modules/ajv-keywords": { + "version": "3.5.2", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", + "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", + "dev": true, + "peerDependencies": { + "ajv": "^6.9.1" + } + }, + "node_modules/ansi-html-community": { + "version": "0.0.8", + "resolved": "https://registry.npmjs.org/ansi-html-community/-/ansi-html-community-0.0.8.tgz", + "integrity": "sha512-1APHAyr3+PCamwNw3bXCPp4HFLONZt/yIH0sZp0/469KWNTEy+qN5jQ3GVX6DMZ1UXAi34yVwtTeaG/HpBuuzw==", + "dev": true, + "engines": [ + "node >= 0.8.0" + ], + "bin": { + "ansi-html": "bin/ansi-html" + } + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/array-flatten": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", + "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", + "dev": true + }, + "node_modules/array-union": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", + "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true + }, + "node_modules/batch": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", + "integrity": "sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY=", + "dev": true + }, + "node_modules/big.js": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", + "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", + "dev": true, + "engines": { + "node": "*" + } + }, + "node_modules/binary-extensions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/body-parser": { + "version": "1.20.1", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", + "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "dev": true, + "dependencies": { + "bytes": "3.1.2", + "content-type": "~1.0.4", + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "on-finished": "2.4.1", + "qs": "6.11.0", + "raw-body": "2.5.1", + "type-is": "~1.6.18", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8", + "npm": "1.2.8000 || >= 1.4.16" + } + }, + "node_modules/body-parser/node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/body-parser/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/body-parser/node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/bonjour-service": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", + "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", + "dev": true, + "dependencies": { + "array-flatten": "^2.1.2", + "dns-equal": "^1.0.0", + "fast-deep-equal": "^3.1.3", + "multicast-dns": "^7.2.5" + } + }, + "node_modules/brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dev": true, + "dependencies": { + "fill-range": "^7.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/browserslist": { + "version": "4.21.11", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", + "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "caniuse-lite": "^1.0.30001538", + "electron-to-chromium": "^1.4.526", + "node-releases": "^2.0.13", + "update-browserslist-db": "^1.0.13" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true + }, + "node_modules/bytes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", + "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/cacache": { + "version": "15.3.0", + "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", + "integrity": "sha512-VVdYzXEn+cnbXpFgWs5hTT7OScegHVmLhJIR8Ufqk3iFD6A6j5iSX1KuBTfNEv4tdJWE2PzA6IVFtcLC7fN9wQ==", + "dev": true, + "dependencies": { + "@npmcli/fs": "^1.0.0", + "@npmcli/move-file": "^1.0.1", + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "glob": "^7.1.4", + "infer-owner": "^1.0.4", + "lru-cache": "^6.0.0", + "minipass": "^3.1.1", + "minipass-collect": "^1.0.2", + "minipass-flush": "^1.0.5", + "minipass-pipeline": "^1.2.2", + "mkdirp": "^1.0.3", + "p-map": "^4.0.0", + "promise-inflight": "^1.0.1", + "rimraf": "^3.0.2", + "ssri": "^8.0.1", + "tar": "^6.0.2", + "unique-filename": "^1.1.1" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/call-bind": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", + "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.1", + "get-intrinsic": "^1.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001538", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", + "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ] + }, + "node_modules/chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + ], + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chownr": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", + "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", + "dev": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/chrome-trace-event": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.2.tgz", + "integrity": "sha512-9e/zx1jw7B4CO+c/RXoCsfg/x1AfUBioy4owYH0bJprEYAx5hRFLRhWBqHAG57D0ZM4H7vxbP7bPe0VwhQRYDQ==", + "dev": true, + "dependencies": { + "tslib": "^1.9.0" + }, + "engines": { + "node": ">=6.0" + } + }, + "node_modules/clean-stack": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", + "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/clone-deep": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", + "integrity": "sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==", + "dev": true, + "dependencies": { + "is-plain-object": "^2.0.4", + "kind-of": "^6.0.2", + "shallow-clone": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/colorette": { + "version": "2.0.20", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz", + "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", + "dev": true + }, + "node_modules/commander": { + "version": "2.20.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", + "dev": true + }, + "node_modules/commondir": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", + "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", + "dev": true + }, + "node_modules/compressible": { + "version": "2.0.18", + "resolved": "https://registry.npmjs.org/compressible/-/compressible-2.0.18.tgz", + "integrity": "sha512-AF3r7P5dWxL8MxyITRMlORQNaOA2IkAFaTr4k7BUumjPtRpGDTZpl0Pb1XCO6JeDCBdp126Cgs9sMxqSjgYyRg==", + "dev": true, + "dependencies": { + "mime-db": ">= 1.43.0 < 2" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/compression": { + "version": "1.7.4", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.7.4.tgz", + "integrity": "sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ==", + "dev": true, + "dependencies": { + "accepts": "~1.3.5", + "bytes": "3.0.0", + "compressible": "~2.0.16", + "debug": "2.6.9", + "on-headers": "~1.0.2", + "safe-buffer": "5.1.2", + "vary": "~1.1.2" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/compression/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true + }, + "node_modules/connect-history-api-fallback": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", + "integrity": "sha512-U73+6lQFmfiNPrYbXqr6kZ1i1wiRqXnp2nhMsINseWXO8lDau0LGEffJ8kQi4EjLZympVgRdvqjAgiZ1tgzDDA==", + "dev": true, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/content-disposition": { + "version": "0.5.4", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz", + "integrity": "sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ==", + "dev": true, + "dependencies": { + "safe-buffer": "5.2.1" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/content-disposition/node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", + "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==", + "dev": true + }, + "node_modules/copy-webpack-plugin": { + "version": "6.4.1", + "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-6.4.1.tgz", + "integrity": "sha512-MXyPCjdPVx5iiWyl40Va3JGh27bKzOTNY3NjUTrosD2q7dR/cLD0013uqJ3BpFbUjyONINjb6qI7nDIJujrMbA==", + "dev": true, + "dependencies": { + "cacache": "^15.0.5", + "fast-glob": "^3.2.4", + "find-cache-dir": "^3.3.1", + "glob-parent": "^5.1.1", + "globby": "^11.0.1", + "loader-utils": "^2.0.0", + "normalize-path": "^3.0.0", + "p-limit": "^3.0.2", + "schema-utils": "^3.0.0", + "serialize-javascript": "^5.0.1", + "webpack-sources": "^1.4.3" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^4.37.0 || ^5.0.0" + } + }, + "node_modules/core-util-is": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", + "integrity": "sha1-tf1UIgqivFq1eqtxQMlAdUUDwac=", + "dev": true + }, + "node_modules/cross-spawn": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "dev": true, + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/datafusion-wasmtest": { + "resolved": "../pkg", + "link": true + }, + "node_modules/debug": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.1.1.tgz", + "integrity": "sha512-pYAIzeRo8J6KPEaJ0VWOh5Pzkbw/RetuzehGM7QRRX5he4fPHx2rdKMB256ehJCkX+XRQm16eZLqLNS8RSZXZw==", + "deprecated": "Debug versions >=3.2.0 <3.2.7 || >=4 <4.3.1 have a low-severity ReDos regression when used in a Node.js environment. It is recommended you upgrade to 3.2.7 or 4.3.1. (https://github.com/visionmedia/debug/issues/797)", + "dev": true, + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/debug/node_modules/ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + }, + "node_modules/default-gateway": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", + "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "dev": true, + "dependencies": { + "execa": "^5.0.0" + }, + "engines": { + "node": ">= 10" + } + }, + "node_modules/define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/depd": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", + "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/destroy": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.2.0.tgz", + "integrity": "sha512-2sJGJTaXIIaR1w4iJSNoN0hnMY7Gpc/n8D4qSCJw8QqFWXf7cuAgnEHxBpweaVcPevC2l3KpjYCx3NypQQgaJg==", + "dev": true, + "engines": { + "node": ">= 0.8", + "npm": "1.2.8000 || >= 1.4.16" + } + }, + "node_modules/detect-node": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.0.4.tgz", + "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", + "dev": true + }, + "node_modules/dir-glob": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", + "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", + "dev": true, + "dependencies": { + "path-type": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/dns-equal": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", + "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", + "dev": true + }, + "node_modules/dns-packet": { + "version": "5.6.1", + "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", + "integrity": "sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==", + "dev": true, + "dependencies": { + "@leichtgewicht/ip-codec": "^2.0.1" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "dev": true + }, + "node_modules/electron-to-chromium": { + "version": "1.4.528", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", + "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "dev": true + }, + "node_modules/emojis-list": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", + "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/enhanced-resolve": { + "version": "5.15.0", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.15.0.tgz", + "integrity": "sha512-LXYT42KJ7lpIKECr2mAXIaMldcNCh/7E0KBKOu4KSfkHmP+mZmSs+8V5gBAqisWBy0OO4W5Oyys0GO1Y8KtdKg==", + "dev": true, + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/envinfo": { + "version": "7.10.0", + "resolved": "https://registry.npmjs.org/envinfo/-/envinfo-7.10.0.tgz", + "integrity": "sha512-ZtUjZO6l5mwTHvc1L9+1q5p/R3wTopcfqMW8r5t8SJSKqeVI/LtajORwRFEKpEFuekjD0VBjwu1HMxL4UalIRw==", + "dev": true, + "bin": { + "envinfo": "dist/cli.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/es-module-lexer": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", + "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "dev": true + }, + "node_modules/escalade": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", + "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=", + "dev": true + }, + "node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esrecurse/node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/eventemitter3": { + "version": "4.0.7", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz", + "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", + "dev": true + }, + "node_modules/events": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "dev": true, + "engines": { + "node": ">=0.8.x" + } + }, + "node_modules/execa": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", + "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", + "dev": true, + "dependencies": { + "cross-spawn": "^7.0.3", + "get-stream": "^6.0.0", + "human-signals": "^2.1.0", + "is-stream": "^2.0.0", + "merge-stream": "^2.0.0", + "npm-run-path": "^4.0.1", + "onetime": "^5.1.2", + "signal-exit": "^3.0.3", + "strip-final-newline": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sindresorhus/execa?sponsor=1" + } + }, + "node_modules/express": { + "version": "4.18.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", + "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "dev": true, + "dependencies": { + "accepts": "~1.3.8", + "array-flatten": "1.1.1", + "body-parser": "1.20.1", + "content-disposition": "0.5.4", + "content-type": "~1.0.4", + "cookie": "0.5.0", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "2.0.0", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.2.0", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "merge-descriptors": "1.0.1", + "methods": "~1.1.2", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "path-to-regexp": "0.1.7", + "proxy-addr": "~2.0.7", + "qs": "6.11.0", + "range-parser": "~1.2.1", + "safe-buffer": "5.2.1", + "send": "0.18.0", + "serve-static": "1.15.0", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "type-is": "~1.6.18", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + }, + "engines": { + "node": ">= 0.10.0" + } + }, + "node_modules/express/node_modules/array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", + "dev": true + }, + "node_modules/express/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/express/node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/express/node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/express/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true + }, + "node_modules/fast-glob": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", + "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.4" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "node_modules/fastest-levenshtein": { + "version": "1.0.16", + "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", + "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", + "dev": true, + "engines": { + "node": ">= 4.9.1" + } + }, + "node_modules/fastq": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "dev": true, + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/faye-websocket": { + "version": "0.11.4", + "resolved": "https://registry.npmjs.org/faye-websocket/-/faye-websocket-0.11.4.tgz", + "integrity": "sha512-CzbClwlXAuiRQAlUyfqPgvPoNKTckTPGfwZV4ZdAhVcP2lh9KUxJg2b5GkE7XbjKQ3YJnQ9z6D9ntLAlB+tP8g==", + "dev": true, + "dependencies": { + "websocket-driver": ">=0.5.1" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dev": true, + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/finalhandler": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.2.0.tgz", + "integrity": "sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==", + "dev": true, + "dependencies": { + "debug": "2.6.9", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "statuses": "2.0.1", + "unpipe": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/finalhandler/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/finalhandler/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/find-cache-dir": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", + "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", + "dev": true, + "dependencies": { + "commondir": "^1.0.1", + "make-dir": "^3.0.2", + "pkg-dir": "^4.1.0" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/avajs/find-cache-dir?sponsor=1" + } + }, + "node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/follow-redirects": { + "version": "1.15.3", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", + "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fs-minipass": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", + "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", + "dev": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/fs-monkey": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", + "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", + "dev": true + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "dev": true + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", + "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "dev": true + }, + "node_modules/get-intrinsic": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", + "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.1", + "has": "^1.0.3", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-stream": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", + "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "dev": true, + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/glob-to-regexp": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", + "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", + "dev": true + }, + "node_modules/globby": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", + "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "dev": true, + "dependencies": { + "array-union": "^2.1.0", + "dir-glob": "^3.0.1", + "fast-glob": "^3.2.9", + "ignore": "^5.2.0", + "merge2": "^1.4.1", + "slash": "^3.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "dev": true + }, + "node_modules/handle-thing": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/handle-thing/-/handle-thing-2.0.1.tgz", + "integrity": "sha512-9Qn4yBxelxoh2Ow62nP+Ka/kMnOXRi8BXnRaUwezLNhqelnN49xKz4F/dPP8OYLxLxq6JDtZb2i9XznUQbNPTg==", + "dev": true + }, + "node_modules/has": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", + "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.1" + }, + "engines": { + "node": ">= 0.4.0" + } + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/has-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", + "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", + "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hpack.js": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", + "integrity": "sha1-h3dMCUnlE/QuhFdbPEVoH63ioLI=", + "dev": true, + "dependencies": { + "inherits": "^2.0.1", + "obuf": "^1.0.0", + "readable-stream": "^2.0.1", + "wbuf": "^1.1.0" + } + }, + "node_modules/html-entities": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", + "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/mdevils" + }, + { + "type": "patreon", + "url": "https://patreon.com/mdevils" + } + ] + }, + "node_modules/http-deceiver": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", + "integrity": "sha1-+nFolEq5pRnTN8sL7HKE3D5yPYc=", + "dev": true + }, + "node_modules/http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "dev": true, + "dependencies": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-errors/node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-errors/node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true + }, + "node_modules/http-errors/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-parser-js": { + "version": "0.5.8", + "resolved": "https://registry.npmjs.org/http-parser-js/-/http-parser-js-0.5.8.tgz", + "integrity": "sha512-SGeBX54F94Wgu5RH3X5jsDtf4eHyRogWX1XGT3b4HuW3tQPM4AaBzoUji/4AAJNXCEOWZ5O0DgZmJw1947gD5Q==", + "dev": true + }, + "node_modules/http-proxy": { + "version": "1.18.1", + "resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz", + "integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==", + "dev": true, + "dependencies": { + "eventemitter3": "^4.0.0", + "follow-redirects": "^1.0.0", + "requires-port": "^1.0.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/http-proxy-middleware": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", + "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "dev": true, + "dependencies": { + "@types/http-proxy": "^1.17.8", + "http-proxy": "^1.18.1", + "is-glob": "^4.0.1", + "is-plain-obj": "^3.0.0", + "micromatch": "^4.0.2" + }, + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "@types/express": "^4.17.13" + }, + "peerDependenciesMeta": { + "@types/express": { + "optional": true + } + } + }, + "node_modules/human-signals": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", + "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "dev": true, + "engines": { + "node": ">=10.17.0" + } + }, + "node_modules/iconv-lite": { + "version": "0.4.24", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", + "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", + "dev": true, + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/ignore": { + "version": "5.2.4", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", + "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/import-local": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz", + "integrity": "sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg==", + "dev": true, + "dependencies": { + "pkg-dir": "^4.2.0", + "resolve-cwd": "^3.0.0" + }, + "bin": { + "import-local-fixture": "fixtures/cli.js" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/infer-owner": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/infer-owner/-/infer-owner-1.0.4.tgz", + "integrity": "sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A==", + "dev": true + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "dev": true, + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=", + "dev": true + }, + "node_modules/interpret": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/interpret/-/interpret-3.1.1.tgz", + "integrity": "sha512-6xwYfHbajpoF0xLW+iwLkhwgvLoZDfjYfoFNu8ftMoXINzwuymNLd9u/KmwtdT2GbR+/Cz66otEGEVVUHX9QLQ==", + "dev": true, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/ipaddr.js": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-2.1.0.tgz", + "integrity": "sha512-LlbxQ7xKzfBusov6UMi4MFpEg0m+mAm9xyNGEduwXMEDuf4WfzB/RZwMVYEd7IKGvh4IUkEXYxtAVu9T3OelJQ==", + "dev": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-core-module": { + "version": "2.13.0", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.0.tgz", + "integrity": "sha512-Z7dk6Qo8pOCp3l4tsX2C5ZVas4V+UxwQodwZhLopL91TX8UyyHEXafPcyoeeWuLrwzHcr3igO78wNLwHJHsMCQ==", + "dev": true, + "dependencies": { + "has": "^1.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-docker": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", + "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "dev": true, + "bin": { + "is-docker": "cli.js" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-plain-obj": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-plain-obj/-/is-plain-obj-3.0.0.tgz", + "integrity": "sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-plain-object": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", + "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", + "dev": true, + "dependencies": { + "isobject": "^3.0.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "dev": true, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-wsl": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", + "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "dev": true, + "dependencies": { + "is-docker": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=", + "dev": true + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true + }, + "node_modules/isobject": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", + "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/jest-worker": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-27.5.1.tgz", + "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", + "dev": true, + "dependencies": { + "@types/node": "*", + "merge-stream": "^2.0.0", + "supports-color": "^8.0.0" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", + "dev": true + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true + }, + "node_modules/json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true, + "bin": { + "json5": "lib/cli.js" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/kind-of": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", + "integrity": "sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/launch-editor": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", + "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "dev": true, + "dependencies": { + "picocolors": "^1.0.0", + "shell-quote": "^1.7.3" + } + }, + "node_modules/loader-runner": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", + "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "dev": true, + "engines": { + "node": ">=6.11.5" + } + }, + "node_modules/loader-utils": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", + "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", + "dev": true, + "dependencies": { + "big.js": "^5.2.2", + "emojis-list": "^3.0.0", + "json5": "^2.1.2" + }, + "engines": { + "node": ">=8.9.0" + } + }, + "node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dev": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/make-dir": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", + "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", + "dev": true, + "dependencies": { + "semver": "^6.0.0" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/make-dir/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/media-typer": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/memfs": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", + "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "dev": true, + "dependencies": { + "fs-monkey": "^1.0.4" + }, + "engines": { + "node": ">= 4.0.0" + } + }, + "node_modules/merge-descriptors": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", + "integrity": "sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==", + "dev": true + }, + "node_modules/merge-stream": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", + "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", + "dev": true + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/methods": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", + "integrity": "sha512-iclAHeNqNm68zFtnZ0e+1L2yUIdvzNoauKU4WBA3VvH/vPFieF7qfRlwUZU+DA9P9bPXIS90ulxoUoCH23sV2w==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/micromatch": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", + "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "dev": true, + "dependencies": { + "braces": "^3.0.2", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/mime": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz", + "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==", + "dev": true, + "bin": { + "mime": "cli.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "dev": true, + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mimic-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", + "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/minimalistic-assert": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", + "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", + "dev": true + }, + "node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "dev": true, + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minipass-collect": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/minipass-collect/-/minipass-collect-1.0.2.tgz", + "integrity": "sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA==", + "dev": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/minipass-flush": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/minipass-flush/-/minipass-flush-1.0.5.tgz", + "integrity": "sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw==", + "dev": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/minipass-pipeline": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/minipass-pipeline/-/minipass-pipeline-1.2.4.tgz", + "integrity": "sha512-xuIq7cIOt09RPRJ19gdi4b+RiNvDFYe5JH+ggNvBqGqpQXcru3PcRmOZuHBKWK1Txf9+cQ+HMVN4d6z46LZP7A==", + "dev": true, + "dependencies": { + "minipass": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/minizlib": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", + "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", + "dev": true, + "dependencies": { + "minipass": "^3.0.0", + "yallist": "^4.0.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/mkdirp": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", + "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", + "dev": true, + "bin": { + "mkdirp": "bin/cmd.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=", + "dev": true + }, + "node_modules/multicast-dns": { + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", + "integrity": "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==", + "dev": true, + "dependencies": { + "dns-packet": "^5.2.2", + "thunky": "^1.0.2" + }, + "bin": { + "multicast-dns": "cli.js" + } + }, + "node_modules/negotiator": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz", + "integrity": "sha512-+EUsqGPLsM+j/zdChZjsnX51g4XrHFOIXwfnCVPGlQk/k5giakcKsuxCObBRu6DSm9opw/O6slWbJdghQM4bBg==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/neo-async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", + "dev": true + }, + "node_modules/node-forge": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", + "integrity": "sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==", + "dev": true, + "engines": { + "node": ">= 6.13.0" + } + }, + "node_modules/node-releases": { + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", + "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "dev": true + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/npm-run-path": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", + "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", + "dev": true, + "dependencies": { + "path-key": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/object-inspect": { + "version": "1.12.3", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", + "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/obuf": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/obuf/-/obuf-1.1.2.tgz", + "integrity": "sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg==", + "dev": true + }, + "node_modules/on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "dev": true, + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/on-headers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", + "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "dev": true, + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/onetime": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", + "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", + "dev": true, + "dependencies": { + "mimic-fn": "^2.1.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/open": { + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "dev": true, + "dependencies": { + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-locate/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-map": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/p-map/-/p-map-4.0.0.tgz", + "integrity": "sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ==", + "dev": true, + "dependencies": { + "aggregate-error": "^3.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-retry": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", + "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "dev": true, + "dependencies": { + "@types/retry": "0.12.0", + "retry": "^0.13.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/parseurl": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", + "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true + }, + "node_modules/path-to-regexp": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", + "integrity": "sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==", + "dev": true + }, + "node_modules/path-type": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/picocolors": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", + "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "dev": true + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pkg-dir": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", + "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", + "dev": true, + "dependencies": { + "find-up": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.0.tgz", + "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==", + "dev": true + }, + "node_modules/promise-inflight": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", + "integrity": "sha512-6zWPyEOFaQBJYcGMHBKTKJ3u6TBsnMFOIZSa6ce1e/ZrrsOlnHRHbabMjLiBYKp+n44X9eUI6VUPaukCXHuG4g==", + "dev": true + }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "dev": true, + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/proxy-addr/node_modules/ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "dev": true, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/punycode": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", + "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/qs": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", + "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "dev": true, + "dependencies": { + "side-channel": "^1.0.4" + }, + "engines": { + "node": ">=0.6" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/randombytes": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", + "integrity": "sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==", + "dev": true, + "dependencies": { + "safe-buffer": "^5.1.0" + } + }, + "node_modules/range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", + "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "dev": true, + "dependencies": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/raw-body/node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/readable-stream": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", + "dev": true, + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/rechoir": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/rechoir/-/rechoir-0.8.0.tgz", + "integrity": "sha512-/vxpCXddiX8NGfGO/mTafwjq4aFa/71pvamip0++IQk3zG8cbCj0fifNPrjjF1XMXUne91jL9OoxmdykoEtifQ==", + "dev": true, + "dependencies": { + "resolve": "^1.20.0" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/requires-port": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz", + "integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==", + "dev": true + }, + "node_modules/resolve": { + "version": "1.22.6", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.6.tgz", + "integrity": "sha512-njhxM7mV12JfufShqGy3Rz8j11RPdLy4xi15UurGJeoHLfJpVXKdh3ueuOqbYUcDZnffr6X739JBo5LzyahEsw==", + "dev": true, + "dependencies": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-cwd": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", + "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", + "dev": true, + "dependencies": { + "resolve-from": "^5.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/resolve-from": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", + "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/retry": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", + "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", + "dev": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true, + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", + "dev": true + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "dev": true + }, + "node_modules/schema-utils": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", + "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "dev": true, + "dependencies": { + "@types/json-schema": "^7.0.8", + "ajv": "^6.12.5", + "ajv-keywords": "^3.5.2" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/select-hose": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/select-hose/-/select-hose-2.0.0.tgz", + "integrity": "sha1-Yl2GWPhlr0Psliv8N2o3NZpJlMo=", + "dev": true + }, + "node_modules/selfsigned": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", + "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "dev": true, + "dependencies": { + "node-forge": "^1" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/semver": { + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + "dev": true, + "dependencies": { + "lru-cache": "^6.0.0" + }, + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/send": { + "version": "0.18.0", + "resolved": "https://registry.npmjs.org/send/-/send-0.18.0.tgz", + "integrity": "sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==", + "dev": true, + "dependencies": { + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "mime": "1.6.0", + "ms": "2.1.3", + "on-finished": "2.4.1", + "range-parser": "~1.2.1", + "statuses": "2.0.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/send/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/send/node_modules/debug/node_modules/ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", + "dev": true + }, + "node_modules/send/node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/send/node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + }, + "node_modules/send/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/serialize-javascript": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-5.0.1.tgz", + "integrity": "sha512-SaaNal9imEO737H2c05Og0/8LUXG7EnsZyMa8MzkmuHoELfT6txuj0cMqRj6zfPKnmQ1yasR4PCJc8x+M4JSPA==", + "dev": true, + "dependencies": { + "randombytes": "^2.1.0" + } + }, + "node_modules/serve-index": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/serve-index/-/serve-index-1.9.1.tgz", + "integrity": "sha1-03aNabHn2C5c4FD/9bRTvqEqkjk=", + "dev": true, + "dependencies": { + "accepts": "~1.3.4", + "batch": "0.6.1", + "debug": "2.6.9", + "escape-html": "~1.0.3", + "http-errors": "~1.6.2", + "mime-types": "~2.1.17", + "parseurl": "~1.3.2" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/serve-index/node_modules/debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "dependencies": { + "ms": "2.0.0" + } + }, + "node_modules/serve-index/node_modules/http-errors": { + "version": "1.6.3", + "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", + "dev": true, + "dependencies": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/serve-index/node_modules/setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==", + "dev": true + }, + "node_modules/serve-static": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.15.0.tgz", + "integrity": "sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==", + "dev": true, + "dependencies": { + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "parseurl": "~1.3.3", + "send": "0.18.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "dev": true + }, + "node_modules/shallow-clone": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/shallow-clone/-/shallow-clone-3.0.1.tgz", + "integrity": "sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA==", + "dev": true, + "dependencies": { + "kind-of": "^6.0.2" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/shell-quote": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", + "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", + "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "dev": true, + "dependencies": { + "call-bind": "^1.0.0", + "get-intrinsic": "^1.0.2", + "object-inspect": "^1.9.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "dev": true + }, + "node_modules/slash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", + "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/sockjs": { + "version": "0.3.24", + "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz", + "integrity": "sha512-GJgLTZ7vYb/JtPSSZ10hsOYIvEYsjbNU+zPdIHcUaWVNUEPivzxku31865sSSud0Da0W4lEeOPlmw93zLQchuQ==", + "dev": true, + "dependencies": { + "faye-websocket": "^0.11.3", + "uuid": "^8.3.2", + "websocket-driver": "^0.7.4" + } + }, + "node_modules/source-list-map": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/source-list-map/-/source-list-map-2.0.1.tgz", + "integrity": "sha512-qnQ7gVMxGNxsiL4lEuJwe/To8UnK7fAnmbGEEH8RpLouuKbeEm0lhbQVFIrNSuB+G7tVrAlVsZgETT5nljf+Iw==", + "dev": true + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/spdy": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/spdy/-/spdy-4.0.2.tgz", + "integrity": "sha512-r46gZQZQV+Kl9oItvl1JZZqJKGr+oEkB08A6BzkiR7593/7IbtuncXHd2YoYeTsG4157ZssMu9KYvUHLcjcDoA==", + "dev": true, + "dependencies": { + "debug": "^4.1.0", + "handle-thing": "^2.0.0", + "http-deceiver": "^1.2.7", + "select-hose": "^2.0.0", + "spdy-transport": "^3.0.0" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/spdy-transport": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/spdy-transport/-/spdy-transport-3.0.0.tgz", + "integrity": "sha512-hsLVFE5SjA6TCisWeJXFKniGGOpBgMLmerfO2aCyCU5s7nJ/rpAepqmFifv/GCbSbueEeAJJnmSQ2rKC/g8Fcw==", + "dev": true, + "dependencies": { + "debug": "^4.1.0", + "detect-node": "^2.0.4", + "hpack.js": "^2.1.6", + "obuf": "^1.1.2", + "readable-stream": "^3.0.6", + "wbuf": "^1.7.3" + } + }, + "node_modules/spdy-transport/node_modules/readable-stream": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.0.tgz", + "integrity": "sha512-BViHy7LKeTz4oNnkcLJ+lVSL6vpiFeX6/d3oSH8zCW7UxP2onchk+vTGB143xuFjHS3deTgkKoXXymXqymiIdA==", + "dev": true, + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/ssri": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/ssri/-/ssri-8.0.1.tgz", + "integrity": "sha512-97qShzy1AiyxvPNIkLWoGua7xoQzzPjQ0HAH4B0rWKo7SZ6USuPcrUiAFrws0UH8RrbWmgq3LMTObhPIHbbBeQ==", + "dev": true, + "dependencies": { + "minipass": "^3.1.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/statuses": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", + "integrity": "sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow=", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dev": true, + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/strip-final-newline": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", + "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/tapable": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "dev": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/tar": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.0.tgz", + "integrity": "sha512-/Wo7DcT0u5HUV486xg675HtjNd3BXZ6xDbzsCUZPt5iw8bTQ63bP0Raut3mvro9u+CUyq7YQd8Cx55fsZXxqLQ==", + "dev": true, + "dependencies": { + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "minipass": "^5.0.0", + "minizlib": "^2.1.1", + "mkdirp": "^1.0.3", + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/tar/node_modules/minipass": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", + "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/terser": { + "version": "5.20.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.20.0.tgz", + "integrity": "sha512-e56ETryaQDyebBwJIWYB2TT6f2EZ0fL0sW/JRXNMN26zZdKi2u/E/5my5lG6jNxym6qsrVXfFRmOdV42zlAgLQ==", + "dev": true, + "dependencies": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.8.2", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + }, + "bin": { + "terser": "bin/terser" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/terser-webpack-plugin": { + "version": "5.3.9", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.9.tgz", + "integrity": "sha512-ZuXsqE07EcggTWQjXUj+Aot/OMcD0bMKGgF63f7UxYcu5/AJF53aIpK1YoP5xR9l6s/Hy2b+t1AM0bLNPRuhwA==", + "dev": true, + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.17", + "jest-worker": "^27.4.5", + "schema-utils": "^3.1.1", + "serialize-javascript": "^6.0.1", + "terser": "^5.16.8" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^5.1.0" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "esbuild": { + "optional": true + }, + "uglify-js": { + "optional": true + } + } + }, + "node_modules/terser-webpack-plugin/node_modules/serialize-javascript": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", + "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", + "dev": true, + "dependencies": { + "randombytes": "^2.1.0" + } + }, + "node_modules/thunky": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz", + "integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", + "dev": true + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "dev": true, + "engines": { + "node": ">=0.6" + } + }, + "node_modules/tslib": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.11.1.tgz", + "integrity": "sha512-aZW88SY8kQbU7gpV19lN24LtXh/yD4ZZg6qieAJDDg+YBsJcSmLGK9QpnUjAKVG/xefmvJGd1WUmfpT/g6AJGA==", + "dev": true + }, + "node_modules/type-is": { + "version": "1.6.18", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz", + "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==", + "dev": true, + "dependencies": { + "media-typer": "0.3.0", + "mime-types": "~2.1.24" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/unique-filename": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/unique-filename/-/unique-filename-1.1.1.tgz", + "integrity": "sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ==", + "dev": true, + "dependencies": { + "unique-slug": "^2.0.0" + } + }, + "node_modules/unique-slug": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/unique-slug/-/unique-slug-2.0.2.tgz", + "integrity": "sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w==", + "dev": true, + "dependencies": { + "imurmurhash": "^0.1.4" + } + }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/update-browserslist-db": { + "version": "1.0.13", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", + "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "escalade": "^3.1.1", + "picocolors": "^1.0.0" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=", + "dev": true + }, + "node_modules/utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==", + "dev": true, + "engines": { + "node": ">= 0.4.0" + } + }, + "node_modules/uuid": { + "version": "8.3.2", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", + "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "dev": true, + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/watchpack": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", + "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", + "dev": true, + "dependencies": { + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.1.2" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/wbuf": { + "version": "1.7.3", + "resolved": "https://registry.npmjs.org/wbuf/-/wbuf-1.7.3.tgz", + "integrity": "sha512-O84QOnr0icsbFGLS0O3bI5FswxzRr8/gHwWkDlQFskhSPryQXvrTMxjxGP4+iWYoauLoBvfDpkrOauZ+0iZpDA==", + "dev": true, + "dependencies": { + "minimalistic-assert": "^1.0.0" + } + }, + "node_modules/webpack": { + "version": "5.88.2", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.88.2.tgz", + "integrity": "sha512-JmcgNZ1iKj+aiR0OvTYtWQqJwq37Pf683dY9bVORwVbUrDhLhdn/PlO2sHsFHPkj7sHNQF3JwaAkp49V+Sq1tQ==", + "dev": true, + "dependencies": { + "@types/eslint-scope": "^3.7.3", + "@types/estree": "^1.0.0", + "@webassemblyjs/ast": "^1.11.5", + "@webassemblyjs/wasm-edit": "^1.11.5", + "@webassemblyjs/wasm-parser": "^1.11.5", + "acorn": "^8.7.1", + "acorn-import-assertions": "^1.9.0", + "browserslist": "^4.14.5", + "chrome-trace-event": "^1.0.2", + "enhanced-resolve": "^5.15.0", + "es-module-lexer": "^1.2.1", + "eslint-scope": "5.1.1", + "events": "^3.2.0", + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.2.9", + "json-parse-even-better-errors": "^2.3.1", + "loader-runner": "^4.2.0", + "mime-types": "^2.1.27", + "neo-async": "^2.6.2", + "schema-utils": "^3.2.0", + "tapable": "^2.1.1", + "terser-webpack-plugin": "^5.3.7", + "watchpack": "^2.4.0", + "webpack-sources": "^3.2.3" + }, + "bin": { + "webpack": "bin/webpack.js" + }, + "engines": { + "node": ">=10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependenciesMeta": { + "webpack-cli": { + "optional": true + } + } + }, + "node_modules/webpack-cli": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/webpack-cli/-/webpack-cli-5.1.4.tgz", + "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", + "dev": true, + "dependencies": { + "@discoveryjs/json-ext": "^0.5.0", + "@webpack-cli/configtest": "^2.1.1", + "@webpack-cli/info": "^2.0.2", + "@webpack-cli/serve": "^2.0.5", + "colorette": "^2.0.14", + "commander": "^10.0.1", + "cross-spawn": "^7.0.3", + "envinfo": "^7.7.3", + "fastest-levenshtein": "^1.0.12", + "import-local": "^3.0.2", + "interpret": "^3.1.1", + "rechoir": "^0.8.0", + "webpack-merge": "^5.7.3" + }, + "bin": { + "webpack-cli": "bin/cli.js" + }, + "engines": { + "node": ">=14.15.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "5.x.x" + }, + "peerDependenciesMeta": { + "@webpack-cli/generators": { + "optional": true + }, + "webpack-bundle-analyzer": { + "optional": true + }, + "webpack-dev-server": { + "optional": true + } + } + }, + "node_modules/webpack-cli/node_modules/commander": { + "version": "10.0.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz", + "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "dev": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/webpack-dev-middleware": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz", + "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==", + "dev": true, + "dependencies": { + "colorette": "^2.0.10", + "memfs": "^3.4.3", + "mime-types": "^2.1.31", + "range-parser": "^1.2.1", + "schema-utils": "^4.0.0" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^4.0.0 || ^5.0.0" + } + }, + "node_modules/webpack-dev-middleware/node_modules/ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/webpack-dev-middleware/node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.3" + }, + "peerDependencies": { + "ajv": "^8.8.2" + } + }, + "node_modules/webpack-dev-middleware/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "node_modules/webpack-dev-middleware/node_modules/schema-utils": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", + "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "dev": true, + "dependencies": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/webpack-dev-server": { + "version": "4.15.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", + "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", + "dev": true, + "dependencies": { + "@types/bonjour": "^3.5.9", + "@types/connect-history-api-fallback": "^1.3.5", + "@types/express": "^4.17.13", + "@types/serve-index": "^1.9.1", + "@types/serve-static": "^1.13.10", + "@types/sockjs": "^0.3.33", + "@types/ws": "^8.5.5", + "ansi-html-community": "^0.0.8", + "bonjour-service": "^1.0.11", + "chokidar": "^3.5.3", + "colorette": "^2.0.10", + "compression": "^1.7.4", + "connect-history-api-fallback": "^2.0.0", + "default-gateway": "^6.0.3", + "express": "^4.17.3", + "graceful-fs": "^4.2.6", + "html-entities": "^2.3.2", + "http-proxy-middleware": "^2.0.3", + "ipaddr.js": "^2.0.1", + "launch-editor": "^2.6.0", + "open": "^8.0.9", + "p-retry": "^4.5.0", + "rimraf": "^3.0.2", + "schema-utils": "^4.0.0", + "selfsigned": "^2.1.1", + "serve-index": "^1.9.1", + "sockjs": "^0.3.24", + "spdy": "^4.0.2", + "webpack-dev-middleware": "^5.3.1", + "ws": "^8.13.0" + }, + "bin": { + "webpack-dev-server": "bin/webpack-dev-server.js" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^4.37.0 || ^5.0.0" + }, + "peerDependenciesMeta": { + "webpack": { + "optional": true + }, + "webpack-cli": { + "optional": true + } + } + }, + "node_modules/webpack-dev-server/node_modules/ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/webpack-dev-server/node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "dependencies": { + "fast-deep-equal": "^3.1.3" + }, + "peerDependencies": { + "ajv": "^8.8.2" + } + }, + "node_modules/webpack-dev-server/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "node_modules/webpack-dev-server/node_modules/schema-utils": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", + "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "dev": true, + "dependencies": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + }, + "engines": { + "node": ">= 12.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/webpack-merge": { + "version": "5.9.0", + "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.9.0.tgz", + "integrity": "sha512-6NbRQw4+Sy50vYNTw7EyOn41OZItPiXB8GNv3INSoe3PSFaHJEz3SHTrYVaRm2LilNGnFUzh0FAwqPEmU/CwDg==", + "dev": true, + "dependencies": { + "clone-deep": "^4.0.1", + "wildcard": "^2.0.0" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/webpack-sources": { + "version": "1.4.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-1.4.3.tgz", + "integrity": "sha512-lgTS3Xhv1lCOKo7SA5TjKXMjpSM4sBjNV5+q2bqesbSPs5FjGmU6jjtBSkX9b4qW87vDIsCIlUPOEhbZrMdjeQ==", + "dev": true, + "dependencies": { + "source-list-map": "^2.0.0", + "source-map": "~0.6.1" + } + }, + "node_modules/webpack/node_modules/webpack-sources": { + "version": "3.2.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", + "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "dev": true, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/websocket-driver": { + "version": "0.7.4", + "resolved": "https://registry.npmjs.org/websocket-driver/-/websocket-driver-0.7.4.tgz", + "integrity": "sha512-b17KeDIQVjvb0ssuSDF2cYXSg2iztliJ4B9WdsuB6J952qCPKmnVq4DyW5motImXHDC1cBT/1UezrJVsKw5zjg==", + "dev": true, + "dependencies": { + "http-parser-js": ">=0.5.1", + "safe-buffer": ">=5.1.0", + "websocket-extensions": ">=0.1.1" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/websocket-extensions": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/websocket-extensions/-/websocket-extensions-0.1.4.tgz", + "integrity": "sha512-OqedPIGOfsDlo31UNwYbCFMSaO9m9G/0faIHj5/dZFDMFqPTcx6UwqyOy3COEaEOg/9VsGIpdqn62W5KhoKSpg==", + "dev": true, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wildcard": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/wildcard/-/wildcard-2.0.1.tgz", + "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", + "dev": true + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "dev": true + }, + "node_modules/ws": { + "version": "8.14.2", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz", + "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==", + "dev": true, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "dev": true + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + } + }, + "dependencies": { + "@discoveryjs/json-ext": { + "version": "0.5.7", + "resolved": "https://registry.npmjs.org/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz", + "integrity": "sha512-dBVuXR082gk3jsFp7Rd/JI4kytwGHecnCoTtXFb7DB6CNHp4rg5k1bhg0nWdLGLnOV71lmDzGQaLMy8iPLY0pw==", + "dev": true + }, + "@gar/promisify": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@gar/promisify/-/promisify-1.1.3.tgz", + "integrity": "sha512-k2Ty1JcVojjJFwrg/ThKi2ujJ7XNLYaFGNB/bWT9wGR+oSMJHMa5w+CUq6p/pVrKeNNgA7pCqEcjSnHVoqJQFw==", + "dev": true + }, + "@jridgewell/gen-mapping": { + "version": "0.3.3", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", + "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "dev": true, + "requires": { + "@jridgewell/set-array": "^1.0.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.9" + } + }, + "@jridgewell/resolve-uri": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz", + "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==", + "dev": true + }, + "@jridgewell/set-array": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", + "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", + "dev": true + }, + "@jridgewell/source-map": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.5.tgz", + "integrity": "sha512-UTYAUj/wviwdsMfzoSJspJxbkH5o1snzwX0//0ENX1u/55kkZZkcTZP6u9bwKGkv+dkk9at4m1Cpt0uY80kcpQ==", + "dev": true, + "requires": { + "@jridgewell/gen-mapping": "^0.3.0", + "@jridgewell/trace-mapping": "^0.3.9" + } + }, + "@jridgewell/sourcemap-codec": { + "version": "1.4.15", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz", + "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==", + "dev": true + }, + "@jridgewell/trace-mapping": { + "version": "0.3.19", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz", + "integrity": "sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw==", + "dev": true, + "requires": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "@leichtgewicht/ip-codec": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", + "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", + "dev": true + }, + "@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "requires": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + } + }, + "@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true + }, + "@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "requires": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + } + }, + "@npmcli/fs": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-1.1.1.tgz", + "integrity": "sha512-8KG5RD0GVP4ydEzRn/I4BNDuxDtqVbOdm8675T49OIG/NGhaK0pjPX7ZcDlvKYbA+ulvVK3ztfcF4uBdOxuJbQ==", + "dev": true, + "requires": { + "@gar/promisify": "^1.0.1", + "semver": "^7.3.5" + } + }, + "@npmcli/move-file": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-1.1.2.tgz", + "integrity": "sha512-1SUf/Cg2GzGDyaf15aR9St9TWlb+XvbZXWpDx8YKs7MLzMH/BCeopv+y9vzrzgkfykCGuWOlSu3mZhj2+FQcrg==", + "dev": true, + "requires": { + "mkdirp": "^1.0.4", + "rimraf": "^3.0.2" + } + }, + "@types/body-parser": { + "version": "1.19.3", + "resolved": "https://registry.npmjs.org/@types/body-parser/-/body-parser-1.19.3.tgz", + "integrity": "sha512-oyl4jvAfTGX9Bt6Or4H9ni1Z447/tQuxnZsytsCaExKlmJiU8sFgnIBRzJUpKwB5eWn9HuBYlUlVA74q/yN0eQ==", + "dev": true, + "requires": { + "@types/connect": "*", + "@types/node": "*" + } + }, + "@types/bonjour": { + "version": "3.5.11", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", + "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, + "@types/connect": { + "version": "3.4.36", + "resolved": "https://registry.npmjs.org/@types/connect/-/connect-3.4.36.tgz", + "integrity": "sha512-P63Zd/JUGq+PdrM1lv0Wv5SBYeA2+CORvbrXbngriYY0jzLUWfQMQQxOhjONEz/wlHOAxOdY7CY65rgQdTjq2w==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, + "@types/connect-history-api-fallback": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", + "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "dev": true, + "requires": { + "@types/express-serve-static-core": "*", + "@types/node": "*" + } + }, + "@types/eslint": { + "version": "8.44.2", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.44.2.tgz", + "integrity": "sha512-sdPRb9K6iL5XZOmBubg8yiFp5yS/JdUDQsq5e6h95km91MCYMuvp7mh1fjPEYUhvHepKpZOjnEaMBR4PxjWDzg==", + "dev": true, + "requires": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "@types/eslint-scope": { + "version": "3.7.4", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", + "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", + "dev": true, + "requires": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, + "@types/estree": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.1.tgz", + "integrity": "sha512-LG4opVs2ANWZ1TJoKc937iMmNstM/d0ae1vNbnBvBhqCSezgVUOzcLCqbI5elV8Vy6WKwKjaqR+zO9VKirBBCA==", + "dev": true + }, + "@types/express": { + "version": "4.17.17", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", + "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "dev": true, + "requires": { + "@types/body-parser": "*", + "@types/express-serve-static-core": "^4.17.33", + "@types/qs": "*", + "@types/serve-static": "*" + } + }, + "@types/express-serve-static-core": { + "version": "4.17.36", + "resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.17.36.tgz", + "integrity": "sha512-zbivROJ0ZqLAtMzgzIUC4oNqDG9iF0lSsAqpOD9kbs5xcIM3dTiyuHvBc7R8MtWBp3AAWGaovJa+wzWPjLYW7Q==", + "dev": true, + "requires": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "@types/http-errors": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@types/http-errors/-/http-errors-2.0.2.tgz", + "integrity": "sha512-lPG6KlZs88gef6aD85z3HNkztpj7w2R7HmR3gygjfXCQmsLloWNARFkMuzKiiY8FGdh1XDpgBdrSf4aKDiA7Kg==", + "dev": true + }, + "@types/http-proxy": { + "version": "1.17.12", + "resolved": "https://registry.npmjs.org/@types/http-proxy/-/http-proxy-1.17.12.tgz", + "integrity": "sha512-kQtujO08dVtQ2wXAuSFfk9ASy3sug4+ogFR8Kd8UgP8PEuc1/G/8yjYRmp//PcDNJEUKOza/MrQu15bouEUCiw==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, + "@types/json-schema": { + "version": "7.0.13", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.13.tgz", + "integrity": "sha512-RbSSoHliUbnXj3ny0CNFOoxrIDV6SUGyStHsvDqosw6CkdPV8TtWGlfecuK4ToyMEAql6pzNxgCFKanovUzlgQ==", + "dev": true + }, + "@types/mime": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@types/mime/-/mime-1.3.2.tgz", + "integrity": "sha512-YATxVxgRqNH6nHEIsvg6k2Boc1JHI9ZbH5iWFFv/MTkchz3b1ieGDa5T0a9RznNdI0KhVbdbWSN+KWWrQZRxTw==", + "dev": true + }, + "@types/node": { + "version": "20.6.3", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.6.3.tgz", + "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", + "dev": true + }, + "@types/qs": { + "version": "6.9.8", + "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", + "integrity": "sha512-u95svzDlTysU5xecFNTgfFG5RUWu1A9P0VzgpcIiGZA9iraHOdSzcxMxQ55DyeRaGCSxQi7LxXDI4rzq/MYfdg==", + "dev": true + }, + "@types/range-parser": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/@types/range-parser/-/range-parser-1.2.4.tgz", + "integrity": "sha512-EEhsLsD6UsDM1yFhAvy0Cjr6VwmpMWqFBCb9w07wVugF7w9nfajxLuVmngTIpgS6svCnm6Vaw+MZhoDCKnOfsw==", + "dev": true + }, + "@types/retry": { + "version": "0.12.0", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", + "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "dev": true + }, + "@types/send": { + "version": "0.17.1", + "resolved": "https://registry.npmjs.org/@types/send/-/send-0.17.1.tgz", + "integrity": "sha512-Cwo8LE/0rnvX7kIIa3QHCkcuF21c05Ayb0ZfxPiv0W8VRiZiNW/WuRupHKpqqGVGf7SUA44QSOUKaEd9lIrd/Q==", + "dev": true, + "requires": { + "@types/mime": "^1", + "@types/node": "*" + } + }, + "@types/serve-index": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", + "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "dev": true, + "requires": { + "@types/express": "*" + } + }, + "@types/serve-static": { + "version": "1.15.2", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", + "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "dev": true, + "requires": { + "@types/http-errors": "*", + "@types/mime": "*", + "@types/node": "*" + } + }, + "@types/sockjs": { + "version": "0.3.33", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", + "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, + "@types/ws": { + "version": "8.5.5", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", + "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, + "@webassemblyjs/ast": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.6.tgz", + "integrity": "sha512-IN1xI7PwOvLPgjcf180gC1bqn3q/QaOCwYUahIOhbYUu8KA/3tw2RT/T0Gidi1l7Hhj5D/INhJxiICObqpMu4Q==", + "dev": true, + "requires": { + "@webassemblyjs/helper-numbers": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6" + } + }, + "@webassemblyjs/floating-point-hex-parser": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.6.tgz", + "integrity": "sha512-ejAj9hfRJ2XMsNHk/v6Fu2dGS+i4UaXBXGemOfQ/JfQ6mdQg/WXtwleQRLLS4OvfDhv8rYnVwH27YJLMyYsxhw==", + "dev": true + }, + "@webassemblyjs/helper-api-error": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.6.tgz", + "integrity": "sha512-o0YkoP4pVu4rN8aTJgAyj9hC2Sv5UlkzCHhxqWj8butaLvnpdc2jOwh4ewE6CX0txSfLn/UYaV/pheS2Txg//Q==", + "dev": true + }, + "@webassemblyjs/helper-buffer": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.6.tgz", + "integrity": "sha512-z3nFzdcp1mb8nEOFFk8DrYLpHvhKC3grJD2ardfKOzmbmJvEf/tPIqCY+sNcwZIY8ZD7IkB2l7/pqhUhqm7hLA==", + "dev": true + }, + "@webassemblyjs/helper-numbers": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.6.tgz", + "integrity": "sha512-vUIhZ8LZoIWHBohiEObxVm6hwP034jwmc9kuq5GdHZH0wiLVLIPcMCdpJzG4C11cHoQ25TFIQj9kaVADVX7N3g==", + "dev": true, + "requires": { + "@webassemblyjs/floating-point-hex-parser": "1.11.6", + "@webassemblyjs/helper-api-error": "1.11.6", + "@xtuc/long": "4.2.2" + } + }, + "@webassemblyjs/helper-wasm-bytecode": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.6.tgz", + "integrity": "sha512-sFFHKwcmBprO9e7Icf0+gddyWYDViL8bpPjJJl0WHxCdETktXdmtWLGVzoHbqUcY4Be1LkNfwTmXOJUFZYSJdA==", + "dev": true + }, + "@webassemblyjs/helper-wasm-section": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.6.tgz", + "integrity": "sha512-LPpZbSOwTpEC2cgn4hTydySy1Ke+XEu+ETXuoyvuyezHO3Kjdu90KK95Sh9xTbmjrCsUwvWwCOQQNta37VrS9g==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6" + } + }, + "@webassemblyjs/ieee754": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.11.6.tgz", + "integrity": "sha512-LM4p2csPNvbij6U1f19v6WR56QZ8JcHg3QIJTlSwzFcmx6WSORicYj6I63f9yU1kEUtrpG+kjkiIAkevHpDXrg==", + "dev": true, + "requires": { + "@xtuc/ieee754": "^1.2.0" + } + }, + "@webassemblyjs/leb128": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.11.6.tgz", + "integrity": "sha512-m7a0FhE67DQXgouf1tbN5XQcdWoNgaAuoULHIfGFIEVKA6tu/edls6XnIlkmS6FrXAquJRPni3ZZKjw6FSPjPQ==", + "dev": true, + "requires": { + "@xtuc/long": "4.2.2" + } + }, + "@webassemblyjs/utf8": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.11.6.tgz", + "integrity": "sha512-vtXf2wTQ3+up9Zsg8sa2yWiQpzSsMyXj0qViVP6xKGCUT8p8YJ6HqI7l5eCnWx1T/FYdsv07HQs2wTFbbof/RA==", + "dev": true + }, + "@webassemblyjs/wasm-edit": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.6.tgz", + "integrity": "sha512-Ybn2I6fnfIGuCR+Faaz7YcvtBKxvoLV3Lebn1tM4o/IAJzmi9AWYIPWpyBfU8cC+JxAO57bk4+zdsTjJR+VTOw==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/helper-wasm-section": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6", + "@webassemblyjs/wasm-opt": "1.11.6", + "@webassemblyjs/wasm-parser": "1.11.6", + "@webassemblyjs/wast-printer": "1.11.6" + } + }, + "@webassemblyjs/wasm-gen": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.6.tgz", + "integrity": "sha512-3XOqkZP/y6B4F0PBAXvI1/bky7GryoogUtfwExeP/v7Nzwo1QLcq5oQmpKlftZLbT+ERUOAZVQjuNVak6UXjPA==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/ieee754": "1.11.6", + "@webassemblyjs/leb128": "1.11.6", + "@webassemblyjs/utf8": "1.11.6" + } + }, + "@webassemblyjs/wasm-opt": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.6.tgz", + "integrity": "sha512-cOrKuLRE7PCe6AsOVl7WasYf3wbSo4CeOk6PkrjS7g57MFfVUF9u6ysQBBODX0LdgSvQqRiGz3CXvIDKcPNy4g==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/wasm-gen": "1.11.6", + "@webassemblyjs/wasm-parser": "1.11.6" + } + }, + "@webassemblyjs/wasm-parser": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.6.tgz", + "integrity": "sha512-6ZwPeGzMJM3Dqp3hCsLgESxBGtT/OeCvCZ4TA1JUPYgmhAx38tTPR9JaKy0S5H3evQpO/h2uWs2j6Yc/fjkpTQ==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/helper-api-error": "1.11.6", + "@webassemblyjs/helper-wasm-bytecode": "1.11.6", + "@webassemblyjs/ieee754": "1.11.6", + "@webassemblyjs/leb128": "1.11.6", + "@webassemblyjs/utf8": "1.11.6" + } + }, + "@webassemblyjs/wast-printer": { + "version": "1.11.6", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.6.tgz", + "integrity": "sha512-JM7AhRcE+yW2GWYaKeHL5vt4xqee5N2WcezptmgyhNS+ScggqcT1OtXykhAb13Sn5Yas0j2uv9tHgrjwvzAP4A==", + "dev": true, + "requires": { + "@webassemblyjs/ast": "1.11.6", + "@xtuc/long": "4.2.2" + } + }, + "@webpack-cli/configtest": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.1.1.tgz", + "integrity": "sha512-wy0mglZpDSiSS0XHrVR+BAdId2+yxPSoJW8fsna3ZpYSlufjvxnP4YbKTCBZnNIcGN4r6ZPXV55X4mYExOfLmw==", + "dev": true, + "requires": {} + }, + "@webpack-cli/info": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@webpack-cli/info/-/info-2.0.2.tgz", + "integrity": "sha512-zLHQdI/Qs1UyT5UBdWNqsARasIA+AaF8t+4u2aS2nEpBQh2mWIVb8qAklq0eUENnC5mOItrIB4LiS9xMtph18A==", + "dev": true, + "requires": {} + }, + "@webpack-cli/serve": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@webpack-cli/serve/-/serve-2.0.5.tgz", + "integrity": "sha512-lqaoKnRYBdo1UgDX8uF24AfGMifWK19TxPmM5FHc2vAGxrJ/qtyUyFBWoY1tISZdelsQ5fBcOusifo5o5wSJxQ==", + "dev": true, + "requires": {} + }, + "@xtuc/ieee754": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@xtuc/ieee754/-/ieee754-1.2.0.tgz", + "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", + "dev": true + }, + "@xtuc/long": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", + "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", + "dev": true + }, + "accepts": { + "version": "1.3.8", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", + "integrity": "sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw==", + "dev": true, + "requires": { + "mime-types": "~2.1.34", + "negotiator": "0.6.3" + } + }, + "acorn": { + "version": "8.10.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", + "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", + "dev": true + }, + "acorn-import-assertions": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.9.0.tgz", + "integrity": "sha512-cmMwop9x+8KFhxvKrKfPYmN6/pKTYYHBqLa0DfvVZcKMJWNyWLnaqND7dx/qn66R7ewM1UX5XMaDVP5wlVTaVA==", + "dev": true, + "requires": {} + }, + "aggregate-error": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz", + "integrity": "sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA==", + "dev": true, + "requires": { + "clean-stack": "^2.0.0", + "indent-string": "^4.0.0" + } + }, + "ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + } + }, + "ajv-formats": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/ajv-formats/-/ajv-formats-2.1.1.tgz", + "integrity": "sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==", + "dev": true, + "requires": { + "ajv": "^8.0.0" + }, + "dependencies": { + "ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + } + }, + "json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + } + } + }, + "ajv-keywords": { + "version": "3.5.2", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", + "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", + "dev": true, + "requires": {} + }, + "ansi-html-community": { + "version": "0.0.8", + "resolved": "https://registry.npmjs.org/ansi-html-community/-/ansi-html-community-0.0.8.tgz", + "integrity": "sha512-1APHAyr3+PCamwNw3bXCPp4HFLONZt/yIH0sZp0/469KWNTEy+qN5jQ3GVX6DMZ1UXAi34yVwtTeaG/HpBuuzw==", + "dev": true + }, + "anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "requires": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + } + }, + "array-flatten": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", + "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", + "dev": true + }, + "array-union": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", + "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", + "dev": true + }, + "balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true + }, + "batch": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", + "integrity": "sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY=", + "dev": true + }, + "big.js": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/big.js/-/big.js-5.2.2.tgz", + "integrity": "sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ==", + "dev": true + }, + "binary-extensions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "dev": true + }, + "body-parser": { + "version": "1.20.1", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", + "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "dev": true, + "requires": { + "bytes": "3.1.2", + "content-type": "~1.0.4", + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "on-finished": "2.4.1", + "qs": "6.11.0", + "raw-body": "2.5.1", + "type-is": "~1.6.18", + "unpipe": "1.0.0" + }, + "dependencies": { + "bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "dev": true + }, + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + } + }, + "depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true + } + } + }, + "bonjour-service": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", + "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", + "dev": true, + "requires": { + "array-flatten": "^2.1.2", + "dns-equal": "^1.0.0", + "fast-deep-equal": "^3.1.3", + "multicast-dns": "^7.2.5" + } + }, + "brace-expansion": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, + "requires": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dev": true, + "requires": { + "fill-range": "^7.0.1" + } + }, + "browserslist": { + "version": "4.21.11", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.21.11.tgz", + "integrity": "sha512-xn1UXOKUz7DjdGlg9RrUr0GGiWzI97UQJnugHtH0OLDfJB7jMgoIkYvRIEO1l9EeEERVqeqLYOcFBW9ldjypbQ==", + "dev": true, + "requires": { + "caniuse-lite": "^1.0.30001538", + "electron-to-chromium": "^1.4.526", + "node-releases": "^2.0.13", + "update-browserslist-db": "^1.0.13" + } + }, + "buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true + }, + "bytes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", + "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", + "dev": true + }, + "cacache": { + "version": "15.3.0", + "resolved": "https://registry.npmjs.org/cacache/-/cacache-15.3.0.tgz", + "integrity": "sha512-VVdYzXEn+cnbXpFgWs5hTT7OScegHVmLhJIR8Ufqk3iFD6A6j5iSX1KuBTfNEv4tdJWE2PzA6IVFtcLC7fN9wQ==", + "dev": true, + "requires": { + "@npmcli/fs": "^1.0.0", + "@npmcli/move-file": "^1.0.1", + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "glob": "^7.1.4", + "infer-owner": "^1.0.4", + "lru-cache": "^6.0.0", + "minipass": "^3.1.1", + "minipass-collect": "^1.0.2", + "minipass-flush": "^1.0.5", + "minipass-pipeline": "^1.2.2", + "mkdirp": "^1.0.3", + "p-map": "^4.0.0", + "promise-inflight": "^1.0.1", + "rimraf": "^3.0.2", + "ssri": "^8.0.1", + "tar": "^6.0.2", + "unique-filename": "^1.1.1" + } + }, + "call-bind": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", + "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "dev": true, + "requires": { + "function-bind": "^1.1.1", + "get-intrinsic": "^1.0.2" + } + }, + "caniuse-lite": { + "version": "1.0.30001538", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001538.tgz", + "integrity": "sha512-HWJnhnID+0YMtGlzcp3T9drmBJUVDchPJ08tpUGFLs9CYlwWPH2uLgpHn8fND5pCgXVtnGS3H4QR9XLMHVNkHw==", + "dev": true + }, + "chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "dev": true, + "requires": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "fsevents": "~2.3.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + } + }, + "chownr": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", + "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", + "dev": true + }, + "chrome-trace-event": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.2.tgz", + "integrity": "sha512-9e/zx1jw7B4CO+c/RXoCsfg/x1AfUBioy4owYH0bJprEYAx5hRFLRhWBqHAG57D0ZM4H7vxbP7bPe0VwhQRYDQ==", + "dev": true, + "requires": { + "tslib": "^1.9.0" + } + }, + "clean-stack": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", + "integrity": "sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A==", + "dev": true + }, + "clone-deep": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/clone-deep/-/clone-deep-4.0.1.tgz", + "integrity": "sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ==", + "dev": true, + "requires": { + "is-plain-object": "^2.0.4", + "kind-of": "^6.0.2", + "shallow-clone": "^3.0.0" + } + }, + "colorette": { + "version": "2.0.20", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz", + "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", + "dev": true + }, + "commander": { + "version": "2.20.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", + "dev": true + }, + "commondir": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", + "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", + "dev": true + }, + "compressible": { + "version": "2.0.18", + "resolved": "https://registry.npmjs.org/compressible/-/compressible-2.0.18.tgz", + "integrity": "sha512-AF3r7P5dWxL8MxyITRMlORQNaOA2IkAFaTr4k7BUumjPtRpGDTZpl0Pb1XCO6JeDCBdp126Cgs9sMxqSjgYyRg==", + "dev": true, + "requires": { + "mime-db": ">= 1.43.0 < 2" + } + }, + "compression": { + "version": "1.7.4", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.7.4.tgz", + "integrity": "sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ==", + "dev": true, + "requires": { + "accepts": "~1.3.5", + "bytes": "3.0.0", + "compressible": "~2.0.16", + "debug": "2.6.9", + "on-headers": "~1.0.2", + "safe-buffer": "5.1.2", + "vary": "~1.1.2" + }, + "dependencies": { + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + } + } + } + }, + "concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true + }, + "connect-history-api-fallback": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", + "integrity": "sha512-U73+6lQFmfiNPrYbXqr6kZ1i1wiRqXnp2nhMsINseWXO8lDau0LGEffJ8kQi4EjLZympVgRdvqjAgiZ1tgzDDA==", + "dev": true + }, + "content-disposition": { + "version": "0.5.4", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz", + "integrity": "sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ==", + "dev": true, + "requires": { + "safe-buffer": "5.2.1" + }, + "dependencies": { + "safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true + } + } + }, + "content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "dev": true + }, + "cookie": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", + "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "dev": true + }, + "cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==", + "dev": true + }, + "copy-webpack-plugin": { + "version": "6.4.1", + "resolved": "https://registry.npmjs.org/copy-webpack-plugin/-/copy-webpack-plugin-6.4.1.tgz", + "integrity": "sha512-MXyPCjdPVx5iiWyl40Va3JGh27bKzOTNY3NjUTrosD2q7dR/cLD0013uqJ3BpFbUjyONINjb6qI7nDIJujrMbA==", + "dev": true, + "requires": { + "cacache": "^15.0.5", + "fast-glob": "^3.2.4", + "find-cache-dir": "^3.3.1", + "glob-parent": "^5.1.1", + "globby": "^11.0.1", + "loader-utils": "^2.0.0", + "normalize-path": "^3.0.0", + "p-limit": "^3.0.2", + "schema-utils": "^3.0.0", + "serialize-javascript": "^5.0.1", + "webpack-sources": "^1.4.3" + } + }, + "core-util-is": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.2.tgz", + "integrity": "sha1-tf1UIgqivFq1eqtxQMlAdUUDwac=", + "dev": true + }, + "cross-spawn": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "dev": true, + "requires": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + } + }, + "datafusion-wasmtest": { + "version": "file:../pkg" + }, + "debug": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.1.1.tgz", + "integrity": "sha512-pYAIzeRo8J6KPEaJ0VWOh5Pzkbw/RetuzehGM7QRRX5he4fPHx2rdKMB256ehJCkX+XRQm16eZLqLNS8RSZXZw==", + "dev": true, + "requires": { + "ms": "^2.1.1" + }, + "dependencies": { + "ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + } + } + }, + "default-gateway": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", + "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "dev": true, + "requires": { + "execa": "^5.0.0" + } + }, + "define-lazy-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", + "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "dev": true + }, + "depd": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", + "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=", + "dev": true + }, + "destroy": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.2.0.tgz", + "integrity": "sha512-2sJGJTaXIIaR1w4iJSNoN0hnMY7Gpc/n8D4qSCJw8QqFWXf7cuAgnEHxBpweaVcPevC2l3KpjYCx3NypQQgaJg==", + "dev": true + }, + "detect-node": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.0.4.tgz", + "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", + "dev": true + }, + "dir-glob": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", + "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", + "dev": true, + "requires": { + "path-type": "^4.0.0" + } + }, + "dns-equal": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", + "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", + "dev": true + }, + "dns-packet": { + "version": "5.6.1", + "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", + "integrity": "sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==", + "dev": true, + "requires": { + "@leichtgewicht/ip-codec": "^2.0.1" + } + }, + "ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "dev": true + }, + "electron-to-chromium": { + "version": "1.4.528", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.4.528.tgz", + "integrity": "sha512-UdREXMXzLkREF4jA8t89FQjA8WHI6ssP38PMY4/4KhXFQbtImnghh4GkCgrtiZwLKUKVD2iTVXvDVQjfomEQuA==", + "dev": true + }, + "emojis-list": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/emojis-list/-/emojis-list-3.0.0.tgz", + "integrity": "sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==", + "dev": true + }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true + }, + "enhanced-resolve": { + "version": "5.15.0", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.15.0.tgz", + "integrity": "sha512-LXYT42KJ7lpIKECr2mAXIaMldcNCh/7E0KBKOu4KSfkHmP+mZmSs+8V5gBAqisWBy0OO4W5Oyys0GO1Y8KtdKg==", + "dev": true, + "requires": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + } + }, + "envinfo": { + "version": "7.10.0", + "resolved": "https://registry.npmjs.org/envinfo/-/envinfo-7.10.0.tgz", + "integrity": "sha512-ZtUjZO6l5mwTHvc1L9+1q5p/R3wTopcfqMW8r5t8SJSKqeVI/LtajORwRFEKpEFuekjD0VBjwu1HMxL4UalIRw==", + "dev": true + }, + "es-module-lexer": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", + "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", + "dev": true + }, + "escalade": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", + "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "dev": true + }, + "escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=", + "dev": true + }, + "eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "requires": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + } + }, + "esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "requires": { + "estraverse": "^5.2.0" + }, + "dependencies": { + "estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true + } + } + }, + "estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true + }, + "etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "dev": true + }, + "eventemitter3": { + "version": "4.0.7", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz", + "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", + "dev": true + }, + "events": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "dev": true + }, + "execa": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", + "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", + "dev": true, + "requires": { + "cross-spawn": "^7.0.3", + "get-stream": "^6.0.0", + "human-signals": "^2.1.0", + "is-stream": "^2.0.0", + "merge-stream": "^2.0.0", + "npm-run-path": "^4.0.1", + "onetime": "^5.1.2", + "signal-exit": "^3.0.3", + "strip-final-newline": "^2.0.0" + } + }, + "express": { + "version": "4.18.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", + "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "dev": true, + "requires": { + "accepts": "~1.3.8", + "array-flatten": "1.1.1", + "body-parser": "1.20.1", + "content-disposition": "0.5.4", + "content-type": "~1.0.4", + "cookie": "0.5.0", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "2.0.0", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.2.0", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "merge-descriptors": "1.0.1", + "methods": "~1.1.2", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "path-to-regexp": "0.1.7", + "proxy-addr": "~2.0.7", + "qs": "6.11.0", + "range-parser": "~1.2.1", + "safe-buffer": "5.2.1", + "send": "0.18.0", + "serve-static": "1.15.0", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "type-is": "~1.6.18", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + }, + "dependencies": { + "array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", + "dev": true + }, + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + } + }, + "depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true + }, + "safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true + }, + "statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true + } + } + }, + "fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true + }, + "fast-glob": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.1.tgz", + "integrity": "sha512-kNFPyjhh5cKjrUltxs+wFx+ZkbRaxxmZ+X0ZU31SOsxCEtP9VPgtq2teZw1DebupL5GmDaNQ6yKMMVcM41iqDg==", + "dev": true, + "requires": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.4" + } + }, + "fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true + }, + "fastest-levenshtein": { + "version": "1.0.16", + "resolved": "https://registry.npmjs.org/fastest-levenshtein/-/fastest-levenshtein-1.0.16.tgz", + "integrity": "sha512-eRnCtTTtGZFpQCwhJiUOuxPQWRXVKYDn0b2PeHfXL6/Zi53SLAzAHfVhVWK2AryC/WH05kGfxhFIPvTF0SXQzg==", + "dev": true + }, + "fastq": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "dev": true, + "requires": { + "reusify": "^1.0.4" + } + }, + "faye-websocket": { + "version": "0.11.4", + "resolved": "https://registry.npmjs.org/faye-websocket/-/faye-websocket-0.11.4.tgz", + "integrity": "sha512-CzbClwlXAuiRQAlUyfqPgvPoNKTckTPGfwZV4ZdAhVcP2lh9KUxJg2b5GkE7XbjKQ3YJnQ9z6D9ntLAlB+tP8g==", + "dev": true, + "requires": { + "websocket-driver": ">=0.5.1" + } + }, + "fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dev": true, + "requires": { + "to-regex-range": "^5.0.1" + } + }, + "finalhandler": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.2.0.tgz", + "integrity": "sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==", + "dev": true, + "requires": { + "debug": "2.6.9", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "on-finished": "2.4.1", + "parseurl": "~1.3.3", + "statuses": "2.0.1", + "unpipe": "~1.0.0" + }, + "dependencies": { + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + } + }, + "statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true + } + } + }, + "find-cache-dir": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/find-cache-dir/-/find-cache-dir-3.3.2.tgz", + "integrity": "sha512-wXZV5emFEjrridIgED11OoUKLxiYjAcqot/NJdAkOhlJ+vGzwhOAfcG5OX1jP+S0PcjEn8bdMJv+g2jwQ3Onig==", + "dev": true, + "requires": { + "commondir": "^1.0.1", + "make-dir": "^3.0.2", + "pkg-dir": "^4.1.0" + } + }, + "find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "dev": true, + "requires": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + } + }, + "follow-redirects": { + "version": "1.15.3", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", + "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "dev": true + }, + "forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "dev": true + }, + "fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", + "dev": true + }, + "fs-minipass": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", + "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", + "dev": true, + "requires": { + "minipass": "^3.0.0" + } + }, + "fs-monkey": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", + "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", + "dev": true + }, + "fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "dev": true + }, + "fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "optional": true + }, + "function-bind": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", + "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "dev": true + }, + "get-intrinsic": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", + "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "dev": true, + "requires": { + "function-bind": "^1.1.1", + "has": "^1.0.3", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3" + } + }, + "get-stream": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", + "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", + "dev": true + }, + "glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "dev": true, + "requires": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + } + }, + "glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "requires": { + "is-glob": "^4.0.1" + } + }, + "glob-to-regexp": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", + "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", + "dev": true + }, + "globby": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", + "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "dev": true, + "requires": { + "array-union": "^2.1.0", + "dir-glob": "^3.0.1", + "fast-glob": "^3.2.9", + "ignore": "^5.2.0", + "merge2": "^1.4.1", + "slash": "^3.0.0" + } + }, + "graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "dev": true + }, + "handle-thing": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/handle-thing/-/handle-thing-2.0.1.tgz", + "integrity": "sha512-9Qn4yBxelxoh2Ow62nP+Ka/kMnOXRi8BXnRaUwezLNhqelnN49xKz4F/dPP8OYLxLxq6JDtZb2i9XznUQbNPTg==", + "dev": true + }, + "has": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", + "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", + "dev": true, + "requires": { + "function-bind": "^1.1.1" + } + }, + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true + }, + "has-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", + "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "dev": true + }, + "has-symbols": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", + "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "dev": true + }, + "hpack.js": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", + "integrity": "sha1-h3dMCUnlE/QuhFdbPEVoH63ioLI=", + "dev": true, + "requires": { + "inherits": "^2.0.1", + "obuf": "^1.0.0", + "readable-stream": "^2.0.1", + "wbuf": "^1.1.0" + } + }, + "html-entities": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", + "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", + "dev": true + }, + "http-deceiver": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", + "integrity": "sha1-+nFolEq5pRnTN8sL7HKE3D5yPYc=", + "dev": true + }, + "http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "dev": true, + "requires": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "dependencies": { + "depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true + }, + "inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true + }, + "statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true + } + } + }, + "http-parser-js": { + "version": "0.5.8", + "resolved": "https://registry.npmjs.org/http-parser-js/-/http-parser-js-0.5.8.tgz", + "integrity": "sha512-SGeBX54F94Wgu5RH3X5jsDtf4eHyRogWX1XGT3b4HuW3tQPM4AaBzoUji/4AAJNXCEOWZ5O0DgZmJw1947gD5Q==", + "dev": true + }, + "http-proxy": { + "version": "1.18.1", + "resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz", + "integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==", + "dev": true, + "requires": { + "eventemitter3": "^4.0.0", + "follow-redirects": "^1.0.0", + "requires-port": "^1.0.0" + } + }, + "http-proxy-middleware": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", + "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "dev": true, + "requires": { + "@types/http-proxy": "^1.17.8", + "http-proxy": "^1.18.1", + "is-glob": "^4.0.1", + "is-plain-obj": "^3.0.0", + "micromatch": "^4.0.2" + } + }, + "human-signals": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", + "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "dev": true + }, + "iconv-lite": { + "version": "0.4.24", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", + "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", + "dev": true, + "requires": { + "safer-buffer": ">= 2.1.2 < 3" + } + }, + "ignore": { + "version": "5.2.4", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", + "integrity": "sha512-MAb38BcSbH0eHNBxn7ql2NH/kX33OkB3lZ1BNdh7ENeRChHTYsTvWrMubiIAMNS2llXEEgZ1MUOBtXChP3kaFQ==", + "dev": true + }, + "import-local": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/import-local/-/import-local-3.1.0.tgz", + "integrity": "sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg==", + "dev": true, + "requires": { + "pkg-dir": "^4.2.0", + "resolve-cwd": "^3.0.0" + } + }, + "imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true + }, + "indent-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/indent-string/-/indent-string-4.0.0.tgz", + "integrity": "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==", + "dev": true + }, + "infer-owner": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/infer-owner/-/infer-owner-1.0.4.tgz", + "integrity": "sha512-IClj+Xz94+d7irH5qRyfJonOdfTzuDaifE6ZPWfx0N0+/ATZCbuTPq2prFl526urkQd90WyUKIh1DfBQ2hMz9A==", + "dev": true + }, + "inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "dev": true, + "requires": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=", + "dev": true + }, + "interpret": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/interpret/-/interpret-3.1.1.tgz", + "integrity": "sha512-6xwYfHbajpoF0xLW+iwLkhwgvLoZDfjYfoFNu8ftMoXINzwuymNLd9u/KmwtdT2GbR+/Cz66otEGEVVUHX9QLQ==", + "dev": true + }, + "ipaddr.js": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-2.1.0.tgz", + "integrity": "sha512-LlbxQ7xKzfBusov6UMi4MFpEg0m+mAm9xyNGEduwXMEDuf4WfzB/RZwMVYEd7IKGvh4IUkEXYxtAVu9T3OelJQ==", + "dev": true + }, + "is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "requires": { + "binary-extensions": "^2.0.0" + } + }, + "is-core-module": { + "version": "2.13.0", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.13.0.tgz", + "integrity": "sha512-Z7dk6Qo8pOCp3l4tsX2C5ZVas4V+UxwQodwZhLopL91TX8UyyHEXafPcyoeeWuLrwzHcr3igO78wNLwHJHsMCQ==", + "dev": true, + "requires": { + "has": "^1.0.3" + } + }, + "is-docker": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", + "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "dev": true + }, + "is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true + }, + "is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "requires": { + "is-extglob": "^2.1.1" + } + }, + "is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true + }, + "is-plain-obj": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-plain-obj/-/is-plain-obj-3.0.0.tgz", + "integrity": "sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==", + "dev": true + }, + "is-plain-object": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/is-plain-object/-/is-plain-object-2.0.4.tgz", + "integrity": "sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og==", + "dev": true, + "requires": { + "isobject": "^3.0.1" + } + }, + "is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "dev": true + }, + "is-wsl": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", + "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "dev": true, + "requires": { + "is-docker": "^2.0.0" + } + }, + "isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=", + "dev": true + }, + "isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true + }, + "isobject": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/isobject/-/isobject-3.0.1.tgz", + "integrity": "sha512-WhB9zCku7EGTj/HQQRz5aUQEUeoQZH2bWcltRErOpymJ4boYE6wL9Tbr23krRPSZ+C5zqNSrSw+Cc7sZZ4b7vg==", + "dev": true + }, + "jest-worker": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-27.5.1.tgz", + "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", + "dev": true, + "requires": { + "@types/node": "*", + "merge-stream": "^2.0.0", + "supports-color": "^8.0.0" + } + }, + "json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", + "dev": true + }, + "json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true + }, + "json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true + }, + "kind-of": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-6.0.3.tgz", + "integrity": "sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==", + "dev": true + }, + "launch-editor": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", + "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "dev": true, + "requires": { + "picocolors": "^1.0.0", + "shell-quote": "^1.7.3" + } + }, + "loader-runner": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", + "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "dev": true + }, + "loader-utils": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/loader-utils/-/loader-utils-2.0.4.tgz", + "integrity": "sha512-xXqpXoINfFhgua9xiqD8fPFHgkoq1mmmpE92WlDbm9rNRd/EbRb+Gqf908T2DMfuHjjJlksiK2RbHVOdD/MqSw==", + "dev": true, + "requires": { + "big.js": "^5.2.2", + "emojis-list": "^3.0.0", + "json5": "^2.1.2" + } + }, + "locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "dev": true, + "requires": { + "p-locate": "^4.1.0" + } + }, + "lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dev": true, + "requires": { + "yallist": "^4.0.0" + } + }, + "make-dir": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-3.1.0.tgz", + "integrity": "sha512-g3FeP20LNwhALb/6Cz6Dd4F2ngze0jz7tbzrD2wAV+o9FeNHe4rL+yK2md0J/fiSf1sa1ADhXqi5+oVwOM/eGw==", + "dev": true, + "requires": { + "semver": "^6.0.0" + }, + "dependencies": { + "semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true + } + } + }, + "media-typer": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==", + "dev": true + }, + "memfs": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", + "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "dev": true, + "requires": { + "fs-monkey": "^1.0.4" + } + }, + "merge-descriptors": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", + "integrity": "sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==", + "dev": true + }, + "merge-stream": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", + "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", + "dev": true + }, + "merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true + }, + "methods": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", + "integrity": "sha512-iclAHeNqNm68zFtnZ0e+1L2yUIdvzNoauKU4WBA3VvH/vPFieF7qfRlwUZU+DA9P9bPXIS90ulxoUoCH23sV2w==", + "dev": true + }, + "micromatch": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", + "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "dev": true, + "requires": { + "braces": "^3.0.2", + "picomatch": "^2.3.1" + } + }, + "mime": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz", + "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==", + "dev": true + }, + "mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "dev": true + }, + "mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "dev": true, + "requires": { + "mime-db": "1.52.0" + } + }, + "mimic-fn": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", + "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", + "dev": true + }, + "minimalistic-assert": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", + "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", + "dev": true + }, + "minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "requires": { + "brace-expansion": "^1.1.7" + } + }, + "minipass": { + "version": "3.3.6", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", + "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", + "dev": true, + "requires": { + "yallist": "^4.0.0" + } + }, + "minipass-collect": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/minipass-collect/-/minipass-collect-1.0.2.tgz", + "integrity": "sha512-6T6lH0H8OG9kITm/Jm6tdooIbogG9e0tLgpY6mphXSm/A9u8Nq1ryBG+Qspiub9LjWlBPsPS3tWQ/Botq4FdxA==", + "dev": true, + "requires": { + "minipass": "^3.0.0" + } + }, + "minipass-flush": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/minipass-flush/-/minipass-flush-1.0.5.tgz", + "integrity": "sha512-JmQSYYpPUqX5Jyn1mXaRwOda1uQ8HP5KAT/oDSLCzt1BYRhQU0/hDtsB1ufZfEEzMZ9aAVmsBw8+FWsIXlClWw==", + "dev": true, + "requires": { + "minipass": "^3.0.0" + } + }, + "minipass-pipeline": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/minipass-pipeline/-/minipass-pipeline-1.2.4.tgz", + "integrity": "sha512-xuIq7cIOt09RPRJ19gdi4b+RiNvDFYe5JH+ggNvBqGqpQXcru3PcRmOZuHBKWK1Txf9+cQ+HMVN4d6z46LZP7A==", + "dev": true, + "requires": { + "minipass": "^3.0.0" + } + }, + "minizlib": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", + "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", + "dev": true, + "requires": { + "minipass": "^3.0.0", + "yallist": "^4.0.0" + } + }, + "mkdirp": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", + "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", + "dev": true + }, + "ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=", + "dev": true + }, + "multicast-dns": { + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", + "integrity": "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==", + "dev": true, + "requires": { + "dns-packet": "^5.2.2", + "thunky": "^1.0.2" + } + }, + "negotiator": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz", + "integrity": "sha512-+EUsqGPLsM+j/zdChZjsnX51g4XrHFOIXwfnCVPGlQk/k5giakcKsuxCObBRu6DSm9opw/O6slWbJdghQM4bBg==", + "dev": true + }, + "neo-async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", + "dev": true + }, + "node-forge": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", + "integrity": "sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==", + "dev": true + }, + "node-releases": { + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.13.tgz", + "integrity": "sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==", + "dev": true + }, + "normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true + }, + "npm-run-path": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", + "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", + "dev": true, + "requires": { + "path-key": "^3.0.0" + } + }, + "object-inspect": { + "version": "1.12.3", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", + "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "dev": true + }, + "obuf": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/obuf/-/obuf-1.1.2.tgz", + "integrity": "sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg==", + "dev": true + }, + "on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "dev": true, + "requires": { + "ee-first": "1.1.1" + } + }, + "on-headers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", + "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "dev": true + }, + "once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "dev": true, + "requires": { + "wrappy": "1" + } + }, + "onetime": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", + "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", + "dev": true, + "requires": { + "mimic-fn": "^2.1.0" + } + }, + "open": { + "version": "8.4.2", + "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", + "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "dev": true, + "requires": { + "define-lazy-prop": "^2.0.0", + "is-docker": "^2.1.1", + "is-wsl": "^2.2.0" + } + }, + "p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "requires": { + "yocto-queue": "^0.1.0" + } + }, + "p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "dev": true, + "requires": { + "p-limit": "^2.2.0" + }, + "dependencies": { + "p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "dev": true, + "requires": { + "p-try": "^2.0.0" + } + } + } + }, + "p-map": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/p-map/-/p-map-4.0.0.tgz", + "integrity": "sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ==", + "dev": true, + "requires": { + "aggregate-error": "^3.0.0" + } + }, + "p-retry": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", + "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "dev": true, + "requires": { + "@types/retry": "0.12.0", + "retry": "^0.13.1" + } + }, + "p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "dev": true + }, + "parseurl": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", + "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "dev": true + }, + "path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true + }, + "path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "dev": true + }, + "path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true + }, + "path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true + }, + "path-to-regexp": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", + "integrity": "sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==", + "dev": true + }, + "path-type": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "dev": true + }, + "picocolors": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", + "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "dev": true + }, + "picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true + }, + "pkg-dir": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/pkg-dir/-/pkg-dir-4.2.0.tgz", + "integrity": "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ==", + "dev": true, + "requires": { + "find-up": "^4.0.0" + } + }, + "process-nextick-args": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.0.tgz", + "integrity": "sha512-MtEC1TqN0EU5nephaJ4rAtThHtC86dNN9qCuEhtshvpVBkAW5ZO7BASN9REnF9eoXGcRub+pFuKEpOHE+HbEMw==", + "dev": true + }, + "promise-inflight": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", + "integrity": "sha512-6zWPyEOFaQBJYcGMHBKTKJ3u6TBsnMFOIZSa6ce1e/ZrrsOlnHRHbabMjLiBYKp+n44X9eUI6VUPaukCXHuG4g==", + "dev": true + }, + "proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "dev": true, + "requires": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "dependencies": { + "ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "dev": true + } + } + }, + "punycode": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", + "integrity": "sha512-rRV+zQD8tVFys26lAGR9WUuS4iUAngJScM+ZRSKtvl5tKeZ2t5bvdNFdNHBW9FWR4guGHlgmsZ1G7BSm2wTbuA==", + "dev": true + }, + "qs": { + "version": "6.11.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", + "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "dev": true, + "requires": { + "side-channel": "^1.0.4" + } + }, + "queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true + }, + "randombytes": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", + "integrity": "sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==", + "dev": true, + "requires": { + "safe-buffer": "^5.1.0" + } + }, + "range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "dev": true + }, + "raw-body": { + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", + "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "dev": true, + "requires": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.4.24", + "unpipe": "1.0.0" + }, + "dependencies": { + "bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "dev": true + } + } + }, + "readable-stream": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", + "dev": true, + "requires": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "requires": { + "picomatch": "^2.2.1" + } + }, + "rechoir": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/rechoir/-/rechoir-0.8.0.tgz", + "integrity": "sha512-/vxpCXddiX8NGfGO/mTafwjq4aFa/71pvamip0++IQk3zG8cbCj0fifNPrjjF1XMXUne91jL9OoxmdykoEtifQ==", + "dev": true, + "requires": { + "resolve": "^1.20.0" + } + }, + "require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true + }, + "requires-port": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz", + "integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==", + "dev": true + }, + "resolve": { + "version": "1.22.6", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.6.tgz", + "integrity": "sha512-njhxM7mV12JfufShqGy3Rz8j11RPdLy4xi15UurGJeoHLfJpVXKdh3ueuOqbYUcDZnffr6X739JBo5LzyahEsw==", + "dev": true, + "requires": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + } + }, + "resolve-cwd": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/resolve-cwd/-/resolve-cwd-3.0.0.tgz", + "integrity": "sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg==", + "dev": true, + "requires": { + "resolve-from": "^5.0.0" + } + }, + "resolve-from": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-5.0.0.tgz", + "integrity": "sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==", + "dev": true + }, + "retry": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", + "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", + "dev": true + }, + "reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true + }, + "rimraf": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "dev": true, + "requires": { + "glob": "^7.1.3" + } + }, + "run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "requires": { + "queue-microtask": "^1.2.2" + } + }, + "safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", + "dev": true + }, + "safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "dev": true + }, + "schema-utils": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", + "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "dev": true, + "requires": { + "@types/json-schema": "^7.0.8", + "ajv": "^6.12.5", + "ajv-keywords": "^3.5.2" + } + }, + "select-hose": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/select-hose/-/select-hose-2.0.0.tgz", + "integrity": "sha1-Yl2GWPhlr0Psliv8N2o3NZpJlMo=", + "dev": true + }, + "selfsigned": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", + "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "dev": true, + "requires": { + "node-forge": "^1" + } + }, + "semver": { + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + "dev": true, + "requires": { + "lru-cache": "^6.0.0" + } + }, + "send": { + "version": "0.18.0", + "resolved": "https://registry.npmjs.org/send/-/send-0.18.0.tgz", + "integrity": "sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==", + "dev": true, + "requires": { + "debug": "2.6.9", + "depd": "2.0.0", + "destroy": "1.2.0", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "2.0.0", + "mime": "1.6.0", + "ms": "2.1.3", + "on-finished": "2.4.1", + "range-parser": "~1.2.1", + "statuses": "2.0.1" + }, + "dependencies": { + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + }, + "dependencies": { + "ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", + "dev": true + } + } + }, + "depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "dev": true + }, + "ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "dev": true + }, + "statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "dev": true + } + } + }, + "serialize-javascript": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-5.0.1.tgz", + "integrity": "sha512-SaaNal9imEO737H2c05Og0/8LUXG7EnsZyMa8MzkmuHoELfT6txuj0cMqRj6zfPKnmQ1yasR4PCJc8x+M4JSPA==", + "dev": true, + "requires": { + "randombytes": "^2.1.0" + } + }, + "serve-index": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/serve-index/-/serve-index-1.9.1.tgz", + "integrity": "sha1-03aNabHn2C5c4FD/9bRTvqEqkjk=", + "dev": true, + "requires": { + "accepts": "~1.3.4", + "batch": "0.6.1", + "debug": "2.6.9", + "escape-html": "~1.0.3", + "http-errors": "~1.6.2", + "mime-types": "~2.1.17", + "parseurl": "~1.3.2" + }, + "dependencies": { + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "dev": true, + "requires": { + "ms": "2.0.0" + } + }, + "http-errors": { + "version": "1.6.3", + "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", + "dev": true, + "requires": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + } + }, + "setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==", + "dev": true + } + } + }, + "serve-static": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.15.0.tgz", + "integrity": "sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==", + "dev": true, + "requires": { + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "parseurl": "~1.3.3", + "send": "0.18.0" + } + }, + "setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "dev": true + }, + "shallow-clone": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/shallow-clone/-/shallow-clone-3.0.1.tgz", + "integrity": "sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA==", + "dev": true, + "requires": { + "kind-of": "^6.0.2" + } + }, + "shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "requires": { + "shebang-regex": "^3.0.0" + } + }, + "shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true + }, + "shell-quote": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", + "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "dev": true + }, + "side-channel": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", + "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "dev": true, + "requires": { + "call-bind": "^1.0.0", + "get-intrinsic": "^1.0.2", + "object-inspect": "^1.9.0" + } + }, + "signal-exit": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", + "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", + "dev": true + }, + "slash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", + "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "dev": true + }, + "sockjs": { + "version": "0.3.24", + "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.24.tgz", + "integrity": "sha512-GJgLTZ7vYb/JtPSSZ10hsOYIvEYsjbNU+zPdIHcUaWVNUEPivzxku31865sSSud0Da0W4lEeOPlmw93zLQchuQ==", + "dev": true, + "requires": { + "faye-websocket": "^0.11.3", + "uuid": "^8.3.2", + "websocket-driver": "^0.7.4" + } + }, + "source-list-map": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/source-list-map/-/source-list-map-2.0.1.tgz", + "integrity": "sha512-qnQ7gVMxGNxsiL4lEuJwe/To8UnK7fAnmbGEEH8RpLouuKbeEm0lhbQVFIrNSuB+G7tVrAlVsZgETT5nljf+Iw==", + "dev": true + }, + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true + }, + "source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + "requires": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "spdy": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/spdy/-/spdy-4.0.2.tgz", + "integrity": "sha512-r46gZQZQV+Kl9oItvl1JZZqJKGr+oEkB08A6BzkiR7593/7IbtuncXHd2YoYeTsG4157ZssMu9KYvUHLcjcDoA==", + "dev": true, + "requires": { + "debug": "^4.1.0", + "handle-thing": "^2.0.0", + "http-deceiver": "^1.2.7", + "select-hose": "^2.0.0", + "spdy-transport": "^3.0.0" + } + }, + "spdy-transport": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/spdy-transport/-/spdy-transport-3.0.0.tgz", + "integrity": "sha512-hsLVFE5SjA6TCisWeJXFKniGGOpBgMLmerfO2aCyCU5s7nJ/rpAepqmFifv/GCbSbueEeAJJnmSQ2rKC/g8Fcw==", + "dev": true, + "requires": { + "debug": "^4.1.0", + "detect-node": "^2.0.4", + "hpack.js": "^2.1.6", + "obuf": "^1.1.2", + "readable-stream": "^3.0.6", + "wbuf": "^1.7.3" + }, + "dependencies": { + "readable-stream": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.0.tgz", + "integrity": "sha512-BViHy7LKeTz4oNnkcLJ+lVSL6vpiFeX6/d3oSH8zCW7UxP2onchk+vTGB143xuFjHS3deTgkKoXXymXqymiIdA==", + "dev": true, + "requires": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + } + } + } + }, + "ssri": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/ssri/-/ssri-8.0.1.tgz", + "integrity": "sha512-97qShzy1AiyxvPNIkLWoGua7xoQzzPjQ0HAH4B0rWKo7SZ6USuPcrUiAFrws0UH8RrbWmgq3LMTObhPIHbbBeQ==", + "dev": true, + "requires": { + "minipass": "^3.1.1" + } + }, + "statuses": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", + "integrity": "sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow=", + "dev": true + }, + "string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dev": true, + "requires": { + "safe-buffer": "~5.1.0" + } + }, + "strip-final-newline": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", + "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", + "dev": true + }, + "supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + } + }, + "supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true + }, + "tapable": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", + "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "dev": true + }, + "tar": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.0.tgz", + "integrity": "sha512-/Wo7DcT0u5HUV486xg675HtjNd3BXZ6xDbzsCUZPt5iw8bTQ63bP0Raut3mvro9u+CUyq7YQd8Cx55fsZXxqLQ==", + "dev": true, + "requires": { + "chownr": "^2.0.0", + "fs-minipass": "^2.0.0", + "minipass": "^5.0.0", + "minizlib": "^2.1.1", + "mkdirp": "^1.0.3", + "yallist": "^4.0.0" + }, + "dependencies": { + "minipass": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", + "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "dev": true + } + } + }, + "terser": { + "version": "5.20.0", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.20.0.tgz", + "integrity": "sha512-e56ETryaQDyebBwJIWYB2TT6f2EZ0fL0sW/JRXNMN26zZdKi2u/E/5my5lG6jNxym6qsrVXfFRmOdV42zlAgLQ==", + "dev": true, + "requires": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.8.2", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + } + }, + "terser-webpack-plugin": { + "version": "5.3.9", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.9.tgz", + "integrity": "sha512-ZuXsqE07EcggTWQjXUj+Aot/OMcD0bMKGgF63f7UxYcu5/AJF53aIpK1YoP5xR9l6s/Hy2b+t1AM0bLNPRuhwA==", + "dev": true, + "requires": { + "@jridgewell/trace-mapping": "^0.3.17", + "jest-worker": "^27.4.5", + "schema-utils": "^3.1.1", + "serialize-javascript": "^6.0.1", + "terser": "^5.16.8" + }, + "dependencies": { + "serialize-javascript": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", + "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", + "dev": true, + "requires": { + "randombytes": "^2.1.0" + } + } + } + }, + "thunky": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz", + "integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", + "dev": true + }, + "to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "requires": { + "is-number": "^7.0.0" + } + }, + "toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "dev": true + }, + "tslib": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.11.1.tgz", + "integrity": "sha512-aZW88SY8kQbU7gpV19lN24LtXh/yD4ZZg6qieAJDDg+YBsJcSmLGK9QpnUjAKVG/xefmvJGd1WUmfpT/g6AJGA==", + "dev": true + }, + "type-is": { + "version": "1.6.18", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz", + "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==", + "dev": true, + "requires": { + "media-typer": "0.3.0", + "mime-types": "~2.1.24" + } + }, + "unique-filename": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/unique-filename/-/unique-filename-1.1.1.tgz", + "integrity": "sha512-Vmp0jIp2ln35UTXuryvjzkjGdRyf9b2lTXuSYUiPmzRcl3FDtYqAwOnTJkAngD9SWhnoJzDbTKwaOrZ+STtxNQ==", + "dev": true, + "requires": { + "unique-slug": "^2.0.0" + } + }, + "unique-slug": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/unique-slug/-/unique-slug-2.0.2.tgz", + "integrity": "sha512-zoWr9ObaxALD3DOPfjPSqxt4fnZiWblxHIgeWqW8x7UqDzEtHEQLzji2cuJYQFCU6KmoJikOYAZlrTHHebjx2w==", + "dev": true, + "requires": { + "imurmurhash": "^0.1.4" + } + }, + "unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "dev": true + }, + "update-browserslist-db": { + "version": "1.0.13", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", + "integrity": "sha512-xebP81SNcPuNpPP3uzeW1NYXxI3rxyJzF3pD6sH4jE7o/IX+WtSpwnVU+qIsDPyk0d3hmFQ7mjqc6AtV604hbg==", + "dev": true, + "requires": { + "escalade": "^3.1.1", + "picocolors": "^1.0.0" + } + }, + "uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "requires": { + "punycode": "^2.1.0" + } + }, + "util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8=", + "dev": true + }, + "utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==", + "dev": true + }, + "uuid": { + "version": "8.3.2", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", + "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "dev": true + }, + "vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=", + "dev": true + }, + "watchpack": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", + "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", + "dev": true, + "requires": { + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.1.2" + } + }, + "wbuf": { + "version": "1.7.3", + "resolved": "https://registry.npmjs.org/wbuf/-/wbuf-1.7.3.tgz", + "integrity": "sha512-O84QOnr0icsbFGLS0O3bI5FswxzRr8/gHwWkDlQFskhSPryQXvrTMxjxGP4+iWYoauLoBvfDpkrOauZ+0iZpDA==", + "dev": true, + "requires": { + "minimalistic-assert": "^1.0.0" + } + }, + "webpack": { + "version": "5.88.2", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.88.2.tgz", + "integrity": "sha512-JmcgNZ1iKj+aiR0OvTYtWQqJwq37Pf683dY9bVORwVbUrDhLhdn/PlO2sHsFHPkj7sHNQF3JwaAkp49V+Sq1tQ==", + "dev": true, + "requires": { + "@types/eslint-scope": "^3.7.3", + "@types/estree": "^1.0.0", + "@webassemblyjs/ast": "^1.11.5", + "@webassemblyjs/wasm-edit": "^1.11.5", + "@webassemblyjs/wasm-parser": "^1.11.5", + "acorn": "^8.7.1", + "acorn-import-assertions": "^1.9.0", + "browserslist": "^4.14.5", + "chrome-trace-event": "^1.0.2", + "enhanced-resolve": "^5.15.0", + "es-module-lexer": "^1.2.1", + "eslint-scope": "5.1.1", + "events": "^3.2.0", + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.2.9", + "json-parse-even-better-errors": "^2.3.1", + "loader-runner": "^4.2.0", + "mime-types": "^2.1.27", + "neo-async": "^2.6.2", + "schema-utils": "^3.2.0", + "tapable": "^2.1.1", + "terser-webpack-plugin": "^5.3.7", + "watchpack": "^2.4.0", + "webpack-sources": "^3.2.3" + }, + "dependencies": { + "webpack-sources": { + "version": "3.2.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.2.3.tgz", + "integrity": "sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w==", + "dev": true + } + } + }, + "webpack-cli": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/webpack-cli/-/webpack-cli-5.1.4.tgz", + "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", + "dev": true, + "requires": { + "@discoveryjs/json-ext": "^0.5.0", + "@webpack-cli/configtest": "^2.1.1", + "@webpack-cli/info": "^2.0.2", + "@webpack-cli/serve": "^2.0.5", + "colorette": "^2.0.14", + "commander": "^10.0.1", + "cross-spawn": "^7.0.3", + "envinfo": "^7.7.3", + "fastest-levenshtein": "^1.0.12", + "import-local": "^3.0.2", + "interpret": "^3.1.1", + "rechoir": "^0.8.0", + "webpack-merge": "^5.7.3" + }, + "dependencies": { + "commander": { + "version": "10.0.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-10.0.1.tgz", + "integrity": "sha512-y4Mg2tXshplEbSGzx7amzPwKKOCGuoSRP/CjEdwwk0FOGlUbq6lKuoyDZTNZkmxHdJtp54hdfY/JUrdL7Xfdug==", + "dev": true + } + } + }, + "webpack-dev-middleware": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz", + "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==", + "dev": true, + "requires": { + "colorette": "^2.0.10", + "memfs": "^3.4.3", + "mime-types": "^2.1.31", + "range-parser": "^1.2.1", + "schema-utils": "^4.0.0" + }, + "dependencies": { + "ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + } + }, + "ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.3" + } + }, + "json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "schema-utils": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", + "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "dev": true, + "requires": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + } + } + } + }, + "webpack-dev-server": { + "version": "4.15.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", + "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", + "dev": true, + "requires": { + "@types/bonjour": "^3.5.9", + "@types/connect-history-api-fallback": "^1.3.5", + "@types/express": "^4.17.13", + "@types/serve-index": "^1.9.1", + "@types/serve-static": "^1.13.10", + "@types/sockjs": "^0.3.33", + "@types/ws": "^8.5.5", + "ansi-html-community": "^0.0.8", + "bonjour-service": "^1.0.11", + "chokidar": "^3.5.3", + "colorette": "^2.0.10", + "compression": "^1.7.4", + "connect-history-api-fallback": "^2.0.0", + "default-gateway": "^6.0.3", + "express": "^4.17.3", + "graceful-fs": "^4.2.6", + "html-entities": "^2.3.2", + "http-proxy-middleware": "^2.0.3", + "ipaddr.js": "^2.0.1", + "launch-editor": "^2.6.0", + "open": "^8.0.9", + "p-retry": "^4.5.0", + "rimraf": "^3.0.2", + "schema-utils": "^4.0.0", + "selfsigned": "^2.1.1", + "serve-index": "^1.9.1", + "sockjs": "^0.3.24", + "spdy": "^4.0.2", + "webpack-dev-middleware": "^5.3.1", + "ws": "^8.13.0" + }, + "dependencies": { + "ajv": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", + "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2", + "uri-js": "^4.2.2" + } + }, + "ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "requires": { + "fast-deep-equal": "^3.1.3" + } + }, + "json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true + }, + "schema-utils": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", + "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "dev": true, + "requires": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + } + } + } + }, + "webpack-merge": { + "version": "5.9.0", + "resolved": "https://registry.npmjs.org/webpack-merge/-/webpack-merge-5.9.0.tgz", + "integrity": "sha512-6NbRQw4+Sy50vYNTw7EyOn41OZItPiXB8GNv3INSoe3PSFaHJEz3SHTrYVaRm2LilNGnFUzh0FAwqPEmU/CwDg==", + "dev": true, + "requires": { + "clone-deep": "^4.0.1", + "wildcard": "^2.0.0" + } + }, + "webpack-sources": { + "version": "1.4.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-1.4.3.tgz", + "integrity": "sha512-lgTS3Xhv1lCOKo7SA5TjKXMjpSM4sBjNV5+q2bqesbSPs5FjGmU6jjtBSkX9b4qW87vDIsCIlUPOEhbZrMdjeQ==", + "dev": true, + "requires": { + "source-list-map": "^2.0.0", + "source-map": "~0.6.1" + } + }, + "websocket-driver": { + "version": "0.7.4", + "resolved": "https://registry.npmjs.org/websocket-driver/-/websocket-driver-0.7.4.tgz", + "integrity": "sha512-b17KeDIQVjvb0ssuSDF2cYXSg2iztliJ4B9WdsuB6J952qCPKmnVq4DyW5motImXHDC1cBT/1UezrJVsKw5zjg==", + "dev": true, + "requires": { + "http-parser-js": ">=0.5.1", + "safe-buffer": ">=5.1.0", + "websocket-extensions": ">=0.1.1" + } + }, + "websocket-extensions": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/websocket-extensions/-/websocket-extensions-0.1.4.tgz", + "integrity": "sha512-OqedPIGOfsDlo31UNwYbCFMSaO9m9G/0faIHj5/dZFDMFqPTcx6UwqyOy3COEaEOg/9VsGIpdqn62W5KhoKSpg==", + "dev": true + }, + "which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "requires": { + "isexe": "^2.0.0" + } + }, + "wildcard": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/wildcard/-/wildcard-2.0.1.tgz", + "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", + "dev": true + }, + "wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "dev": true + }, + "ws": { + "version": "8.14.2", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz", + "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==", + "dev": true, + "requires": {} + }, + "yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "dev": true + }, + "yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true + } + } +} diff --git a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json new file mode 100644 index 0000000000000..cd32070fa0bc6 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -0,0 +1,35 @@ +{ + "name": "create-wasm-app", + "version": "0.1.0", + "description": "create an app to consume rust-generated wasm packages", + "main": "index.js", + "scripts": { + "build": "webpack --config webpack.config.js", + "start": "webpack-dev-server" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/rustwasm/create-wasm-app.git" + }, + "keywords": [ + "webassembly", + "wasm", + "rust", + "webpack" + ], + "author": "Ashley Williams ", + "license": "(MIT OR Apache-2.0)", + "bugs": { + "url": "https://github.com/rustwasm/create-wasm-app/issues" + }, + "homepage": "https://github.com/rustwasm/create-wasm-app#readme", + "dependencies": { + "datafusion-wasmtest": "../pkg" + }, + "devDependencies": { + "webpack": "5.88.2", + "webpack-cli": "5.1.4", + "webpack-dev-server": "4.15.1", + "copy-webpack-plugin": "6.4.1" + } +} diff --git a/datafusion/wasmtest/datafusion-wasm-app/webpack.config.js b/datafusion/wasmtest/datafusion-wasm-app/webpack.config.js new file mode 100644 index 0000000000000..33f1ac4894322 --- /dev/null +++ b/datafusion/wasmtest/datafusion-wasm-app/webpack.config.js @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +const CopyWebpackPlugin = require("copy-webpack-plugin"); +const path = require('path'); + +module.exports = { + entry: "./bootstrap.js", + output: { + path: path.resolve(__dirname, "dist"), + filename: "bootstrap.js", + }, + mode: "development", + experiments: { + asyncWebAssembly: true, // enabling async WebAssembly + }, + module: { + rules: [ + { + test: /\.wasm$/, + type: "webassembly/async", + }, + ], + }, + plugins: [ + new CopyWebpackPlugin({ + patterns: [ + { from: 'index.html', to: 'index.html' }, // If you want to keep the destination filename same as source filename + ], + }), + ], +}; diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs new file mode 100644 index 0000000000000..5bf9a18f8c6e7 --- /dev/null +++ b/datafusion/wasmtest/src/lib.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +extern crate wasm_bindgen; + +use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_expr::lit; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use datafusion_physical_expr::execution_props::ExecutionProps; +use datafusion_sql::sqlparser::dialect::GenericDialect; +use datafusion_sql::sqlparser::parser::Parser; +use std::sync::Arc; +use wasm_bindgen::prelude::*; + +pub fn set_panic_hook() { + // When the `console_error_panic_hook` feature is enabled, we can call the + // `set_panic_hook` function at least once during initialization, and then + // we will get better error messages if our code ever panics. + // + // For more details see + // https://github.com/rustwasm/console_error_panic_hook#readme + #[cfg(feature = "console_error_panic_hook")] + console_error_panic_hook::set_once(); +} + +/// Make console.log available as the log Rust function +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(js_namespace = console)] + fn log(s: &str); +} + +#[wasm_bindgen] +pub fn try_datafusion() { + set_panic_hook(); + // Create a scalar value (from datafusion-common) + let scalar = ScalarValue::from("Hello, World!"); + log(&format!("ScalarValue: {scalar:?}")); + + // Create an Expr (from datafusion-expr) + let expr = lit(28) + lit(72); + log(&format!("Expr: {expr:?}")); + + // Simplify Expr (using datafusion-phys-expr and datafusion-optimizer) + let schema = Arc::new(DFSchema::empty()); + let execution_props = ExecutionProps::new(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&execution_props).with_schema(schema)); + let simplified_expr = simplifier.simplify(expr).unwrap(); + log(&format!("Simplified Expr: {simplified_expr:?}")); + + // Parse SQL (using datafusion-sql) + let sql = "SELECT 2 + 37"; + let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... + let ast = Parser::parse_sql(&dialect, sql).unwrap(); + log(&format!("Parsed SQL: {ast:?}")); +} diff --git a/dev/changelog/27.0.0.md b/dev/changelog/27.0.0.md new file mode 100644 index 0000000000000..305e238b88611 --- /dev/null +++ b/dev/changelog/27.0.0.md @@ -0,0 +1,203 @@ + + +## [27.0.0](https://github.com/apache/arrow-datafusion/tree/27.0.0) (2023-06-26) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/26.0.0...27.0.0) + +**Breaking changes:** + +- Remove `avro_to_arrow::reader::Reader::next` in favor of `Iterator` implementation. [#6538](https://github.com/apache/arrow-datafusion/pull/6538) (LouisGariepy) +- Add support for appending data to external tables - CSV [#6526](https://github.com/apache/arrow-datafusion/pull/6526) (mustafasrepo) +- Move `physical_plan::file_format` to `datasource::plan` [#6516](https://github.com/apache/arrow-datafusion/pull/6516) (alamb) +- Remove `FromSlice` in favor of `From` impl in upstream arrow-rs code [#6587](https://github.com/apache/arrow-datafusion/pull/6587) (alamb) +- Improve main api doc page, move `avro_to_arrow` to `datasource` [#6564](https://github.com/apache/arrow-datafusion/pull/6564) (alamb) +- Fix Clippy module inception (unwrap `datasource::datasource` and `catalog::catalog` [#6640](https://github.com/apache/arrow-datafusion/pull/6640) (LouisGariepy) +- refactor: unify generic expr rewrite functions into the `datafusion_expr::expr_rewriter` [#6644](https://github.com/apache/arrow-datafusion/pull/6644) (r4ntix) +- Move `PhysicalPlanner` to `physical_planer` module [#6570](https://github.com/apache/arrow-datafusion/pull/6570) (alamb) +- Update documentation for creating User Defined Aggregates (AggregateUDF) [#6729](https://github.com/apache/arrow-datafusion/pull/6729) (alamb) +- Support User Defined Window Functions [#6703](https://github.com/apache/arrow-datafusion/pull/6703) (alamb) +- Minor: Move `PartitionStream` to physical_plan [#6756](https://github.com/apache/arrow-datafusion/pull/6756) (alamb) + +**Implemented enhancements:** + +- feat: support type coercion in Parquet Reader [#6458](https://github.com/apache/arrow-datafusion/pull/6458) (e1ijah1) +- feat: New functions and operations for working with arrays [#6384](https://github.com/apache/arrow-datafusion/pull/6384) (izveigor) +- feat: `DISTINCT` bitwise and boolean aggregate functions [#6581](https://github.com/apache/arrow-datafusion/pull/6581) (izveigor) +- feat: make_array support empty arguments [#6593](https://github.com/apache/arrow-datafusion/pull/6593) (parkma99) +- feat: encapsulate physical optimizer rules into a struct [#6645](https://github.com/apache/arrow-datafusion/pull/6645) (waynexia) +- feat: new concatenation operator for working with arrays [#6615](https://github.com/apache/arrow-datafusion/pull/6615) (izveigor) +- feat: add `-c option` to pass the SQL query directly as an argument on datafusion-cli [#6765](https://github.com/apache/arrow-datafusion/pull/6765) (r4ntix) + +**Fixed bugs:** + +- fix: ignore panics if racing against catalog/schema changes [#6536](https://github.com/apache/arrow-datafusion/pull/6536) (Weijun-H) +- fix: type coercion support date - date [#6578](https://github.com/apache/arrow-datafusion/pull/6578) (jackwener) +- fix: avoid panic in `list_files_for_scan` [#6605](https://github.com/apache/arrow-datafusion/pull/6605) (Folyd) +- fix: analyze/optimize plan in `CREATE TABLE AS SELECT` [#6610](https://github.com/apache/arrow-datafusion/pull/6610) (jackwener) +- fix: remove type coercion of case expression in Expr::Schema [#6614](https://github.com/apache/arrow-datafusion/pull/6614) (jackwener) +- fix: correct test timestamp_add_interval_months [#6622](https://github.com/apache/arrow-datafusion/pull/6622) (jackwener) +- fix: fix more panics in `ListingTable` [#6636](https://github.com/apache/arrow-datafusion/pull/6636) (Folyd) +- fix: median with even number of `Decimal128` not working [#6634](https://github.com/apache/arrow-datafusion/pull/6634) (izveigor) +- fix: port unstable subquery to sqllogicaltest [#6659](https://github.com/apache/arrow-datafusion/pull/6659) (jackwener) +- fix: correct wrong test [#6667](https://github.com/apache/arrow-datafusion/pull/6667) (jackwener) +- fix: from_plan shouldn't use original schema [#6595](https://github.com/apache/arrow-datafusion/pull/6595) (jackwener) +- fix: correct the error type [#6712](https://github.com/apache/arrow-datafusion/pull/6712) (jackwener) +- fix: parser for negative intervals [#6698](https://github.com/apache/arrow-datafusion/pull/6698) (izveigor) + +**Documentation updates:** + +- Minor: Fix doc for round function [#6661](https://github.com/apache/arrow-datafusion/pull/6661) (viirya) +- Docs: Improve documentation for `struct` function` [#6754](https://github.com/apache/arrow-datafusion/pull/6754) (alamb) + +**Merged pull requests:** + +- fix: ignore panics if racing against catalog/schema changes [#6536](https://github.com/apache/arrow-datafusion/pull/6536) (Weijun-H) +- Remove `avro_to_arrow::reader::Reader::next` in favor of `Iterator` implementation. [#6538](https://github.com/apache/arrow-datafusion/pull/6538) (LouisGariepy) +- Support ordering analysis with expressions (not just columns) by Replace `OrderedColumn` with `PhysicalSortExpr` [#6501](https://github.com/apache/arrow-datafusion/pull/6501) (mustafasrepo) +- Prepare for 26.0.0 release [#6533](https://github.com/apache/arrow-datafusion/pull/6533) (andygrove) +- fix Incorrect function-name matching with disabled enable_ident_normalization [#6528](https://github.com/apache/arrow-datafusion/pull/6528) (parkma99) +- Improve error messages with function name suggestion. [#6520](https://github.com/apache/arrow-datafusion/pull/6520) (2010YOUY01) +- Docs: add more PR guidance in contributing guide (smaller PRs) [#6546](https://github.com/apache/arrow-datafusion/pull/6546) (alamb) +- feat: support type coercion in Parquet Reader [#6458](https://github.com/apache/arrow-datafusion/pull/6458) (e1ijah1) +- Update to object_store 0.6 and arrow 41 [#6374](https://github.com/apache/arrow-datafusion/pull/6374) (tustvold) +- feat: New functions and operations for working with arrays [#6384](https://github.com/apache/arrow-datafusion/pull/6384) (izveigor) +- Add support for appending data to external tables - CSV [#6526](https://github.com/apache/arrow-datafusion/pull/6526) (mustafasrepo) +- [Minor] Update hashbrown to 0.14 [#6562](https://github.com/apache/arrow-datafusion/pull/6562) (Dandandan) +- refactor: use bitwise and boolean compute functions [#6568](https://github.com/apache/arrow-datafusion/pull/6568) (izveigor) +- Fix panic propagation in `CoalescePartitions`, consolidates panic propagation into `RecordBatchReceiverStream` [#6507](https://github.com/apache/arrow-datafusion/pull/6507) (alamb) +- Move `physical_plan::file_format` to `datasource::plan` [#6516](https://github.com/apache/arrow-datafusion/pull/6516) (alamb) +- refactor: remove type_coercion in PhysicalExpr. [#6575](https://github.com/apache/arrow-datafusion/pull/6575) (jackwener) +- Minor: remove `tokio_stream` dependency [#6565](https://github.com/apache/arrow-datafusion/pull/6565) (alamb) +- minor: remove useless mut and borrow() [#6580](https://github.com/apache/arrow-datafusion/pull/6580) (jackwener) +- Add tests for object_store builders of datafusion-cli [#6576](https://github.com/apache/arrow-datafusion/pull/6576) (r4ntix) +- Avoid per-batch field lookups in SchemaMapping [#6563](https://github.com/apache/arrow-datafusion/pull/6563) (tustvold) +- Move `JoinType` and `JoinCondition` to `datafusion_common` [#6572](https://github.com/apache/arrow-datafusion/pull/6572) (alamb) +- chore(deps): update substrait requirement from 0.10.0 to 0.11.0 [#6579](https://github.com/apache/arrow-datafusion/pull/6579) (dependabot[bot]) +- refactor: bitwise kernel right and left shifts [#6585](https://github.com/apache/arrow-datafusion/pull/6585) (izveigor) +- fix: type coercion support date - date [#6578](https://github.com/apache/arrow-datafusion/pull/6578) (jackwener) +- make page filter public [#6523](https://github.com/apache/arrow-datafusion/pull/6523) (jiacai2050) +- Minor: Remove some `use crate::` uses in physical_plan [#6573](https://github.com/apache/arrow-datafusion/pull/6573) (alamb) +- feat: `DISTINCT` bitwise and boolean aggregate functions [#6581](https://github.com/apache/arrow-datafusion/pull/6581) (izveigor) +- Make the struct function return the correct data type. [#6594](https://github.com/apache/arrow-datafusion/pull/6594) (jiangzhx) +- fix: avoid panic in `list_files_for_scan` [#6605](https://github.com/apache/arrow-datafusion/pull/6605) (Folyd) +- fix: analyze/optimize plan in `CREATE TABLE AS SELECT` [#6610](https://github.com/apache/arrow-datafusion/pull/6610) (jackwener) +- Minor: Add additional docstrings to Window function implementations [#6592](https://github.com/apache/arrow-datafusion/pull/6592) (alamb) +- Remove `FromSlice` in favor of `From` impl in upstream arrow-rs code [#6587](https://github.com/apache/arrow-datafusion/pull/6587) (alamb) +- [Minor] Cleanup tpch benchmark [#6609](https://github.com/apache/arrow-datafusion/pull/6609) (Dandandan) +- Revert "feat: Implement the bitwise_not in NotExpr (#5902)" [#6599](https://github.com/apache/arrow-datafusion/pull/6599) (jackwener) +- Port remaining tests in functions.rs to sqllogictest [#6608](https://github.com/apache/arrow-datafusion/pull/6608) (jiangzhx) +- fix: remove type coercion of case expression in Expr::Schema [#6614](https://github.com/apache/arrow-datafusion/pull/6614) (jackwener) +- Minor: use upstream `dialect_from_str` [#6616](https://github.com/apache/arrow-datafusion/pull/6616) (alamb) +- Minor: Move `PlanType`, `StringifiedPlan` and `ToStringifiedPlan` `datafusion_common` [#6571](https://github.com/apache/arrow-datafusion/pull/6571) (alamb) +- fix: correct test timestamp_add_interval_months [#6622](https://github.com/apache/arrow-datafusion/pull/6622) (jackwener) +- Impl `Literal` trait for `NonZero*` types [#6627](https://github.com/apache/arrow-datafusion/pull/6627) (Folyd) +- style: make clippy happy and remove redundant prefix [#6624](https://github.com/apache/arrow-datafusion/pull/6624) (jackwener) +- Substrait: Fix incorrect join key fields (indices) when same table is being used more than once [#6135](https://github.com/apache/arrow-datafusion/pull/6135) (nseekhao) +- Minor: Add debug logging for schema mismatch errors [#6626](https://github.com/apache/arrow-datafusion/pull/6626) (alamb) +- Minor: Move functionality into `BuildInScalarFunction` [#6612](https://github.com/apache/arrow-datafusion/pull/6612) (alamb) +- Add datafusion-cli tests to the CI Job [#6600](https://github.com/apache/arrow-datafusion/pull/6600) (r4ntix) +- Refactor joins test to sqllogic [#6525](https://github.com/apache/arrow-datafusion/pull/6525) (aprimadi) +- fix: fix more panics in `ListingTable` [#6636](https://github.com/apache/arrow-datafusion/pull/6636) (Folyd) +- fix: median with even number of `Decimal128` not working [#6634](https://github.com/apache/arrow-datafusion/pull/6634) (izveigor) +- Unify formatting of both groups and files up to 5 elements [#6637](https://github.com/apache/arrow-datafusion/pull/6637) (qrilka) +- feat: make_array support empty arguments [#6593](https://github.com/apache/arrow-datafusion/pull/6593) (parkma99) +- Minor: cleanup the unnecessary CREATE TABLE aggregate_test_100 statement at aggregate.slt [#6641](https://github.com/apache/arrow-datafusion/pull/6641) (jiangzhx) +- chore(deps): update sqllogictest requirement from 0.13.2 to 0.14.0 [#6646](https://github.com/apache/arrow-datafusion/pull/6646) (dependabot[bot]) +- Improve main api doc page, move `avro_to_arrow` to `datasource` [#6564](https://github.com/apache/arrow-datafusion/pull/6564) (alamb) +- Minor: Move `include_rank` into `BuiltInWindowFunctionExpr` [#6620](https://github.com/apache/arrow-datafusion/pull/6620) (alamb) +- Prioritize UDF over scalar built-in function in case of function name… [#6601](https://github.com/apache/arrow-datafusion/pull/6601) (epsio-banay) +- feat: encapsulate physical optimizer rules into a struct [#6645](https://github.com/apache/arrow-datafusion/pull/6645) (waynexia) +- Fix date_trunc signature [#6632](https://github.com/apache/arrow-datafusion/pull/6632) (alamb) +- Return correct scalar types for date_trunc [#6638](https://github.com/apache/arrow-datafusion/pull/6638) (viirya) +- Insert supports specifying column names in any order [#6628](https://github.com/apache/arrow-datafusion/pull/6628) (jonahgao) +- Fix Clippy module inception (unwrap `datasource::datasource` and `catalog::catalog` [#6640](https://github.com/apache/arrow-datafusion/pull/6640) (LouisGariepy) +- Add hash support for PhysicalExpr and PhysicalSortExpr [#6625](https://github.com/apache/arrow-datafusion/pull/6625) (mustafasrepo) +- Port tests in joins.rs to sqllogictes [#6642](https://github.com/apache/arrow-datafusion/pull/6642) (jiangzhx) +- Minor: Add test for date_trunc schema on scalars [#6655](https://github.com/apache/arrow-datafusion/pull/6655) (alamb) +- Simplify and encapsulate window function state management [#6621](https://github.com/apache/arrow-datafusion/pull/6621) (alamb) +- Minor: Move get_equal_orderings into `BuiltInWindowFunctionExpr`, remove `BuiltInWindowFunctionExpr::as_any` [#6619](https://github.com/apache/arrow-datafusion/pull/6619) (alamb) +- minor: use sql to setup test data for joins.slt rather than rust [#6656](https://github.com/apache/arrow-datafusion/pull/6656) (alamb) +- Support wider range of Subquery, handle the Count bug [#6457](https://github.com/apache/arrow-datafusion/pull/6457) (mingmwang) +- fix: port unstable subquery to sqllogicaltest [#6659](https://github.com/apache/arrow-datafusion/pull/6659) (jackwener) +- Minor: Fix doc for round function [#6661](https://github.com/apache/arrow-datafusion/pull/6661) (viirya) +- refactor: unify generic expr rewrite functions into the `datafusion_expr::expr_rewriter` [#6644](https://github.com/apache/arrow-datafusion/pull/6644) (r4ntix) +- Minor: add test cases for coercion bitwise shifts [#6651](https://github.com/apache/arrow-datafusion/pull/6651) (izveigor) +- refactor: unify replace count(\*) analyzer by removing it in sql crate [#6660](https://github.com/apache/arrow-datafusion/pull/6660) (jackwener) +- Combine evaluate_stateful and evaluate_inside_range [#6665](https://github.com/apache/arrow-datafusion/pull/6665) (mustafasrepo) +- Support internal cast for BuiltinScalarFunction::MakeArray [#6607](https://github.com/apache/arrow-datafusion/pull/6607) (jayzhan211) +- minor: use sql to setup test data for aggregate.slt rather than rust [#6664](https://github.com/apache/arrow-datafusion/pull/6664) (jiangzhx) +- Minor: Add tests for User Defined Aggregate functions [#6669](https://github.com/apache/arrow-datafusion/pull/6669) (alamb) +- fix: correct wrong test [#6667](https://github.com/apache/arrow-datafusion/pull/6667) (jackwener) +- fix: from_plan shouldn't use original schema [#6595](https://github.com/apache/arrow-datafusion/pull/6595) (jackwener) +- feat: new concatenation operator for working with arrays [#6615](https://github.com/apache/arrow-datafusion/pull/6615) (izveigor) +- Minor: Add more doc strings to WindowExpr [#6663](https://github.com/apache/arrow-datafusion/pull/6663) (alamb) +- minor: `with_new_inputs` replace `from_plan` [#6680](https://github.com/apache/arrow-datafusion/pull/6680) (jackwener) +- Docs: Update roadmap to point at EPIC's, clarify project goals [#6639](https://github.com/apache/arrow-datafusion/pull/6639) (alamb) +- Disable incremental compilation on CI [#6688](https://github.com/apache/arrow-datafusion/pull/6688) (alamb) +- Allow `AggregateUDF` to define retractable batch , implement sliding window functions [#6671](https://github.com/apache/arrow-datafusion/pull/6671) (alamb) +- Minor: Update user guide [#6692](https://github.com/apache/arrow-datafusion/pull/6692) (comphead) +- Minor: consolidate repartition test into sql_integration to save builder space and build time [#6685](https://github.com/apache/arrow-datafusion/pull/6685) (alamb) +- Minor: combine `statistics`, `filter_pushdown` and `custom_sources provider` tests together to reduce CI disk space [#6683](https://github.com/apache/arrow-datafusion/pull/6683) (alamb) +- Move `PhysicalPlanner` to `physical_planer` module [#6570](https://github.com/apache/arrow-datafusion/pull/6570) (alamb) +- Rename integration tests to match crate they are defined in [#6687](https://github.com/apache/arrow-datafusion/pull/6687) (alamb) +- Minor: combine fuzz tests into a single binary to save builder space and build time [#6684](https://github.com/apache/arrow-datafusion/pull/6684) (alamb) +- Minor: consolidate datafusion_substrait tests into `substrait_integration` to save builder space and build time #6685 [#6686](https://github.com/apache/arrow-datafusion/pull/6686) (alamb) +- removed self.all_values.len() from inside reserve [#6689](https://github.com/apache/arrow-datafusion/pull/6689) (BryanEmond) +- Replace supports_bounded_execution with supports_retract_batch [#6695](https://github.com/apache/arrow-datafusion/pull/6695) (mustafasrepo) +- Move `dataframe` and `dataframe_functon` into `core_integration` test binary [#6697](https://github.com/apache/arrow-datafusion/pull/6697) (alamb) +- refactor: fix clippy allow too many arguments [#6705](https://github.com/apache/arrow-datafusion/pull/6705) (aprimadi) +- Fix documentation typo [#6704](https://github.com/apache/arrow-datafusion/pull/6704) (aprimadi) +- fix: correct the error type [#6712](https://github.com/apache/arrow-datafusion/pull/6712) (jackwener) +- Port test in subqueries.rs from rust to sqllogictest [#6675](https://github.com/apache/arrow-datafusion/pull/6675) (jiangzhx) +- Improve performance/memory usage of HashJoin datastructure (5-15% improvement on selected TPC-H queries) [#6679](https://github.com/apache/arrow-datafusion/pull/6679) (Dandandan) +- refactor: alias() should skip add alias for `Expr::Sort` [#6707](https://github.com/apache/arrow-datafusion/pull/6707) (jackwener) +- chore(deps): update strum/strum_macros requirement from 0.24 to 0.25 [#6717](https://github.com/apache/arrow-datafusion/pull/6717) (jackwener) +- Move alias generator to per-query execution props [#6706](https://github.com/apache/arrow-datafusion/pull/6706) (aprimadi) +- fix: parser for negative intervals [#6698](https://github.com/apache/arrow-datafusion/pull/6698) (izveigor) +- Minor: Improve UX for setting `ExecutionProps::query_execution_start_time` [#6719](https://github.com/apache/arrow-datafusion/pull/6719) (alamb) +- add Eq and PartialEq to ListingTableUrl [#6725](https://github.com/apache/arrow-datafusion/pull/6725) (fsdvh) +- Support Expr::InList to Substrait::RexType [#6604](https://github.com/apache/arrow-datafusion/pull/6604) (jayzhan211) +- MINOR: Add maintains input order flag to CoalesceBatches [#6730](https://github.com/apache/arrow-datafusion/pull/6730) (mustafasrepo) +- Minor: Update copyight date on website [#6727](https://github.com/apache/arrow-datafusion/pull/6727) (alamb) +- Display all partitions and files in EXPLAIN VERBOSE [#6711](https://github.com/apache/arrow-datafusion/pull/6711) (qrilka) +- Update `arrow`, `arrow-flight` and `parquet` to `42.0.0` [#6702](https://github.com/apache/arrow-datafusion/pull/6702) (alamb) +- Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate [#6690](https://github.com/apache/arrow-datafusion/pull/6690) (alamb) +- Hash Join Vectorized collision checking [#6724](https://github.com/apache/arrow-datafusion/pull/6724) (Dandandan) +- Return null for date_trunc(null) instead of panic [#6723](https://github.com/apache/arrow-datafusion/pull/6723) (BryanEmond) +- `derive(Debug)` for `Expr` [#6708](https://github.com/apache/arrow-datafusion/pull/6708) (parkma99) +- refactor: extract merge_projection common function. [#6735](https://github.com/apache/arrow-datafusion/pull/6735) (jackwener) +- Fix up some `DataFusionError::Internal` errors with correct type [#6721](https://github.com/apache/arrow-datafusion/pull/6721) (alamb) +- Minor: remove some uses of unwrap [#6738](https://github.com/apache/arrow-datafusion/pull/6738) (alamb) +- Minor: remove dead code with decimal datatypes from `in_list` [#6737](https://github.com/apache/arrow-datafusion/pull/6737) (izveigor) +- Update documentation for creating User Defined Aggregates (AggregateUDF) [#6729](https://github.com/apache/arrow-datafusion/pull/6729) (alamb) +- Support User Defined Window Functions [#6703](https://github.com/apache/arrow-datafusion/pull/6703) (alamb) +- MINOR: Aggregate ordering substrait support [#6745](https://github.com/apache/arrow-datafusion/pull/6745) (mustafasrepo) +- chore(deps): update itertools requirement from 0.10 to 0.11 [#6752](https://github.com/apache/arrow-datafusion/pull/6752) (jackwener) +- refactor: move some code in physical_plan/common.rs before tests module [#6749](https://github.com/apache/arrow-datafusion/pull/6749) (aprimadi) +- Add support for order-sensitive aggregation for multipartitions [#6734](https://github.com/apache/arrow-datafusion/pull/6734) (mustafasrepo) +- Update sqlparser-rs to version `0.35.0` [#6753](https://github.com/apache/arrow-datafusion/pull/6753) (alamb) +- Docs: Update SQL status page [#6736](https://github.com/apache/arrow-datafusion/pull/6736) (alamb) +- fix typo [#6761](https://github.com/apache/arrow-datafusion/pull/6761) (Weijun-H) +- Minor: Move `PartitionStream` to physical_plan [#6756](https://github.com/apache/arrow-datafusion/pull/6756) (alamb) +- Docs: Improve documentation for `struct` function` [#6754](https://github.com/apache/arrow-datafusion/pull/6754) (alamb) +- add UT to verify the fix on "issues/6606" [#6762](https://github.com/apache/arrow-datafusion/pull/6762) (mingmwang) +- Re-export modules individually to fix rustdocs [#6757](https://github.com/apache/arrow-datafusion/pull/6757) (alamb) +- Order Preserving RepartitionExec Implementation [#6742](https://github.com/apache/arrow-datafusion/pull/6742) (mustafasrepo) +- feat: add `-c option` to pass the SQL query directly as an argument on datafusion-cli [#6765](https://github.com/apache/arrow-datafusion/pull/6765) (r4ntix) diff --git a/dev/changelog/28.0.0.md b/dev/changelog/28.0.0.md new file mode 100644 index 0000000000000..a51427be5c345 --- /dev/null +++ b/dev/changelog/28.0.0.md @@ -0,0 +1,194 @@ + + +## [28.0.0](https://github.com/apache/arrow-datafusion/tree/28.0.0) (2023-07-21) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/27.0.0...28.0.0) + +**Breaking changes:** + +- Cleanup type coercion (#3419) [#6778](https://github.com/apache/arrow-datafusion/pull/6778) (tustvold) +- refactor: encapsulate Alias as a struct [#6795](https://github.com/apache/arrow-datafusion/pull/6795) (jackwener) +- Set `DisplayAs` to be a supertrait of `ExecutionPlan` [#6835](https://github.com/apache/arrow-datafusion/pull/6835) (qrilka) +- [MINOR] Remove unnecessary api from MemTable [#6861](https://github.com/apache/arrow-datafusion/pull/6861) (metesynnada) +- refactor: Merge Expr::Like and Expr::ILike [#7007](https://github.com/apache/arrow-datafusion/pull/7007) (waynexia) + +**Implemented enhancements:** + +- feat: `array_contains` [#6618](https://github.com/apache/arrow-datafusion/pull/6618) (izveigor) +- feat: support `NULL` in array functions [#6662](https://github.com/apache/arrow-datafusion/pull/6662) (izveigor) +- feat: implement posgres style `encode`/`decode` [#6821](https://github.com/apache/arrow-datafusion/pull/6821) (ozgrakkurt) +- feat: column support for `array_append`, `array_prepend`, `array_position` and `array_positions` [#6805](https://github.com/apache/arrow-datafusion/pull/6805) (izveigor) +- feat: preserve metadata for `Field` and `Schema` in proto [#6865](https://github.com/apache/arrow-datafusion/pull/6865) (jonahgao) +- feat: Add graphviz display format for execution plan. [#6726](https://github.com/apache/arrow-datafusion/pull/6726) (liurenjie1024) +- feat: implement substrait join filter support [#6868](https://github.com/apache/arrow-datafusion/pull/6868) (nseekhao) +- feat: column support for `array_dims`, `array_ndims`, `cardinality` and `array_length` [#6864](https://github.com/apache/arrow-datafusion/pull/6864) (izveigor) +- feat: support for `NestedLoopJoinExec` in datafusion-proto [#6902](https://github.com/apache/arrow-datafusion/pull/6902) (r4ntix) +- feat: add round trip test of physical plan in tpch unit tests [#6918](https://github.com/apache/arrow-datafusion/pull/6918) (r4ntix) +- feat: implement substrait for LIKE/ILIKE expr [#6840](https://github.com/apache/arrow-datafusion/pull/6840) (waynexia) +- feat: array functions treat an array as an element [#6986](https://github.com/apache/arrow-datafusion/pull/6986) (izveigor) + +**Fixed bugs:** + +- fix: incorrect nullability of `between` expr [#6786](https://github.com/apache/arrow-datafusion/pull/6786) (jonahgao) +- fix: incorrect nullability of `InList` expr [#6799](https://github.com/apache/arrow-datafusion/pull/6799) (jonahgao) +- fix: from_plan generate Agg can be with different schema. [#6820](https://github.com/apache/arrow-datafusion/pull/6820) (jackwener) +- fix: incorrect nullability of `Like` expressions [#6829](https://github.com/apache/arrow-datafusion/pull/6829) (jonahgao) +- fix: incorrect simplification of case expr [#7006](https://github.com/apache/arrow-datafusion/pull/7006) (jonahgao) +- fix: `array_concat` with arrays with different dimensions, add `_list*` aliases for `_array*` functions [#7008](https://github.com/apache/arrow-datafusion/pull/7008) (izveigor) + +**Documentation updates:** + +- docs: Add `encode` and `decode` to the user guide [#6856](https://github.com/apache/arrow-datafusion/pull/6856) (alamb) + +**Merged pull requests:** + +- chore(deps): update indexmap requirement from 1.9.2 to 2.0.0 [#6766](https://github.com/apache/arrow-datafusion/pull/6766) (dependabot[bot]) +- Support IsDistinctFrom and IsNotDistinctFrom on interval types [#6776](https://github.com/apache/arrow-datafusion/pull/6776) (joroKr21) +- Protect main branch [#6775](https://github.com/apache/arrow-datafusion/pull/6775) (tustvold) +- Prepare 27.0.0 release [#6773](https://github.com/apache/arrow-datafusion/pull/6773) (andygrove) +- Support hex string literal [#6767](https://github.com/apache/arrow-datafusion/pull/6767) (ShiKaiWi) +- feat: `array_contains` [#6618](https://github.com/apache/arrow-datafusion/pull/6618) (izveigor) +- Make 'date_trunc' returns the same type as its input [#6654](https://github.com/apache/arrow-datafusion/pull/6654) (Weijun-H) +- Fix inserting into a table with non-nullable columns [#6722](https://github.com/apache/arrow-datafusion/pull/6722) (jonahgao) +- Cleanup type coercion (#3419) [#6778](https://github.com/apache/arrow-datafusion/pull/6778) (tustvold) +- Properly project grouping set expressions [#6777](https://github.com/apache/arrow-datafusion/pull/6777) (fsdvh) +- Minor: Simplify `date_trunc` code and add comments [#6783](https://github.com/apache/arrow-datafusion/pull/6783) (alamb) +- Minor: Add array / array sqllogic tests for `array_contains` [#6771](https://github.com/apache/arrow-datafusion/pull/6771) (alamb) +- Minor: Make `date_trunc` code easier to understand [#6789](https://github.com/apache/arrow-datafusion/pull/6789) (alamb) +- feat: support `NULL` in array functions [#6662](https://github.com/apache/arrow-datafusion/pull/6662) (izveigor) +- fix: incorrect nullability of `between` expr [#6786](https://github.com/apache/arrow-datafusion/pull/6786) (jonahgao) +- Use checked division kernel [#6792](https://github.com/apache/arrow-datafusion/pull/6792) (tustvold) +- Minor: add sqllogictests for binary data type [#6770](https://github.com/apache/arrow-datafusion/pull/6770) (alamb) +- refactor: encapsulate Alias as a struct [#6795](https://github.com/apache/arrow-datafusion/pull/6795) (jackwener) +- chore(deps): bump actions/labeler from 4.1.0 to 4.2.0 [#6803](https://github.com/apache/arrow-datafusion/pull/6803) (dependabot[bot]) +- Consistently coerce dictionaries for arithmetic [#6785](https://github.com/apache/arrow-datafusion/pull/6785) (tustvold) +- Implement serialization for UDWF and UDAF in plan protobuf [#6769](https://github.com/apache/arrow-datafusion/pull/6769) (parkma99) +- fix: incorrect nullability of `InList` expr [#6799](https://github.com/apache/arrow-datafusion/pull/6799) (jonahgao) +- Fix timestamp_add_interval_months to pass any date [#6815](https://github.com/apache/arrow-datafusion/pull/6815) (jayzhan211) +- Minor: Log TPCH benchmark results [#6813](https://github.com/apache/arrow-datafusion/pull/6813) (alamb) +- Refactor Decimal128 averaging code to be vectorizable (and easier to read) [#6810](https://github.com/apache/arrow-datafusion/pull/6810) (alamb) +- Minor: Encapsulate `return_type` and `signature` in `AggregateFunction` and `WindowFunction` [#6748](https://github.com/apache/arrow-datafusion/pull/6748) (alamb) +- fix: from_plan generate Agg can be with different schema. [#6820](https://github.com/apache/arrow-datafusion/pull/6820) (jackwener) +- [MINOR] Improve performance of `create_hashes` [#6816](https://github.com/apache/arrow-datafusion/pull/6816) (Dandandan) +- Add fetch to `SortPreservingMergeExec` and `SortPreservingMergeStream` [#6811](https://github.com/apache/arrow-datafusion/pull/6811) (Dandandan) +- chore(deps): update substrait requirement from 0.11.0 to 0.12.0 [#6825](https://github.com/apache/arrow-datafusion/pull/6825) (dependabot[bot]) +- Upgrade arrow 43 [#6812](https://github.com/apache/arrow-datafusion/pull/6812) (tustvold) +- Fix cargo build warning [#6831](https://github.com/apache/arrow-datafusion/pull/6831) (viirya) +- Simplify `IsUnkown` and `IsNotUnkown` expression [#6830](https://github.com/apache/arrow-datafusion/pull/6830) (jonahgao) +- fix: incorrect nullability of `Like` expressions [#6829](https://github.com/apache/arrow-datafusion/pull/6829) (jonahgao) +- Minor: Add one more assert to `hash_array_primitive` [#6834](https://github.com/apache/arrow-datafusion/pull/6834) (alamb) +- revert #6595 #6820 [#6827](https://github.com/apache/arrow-datafusion/pull/6827) (jackwener) +- Add Duration to ScalarValue [#6838](https://github.com/apache/arrow-datafusion/pull/6838) (tustvold) +- Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet [#6750](https://github.com/apache/arrow-datafusion/pull/6750) (aprimadi) +- Add clickbench queries to sqllogictest coverage [#6836](https://github.com/apache/arrow-datafusion/pull/6836) (alamb) +- feat: implement posgres style `encode`/`decode` [#6821](https://github.com/apache/arrow-datafusion/pull/6821) (ozgrakkurt) +- chore(deps): update rstest requirement from 0.17.0 to 0.18.0 [#6847](https://github.com/apache/arrow-datafusion/pull/6847) (dependabot[bot]) +- [minior] support serde for some function [#6846](https://github.com/apache/arrow-datafusion/pull/6846) (liukun4515) +- Support fixed_size_list for make_array [#6759](https://github.com/apache/arrow-datafusion/pull/6759) (jayzhan211) +- Improve median performance. [#6837](https://github.com/apache/arrow-datafusion/pull/6837) (vincev) +- Mismatch in MemTable of Select Into when projecting on aggregate window functions [#6566](https://github.com/apache/arrow-datafusion/pull/6566) (berkaysynnada) +- feat: column support for `array_append`, `array_prepend`, `array_position` and `array_positions` [#6805](https://github.com/apache/arrow-datafusion/pull/6805) (izveigor) +- MINOR: Fix ordering of the aggregate_source_with_order table [#6852](https://github.com/apache/arrow-datafusion/pull/6852) (mustafasrepo) +- Return error when internal multiplication overflowing in decimal division kernel [#6833](https://github.com/apache/arrow-datafusion/pull/6833) (viirya) +- Deprecate ScalarValue::and, ScalarValue::or (#6842) [#6844](https://github.com/apache/arrow-datafusion/pull/6844) (tustvold) +- chore(deps): update bigdecimal requirement from 0.3.0 to 0.4.0 [#6848](https://github.com/apache/arrow-datafusion/pull/6848) (dependabot[bot]) +- feat: preserve metadata for `Field` and `Schema` in proto [#6865](https://github.com/apache/arrow-datafusion/pull/6865) (jonahgao) +- Set `DisplayAs` to be a supertrait of `ExecutionPlan` [#6835](https://github.com/apache/arrow-datafusion/pull/6835) (qrilka) +- [MINOR] Remove unnecessary api from MemTable [#6861](https://github.com/apache/arrow-datafusion/pull/6861) (metesynnada) +- Adjustment of HashJoinExec APIs to Preserve Probe Side Order [#6858](https://github.com/apache/arrow-datafusion/pull/6858) (metesynnada) +- [MINOR] Adding order into StreamingTableExec [#6860](https://github.com/apache/arrow-datafusion/pull/6860) (metesynnada) +- Docs: try and clarify what `PartitionEvaluator` functions are called [#6869](https://github.com/apache/arrow-datafusion/pull/6869) (alamb) +- docs: Add `encode` and `decode` to the user guide [#6856](https://github.com/apache/arrow-datafusion/pull/6856) (alamb) +- Fix build on main due to logical conflict [#6875](https://github.com/apache/arrow-datafusion/pull/6875) (alamb) +- feat: Add graphviz display format for execution plan. [#6726](https://github.com/apache/arrow-datafusion/pull/6726) (liurenjie1024) +- Fix (another) logical conflict [#6882](https://github.com/apache/arrow-datafusion/pull/6882) (alamb) +- Minor: Consolidate display related traits [#6883](https://github.com/apache/arrow-datafusion/pull/6883) (alamb) +- test: parquet use the byte array as the physical type to store decimal [#6851](https://github.com/apache/arrow-datafusion/pull/6851) (smallzhongfeng) +- Make streaming_merge public [#6874](https://github.com/apache/arrow-datafusion/pull/6874) (kazuyukitanimura) +- Performance: Use a specialized sum accumulator for retractable aggregregates [#6888](https://github.com/apache/arrow-datafusion/pull/6888) (alamb) +- Support array concatenation for arrays with different dimensions [#6872](https://github.com/apache/arrow-datafusion/pull/6872) (jayzhan211) +- feat: implement substrait join filter support [#6868](https://github.com/apache/arrow-datafusion/pull/6868) (nseekhao) +- feat: column support for `array_dims`, `array_ndims`, `cardinality` and `array_length` [#6864](https://github.com/apache/arrow-datafusion/pull/6864) (izveigor) +- Add FixedSizeBinary support to binary_op_dyn_scalar [#6891](https://github.com/apache/arrow-datafusion/pull/6891) (maxburke) +- Minor: deleted duplicated substrait integration test [#6894](https://github.com/apache/arrow-datafusion/pull/6894) (alamb) +- Minor: add test cases with columns for math expressions [#6787](https://github.com/apache/arrow-datafusion/pull/6787) (izveigor) +- Minor: reduce redundant code [#6901](https://github.com/apache/arrow-datafusion/pull/6901) (smallzhongfeng) +- Minor: Add some more doc comments to `BoundedAggregateStream` [#6881](https://github.com/apache/arrow-datafusion/pull/6881) (alamb) +- feat: support for `NestedLoopJoinExec` in datafusion-proto [#6902](https://github.com/apache/arrow-datafusion/pull/6902) (r4ntix) +- Fix `make_array` null handling, update tests [#6900](https://github.com/apache/arrow-datafusion/pull/6900) (alamb) +- chore(deps): bump actions/labeler from 4.2.0 to 4.3.0 [#6911](https://github.com/apache/arrow-datafusion/pull/6911) (dependabot[bot]) +- Minor: Add TPCH scale factor 10 to bench.sh, use 10 iteration [#6893](https://github.com/apache/arrow-datafusion/pull/6893) (alamb) +- Minor: Add output to aggregrate_fuzz.rs on failure [#6905](https://github.com/apache/arrow-datafusion/pull/6905) (alamb) +- allow window UDF to return null [#6915](https://github.com/apache/arrow-datafusion/pull/6915) (mhilton) +- Minor: Add factory method to PartitionedFile to create File Scan [#6909](https://github.com/apache/arrow-datafusion/pull/6909) (comphead) +- [minor]fix doc to remove duplicate content [#6923](https://github.com/apache/arrow-datafusion/pull/6923) (liukun4515) +- Revert "chore(deps): update bigdecimal requirement from 0.3.0 to 0.4.0 (#6848)" [#6896](https://github.com/apache/arrow-datafusion/pull/6896) (alamb) +- [Minor] Make FileScanConfig::project pub [#6931](https://github.com/apache/arrow-datafusion/pull/6931) (Dandandan) +- feat: add round trip test of physical plan in tpch unit tests [#6918](https://github.com/apache/arrow-datafusion/pull/6918) (r4ntix) +- Minor: Use thiserror to implement the `From` trait for `DFSqlLogicTestError` [#6924](https://github.com/apache/arrow-datafusion/pull/6924) (jonahgao) +- parallel csv scan [#6801](https://github.com/apache/arrow-datafusion/pull/6801) (2010YOUY01) +- Add additional test coverage for aggregaes using dates/times/timestamps/decimals [#6939](https://github.com/apache/arrow-datafusion/pull/6939) (alamb) +- Replace repartition execs with sort preserving repartition execs [#6921](https://github.com/apache/arrow-datafusion/pull/6921) (mertak) +- Vectorized hash grouping [#6904](https://github.com/apache/arrow-datafusion/pull/6904) (alamb) +- Fix incorrect results in `BitAnd` GroupsAccumulator [#6957](https://github.com/apache/arrow-datafusion/pull/6957) (alamb) +- Fixes for clippy 1.71 [#6959](https://github.com/apache/arrow-datafusion/pull/6959) (alamb) +- Improve unnest_column performance [#6903](https://github.com/apache/arrow-datafusion/pull/6903) (vincev) +- Pass `schema_infer_max_records` to JsonFormat. [#6945](https://github.com/apache/arrow-datafusion/pull/6945) (vincev) +- deps: bump sqllogictest to 0.15.0 [#6941](https://github.com/apache/arrow-datafusion/pull/6941) (jonahgao) +- Preserve field metadata across expressions in logical plans [#6920](https://github.com/apache/arrow-datafusion/pull/6920) (dexterduck) +- Support equality and comparison between interval arrays and scalars [#6948](https://github.com/apache/arrow-datafusion/pull/6948) (joroKr21) +- chore(deps): update bigdecimal requirement from 0.3.0 to 0.4.1 [#6946](https://github.com/apache/arrow-datafusion/pull/6946) (dependabot[bot]) +- feat: implement substrait for LIKE/ILIKE expr [#6840](https://github.com/apache/arrow-datafusion/pull/6840) (waynexia) +- Minor: Add comments about initial value for `BitAnd` accumulator [#6964](https://github.com/apache/arrow-datafusion/pull/6964) (alamb) +- [Functions] Support Arithmetic function COT() [#6925](https://github.com/apache/arrow-datafusion/pull/6925) (Syleechan) +- Minor: remove duplication in Min/Max accumulator [#6960](https://github.com/apache/arrow-datafusion/pull/6960) (alamb) +- [MINOR]Add new tests [#6953](https://github.com/apache/arrow-datafusion/pull/6953) (mustafasrepo) +- Column support for array concat [#6879](https://github.com/apache/arrow-datafusion/pull/6879) (jayzhan211) +- Minor: Add FixedSizeBinaryTest [#6895](https://github.com/apache/arrow-datafusion/pull/6895) (alamb) +- [MINOR] Remove update state api from PartitionEvaluator [#6966](https://github.com/apache/arrow-datafusion/pull/6966) (mustafasrepo) +- Fix required partitioning of Single aggregation mode [#6950](https://github.com/apache/arrow-datafusion/pull/6950) (Dandandan) +- [MINOR] Remove global sort rule from planner [#6965](https://github.com/apache/arrow-datafusion/pull/6965) (mustafasrepo) +- Column support for array_to_string [#6940](https://github.com/apache/arrow-datafusion/pull/6940) (jayzhan211) +- chore: fix format [#6991](https://github.com/apache/arrow-datafusion/pull/6991) (Weijun-H) +- Extend Ordering Equivalence Support [#6956](https://github.com/apache/arrow-datafusion/pull/6956) (mustafasrepo) +- chore: break earlier in macro `contains!` [#6989](https://github.com/apache/arrow-datafusion/pull/6989) (Weijun-H) +- fix: incorrect simplification of case expr [#7006](https://github.com/apache/arrow-datafusion/pull/7006) (jonahgao) +- Minor: Add String/Binary aggregate tests [#6962](https://github.com/apache/arrow-datafusion/pull/6962) (alamb) +- [MINOR] Supporting repartition joins conf in SHJ [#6998](https://github.com/apache/arrow-datafusion/pull/6998) (metesynnada) +- [MINOR] Code refactor on hash join utils [#6999](https://github.com/apache/arrow-datafusion/pull/6999) (metesynnada) +- feat: array functions treat an array as an element [#6986](https://github.com/apache/arrow-datafusion/pull/6986) (izveigor) +- [MINOR] Moving some test utils from EnsureSorting to test_utils [#7009](https://github.com/apache/arrow-datafusion/pull/7009) (metesynnada) +- MINOR: Bug fix, Use correct ordering equivalence when window expr contains partition by [#7011](https://github.com/apache/arrow-datafusion/pull/7011) (mustafasrepo) +- refactor: Merge Expr::Like and Expr::ILike [#7007](https://github.com/apache/arrow-datafusion/pull/7007) (waynexia) +- Docs: Add docs to `RepartitionExec` and architecture guide [#7003](https://github.com/apache/arrow-datafusion/pull/7003) (alamb) +- Consolidate `BoundedAggregateStream` [#6932](https://github.com/apache/arrow-datafusion/pull/6932) (alamb) +- Minor: Improve aggregate test coverage more [#6952](https://github.com/apache/arrow-datafusion/pull/6952) (alamb) +- Don't store hashes in GroupOrdering [#7029](https://github.com/apache/arrow-datafusion/pull/7029) (tustvold) +- Extract GroupValues (#6969) [#7016](https://github.com/apache/arrow-datafusion/pull/7016) (tustvold) +- Refactor AnalysisContext and statistics() of FilterExec [#6982](https://github.com/apache/arrow-datafusion/pull/6982) (berkaysynnada) +- Fix `datafusion-cli/Dockerfile` to build successfully [#7031](https://github.com/apache/arrow-datafusion/pull/7031) (sarutak) +- functions: support trunc() function with one or two args [#6942](https://github.com/apache/arrow-datafusion/pull/6942) (Syleechan) +- Move the column aliases below the SubqueryAlias [#7035](https://github.com/apache/arrow-datafusion/pull/7035) (jonahgao) +- fix: `array_concat` with arrays with different dimensions, add `_list*` aliases for `_array*` functions [#7008](https://github.com/apache/arrow-datafusion/pull/7008) (izveigor) +- Add support for ClickBench in bench.sh [#7005](https://github.com/apache/arrow-datafusion/pull/7005) (alamb) +- Remove RowAccumulators and datafusion-row [#6968](https://github.com/apache/arrow-datafusion/pull/6968) (alamb) +- Decimal256 coercion [#7034](https://github.com/apache/arrow-datafusion/pull/7034) (jdye64) +- Double RawTable on grow instead of triple [#7041](https://github.com/apache/arrow-datafusion/pull/7041) (tustvold) +- Specialize single column primitive group values [#7043](https://github.com/apache/arrow-datafusion/pull/7043) (tustvold) diff --git a/dev/changelog/29.0.0.md b/dev/changelog/29.0.0.md new file mode 100644 index 0000000000000..6d946eb61cba1 --- /dev/null +++ b/dev/changelog/29.0.0.md @@ -0,0 +1,162 @@ + + +## [29.0.0](https://github.com/apache/arrow-datafusion/tree/29.0.0) (2023-08-11) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/28.0.0...29.0.0) + +**Breaking changes:** + +- change the input_type parameter of the create_udaf function from DataType to Vec [#7096](https://github.com/apache/arrow-datafusion/pull/7096) (jiangzhx) +- Implement `array_slice` and `array_element`, remove `array_trim` [#6936](https://github.com/apache/arrow-datafusion/pull/6936) (izveigor) +- improve the ergonomics of creating field and list array accesses [#7215](https://github.com/apache/arrow-datafusion/pull/7215) (izveigor) +- Update Arrow 45.0.0 And Datum Arithmetic, change Decimal Division semantics [#6832](https://github.com/apache/arrow-datafusion/pull/6832) (tustvold) + +**Implemented enhancements:** + +- feat: support SQL array replacement and removement functions [#7057](https://github.com/apache/arrow-datafusion/pull/7057) (izveigor) +- feat: array containment operator `@>` and `<@` [#6885](https://github.com/apache/arrow-datafusion/pull/6885) (izveigor) +- feat: add sqllogictests crate [#7134](https://github.com/apache/arrow-datafusion/pull/7134) (tshauck) +- feat: allow `datafusion-cli` to accept multiple statements [#7138](https://github.com/apache/arrow-datafusion/pull/7138) (NiwakaDev) +- feat: Add linear regression aggregate functions [#7211](https://github.com/apache/arrow-datafusion/pull/7211) (2010YOUY01) + +**Fixed bugs:** + +- fix: disallow interval - timestamp [#7086](https://github.com/apache/arrow-datafusion/pull/7086) (jackwener) +- fix: Projection columns_map remove name search [#7099](https://github.com/apache/arrow-datafusion/pull/7099) (mustafasrepo) +- fix: fix index bug and add test to check it [#7124](https://github.com/apache/arrow-datafusion/pull/7124) (mustafasrepo) +- fix: Fix panic in filter predicate [#7126](https://github.com/apache/arrow-datafusion/pull/7126) (alamb) +- fix: correct count(\*) alias [#7081](https://github.com/apache/arrow-datafusion/pull/7081) (jackwener) +- fix: skip compression tests on --no-default-features [#7172](https://github.com/apache/arrow-datafusion/pull/7172) (not-my-profile) +- fix: typo in substrait [#7224](https://github.com/apache/arrow-datafusion/pull/7224) (waynexia) + +**Documentation updates:** + +- Add additional links to main README [#7102](https://github.com/apache/arrow-datafusion/pull/7102) (alamb) +- docs: fix broken link [#7177](https://github.com/apache/arrow-datafusion/pull/7177) (SteveLauC) + +**Merged pull requests:** + +- [Minor] Speedup to_array_of_size for Decimal128 [#7055](https://github.com/apache/arrow-datafusion/pull/7055) (Dandandan) +- Replace `array_contains` with SQL array functions: `array_has`, `array_has_any`, `array_has_all` [#6990](https://github.com/apache/arrow-datafusion/pull/6990) (jayzhan211) +- Add more Decimal256 type coercion [#7047](https://github.com/apache/arrow-datafusion/pull/7047) (viirya) +- Create `dfbench`, split up `tpch` benchmark runner into modules [#7054](https://github.com/apache/arrow-datafusion/pull/7054) (alamb) +- chore(deps): update sqlparser requirement from 0.35 to 0.36.1 [#7051](https://github.com/apache/arrow-datafusion/pull/7051) (alamb) +- use ObjectStore for dataframe writes [#6987](https://github.com/apache/arrow-datafusion/pull/6987) (devinjdangelo) +- Prepare 28.0.0 Release [#7056](https://github.com/apache/arrow-datafusion/pull/7056) (andygrove) +- refactor: with_inputs() can use original schema to avoid recompute schema. [#7069](https://github.com/apache/arrow-datafusion/pull/7069) (jackwener) +- Fix cli tests [#7083](https://github.com/apache/arrow-datafusion/pull/7083) (mustafasrepo) +- Ignore blank lines and comments at the end of query files for datafusion-cli [#7076](https://github.com/apache/arrow-datafusion/pull/7076) (sarutak) +- Support case sensitive column for `with_column_renamed` [#7063](https://github.com/apache/arrow-datafusion/pull/7063) (comphead) +- Add Decimal256 to `ScalarValue` [#7048](https://github.com/apache/arrow-datafusion/pull/7048) (viirya) +- Enrich CSV reader config: quote & escape [#6927](https://github.com/apache/arrow-datafusion/pull/6927) (parkma99) +- [Refactor] PipelineFixer physical optimizer rule removal [#7059](https://github.com/apache/arrow-datafusion/pull/7059) (metesynnada) +- fix: disallow interval - timestamp [#7086](https://github.com/apache/arrow-datafusion/pull/7086) (jackwener) +- Add Utf8->Binary type coercion for comparison [#7080](https://github.com/apache/arrow-datafusion/pull/7080) (jonahgao) +- Refactor Replace Repartition rule [#7090](https://github.com/apache/arrow-datafusion/pull/7090) (mustafasrepo) +- change the input_type parameter of the create_udaf function from DataType to Vec [#7096](https://github.com/apache/arrow-datafusion/pull/7096) (jiangzhx) +- fix: Projection columns_map remove name search [#7099](https://github.com/apache/arrow-datafusion/pull/7099) (mustafasrepo) +- Minor: Refine doc comments for BuiltinScalarFunction::return_dimension [#7045](https://github.com/apache/arrow-datafusion/pull/7045) (alamb) +- Relax check during aggregate partial mode. [#7101](https://github.com/apache/arrow-datafusion/pull/7101) (mustafasrepo) +- refactor byte_to_string and string_to_byte [#7091](https://github.com/apache/arrow-datafusion/pull/7091) (parkma99) +- Minor: add test + docs for 2 argument trunc with columns [#7042](https://github.com/apache/arrow-datafusion/pull/7042) (alamb) +- Move inactive projects to a different section [#7104](https://github.com/apache/arrow-datafusion/pull/7104) (alamb) +- Port remaining information_schema rust tests to sqllogictests [#7050](https://github.com/apache/arrow-datafusion/pull/7050) (palash25) +- Change `rust-version` in Cargo.toml to comply with MSRV [#7107](https://github.com/apache/arrow-datafusion/pull/7107) (sarutak) +- create all needed folders in advance for benchmarks [#7105](https://github.com/apache/arrow-datafusion/pull/7105) (smiklos) +- Initial support for functional dependencies handling primary key and unique constraints [#7040](https://github.com/apache/arrow-datafusion/pull/7040) (mustafasrepo) +- Add ClickBench queries to DataFusion benchmark runner [#7060](https://github.com/apache/arrow-datafusion/pull/7060) (alamb) +- feat: support SQL array replacement and removement functions [#7057](https://github.com/apache/arrow-datafusion/pull/7057) (izveigor) +- [doc], [minor]. Update docstring of group by rewrite. [#7111](https://github.com/apache/arrow-datafusion/pull/7111) (mustafasrepo) +- Add additional links to main README [#7102](https://github.com/apache/arrow-datafusion/pull/7102) (alamb) +- fix: fix index bug and add test to check it [#7124](https://github.com/apache/arrow-datafusion/pull/7124) (mustafasrepo) +- fix: Fix panic in filter predicate [#7126](https://github.com/apache/arrow-datafusion/pull/7126) (alamb) +- Add MSRV check as a GA job [#7123](https://github.com/apache/arrow-datafusion/pull/7123) (sarutak) +- Minor: move `AnalysisContext` out of physical_expr and into its own module [#7127](https://github.com/apache/arrow-datafusion/pull/7127) (alamb) +- fix: correct count(\*) alias [#7081](https://github.com/apache/arrow-datafusion/pull/7081) (jackwener) +- `make_array` with column of list [#7137](https://github.com/apache/arrow-datafusion/pull/7137) (jayzhan211) +- feat: array containment operator `@>` and `<@` [#6885](https://github.com/apache/arrow-datafusion/pull/6885) (izveigor) +- [MINOR]: Make memory exec partition number =1, in test utils [#7148](https://github.com/apache/arrow-datafusion/pull/7148) (mustafasrepo) +- Substrait union/union all [#7117](https://github.com/apache/arrow-datafusion/pull/7117) (nseekhao) +- minor: Remove mac m1 compilation for size_of_scalar test [#7151](https://github.com/apache/arrow-datafusion/pull/7151) (mustafasrepo) +- chore: add config option for allowing bounded use of sort-preserving operators [#7164](https://github.com/apache/arrow-datafusion/pull/7164) (wolffcm) +- chore: edition use workspace [#7140](https://github.com/apache/arrow-datafusion/pull/7140) (jackwener) +- [bug]: Fix multi partition wrong column requirement bug [#7129](https://github.com/apache/arrow-datafusion/pull/7129) (mustafasrepo) +- Refactor memory_limit tests to make them easier to extend [#7131](https://github.com/apache/arrow-datafusion/pull/7131) (alamb) +- Minor: show output ordering in MemoryExec [#7169](https://github.com/apache/arrow-datafusion/pull/7169) (alamb) +- Move ordering equivalence, and output ordering for joins to util functions [#7167](https://github.com/apache/arrow-datafusion/pull/7167) (mustafasrepo) +- Add regr_slope() aggregate function [#7135](https://github.com/apache/arrow-datafusion/pull/7135) (2010YOUY01) +- Add expression for array_agg [#7159](https://github.com/apache/arrow-datafusion/pull/7159) (willrnch) +- fix: skip compression tests on --no-default-features [#7172](https://github.com/apache/arrow-datafusion/pull/7172) (not-my-profile) +- HashJoin order fixing [#7155](https://github.com/apache/arrow-datafusion/pull/7155) (metesynnada) +- tweak: demote heading levels in PR template [#7176](https://github.com/apache/arrow-datafusion/pull/7176) (not-my-profile) +- feat: add sqllogictests crate [#7134](https://github.com/apache/arrow-datafusion/pull/7134) (tshauck) +- docs: fix broken link [#7177](https://github.com/apache/arrow-datafusion/pull/7177) (SteveLauC) +- Add nanvl builtin function [#7171](https://github.com/apache/arrow-datafusion/pull/7171) (sarutak) +- chore(deps): update apache-avro requirement from 0.14 to 0.15 [#7174](https://github.com/apache/arrow-datafusion/pull/7174) (jackwener) +- make dataframe.task_ctx public [#7183](https://github.com/apache/arrow-datafusion/pull/7183) (milenkovicm) +- feat: allow `datafusion-cli` to accept multiple statements [#7138](https://github.com/apache/arrow-datafusion/pull/7138) (NiwakaDev) +- Add `plan_err!` error macro [#7115](https://github.com/apache/arrow-datafusion/pull/7115) (comphead) +- refactor: add ExecutionPlan::file_scan_config to avoid downcasting [#7175](https://github.com/apache/arrow-datafusion/pull/7175) (not-my-profile) +- Minor: Add documentation + diagrams for ExternalSorter [#7179](https://github.com/apache/arrow-datafusion/pull/7179) (alamb) +- Support simplifying expressions such as `~ ^(ba_r|foo)$` , where the string includes underline [#7186](https://github.com/apache/arrow-datafusion/pull/7186) (tanruixiang) +- Add MemoryReservation::{split_off, take, new_empty} [#7184](https://github.com/apache/arrow-datafusion/pull/7184) (alamb) +- Update bench.sh to only run 5 iterations [#7189](https://github.com/apache/arrow-datafusion/pull/7189) (alamb) +- Implement `array_slice` and `array_element`, remove `array_trim` [#6936](https://github.com/apache/arrow-datafusion/pull/6936) (izveigor) +- Unify DataFrame and SQL (Insert Into) Write Methods [#7141](https://github.com/apache/arrow-datafusion/pull/7141) (devinjdangelo) +- Minor: Further Increase stack_size to prevent roundtrip_deeply_nested test stack overflow [#7208](https://github.com/apache/arrow-datafusion/pull/7208) (devinjdangelo) +- Don't track files generated by regen.sh [#7204](https://github.com/apache/arrow-datafusion/pull/7204) (sarutak) +- Update some docs/scripts to reflect the removed/added packages. [#7202](https://github.com/apache/arrow-datafusion/pull/7202) (sarutak) +- Implement `array_repeat`, remove `array_fill` [#7199](https://github.com/apache/arrow-datafusion/pull/7199) (izveigor) +- Use tokio only if running from a multi-thread tokio context [#7205](https://github.com/apache/arrow-datafusion/pull/7205) (viirya) +- Remove Outdated NY Taxi benchmark [#7210](https://github.com/apache/arrow-datafusion/pull/7210) (alamb) +- improve the ergonomics of creating field and list array accesses [#7215](https://github.com/apache/arrow-datafusion/pull/7215) (izveigor) +- [MINOR] Document refactor on NestedLoopJoin [#7217](https://github.com/apache/arrow-datafusion/pull/7217) (metesynnada) +- Docs: Add GlareDB to list of DataFusion users [#7223](https://github.com/apache/arrow-datafusion/pull/7223) (alamb) +- fix: typo in substrait [#7224](https://github.com/apache/arrow-datafusion/pull/7224) (waynexia) +- Minor: Add constructors to GetFieldAccessExpr and add docs [#7219](https://github.com/apache/arrow-datafusion/pull/7219) (alamb) +- chore: required at least 1 approve before merge [#7226](https://github.com/apache/arrow-datafusion/pull/7226) (jackwener) +- feat: Add linear regression aggregate functions [#7211](https://github.com/apache/arrow-datafusion/pull/7211) (2010YOUY01) +- Add `Expr::field`, `Expr::index`, and `Expr::slice`, add docs [#7218](https://github.com/apache/arrow-datafusion/pull/7218) (alamb) +- Extend insert into support to include Json backed tables [#7212](https://github.com/apache/arrow-datafusion/pull/7212) (devinjdangelo) +- Minor: rename `GetFieldAccessCharacteristic` and add docs [#7220](https://github.com/apache/arrow-datafusion/pull/7220) (alamb) +- Minor: Remove unecessary `clone_with_replacement` [#7232](https://github.com/apache/arrow-datafusion/pull/7232) (alamb) +- Update Arrow 45.0.0 And Datum Arithmetic, change Decimal Division semantics [#6832](https://github.com/apache/arrow-datafusion/pull/6832) (tustvold) +- Support `make_array` null handling in nested version [#7207](https://github.com/apache/arrow-datafusion/pull/7207) (jayzhan211) +- [Minor], Bug Fix: Add empty ordering check at the source. [#7230](https://github.com/apache/arrow-datafusion/pull/7230) (mustafasrepo) +- Minor: with preserve order now receives argument [#7231](https://github.com/apache/arrow-datafusion/pull/7231) (mustafasrepo) +- Minor: Remove [[example]] table from datafusion-examples/Cargo.toml [#7235](https://github.com/apache/arrow-datafusion/pull/7235) (sarutak) +- Remove additional cast from TPCH q8 [#7233](https://github.com/apache/arrow-datafusion/pull/7233) (viirya) +- Minor: Move `project_schema` to `datafusion_common` [#7237](https://github.com/apache/arrow-datafusion/pull/7237) (alamb) +- Minor: Extract ExecutionPlanVisitor to its own module [#7236](https://github.com/apache/arrow-datafusion/pull/7236) (alamb) +- Minor: Move streams out of `physical_plan` module [#7234](https://github.com/apache/arrow-datafusion/pull/7234) (alamb) +- doc: Add link to contributor's guide for new functions within the src [#7240](https://github.com/apache/arrow-datafusion/pull/7240) (2010YOUY01) +- Account for memory usage in SortPreservingMerge (#5885) [#7130](https://github.com/apache/arrow-datafusion/pull/7130) (alamb) +- Deprecate `batch_byte_size` [#7245](https://github.com/apache/arrow-datafusion/pull/7245) (alamb) +- Minor: Move `Partitioning` and`Distribution` to physical_expr [#7238](https://github.com/apache/arrow-datafusion/pull/7238) (alamb) +- Minor: remove duplication in `create_writer` [#7229](https://github.com/apache/arrow-datafusion/pull/7229) (alamb) +- Support array `flatten` sql function [#7239](https://github.com/apache/arrow-datafusion/pull/7239) (jayzhan211) +- Minor: fix clippy for memory_limit test [#7248](https://github.com/apache/arrow-datafusion/pull/7248) (yjshen) +- Update `physical_plan` tests to not use SessionContext [#7243](https://github.com/apache/arrow-datafusion/pull/7243) (alamb) +- Add API to make `unnest` consistent with DuckDB/ClickHouse, add option for preserve_nulls, update docs [#7168](https://github.com/apache/arrow-datafusion/pull/7168) (alamb) +- chore(sqllogictests-doc): add testing set up [#7258](https://github.com/apache/arrow-datafusion/pull/7258) (appletreeisyellow) +- Avoid to use TempDir::into_path for temporary dirs expected to be deleted automatically [#7252](https://github.com/apache/arrow-datafusion/pull/7252) (sarutak) +- [MINOR]: update benefits_from_input_partitioning implementation for projection and repartition [#7246](https://github.com/apache/arrow-datafusion/pull/7246) (mustafasrepo) +- Adding order equivalence support on MemoryExec [#7259](https://github.com/apache/arrow-datafusion/pull/7259) (metesynnada) +- chore(functions): fix function names typo [#7269](https://github.com/apache/arrow-datafusion/pull/7269) (appletreeisyellow) diff --git a/dev/changelog/30.0.0.md b/dev/changelog/30.0.0.md new file mode 100644 index 0000000000000..e713555497447 --- /dev/null +++ b/dev/changelog/30.0.0.md @@ -0,0 +1,83 @@ + + +## [30.0.0](https://github.com/apache/arrow-datafusion/tree/30.0.0) (2023-08-22) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/29.0.0...30.0.0) + +**Implemented enhancements:** + +- feat: Add support for PostgreSQL bitwise XOR operator [#7256](https://github.com/apache/arrow-datafusion/pull/7256) (jonahgao) + +**Fixed bugs:** + +- fix(functions): support `Dictionary` for string and int functions [#7262](https://github.com/apache/arrow-datafusion/pull/7262) (appletreeisyellow) +- fix: CLI should support different sql dialects [#7263](https://github.com/apache/arrow-datafusion/pull/7263) (jonahgao) +- fix: build_timestamp_list data type mismatch [#7267](https://github.com/apache/arrow-datafusion/pull/7267) (yukkit) + +**Documentation updates:** + +- Minor: Remove stubbed out redundant Execution Plan section of library user guide [#7309](https://github.com/apache/arrow-datafusion/pull/7309) (alamb) + +**Merged pull requests:** + +- chore(functions): fix function names typo [#7269](https://github.com/apache/arrow-datafusion/pull/7269) (appletreeisyellow) +- fix(functions): support `Dictionary` for string and int functions [#7262](https://github.com/apache/arrow-datafusion/pull/7262) (appletreeisyellow) +- Change benefits_from_partitioning flag to vector [#7247](https://github.com/apache/arrow-datafusion/pull/7247) (mustafasrepo) +- fix: CLI should support different sql dialects [#7263](https://github.com/apache/arrow-datafusion/pull/7263) (jonahgao) +- fix: build_timestamp_list data type mismatch [#7267](https://github.com/apache/arrow-datafusion/pull/7267) (yukkit) +- feat: Add support for PostgreSQL bitwise XOR operator [#7256](https://github.com/apache/arrow-datafusion/pull/7256) (jonahgao) +- Improve error message for aggregate/window functions [#7265](https://github.com/apache/arrow-datafusion/pull/7265) (2010YOUY01) +- Extend insert into to support Parquet backed tables [#7244](https://github.com/apache/arrow-datafusion/pull/7244) (devinjdangelo) +- Operators documentation [#7264](https://github.com/apache/arrow-datafusion/pull/7264) (spaydar) +- Minor: Add upstream ticket reference in comments [#7275](https://github.com/apache/arrow-datafusion/pull/7275) (alamb) +- Add parquet-filter and sort benchmarks to dfbench [#7120](https://github.com/apache/arrow-datafusion/pull/7120) (alamb) +- Allow `skip_failed_rules` to skip buggy logical plan rules that have a schema mismatch [#7277](https://github.com/apache/arrow-datafusion/pull/7277) (smiklos) +- Enable creating and inserting to empty external tables via SQL [#7276](https://github.com/apache/arrow-datafusion/pull/7276) (devinjdangelo) +- Prepare 29.0.0 Release [#7270](https://github.com/apache/arrow-datafusion/pull/7270) (andygrove) +- Hotfix: Test in information_schema.slt fails [#7286](https://github.com/apache/arrow-datafusion/pull/7286) (sarutak) +- Move sqllogictests to sqllogictests crate to break cyclic dependency [#7284](https://github.com/apache/arrow-datafusion/pull/7284) (alamb) +- Add isnan and iszero [#7274](https://github.com/apache/arrow-datafusion/pull/7274) (sarutak) +- Add library guide for table provider and catalog providers [#7287](https://github.com/apache/arrow-datafusion/pull/7287) (tshauck) +- Implement Support for Copy To Logical and Physical plans [#7283](https://github.com/apache/arrow-datafusion/pull/7283) (devinjdangelo) +- Add `internal_err!` error macro [#7293](https://github.com/apache/arrow-datafusion/pull/7293) (comphead) +- refactor: data types in `array_expressions` [#7280](https://github.com/apache/arrow-datafusion/pull/7280) (izveigor) +- Fix Unnest for array aggregations. [#7300](https://github.com/apache/arrow-datafusion/pull/7300) (vincev) +- Minor: Followup tasks for `nanvl` [#7311](https://github.com/apache/arrow-datafusion/pull/7311) (sarutak) +- Minor: Remove stubbed out redundant Execution Plan section of library user guide [#7309](https://github.com/apache/arrow-datafusion/pull/7309) (alamb) +- Minor: fix some parquet writer session level defaults [#7295](https://github.com/apache/arrow-datafusion/pull/7295) (devinjdangelo) +- Add Sqllogictests for INSERT INTO external table [#7294](https://github.com/apache/arrow-datafusion/pull/7294) (devinjdangelo) +- Minor: Fix documentation typos for array expressions [#7314](https://github.com/apache/arrow-datafusion/pull/7314) (Weijun-H) +- Qualify filter fields in the update plan [#7316](https://github.com/apache/arrow-datafusion/pull/7316) (gruuya) +- chore(deps): update tokio requirement to 1.28 [#7324](https://github.com/apache/arrow-datafusion/pull/7324) (jonahgao) +- Bug-fix / Join Output Orderings [#7296](https://github.com/apache/arrow-datafusion/pull/7296) (berkaysynnada) +- Add `internal_err` error macros. Part 2 [#7321](https://github.com/apache/arrow-datafusion/pull/7321) (comphead) +- Minor: Improve doc comments to datafusion-sql [#7318](https://github.com/apache/arrow-datafusion/pull/7318) (alamb) +- Minor: make memory_limit tests more self describing [#7190](https://github.com/apache/arrow-datafusion/pull/7190) (alamb) +- Minor: Improve docstrings for `LogicalPlan` [#7331](https://github.com/apache/arrow-datafusion/pull/7331) (alamb) +- minor: fix doc/typo [#7341](https://github.com/apache/arrow-datafusion/pull/7341) (jackwener) +- Minor: Extract `FileScanConfig` into its own module [#7335](https://github.com/apache/arrow-datafusion/pull/7335) (alamb) +- Minor: Move shared testing code into datafusion_common [#7334](https://github.com/apache/arrow-datafusion/pull/7334) (alamb) +- refine: `substr` error [#7339](https://github.com/apache/arrow-datafusion/pull/7339) (Weijun-H) +- Add `not_impl_err` error macro [#7340](https://github.com/apache/arrow-datafusion/pull/7340) (comphead) +- chore: public sql_statement_to_plan_with_context() [#7268](https://github.com/apache/arrow-datafusion/pull/7268) (waynexia) +- Deprecate ScalarValue bitor, bitand, and bitxor (#6842) [#7351](https://github.com/apache/arrow-datafusion/pull/7351) (tustvold) +- feature: Support `EXPLAIN COPY` [#7291](https://github.com/apache/arrow-datafusion/pull/7291) (alamb) +- Add `SQLOptions` for controlling allowed SQL statements, update docs [#7333](https://github.com/apache/arrow-datafusion/pull/7333) (alamb) +- Refactor: Consolidate OutputFileFormat and FileType into datafusion_common [#7336](https://github.com/apache/arrow-datafusion/pull/7336) (devinjdangelo) diff --git a/dev/changelog/31.0.0.md b/dev/changelog/31.0.0.md new file mode 100644 index 0000000000000..9f606ffd51e12 --- /dev/null +++ b/dev/changelog/31.0.0.md @@ -0,0 +1,123 @@ + + +## [31.0.0](https://github.com/apache/arrow-datafusion/tree/31.0.0) (2023-09-08) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/30.0.0...31.0.0) + +**Breaking changes:** + +- Specialize Avg and Sum accumulators (#6842) [#7358](https://github.com/apache/arrow-datafusion/pull/7358) (tustvold) +- Use datum arithmetic scalar value [#7375](https://github.com/apache/arrow-datafusion/pull/7375) (tustvold) + +**Implemented enhancements:** + +- feat: `array-empty` [#7313](https://github.com/apache/arrow-datafusion/pull/7313) (Weijun-H) +- Support `REPLACE` SQL alias syntax [#7368](https://github.com/apache/arrow-datafusion/pull/7368) (berkaysynnada) +- feat: support primary key alternate syntax [#7160](https://github.com/apache/arrow-datafusion/pull/7160) (parkma99) +- docs: Add `Expr` library developer page [#7359](https://github.com/apache/arrow-datafusion/pull/7359) (tshauck) +- feat: support Binary for `min/max` [#7397](https://github.com/apache/arrow-datafusion/pull/7397) (Weijun-H) +- feat: Add memory pool configuration to `datafusion-cli` [#7424](https://github.com/apache/arrow-datafusion/pull/7424) (Weijun-H) +- Support Configuring Arrow RecordBatch Writers via SQL Statement Options [#7390](https://github.com/apache/arrow-datafusion/pull/7390) (devinjdangelo) +- Add ROLLUP and GROUPING SETS substrait support [#7382](https://github.com/apache/arrow-datafusion/pull/7382) (nseekhao) +- feat: Allow creating a ValuesExec from record batches [#7444](https://github.com/apache/arrow-datafusion/pull/7444) (scsmithr) +- minor: Add ARROW to `CREATE EXTERNAL TABLE` docs and add example of `COMPRESSION TYPE` [#7489](https://github.com/apache/arrow-datafusion/pull/7489) (alamb) +- Support Configuring Parquet Column Specific Options via SQL Statement Options [#7466](https://github.com/apache/arrow-datafusion/pull/7466) (devinjdangelo) +- Write Multiple Parquet Files in Parallel [#7483](https://github.com/apache/arrow-datafusion/pull/7483) (devinjdangelo) +- feat: explain with statistics [#7459](https://github.com/apache/arrow-datafusion/pull/7459) (korowa) + +**Fixed bugs:** + +- fix: union_distinct shouldn't remove child distinct [#7346](https://github.com/apache/arrow-datafusion/pull/7346) (jackwener) +- fix: inconsistent scalar types in `DistinctArrayAggAccumulator` state [#7385](https://github.com/apache/arrow-datafusion/pull/7385) (korowa) +- fix: incorrect nullability calculation of `InListExpr` [#7496](https://github.com/apache/arrow-datafusion/pull/7496) (jonahgao) + +**Merged pull requests:** + +- Remove redundant type check in Avg [#7374](https://github.com/apache/arrow-datafusion/pull/7374) (viirya) +- feat: `array-empty` [#7313](https://github.com/apache/arrow-datafusion/pull/7313) (Weijun-H) +- Minor: add `WriteOp::name` and `DmlStatement::name` [#7329](https://github.com/apache/arrow-datafusion/pull/7329) (alamb) +- Specialize Median Accumulator [#7376](https://github.com/apache/arrow-datafusion/pull/7376) (tustvold) +- Specialize Avg and Sum accumulators (#6842) [#7358](https://github.com/apache/arrow-datafusion/pull/7358) (tustvold) +- Change error type of invalid argument to PlanError rather than InternalError, remove misleading comments [#7355](https://github.com/apache/arrow-datafusion/pull/7355) (alamb) +- Implement `array_pop_back` function [#7348](https://github.com/apache/arrow-datafusion/pull/7348) (tanruixiang) +- Add `exec_err!` error macro [#7361](https://github.com/apache/arrow-datafusion/pull/7361) (comphead) +- Update sqlparser requirement from 0.36.1 to 0.37.0 [#7387](https://github.com/apache/arrow-datafusion/pull/7387) (viirya) +- DML documentation [#7362](https://github.com/apache/arrow-datafusion/pull/7362) (spaydar) +- Support `REPLACE` SQL alias syntax [#7368](https://github.com/apache/arrow-datafusion/pull/7368) (berkaysynnada) +- Bug-fix/next_value() of Min/Max Scalar Values [#7384](https://github.com/apache/arrow-datafusion/pull/7384) (berkaysynnada) +- Prepare 30.0.0 release [#7372](https://github.com/apache/arrow-datafusion/pull/7372) (andygrove) +- fix: union_distinct shouldn't remove child distinct [#7346](https://github.com/apache/arrow-datafusion/pull/7346) (jackwener) +- feat: support primary key alternate syntax [#7160](https://github.com/apache/arrow-datafusion/pull/7160) (parkma99) +- Merge hash table implementations and remove leftover utilities [#7366](https://github.com/apache/arrow-datafusion/pull/7366) (metesynnada) +- Minor: remove stray `println!` from `array_expressions.rs` [#7389](https://github.com/apache/arrow-datafusion/pull/7389) (alamb) +- Projection Order Propagation [#7364](https://github.com/apache/arrow-datafusion/pull/7364) (berkaysynnada) +- Document and `scratch` directory for sqllogictest and make test specific [#7312](https://github.com/apache/arrow-datafusion/pull/7312) (alamb) +- Minor: Move test for `Distribution` and `Partitioning` with code [#7392](https://github.com/apache/arrow-datafusion/pull/7392) (alamb) +- Minor: move datasource statistics code into its own module [#7391](https://github.com/apache/arrow-datafusion/pull/7391) (alamb) +- Use datum arithmetic scalar value [#7375](https://github.com/apache/arrow-datafusion/pull/7375) (tustvold) +- Fix IN expr for NaN [#7378](https://github.com/apache/arrow-datafusion/pull/7378) (sarutak) +- Unnecessary to list all files during partition pruning [#7395](https://github.com/apache/arrow-datafusion/pull/7395) (smallzhongfeng) +- Optimize `Unnest` and implement `skip_nulls=true` if specified [#7371](https://github.com/apache/arrow-datafusion/pull/7371) (smiklos) +- Docs: Add query syntax to `COPY` docs [#7388](https://github.com/apache/arrow-datafusion/pull/7388) (alamb) +- Clean up clippy for Rust 1.72 release [#7399](https://github.com/apache/arrow-datafusion/pull/7399) (alamb) +- fix: inconsistent scalar types in `DistinctArrayAggAccumulator` state [#7385](https://github.com/apache/arrow-datafusion/pull/7385) (korowa) +- Fix python CI [#7416](https://github.com/apache/arrow-datafusion/pull/7416) (tustvold) +- docs: Add `Expr` library developer page [#7359](https://github.com/apache/arrow-datafusion/pull/7359) (tshauck) +- Update ObjectStore 0.7.0 and Arrow 46.0.0 [#7282](https://github.com/apache/arrow-datafusion/pull/7282) (tustvold) +- Fix Decimal256 scalar display string in sqllogictest [#7404](https://github.com/apache/arrow-datafusion/pull/7404) (viirya) +- feat: support Binary for `min/max` [#7397](https://github.com/apache/arrow-datafusion/pull/7397) (Weijun-H) +- Make sqllogictest distinguish NaN from -NaN [#7403](https://github.com/apache/arrow-datafusion/pull/7403) (sarutak) +- Replace lazy_static with OnceLock [#7409](https://github.com/apache/arrow-datafusion/pull/7409) (sarutak) +- Minor: Remove the unreached simplification rule for `0 / 0` [#7405](https://github.com/apache/arrow-datafusion/pull/7405) (jonahgao) +- feat: Add memory pool configuration to `datafusion-cli` [#7424](https://github.com/apache/arrow-datafusion/pull/7424) (Weijun-H) +- Minor: Debug log when FairPool is created [#7431](https://github.com/apache/arrow-datafusion/pull/7431) (alamb) +- Support Configuring Arrow RecordBatch Writers via SQL Statement Options [#7390](https://github.com/apache/arrow-datafusion/pull/7390) (devinjdangelo) +- Add ROLLUP and GROUPING SETS substrait support [#7382](https://github.com/apache/arrow-datafusion/pull/7382) (nseekhao) +- Refactor sort_fuzz test to clarify what is covered [#7439](https://github.com/apache/arrow-datafusion/pull/7439) (alamb) +- Use DateTime::from_naive_utc_and_offset instead of DateTime::from_utc [#7451](https://github.com/apache/arrow-datafusion/pull/7451) (sarutak) +- Update substrait requirement from 0.12.0 to 0.13.1 [#7443](https://github.com/apache/arrow-datafusion/pull/7443) (viirya) +- [minior fix]: adjust the projection statistics [#7428](https://github.com/apache/arrow-datafusion/pull/7428) (liukun4515) +- Add new known users: Arroyo and Restate [#7464](https://github.com/apache/arrow-datafusion/pull/7464) (jychen7) +- ScalarFunctionExpr Maintaining Order [#7417](https://github.com/apache/arrow-datafusion/pull/7417) (berkaysynnada) +- Bug-fix/find_orderings_of_exprs [#7457](https://github.com/apache/arrow-datafusion/pull/7457) (berkaysynnada) +- Update prost-derive requirement from 0.11 to 0.12 [#7468](https://github.com/apache/arrow-datafusion/pull/7468) (dependabot[bot]) +- Revert "Update prost-derive requirement from 0.11 to 0.12 (#7468)" [#7476](https://github.com/apache/arrow-datafusion/pull/7476) (viirya) +- Return error if spill file does not exist in ExternalSorter [#7479](https://github.com/apache/arrow-datafusion/pull/7479) (viirya) +- [minor fix]: Remove unused duplicate `file_type.rs` [#7478](https://github.com/apache/arrow-datafusion/pull/7478) (sarutak) +- Minor: more flexible pool size setting for datafusion-cli [#7454](https://github.com/apache/arrow-datafusion/pull/7454) (yjshen) +- Bump actions/checkout from 3 to 4 [#7480](https://github.com/apache/arrow-datafusion/pull/7480) (dependabot[bot]) +- Support Write Options in DataFrame::write\_\* methods [#7435](https://github.com/apache/arrow-datafusion/pull/7435) (devinjdangelo) +- cp_solver, Duration vs Interval cases [#7475](https://github.com/apache/arrow-datafusion/pull/7475) (berkaysynnada) +- feat: Allow creating a ValuesExec from record batches [#7444](https://github.com/apache/arrow-datafusion/pull/7444) (scsmithr) +- Make `LogicalPlan::with_new_exprs,` deprecate `from_plan` [#7396](https://github.com/apache/arrow-datafusion/pull/7396) (alamb) +- refactor: change file type logic for create table [#7477](https://github.com/apache/arrow-datafusion/pull/7477) (tshauck) +- minor: do not fail analyzer if subquery plan contains extension [#7455](https://github.com/apache/arrow-datafusion/pull/7455) (waynexia) +- Make IN expr work with multiple items [#7449](https://github.com/apache/arrow-datafusion/pull/7449) (sarutak) +- Minor: Add doc comments and example for `ScalarVaue::to_scalar` [#7491](https://github.com/apache/arrow-datafusion/pull/7491) (alamb) +- minor: Add ARROW to `CREATE EXTERNAL TABLE` docs and add example of `COMPRESSION TYPE` [#7489](https://github.com/apache/arrow-datafusion/pull/7489) (alamb) +- Add backtrace to error messages [#7434](https://github.com/apache/arrow-datafusion/pull/7434) (comphead) +- Make sqllogictest platform-independent for the sign of NaN [#7462](https://github.com/apache/arrow-datafusion/pull/7462) (sarutak) +- Support Configuring Parquet Column Specific Options via SQL Statement Options [#7466](https://github.com/apache/arrow-datafusion/pull/7466) (devinjdangelo) +- Minor: improve error message [#7498](https://github.com/apache/arrow-datafusion/pull/7498) (alamb) +- Write Multiple Parquet Files in Parallel [#7483](https://github.com/apache/arrow-datafusion/pull/7483) (devinjdangelo) +- `PrimitiveGroupsAccumulator` should propagate timestamp timezone information properly [#7494](https://github.com/apache/arrow-datafusion/pull/7494) (sunchao) +- Minor: Add `ScalarValue::data_type()` for consistency with other APIs [#7492](https://github.com/apache/arrow-datafusion/pull/7492) (alamb) +- feat: explain with statistics [#7459](https://github.com/apache/arrow-datafusion/pull/7459) (korowa) +- fix: incorrect nullability calculation of `InListExpr` [#7496](https://github.com/apache/arrow-datafusion/pull/7496) (jonahgao) diff --git a/dev/changelog/32.0.0.md b/dev/changelog/32.0.0.md new file mode 100644 index 0000000000000..781fd50015524 --- /dev/null +++ b/dev/changelog/32.0.0.md @@ -0,0 +1,195 @@ + + +## [32.0.0](https://github.com/apache/arrow-datafusion/tree/32.0.0) (2023-10-07) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/31.0.0...32.0.0) + +**Breaking changes:** + +- Remove implicit interval type coercion from ScalarValue comparison [#7514](https://github.com/apache/arrow-datafusion/pull/7514) (tustvold) +- Remove get_scan_files and ExecutionPlan::file_scan_config (#7357) [#7487](https://github.com/apache/arrow-datafusion/pull/7487) (tustvold) +- Move `FileCompressionType` out of `common` and into `core` [#7596](https://github.com/apache/arrow-datafusion/pull/7596) (haohuaijin) +- Update arrow 47.0.0 in DataFusion [#7587](https://github.com/apache/arrow-datafusion/pull/7587) (tustvold) +- Rename `bounded_order_preserving_variants` config to `prefer_exising_sort` and update docs [#7723](https://github.com/apache/arrow-datafusion/pull/7723) (alamb) + +**Implemented enhancements:** + +- Parallelize Stateless (CSV/JSON) File Write Serialization [#7452](https://github.com/apache/arrow-datafusion/pull/7452) (devinjdangelo) +- Create a Priority Queue based Aggregation with `limit` [#7192](https://github.com/apache/arrow-datafusion/pull/7192) (avantgardnerio) +- feat: add guarantees to simplification [#7467](https://github.com/apache/arrow-datafusion/pull/7467) (wjones127) +- [Minor]: Produce better plan when group by contains all of the ordering requirements [#7542](https://github.com/apache/arrow-datafusion/pull/7542) (mustafasrepo) +- Make AvroArrowArrayReader possible to scan Avro backed table which contains nested records [#7525](https://github.com/apache/arrow-datafusion/pull/7525) (sarutak) +- feat: Support spilling for hash aggregation [#7400](https://github.com/apache/arrow-datafusion/pull/7400) (kazuyukitanimura) +- Parallelize Parquet Serialization [#7562](https://github.com/apache/arrow-datafusion/pull/7562) (devinjdangelo) +- feat: natively support more data types for the `abs` function. [#7568](https://github.com/apache/arrow-datafusion/pull/7568) (jonahgao) +- feat: Parallel collecting parquet files statistics #7573 [#7595](https://github.com/apache/arrow-datafusion/pull/7595) (hengfeiyang) +- Support hashing List columns [#7616](https://github.com/apache/arrow-datafusion/pull/7616) (jonmmease) +- feat: Better large output display in datafusion-cli with --maxrows option [#7617](https://github.com/apache/arrow-datafusion/pull/7617) (2010YOUY01) +- feat: make parse_float_as_decimal work on negative numbers [#7648](https://github.com/apache/arrow-datafusion/pull/7648) (jonahgao) +- Update Default Parquet Write Compression [#7692](https://github.com/apache/arrow-datafusion/pull/7692) (devinjdangelo) +- Support all the codecs supported by Avro [#7718](https://github.com/apache/arrow-datafusion/pull/7718) (sarutak) +- Optimize "ORDER BY + LIMIT" queries for speed / memory with special TopK operator [#7721](https://github.com/apache/arrow-datafusion/pull/7721) (Dandandan) + +**Fixed bugs:** + +- fix: inconsistent behaviors when dividing floating numbers by zero [#7503](https://github.com/apache/arrow-datafusion/pull/7503) (jonahgao) +- fix: skip EliminateCrossJoin rule if inner join with filter is found [#7529](https://github.com/apache/arrow-datafusion/pull/7529) (epsio-banay) +- fix: check for precision overflow when parsing float as decimal [#7627](https://github.com/apache/arrow-datafusion/pull/7627) (jonahgao) +- fix: substrait limit when fetch is None [#7669](https://github.com/apache/arrow-datafusion/pull/7669) (waynexia) +- fix: coerce text to timestamps with timezones [#7720](https://github.com/apache/arrow-datafusion/pull/7720) (mhilton) +- fix: avro_to_arrow: Handle avro nested nullable struct (union) [#7663](https://github.com/apache/arrow-datafusion/pull/7663) (Samrose-Ahmed) + +**Documentation updates:** + +- Documentation Updates for New Write Related Features [#7520](https://github.com/apache/arrow-datafusion/pull/7520) (devinjdangelo) +- Create 2023 Q4 roadmap [#7551](https://github.com/apache/arrow-datafusion/pull/7551) (graydenshand) +- docs: add section on supports_filters_pushdown [#7680](https://github.com/apache/arrow-datafusion/pull/7680) (tshauck) +- Add LanceDB to the list of Known Users [#7716](https://github.com/apache/arrow-datafusion/pull/7716) (alamb) +- Document crate feature flags [#7713](https://github.com/apache/arrow-datafusion/pull/7713) (alamb) + +**Merged pull requests:** + +- Prepare 31.0.0 release [#7508](https://github.com/apache/arrow-datafusion/pull/7508) (andygrove) +- Minor(proto): Implement `TryFrom<&DFSchema>` for `protobuf::DfSchema` [#7505](https://github.com/apache/arrow-datafusion/pull/7505) (jonahgao) +- fix: inconsistent behaviors when dividing floating numbers by zero [#7503](https://github.com/apache/arrow-datafusion/pull/7503) (jonahgao) +- Parallelize Stateless (CSV/JSON) File Write Serialization [#7452](https://github.com/apache/arrow-datafusion/pull/7452) (devinjdangelo) +- Minor: Remove stray comment markings from encoding error message [#7512](https://github.com/apache/arrow-datafusion/pull/7512) (devinjdangelo) +- Remove implicit interval type coercion from ScalarValue comparison [#7514](https://github.com/apache/arrow-datafusion/pull/7514) (tustvold) +- Minor: deprecate ScalarValue::get_datatype() [#7507](https://github.com/apache/arrow-datafusion/pull/7507) (Weijun-H) +- Propagate error from spawned task reading spills [#7510](https://github.com/apache/arrow-datafusion/pull/7510) (viirya) +- Refactor the EnforceDistribution Rule [#7488](https://github.com/apache/arrow-datafusion/pull/7488) (mustafasrepo) +- Remove get_scan_files and ExecutionPlan::file_scan_config (#7357) [#7487](https://github.com/apache/arrow-datafusion/pull/7487) (tustvold) +- Simplify ScalarValue::distance (#7517) [#7519](https://github.com/apache/arrow-datafusion/pull/7519) (tustvold) +- typo: change `delimeter` to `delimiter` [#7521](https://github.com/apache/arrow-datafusion/pull/7521) (Weijun-H) +- Fix some simplification rules for floating-point arithmetic operations [#7515](https://github.com/apache/arrow-datafusion/pull/7515) (jonahgao) +- Documentation Updates for New Write Related Features [#7520](https://github.com/apache/arrow-datafusion/pull/7520) (devinjdangelo) +- [MINOR]: Move tests from repartition to enforce_distribution file [#7539](https://github.com/apache/arrow-datafusion/pull/7539) (mustafasrepo) +- Update the async-trait crate to resolve clippy bug [#7541](https://github.com/apache/arrow-datafusion/pull/7541) (metesynnada) +- Fix flaky `test_sort_fetch_memory_calculation` test [#7534](https://github.com/apache/arrow-datafusion/pull/7534) (viirya) +- Move common code to utils [#7545](https://github.com/apache/arrow-datafusion/pull/7545) (mustafasrepo) +- Minor: Add comments and clearer constructors to `Interval` [#7526](https://github.com/apache/arrow-datafusion/pull/7526) (alamb) +- fix: skip EliminateCrossJoin rule if inner join with filter is found [#7529](https://github.com/apache/arrow-datafusion/pull/7529) (epsio-banay) +- Create a Priority Queue based Aggregation with `limit` [#7192](https://github.com/apache/arrow-datafusion/pull/7192) (avantgardnerio) +- feat: add guarantees to simplification [#7467](https://github.com/apache/arrow-datafusion/pull/7467) (wjones127) +- [Minor]: Produce better plan when group by contains all of the ordering requirements [#7542](https://github.com/apache/arrow-datafusion/pull/7542) (mustafasrepo) +- Minor: beautify interval display [#7554](https://github.com/apache/arrow-datafusion/pull/7554) (Weijun-H) +- replace ptree with termtree [#7560](https://github.com/apache/arrow-datafusion/pull/7560) (avantgardnerio) +- Make AvroArrowArrayReader possible to scan Avro backed table which contains nested records [#7525](https://github.com/apache/arrow-datafusion/pull/7525) (sarutak) +- Fix a race condition issue on reading spilled file [#7538](https://github.com/apache/arrow-datafusion/pull/7538) (sarutak) +- [MINOR]: Add is single method [#7558](https://github.com/apache/arrow-datafusion/pull/7558) (mustafasrepo) +- Fix `describe
` to work without SessionContext [#7441](https://github.com/apache/arrow-datafusion/pull/7441) (alamb) +- Make the tests in SHJ faster [#7543](https://github.com/apache/arrow-datafusion/pull/7543) (metesynnada) +- feat: Support spilling for hash aggregation [#7400](https://github.com/apache/arrow-datafusion/pull/7400) (kazuyukitanimura) +- Make backtrace as a cargo feature [#7527](https://github.com/apache/arrow-datafusion/pull/7527) (comphead) +- Minor: Fix `clippy` by switching to `timestamp_nanos_opt` instead of (deprecated) `timestamp_nanos` [#7572](https://github.com/apache/arrow-datafusion/pull/7572) (alamb) +- Update sqllogictest requirement from 0.15.0 to 0.16.0 [#7569](https://github.com/apache/arrow-datafusion/pull/7569) (dependabot[bot]) +- extract `datafusion-physical-plan` to its own crate [#7432](https://github.com/apache/arrow-datafusion/pull/7432) (alamb) +- First and Last Accumulators should update with state row excluding is_set flag [#7565](https://github.com/apache/arrow-datafusion/pull/7565) (viirya) +- refactor: simplify code of eliminate_cross_join.rs [#7561](https://github.com/apache/arrow-datafusion/pull/7561) (jackwener) +- Update release instructions for datafusion-physical-plan crate [#7576](https://github.com/apache/arrow-datafusion/pull/7576) (alamb) +- Minor: Update chrono pin to `0.4.31` [#7575](https://github.com/apache/arrow-datafusion/pull/7575) (alamb) +- [feat] Introduce cacheManager in session ctx and make StatisticsCache share in session [#7570](https://github.com/apache/arrow-datafusion/pull/7570) (Ted-Jiang) +- Enhance/Refactor Ordering Equivalence Properties [#7566](https://github.com/apache/arrow-datafusion/pull/7566) (mustafasrepo) +- fix misplaced statements in sqllogictest [#7586](https://github.com/apache/arrow-datafusion/pull/7586) (jonahgao) +- Update substrait requirement from 0.13.1 to 0.14.0 [#7585](https://github.com/apache/arrow-datafusion/pull/7585) (dependabot[bot]) +- chore: use the `create_udwf` function in `simple_udwf`, consistent with `simple_udf` and `simple_udaf` [#7579](https://github.com/apache/arrow-datafusion/pull/7579) (tanruixiang) +- Implement protobuf serialization for AnalyzeExec [#7574](https://github.com/apache/arrow-datafusion/pull/7574) (adhish20) +- chore: fix catalog's usage docs error and add docs about `CatalogList` trait [#7582](https://github.com/apache/arrow-datafusion/pull/7582) (tanruixiang) +- Implement `CardinalityAwareRowConverter` while doing streaming merge [#7401](https://github.com/apache/arrow-datafusion/pull/7401) (JayjeetAtGithub) +- Parallelize Parquet Serialization [#7562](https://github.com/apache/arrow-datafusion/pull/7562) (devinjdangelo) +- feat: natively support more data types for the `abs` function. [#7568](https://github.com/apache/arrow-datafusion/pull/7568) (jonahgao) +- implement string_to_array [#7577](https://github.com/apache/arrow-datafusion/pull/7577) (casperhart) +- Create 2023 Q4 roadmap [#7551](https://github.com/apache/arrow-datafusion/pull/7551) (graydenshand) +- chore: reduce `physical-plan` dependencies [#7599](https://github.com/apache/arrow-datafusion/pull/7599) (crepererum) +- Minor: add githubs start/fork buttons to documentation page [#7588](https://github.com/apache/arrow-datafusion/pull/7588) (alamb) +- Minor: add more examples for `CREATE EXTERNAL TABLE` doc [#7594](https://github.com/apache/arrow-datafusion/pull/7594) (comphead) +- Update nix requirement from 0.26.1 to 0.27.1 [#7438](https://github.com/apache/arrow-datafusion/pull/7438) (dependabot[bot]) +- Update sqllogictest requirement from 0.16.0 to 0.17.0 [#7606](https://github.com/apache/arrow-datafusion/pull/7606) (dependabot[bot]) +- Fix panic in TopK [#7609](https://github.com/apache/arrow-datafusion/pull/7609) (avantgardnerio) +- Move `FileCompressionType` out of `common` and into `core` [#7596](https://github.com/apache/arrow-datafusion/pull/7596) (haohuaijin) +- Expose contents of Constraints [#7603](https://github.com/apache/arrow-datafusion/pull/7603) (tv42) +- Change the unbounded_output API default [#7605](https://github.com/apache/arrow-datafusion/pull/7605) (metesynnada) +- feat: Parallel collecting parquet files statistics #7573 [#7595](https://github.com/apache/arrow-datafusion/pull/7595) (hengfeiyang) +- Support hashing List columns [#7616](https://github.com/apache/arrow-datafusion/pull/7616) (jonmmease) +- [MINOR] Make the sink input aware of its plan [#7610](https://github.com/apache/arrow-datafusion/pull/7610) (metesynnada) +- [MINOR] Reduce complexity on SHJ [#7607](https://github.com/apache/arrow-datafusion/pull/7607) (metesynnada) +- feat: Better large output display in datafusion-cli with --maxrows option [#7617](https://github.com/apache/arrow-datafusion/pull/7617) (2010YOUY01) +- Minor: add examples for `arrow_cast` and `arrow_typeof` to user guide [#7615](https://github.com/apache/arrow-datafusion/pull/7615) (alamb) +- [MINOR]: Fix stack overflow bug for get field access expr [#7623](https://github.com/apache/arrow-datafusion/pull/7623) (mustafasrepo) +- Group By All [#7622](https://github.com/apache/arrow-datafusion/pull/7622) (berkaysynnada) +- Implement protobuf serialization for `(Bounded)WindowAggExec`. [#7557](https://github.com/apache/arrow-datafusion/pull/7557) (vrongmeal) +- Make it possible to compile datafusion-common without default features [#7625](https://github.com/apache/arrow-datafusion/pull/7625) (jonmmease) +- Minor: Adding backtrace documentation [#7628](https://github.com/apache/arrow-datafusion/pull/7628) (comphead) +- fix(5975/5976): timezone handling for timestamps and `date_trunc`, `date_part` and `date_bin` [#7614](https://github.com/apache/arrow-datafusion/pull/7614) (wiedld) +- Minor: remove unecessary `Arc`s in datetime_expressions [#7630](https://github.com/apache/arrow-datafusion/pull/7630) (alamb) +- fix: check for precision overflow when parsing float as decimal [#7627](https://github.com/apache/arrow-datafusion/pull/7627) (jonahgao) +- Update arrow 47.0.0 in DataFusion [#7587](https://github.com/apache/arrow-datafusion/pull/7587) (tustvold) +- Add test crate to compile DataFusion with wasm-pack [#7633](https://github.com/apache/arrow-datafusion/pull/7633) (jonmmease) +- Minor: Update documentation of case expression [#7646](https://github.com/apache/arrow-datafusion/pull/7646) (ongchi) +- Minor: improve docstrings on `SessionState` [#7654](https://github.com/apache/arrow-datafusion/pull/7654) (alamb) +- Update example in the DataFrame documentation. [#7650](https://github.com/apache/arrow-datafusion/pull/7650) (jsimpson-gro) +- Add HTTP object store example [#7602](https://github.com/apache/arrow-datafusion/pull/7602) (pka) +- feat: make parse_float_as_decimal work on negative numbers [#7648](https://github.com/apache/arrow-datafusion/pull/7648) (jonahgao) +- Minor: add doc comments to `ExtractEquijoinPredicate` [#7658](https://github.com/apache/arrow-datafusion/pull/7658) (alamb) +- [MINOR]: Do not add unnecessary hash repartition to the physical plan [#7667](https://github.com/apache/arrow-datafusion/pull/7667) (mustafasrepo) +- Minor: add ticket references to parallel parquet writing code [#7592](https://github.com/apache/arrow-datafusion/pull/7592) (alamb) +- Minor: Add ticket reference and add test comment [#7593](https://github.com/apache/arrow-datafusion/pull/7593) (alamb) +- Support Avro's Enum type and Fixed type [#7635](https://github.com/apache/arrow-datafusion/pull/7635) (sarutak) +- Minor: Migrate datafusion-proto tests into it own binary [#7668](https://github.com/apache/arrow-datafusion/pull/7668) (ongchi) +- Upgrade apache-avro to 0.16 [#7674](https://github.com/apache/arrow-datafusion/pull/7674) (sarutak) +- Move window analysis to the window method [#7672](https://github.com/apache/arrow-datafusion/pull/7672) (mustafasrepo) +- Don't add filters to projection in TableScan [#7670](https://github.com/apache/arrow-datafusion/pull/7670) (Dandandan) +- Minor: Improve `TableProviderFilterPushDown` docs [#7685](https://github.com/apache/arrow-datafusion/pull/7685) (alamb) +- FIX: Test timestamp with table [#7701](https://github.com/apache/arrow-datafusion/pull/7701) (jayzhan211) +- Fix bug in `SimplifyExpressions` [#7699](https://github.com/apache/arrow-datafusion/pull/7699) (Dandandan) +- Enhance Enforce Dist capabilities to fix, sub optimal bad plans [#7671](https://github.com/apache/arrow-datafusion/pull/7671) (mustafasrepo) +- docs: add section on supports_filters_pushdown [#7680](https://github.com/apache/arrow-datafusion/pull/7680) (tshauck) +- Improve cache usage in CI [#7678](https://github.com/apache/arrow-datafusion/pull/7678) (sarutak) +- fix: substrait limit when fetch is None [#7669](https://github.com/apache/arrow-datafusion/pull/7669) (waynexia) +- minor: revert parsing precedence between Aggr and UDAF [#7682](https://github.com/apache/arrow-datafusion/pull/7682) (waynexia) +- Minor: Move hash utils to common [#7684](https://github.com/apache/arrow-datafusion/pull/7684) (jayzhan211) +- Update Default Parquet Write Compression [#7692](https://github.com/apache/arrow-datafusion/pull/7692) (devinjdangelo) +- Stop using cache for the benchmark job [#7706](https://github.com/apache/arrow-datafusion/pull/7706) (sarutak) +- Change rust.yml to run benchmark [#7708](https://github.com/apache/arrow-datafusion/pull/7708) (sarutak) +- Extend infer_placeholder_types to support BETWEEN predicates [#7703](https://github.com/apache/arrow-datafusion/pull/7703) (andrelmartins) +- Minor: Add comment explaining why verify benchmark results uses release mode [#7712](https://github.com/apache/arrow-datafusion/pull/7712) (alamb) +- Support all the codecs supported by Avro [#7718](https://github.com/apache/arrow-datafusion/pull/7718) (sarutak) +- Update substrait requirement from 0.14.0 to 0.15.0 [#7719](https://github.com/apache/arrow-datafusion/pull/7719) (dependabot[bot]) +- fix: coerce text to timestamps with timezones [#7720](https://github.com/apache/arrow-datafusion/pull/7720) (mhilton) +- Add LanceDB to the list of Known Users [#7716](https://github.com/apache/arrow-datafusion/pull/7716) (alamb) +- Enable avro reading/writing in datafusion-cli [#7715](https://github.com/apache/arrow-datafusion/pull/7715) (alamb) +- Document crate feature flags [#7713](https://github.com/apache/arrow-datafusion/pull/7713) (alamb) +- Minor: Consolidate UDF tests [#7704](https://github.com/apache/arrow-datafusion/pull/7704) (alamb) +- Minor: fix CI failure due to Cargo.lock in datafusioncli [#7733](https://github.com/apache/arrow-datafusion/pull/7733) (yjshen) +- MINOR: change file to column index in page_filter trace log [#7730](https://github.com/apache/arrow-datafusion/pull/7730) (mapleFU) +- preserve array type / timezone in `date_bin` and `date_trunc` functions [#7729](https://github.com/apache/arrow-datafusion/pull/7729) (mhilton) +- Remove redundant is_numeric for DataType [#7734](https://github.com/apache/arrow-datafusion/pull/7734) (qrilka) +- fix: avro_to_arrow: Handle avro nested nullable struct (union) [#7663](https://github.com/apache/arrow-datafusion/pull/7663) (Samrose-Ahmed) +- Rename `SessionContext::with_config_rt` to `SessionContext::new_with_config_from_rt`, etc [#7631](https://github.com/apache/arrow-datafusion/pull/7631) (alamb) +- Rename `bounded_order_preserving_variants` config to `prefer_exising_sort` and update docs [#7723](https://github.com/apache/arrow-datafusion/pull/7723) (alamb) +- Optimize "ORDER BY + LIMIT" queries for speed / memory with special TopK operator [#7721](https://github.com/apache/arrow-datafusion/pull/7721) (Dandandan) +- Minor: Improve crate docs [#7740](https://github.com/apache/arrow-datafusion/pull/7740) (alamb) +- [MINOR]: Resolve linter errors in the main [#7753](https://github.com/apache/arrow-datafusion/pull/7753) (mustafasrepo) +- Minor: Build concat_internal() with ListArray construction instead of ArrayData [#7748](https://github.com/apache/arrow-datafusion/pull/7748) (jayzhan211) +- Minor: Add comment on input_schema from AggregateExec [#7727](https://github.com/apache/arrow-datafusion/pull/7727) (viirya) +- Fix column name for COUNT(\*) set by AggregateStatistics [#7757](https://github.com/apache/arrow-datafusion/pull/7757) (qrilka) +- Add documentation about type signatures, and export `TIMEZONE_WILDCARD` [#7726](https://github.com/apache/arrow-datafusion/pull/7726) (alamb) +- [feat] Support cache ListFiles result cache in session level [#7620](https://github.com/apache/arrow-datafusion/pull/7620) (Ted-Jiang) +- Support `SHOW ALL VERBOSE` to show settings description [#7735](https://github.com/apache/arrow-datafusion/pull/7735) (comphead) diff --git a/dev/changelog/33.0.0.md b/dev/changelog/33.0.0.md new file mode 100644 index 0000000000000..17862a64a9512 --- /dev/null +++ b/dev/changelog/33.0.0.md @@ -0,0 +1,292 @@ + + +## [33.0.0](https://github.com/apache/arrow-datafusion/tree/33.0.0) (2023-11-12) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/32.0.0...33.0.0) + +**Breaking changes:** + +- Refactor Statistics, introduce precision estimates (`Exact`, `Inexact`, `Absent`) [#7793](https://github.com/apache/arrow-datafusion/pull/7793) (berkaysynnada) +- Remove redundant unwrap in `ScalarValue::new_primitive`, return a `Result` [#7830](https://github.com/apache/arrow-datafusion/pull/7830) (maruschin) +- Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) +- Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) +- Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) + +**Performance related:** + +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) + +**Implemented enhancements:** + +- Support InsertInto Sorted ListingTable [#7743](https://github.com/apache/arrow-datafusion/pull/7743) (devinjdangelo) +- External Table Primary key support [#7755](https://github.com/apache/arrow-datafusion/pull/7755) (mustafasrepo) +- add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) +- Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) +- Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) +- Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) +- Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) +- Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) +- Support Decimal256 on AVG aggregate expression [#7853](https://github.com/apache/arrow-datafusion/pull/7853) (viirya) +- Support Decimal256 column in create external table [#7866](https://github.com/apache/arrow-datafusion/pull/7866) (viirya) +- Support Decimal256 in Min/Max aggregate expressions [#7881](https://github.com/apache/arrow-datafusion/pull/7881) (viirya) +- Implement Hive-Style Partitioned Write Support [#7801](https://github.com/apache/arrow-datafusion/pull/7801) (devinjdangelo) +- feat: support `Decimal256` for the `abs` function [#7904](https://github.com/apache/arrow-datafusion/pull/7904) (jonahgao) +- Parallelize Serialization of Columns within Parquet RowGroups [#7655](https://github.com/apache/arrow-datafusion/pull/7655) (devinjdangelo) +- feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) +- Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) +- Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) + +**Fixed bugs:** + +- fix: preserve column qualifier for `DataFrame::with_column` [#7792](https://github.com/apache/arrow-datafusion/pull/7792) (jonahgao) +- fix: don't push down volatile predicates in projection [#7909](https://github.com/apache/arrow-datafusion/pull/7909) (haohuaijin) +- fix: generate logical plan for `UPDATE SET FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) +- fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) +- fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) + +**Documentation updates:** + +- Minor: Improve TableProvider document, and add ascii art [#7759](https://github.com/apache/arrow-datafusion/pull/7759) (alamb) +- Expose arrow-schema `serde` crate feature flag [#7829](https://github.com/apache/arrow-datafusion/pull/7829) (lewiszlw) +- doc: fix ExecutionContext to SessionContext in custom-table-providers.md [#7903](https://github.com/apache/arrow-datafusion/pull/7903) (ZENOTME) +- Minor: Document `parquet` crate feature [#7927](https://github.com/apache/arrow-datafusion/pull/7927) (alamb) +- Add some initial content about creating logical plans [#7952](https://github.com/apache/arrow-datafusion/pull/7952) (andygrove) +- Minor: Add implementation examples to ExecutionPlan::execute [#8013](https://github.com/apache/arrow-datafusion/pull/8013) (tustvold) +- Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) +- Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) +- Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) + +**Merged pull requests:** + +- Minor: Improve TableProvider document, and add ascii art [#7759](https://github.com/apache/arrow-datafusion/pull/7759) (alamb) +- Prepare 32.0.0 Release [#7769](https://github.com/apache/arrow-datafusion/pull/7769) (andygrove) +- Minor: Change all file links to GitHub in document [#7768](https://github.com/apache/arrow-datafusion/pull/7768) (ongchi) +- Minor: Improve `PruningPredicate` documentation [#7738](https://github.com/apache/arrow-datafusion/pull/7738) (alamb) +- Support InsertInto Sorted ListingTable [#7743](https://github.com/apache/arrow-datafusion/pull/7743) (devinjdangelo) +- Minor: improve documentation to `stagger_batch` [#7754](https://github.com/apache/arrow-datafusion/pull/7754) (alamb) +- External Table Primary key support [#7755](https://github.com/apache/arrow-datafusion/pull/7755) (mustafasrepo) +- Minor: Build array_array() with ListArray construction instead of ArrayData [#7780](https://github.com/apache/arrow-datafusion/pull/7780) (jayzhan211) +- Minor: Remove unnecessary `#[cfg(feature = "avro")]` [#7773](https://github.com/apache/arrow-datafusion/pull/7773) (sarutak) +- add interval arithmetic for timestamp types [#7758](https://github.com/apache/arrow-datafusion/pull/7758) (mhilton) +- Minor: make tests deterministic [#7771](https://github.com/apache/arrow-datafusion/pull/7771) (Weijun-H) +- Minor: Improve `Interval` Docs [#7782](https://github.com/apache/arrow-datafusion/pull/7782) (alamb) +- `DataSink` additions [#7778](https://github.com/apache/arrow-datafusion/pull/7778) (Dandandan) +- Update substrait requirement from 0.15.0 to 0.16.0 [#7783](https://github.com/apache/arrow-datafusion/pull/7783) (dependabot[bot]) +- Move nested union optimization from plan builder to logical optimizer [#7695](https://github.com/apache/arrow-datafusion/pull/7695) (maruschin) +- Minor: comments that explain the schema used in simply_expressions [#7747](https://github.com/apache/arrow-datafusion/pull/7747) (alamb) +- Update regex-syntax requirement from 0.7.1 to 0.8.0 [#7784](https://github.com/apache/arrow-datafusion/pull/7784) (dependabot[bot]) +- Minor: Add sql test for `UNION` / `UNION ALL` + plans [#7787](https://github.com/apache/arrow-datafusion/pull/7787) (alamb) +- fix: preserve column qualifier for `DataFrame::with_column` [#7792](https://github.com/apache/arrow-datafusion/pull/7792) (jonahgao) +- Interval Arithmetic NegativeExpr Support [#7804](https://github.com/apache/arrow-datafusion/pull/7804) (berkaysynnada) +- Exactness Indicator of Parameters: Precision [#7809](https://github.com/apache/arrow-datafusion/pull/7809) (berkaysynnada) +- add `LogicalPlanBuilder::join_on` [#7805](https://github.com/apache/arrow-datafusion/pull/7805) (haohuaijin) +- Fix SortPreservingRepartition with no existing ordering. [#7811](https://github.com/apache/arrow-datafusion/pull/7811) (mustafasrepo) +- Update zstd requirement from 0.12 to 0.13 [#7806](https://github.com/apache/arrow-datafusion/pull/7806) (dependabot[bot]) +- [Minor]: Remove input_schema field from window executor [#7810](https://github.com/apache/arrow-datafusion/pull/7810) (mustafasrepo) +- refactor(7181): move streaming_merge() into separate mod from the merge node [#7799](https://github.com/apache/arrow-datafusion/pull/7799) (wiedld) +- Improve update error [#7777](https://github.com/apache/arrow-datafusion/pull/7777) (lewiszlw) +- Minor: Update LogicalPlan::join_on API, use it more [#7814](https://github.com/apache/arrow-datafusion/pull/7814) (alamb) +- Add distinct union optimization [#7788](https://github.com/apache/arrow-datafusion/pull/7788) (maruschin) +- Make CI fail on any occurrence of rust-tomlfmt failed [#7774](https://github.com/apache/arrow-datafusion/pull/7774) (ongchi) +- Encode all join conditions in a single expression field [#7612](https://github.com/apache/arrow-datafusion/pull/7612) (nseekhao) +- Update substrait requirement from 0.16.0 to 0.17.0 [#7808](https://github.com/apache/arrow-datafusion/pull/7808) (dependabot[bot]) +- Minor: include `sort` expressions in `SortPreservingRepartitionExec` explain plan [#7796](https://github.com/apache/arrow-datafusion/pull/7796) (alamb) +- minor: add more document to Wildcard expr [#7822](https://github.com/apache/arrow-datafusion/pull/7822) (waynexia) +- Minor: Move `Monotonicity` to `expr` crate [#7820](https://github.com/apache/arrow-datafusion/pull/7820) (2010YOUY01) +- Use code block for better formatting of rustdoc for PhysicalGroupBy [#7823](https://github.com/apache/arrow-datafusion/pull/7823) (qrilka) +- Update explain plan to show `TopK` operator [#7826](https://github.com/apache/arrow-datafusion/pull/7826) (haohuaijin) +- Extract ReceiverStreamBuilder [#7817](https://github.com/apache/arrow-datafusion/pull/7817) (tustvold) +- Extend backtrace coverage for `DatafusionError::Plan` errors errors [#7803](https://github.com/apache/arrow-datafusion/pull/7803) (comphead) +- Add documentation and usability for prepared parameters [#7785](https://github.com/apache/arrow-datafusion/pull/7785) (alamb) +- Implement GetIndexedField for map-typed columns [#7825](https://github.com/apache/arrow-datafusion/pull/7825) (swgillespie) +- Minor: Assert `streaming_merge` has non empty sort exprs [#7795](https://github.com/apache/arrow-datafusion/pull/7795) (alamb) +- Minor: Upgrade docs for `PhysicalExpr::{propagate_constraints, evaluate_bounds}` [#7812](https://github.com/apache/arrow-datafusion/pull/7812) (alamb) +- Change ScalarValue::List to store ArrayRef [#7629](https://github.com/apache/arrow-datafusion/pull/7629) (jayzhan211) +- [MINOR]:Do not introduce unnecessary repartition when row count is 1. [#7832](https://github.com/apache/arrow-datafusion/pull/7832) (mustafasrepo) +- Minor: Add tests for binary / utf8 coercion [#7839](https://github.com/apache/arrow-datafusion/pull/7839) (alamb) +- Avoid panics on error while encoding/decoding ListValue::Array as protobuf [#7837](https://github.com/apache/arrow-datafusion/pull/7837) (alamb) +- Refactor Statistics, introduce precision estimates (`Exact`, `Inexact`, `Absent`) [#7793](https://github.com/apache/arrow-datafusion/pull/7793) (berkaysynnada) +- Remove redundant unwrap in `ScalarValue::new_primitive`, return a `Result` [#7830](https://github.com/apache/arrow-datafusion/pull/7830) (maruschin) +- Fix precision loss when coercing date_part utf8 argument [#7846](https://github.com/apache/arrow-datafusion/pull/7846) (Dandandan) +- Add operator section to user guide, Add `std::ops` operations to `prelude`, and add `not()` expr_fn [#7732](https://github.com/apache/arrow-datafusion/pull/7732) (ongchi) +- Expose arrow-schema `serde` crate feature flag [#7829](https://github.com/apache/arrow-datafusion/pull/7829) (lewiszlw) +- Improve `ContextProvider` naming: rename` get_table_provider` --> `get_table_source`, deprecate `get_table_provider` [#7831](https://github.com/apache/arrow-datafusion/pull/7831) (lewiszlw) +- DataSink Dynamic Execution Time Demux [#7791](https://github.com/apache/arrow-datafusion/pull/7791) (devinjdangelo) +- Add small column on empty projection [#7833](https://github.com/apache/arrow-datafusion/pull/7833) (ch-sc) +- feat(7849): coerce TIMESTAMP to TIMESTAMPTZ [#7850](https://github.com/apache/arrow-datafusion/pull/7850) (mhilton) +- Support `Binary`/`LargeBinary` --> `Utf8`/`LargeUtf8` in ilike and string functions [#7840](https://github.com/apache/arrow-datafusion/pull/7840) (alamb) +- Minor: fix typo in comments [#7856](https://github.com/apache/arrow-datafusion/pull/7856) (haohuaijin) +- Minor: improve `join` / `join_on` docs [#7813](https://github.com/apache/arrow-datafusion/pull/7813) (alamb) +- Support Decimal256 on AVG aggregate expression [#7853](https://github.com/apache/arrow-datafusion/pull/7853) (viirya) +- Minor: fix typo in comments [#7861](https://github.com/apache/arrow-datafusion/pull/7861) (alamb) +- Minor: fix typo in GreedyMemoryPool documentation [#7864](https://github.com/apache/arrow-datafusion/pull/7864) (avh4) +- Minor: fix multiple typos [#7863](https://github.com/apache/arrow-datafusion/pull/7863) (Smoothieewastaken) +- Minor: Fix docstring typos [#7873](https://github.com/apache/arrow-datafusion/pull/7873) (alamb) +- Add CursorValues Decoupling Cursor Data from Cursor Position [#7855](https://github.com/apache/arrow-datafusion/pull/7855) (tustvold) +- Support Decimal256 column in create external table [#7866](https://github.com/apache/arrow-datafusion/pull/7866) (viirya) +- Support Decimal256 in Min/Max aggregate expressions [#7881](https://github.com/apache/arrow-datafusion/pull/7881) (viirya) +- Implement Hive-Style Partitioned Write Support [#7801](https://github.com/apache/arrow-datafusion/pull/7801) (devinjdangelo) +- Minor: fix config typo [#7874](https://github.com/apache/arrow-datafusion/pull/7874) (alamb) +- Add Decimal256 sqllogictests for SUM, MEDIAN and COUNT aggregate expressions [#7889](https://github.com/apache/arrow-datafusion/pull/7889) (viirya) +- [test] add fuzz test for topk [#7772](https://github.com/apache/arrow-datafusion/pull/7772) (Tangruilin) +- Allow Setting Minimum Parallelism with RowCount Based Demuxer [#7841](https://github.com/apache/arrow-datafusion/pull/7841) (devinjdangelo) +- Drop single quotes to make warnings for parquet options not confusing [#7902](https://github.com/apache/arrow-datafusion/pull/7902) (qrilka) +- Add multi-column topk fuzz tests [#7898](https://github.com/apache/arrow-datafusion/pull/7898) (alamb) +- Change `FileScanConfig.table_partition_cols` from `(String, DataType)` to `Field`s [#7890](https://github.com/apache/arrow-datafusion/pull/7890) (NGA-TRAN) +- Maintain time zone in `ScalarValue::new_list` [#7899](https://github.com/apache/arrow-datafusion/pull/7899) (Dandandan) +- [MINOR]: Move joinside struct to common [#7908](https://github.com/apache/arrow-datafusion/pull/7908) (mustafasrepo) +- doc: fix ExecutionContext to SessionContext in custom-table-providers.md [#7903](https://github.com/apache/arrow-datafusion/pull/7903) (ZENOTME) +- Update arrow 48.0.0 [#7854](https://github.com/apache/arrow-datafusion/pull/7854) (tustvold) +- feat: support `Decimal256` for the `abs` function [#7904](https://github.com/apache/arrow-datafusion/pull/7904) (jonahgao) +- [MINOR] Simplify Aggregate, and Projection output_partitioning implementation [#7907](https://github.com/apache/arrow-datafusion/pull/7907) (mustafasrepo) +- Bump actions/setup-node from 3 to 4 [#7915](https://github.com/apache/arrow-datafusion/pull/7915) (dependabot[bot]) +- [Bug Fix]: Fix bug, first last reverse [#7914](https://github.com/apache/arrow-datafusion/pull/7914) (mustafasrepo) +- Minor: provide default implementation for ExecutionPlan::statistics [#7911](https://github.com/apache/arrow-datafusion/pull/7911) (alamb) +- Update substrait requirement from 0.17.0 to 0.18.0 [#7916](https://github.com/apache/arrow-datafusion/pull/7916) (dependabot[bot]) +- Minor: Remove unnecessary clone in datafusion_proto [#7921](https://github.com/apache/arrow-datafusion/pull/7921) (ongchi) +- [MINOR]: Simplify code, change requirement from PhysicalSortExpr to PhysicalSortRequirement [#7913](https://github.com/apache/arrow-datafusion/pull/7913) (mustafasrepo) +- [Minor] Move combine_join util to under equivalence.rs [#7917](https://github.com/apache/arrow-datafusion/pull/7917) (mustafasrepo) +- support scan empty projection [#7920](https://github.com/apache/arrow-datafusion/pull/7920) (haohuaijin) +- Cleanup logical optimizer rules. [#7919](https://github.com/apache/arrow-datafusion/pull/7919) (mustafasrepo) +- Parallelize Serialization of Columns within Parquet RowGroups [#7655](https://github.com/apache/arrow-datafusion/pull/7655) (devinjdangelo) +- feat: Use bloom filter when reading parquet to skip row groups [#7821](https://github.com/apache/arrow-datafusion/pull/7821) (hengfeiyang) +- fix: don't push down volatile predicates in projection [#7909](https://github.com/apache/arrow-datafusion/pull/7909) (haohuaijin) +- Add `parquet` feature flag, enabled by default, and make parquet conditional [#7745](https://github.com/apache/arrow-datafusion/pull/7745) (ongchi) +- [MINOR]: Simplify enforce_distribution, minor changes [#7924](https://github.com/apache/arrow-datafusion/pull/7924) (mustafasrepo) +- Add simple window query to sqllogictest [#7928](https://github.com/apache/arrow-datafusion/pull/7928) (Jefffrey) +- ci: upgrade node to version 20 [#7918](https://github.com/apache/arrow-datafusion/pull/7918) (crepererum) +- Change input for `to_timestamp` function to be seconds rather than nanoseconds, add `to_timestamp_nanos` [#7844](https://github.com/apache/arrow-datafusion/pull/7844) (comphead) +- Minor: Document `parquet` crate feature [#7927](https://github.com/apache/arrow-datafusion/pull/7927) (alamb) +- Minor: reduce some `#cfg(feature = "parquet")` [#7929](https://github.com/apache/arrow-datafusion/pull/7929) (alamb) +- Minor: reduce use of `#cfg(feature = "parquet")` in tests [#7930](https://github.com/apache/arrow-datafusion/pull/7930) (alamb) +- Fix CI failures on `to_timestamp()` calls [#7941](https://github.com/apache/arrow-datafusion/pull/7941) (comphead) +- minor: add a datatype casting for the updated value [#7922](https://github.com/apache/arrow-datafusion/pull/7922) (jonahgao) +- Minor:add `avro` feature in datafusion-examples to make `avro_sql` run [#7946](https://github.com/apache/arrow-datafusion/pull/7946) (haohuaijin) +- Add simple exclude all columns test to sqllogictest [#7945](https://github.com/apache/arrow-datafusion/pull/7945) (Jefffrey) +- Support Partitioning Data by Dictionary Encoded String Array Types [#7896](https://github.com/apache/arrow-datafusion/pull/7896) (devinjdangelo) +- Minor: Remove array() in array_expression [#7961](https://github.com/apache/arrow-datafusion/pull/7961) (jayzhan211) +- Minor: simplify update code [#7943](https://github.com/apache/arrow-datafusion/pull/7943) (alamb) +- Add some initial content about creating logical plans [#7952](https://github.com/apache/arrow-datafusion/pull/7952) (andygrove) +- Minor: Change from `&mut SessionContext` to `&SessionContext` in substrait [#7965](https://github.com/apache/arrow-datafusion/pull/7965) (my-vegetable-has-exploded) +- Fix crate READMEs [#7964](https://github.com/apache/arrow-datafusion/pull/7964) (Jefffrey) +- Minor: Improve `HashJoinExec` documentation [#7953](https://github.com/apache/arrow-datafusion/pull/7953) (alamb) +- chore: clean useless clone baesd on clippy [#7973](https://github.com/apache/arrow-datafusion/pull/7973) (Weijun-H) +- Add README.md to `core`, `execution` and `physical-plan` crates [#7970](https://github.com/apache/arrow-datafusion/pull/7970) (alamb) +- Move source repartitioning into `ExecutionPlan::repartition` [#7936](https://github.com/apache/arrow-datafusion/pull/7936) (alamb) +- minor: fix broken links in README.md [#7986](https://github.com/apache/arrow-datafusion/pull/7986) (jonahgao) +- Minor: Upate the `sqllogictest` crate README [#7971](https://github.com/apache/arrow-datafusion/pull/7971) (alamb) +- Improve MemoryCatalogProvider default impl block placement [#7975](https://github.com/apache/arrow-datafusion/pull/7975) (lewiszlw) +- Fix `ScalarValue` handling of NULL values for ListArray [#7969](https://github.com/apache/arrow-datafusion/pull/7969) (viirya) +- Refactor of Ordering and Prunability Traversals and States [#7985](https://github.com/apache/arrow-datafusion/pull/7985) (berkaysynnada) +- Keep output as scalar for scalar function if all inputs are scalar [#7967](https://github.com/apache/arrow-datafusion/pull/7967) (viirya) +- Fix crate READMEs for core, execution, physical-plan [#7990](https://github.com/apache/arrow-datafusion/pull/7990) (Jefffrey) +- Update sqlparser requirement from 0.38.0 to 0.39.0 [#7983](https://github.com/apache/arrow-datafusion/pull/7983) (jackwener) +- Fix panic in multiple distinct aggregates by fixing `ScalarValue::new_list` [#7989](https://github.com/apache/arrow-datafusion/pull/7989) (alamb) +- Minor: Add `MemoryReservation::consumer` getter [#8000](https://github.com/apache/arrow-datafusion/pull/8000) (milenkovicm) +- fix: generate logical plan for `UPDATE SET FROM` statement [#7984](https://github.com/apache/arrow-datafusion/pull/7984) (jonahgao) +- Create temporary files for reading or writing [#8005](https://github.com/apache/arrow-datafusion/pull/8005) (smallzhongfeng) +- Minor: fix comment on SortExec::with_fetch method [#8011](https://github.com/apache/arrow-datafusion/pull/8011) (westonpace) +- Fix: dataframe_subquery example Optimizer rule `common_sub_expression_eliminate` failed [#8016](https://github.com/apache/arrow-datafusion/pull/8016) (smallzhongfeng) +- Percent Decode URL Paths (#8009) [#8012](https://github.com/apache/arrow-datafusion/pull/8012) (tustvold) +- Minor: Extract common deps into workspace [#7982](https://github.com/apache/arrow-datafusion/pull/7982) (lewiszlw) +- minor: change some plan_err to exec_err [#7996](https://github.com/apache/arrow-datafusion/pull/7996) (waynexia) +- Minor: error on unsupported RESPECT NULLs syntax [#7998](https://github.com/apache/arrow-datafusion/pull/7998) (alamb) +- Break GroupedHashAggregateStream spill batch into smaller chunks [#8004](https://github.com/apache/arrow-datafusion/pull/8004) (milenkovicm) +- Minor: Add implementation examples to ExecutionPlan::execute [#8013](https://github.com/apache/arrow-datafusion/pull/8013) (tustvold) +- Minor: Extend wrap_into_list_array to accept multiple args [#7993](https://github.com/apache/arrow-datafusion/pull/7993) (jayzhan211) +- GroupedHashAggregateStream should register spillable consumer [#8002](https://github.com/apache/arrow-datafusion/pull/8002) (milenkovicm) +- fix: single_distinct_aggretation_to_group_by fail [#7997](https://github.com/apache/arrow-datafusion/pull/7997) (haohuaijin) +- Read only enough bytes to infer Arrow IPC file schema via stream [#7962](https://github.com/apache/arrow-datafusion/pull/7962) (Jefffrey) +- Minor: remove a strange char [#8030](https://github.com/apache/arrow-datafusion/pull/8030) (haohuaijin) +- Minor: Improve documentation for Filter Pushdown [#8023](https://github.com/apache/arrow-datafusion/pull/8023) (alamb) +- Minor: Improve `ExecutionPlan` documentation [#8019](https://github.com/apache/arrow-datafusion/pull/8019) (alamb) +- fix: clippy warnings from nightly rust 1.75 [#8025](https://github.com/apache/arrow-datafusion/pull/8025) (waynexia) +- Minor: Avoid recomputing compute_array_ndims in align_array_dimensions [#7963](https://github.com/apache/arrow-datafusion/pull/7963) (jayzhan211) +- Minor: fix doc and fmt CI check [#8037](https://github.com/apache/arrow-datafusion/pull/8037) (alamb) +- Minor: remove uncessary #cfg test [#8036](https://github.com/apache/arrow-datafusion/pull/8036) (alamb) +- Minor: Improve documentation for `PartitionStream` and `StreamingTableExec` [#8035](https://github.com/apache/arrow-datafusion/pull/8035) (alamb) +- Combine Equivalence and Ordering equivalence to simplify state [#8006](https://github.com/apache/arrow-datafusion/pull/8006) (mustafasrepo) +- Encapsulate `ProjectionMapping` as a struct [#8033](https://github.com/apache/arrow-datafusion/pull/8033) (alamb) +- Minor: Fix bugs in docs for `to_timestamp`, `to_timestamp_seconds`, ... [#8040](https://github.com/apache/arrow-datafusion/pull/8040) (alamb) +- Improve comments for `PartitionSearchMode` struct [#8047](https://github.com/apache/arrow-datafusion/pull/8047) (ozankabak) +- General approach for Array replace [#8050](https://github.com/apache/arrow-datafusion/pull/8050) (jayzhan211) +- Minor: Remove the irrelevant note from the Expression API doc [#8053](https://github.com/apache/arrow-datafusion/pull/8053) (ongchi) +- Minor: Add more documentation about Partitioning [#8022](https://github.com/apache/arrow-datafusion/pull/8022) (alamb) +- Minor: improve documentation for IsNotNull, DISTINCT, etc [#8052](https://github.com/apache/arrow-datafusion/pull/8052) (alamb) +- Prepare 33.0.0 Release [#8057](https://github.com/apache/arrow-datafusion/pull/8057) (andygrove) +- Minor: improve error message by adding types to message [#8065](https://github.com/apache/arrow-datafusion/pull/8065) (alamb) +- Minor: Remove redundant BuiltinScalarFunction::supports_zero_argument() [#8059](https://github.com/apache/arrow-datafusion/pull/8059) (2010YOUY01) +- Add example to ci [#8060](https://github.com/apache/arrow-datafusion/pull/8060) (smallzhongfeng) +- Update substrait requirement from 0.18.0 to 0.19.0 [#8076](https://github.com/apache/arrow-datafusion/pull/8076) (dependabot[bot]) +- Fix incorrect results in COUNT(\*) queries with LIMIT [#8049](https://github.com/apache/arrow-datafusion/pull/8049) (msirek) +- feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` [#7972](https://github.com/apache/arrow-datafusion/pull/7972) (Weijun-H) +- Use FairSpillPool for TaskContext with spillable config [#8072](https://github.com/apache/arrow-datafusion/pull/8072) (viirya) +- Minor: Improve HashJoinStream docstrings [#8070](https://github.com/apache/arrow-datafusion/pull/8070) (alamb) +- Fixing broken link [#8085](https://github.com/apache/arrow-datafusion/pull/8085) (edmondop) +- fix: DataFusion suggests invalid functions [#8083](https://github.com/apache/arrow-datafusion/pull/8083) (jonahgao) +- Replace macro with function for `array_repeat` [#8071](https://github.com/apache/arrow-datafusion/pull/8071) (jayzhan211) +- Minor: remove unnecessary projection in `single_distinct_to_group_by` rule [#8061](https://github.com/apache/arrow-datafusion/pull/8061) (haohuaijin) +- minor: Remove duplicate version numbers for arrow, object_store, and parquet dependencies [#8095](https://github.com/apache/arrow-datafusion/pull/8095) (andygrove) +- fix: add encode/decode to protobuf encoding [#8089](https://github.com/apache/arrow-datafusion/pull/8089) (Syleechan) +- feat: Protobuf serde for Json file sink [#8062](https://github.com/apache/arrow-datafusion/pull/8062) (Jefffrey) +- Minor: use `Expr::alias` in a few places to make the code more concise [#8097](https://github.com/apache/arrow-datafusion/pull/8097) (alamb) +- Minor: Cleanup BuiltinScalarFunction::return_type() [#8088](https://github.com/apache/arrow-datafusion/pull/8088) (2010YOUY01) +- Update sqllogictest requirement from 0.17.0 to 0.18.0 [#8102](https://github.com/apache/arrow-datafusion/pull/8102) (dependabot[bot]) +- Projection Pushdown in PhysicalPlan [#8073](https://github.com/apache/arrow-datafusion/pull/8073) (berkaysynnada) +- Push limit into aggregation for DISTINCT ... LIMIT queries [#8038](https://github.com/apache/arrow-datafusion/pull/8038) (msirek) +- Bug-fix in Filter and Limit statistics [#8094](https://github.com/apache/arrow-datafusion/pull/8094) (berkaysynnada) +- feat: support target table alias in update statement [#8080](https://github.com/apache/arrow-datafusion/pull/8080) (jonahgao) +- Minor: Simlify downcast functions in cast.rs. [#8103](https://github.com/apache/arrow-datafusion/pull/8103) (Weijun-H) +- Fix ArrayAgg schema mismatch issue [#8055](https://github.com/apache/arrow-datafusion/pull/8055) (jayzhan211) +- Minor: Support `nulls` in `array_replace`, avoid a copy [#8054](https://github.com/apache/arrow-datafusion/pull/8054) (alamb) +- Minor: Improve the document format of JoinHashMap [#8090](https://github.com/apache/arrow-datafusion/pull/8090) (Asura7969) +- Simplify ProjectionPushdown and make it more general [#8109](https://github.com/apache/arrow-datafusion/pull/8109) (alamb) +- Minor: clean up the code regarding clippy [#8122](https://github.com/apache/arrow-datafusion/pull/8122) (Weijun-H) +- Support remaining functions in protobuf serialization, add `expr_fn` for `StructFunction` [#8100](https://github.com/apache/arrow-datafusion/pull/8100) (JacobOgle) +- Minor: Cleanup BuiltinScalarFunction's phys-expr creation [#8114](https://github.com/apache/arrow-datafusion/pull/8114) (2010YOUY01) +- rewrite `array_append/array_prepend` to remove deplicate codes [#8108](https://github.com/apache/arrow-datafusion/pull/8108) (Veeupup) +- Implementation of `array_intersect` [#8081](https://github.com/apache/arrow-datafusion/pull/8081) (Veeupup) +- Minor: fix ci break [#8136](https://github.com/apache/arrow-datafusion/pull/8136) (haohuaijin) +- Improve documentation for calculate_prune_length method in `SymmetricHashJoin` [#8125](https://github.com/apache/arrow-datafusion/pull/8125) (Asura7969) +- Minor: remove duplicated `array_replace` tests [#8066](https://github.com/apache/arrow-datafusion/pull/8066) (alamb) +- Minor: Fix temporary files created but not deleted during testing [#8115](https://github.com/apache/arrow-datafusion/pull/8115) (2010YOUY01) +- chore: remove panics in datafusion-common::scalar by making more operations return `Result` [#7901](https://github.com/apache/arrow-datafusion/pull/7901) (junjunjd) +- Fix join order for TPCH Q17 & Q18 by improving FilterExec statistics [#8126](https://github.com/apache/arrow-datafusion/pull/8126) (andygrove) +- Fix: Do not try and preserve order when there is no order to preserve in RepartitionExec [#8127](https://github.com/apache/arrow-datafusion/pull/8127) (alamb) +- feat: add column statistics into explain [#8112](https://github.com/apache/arrow-datafusion/pull/8112) (NGA-TRAN) +- Add subtrait support for `IS NULL` and `IS NOT NULL` [#8093](https://github.com/apache/arrow-datafusion/pull/8093) (tgujar) +- Combine `Expr::Wildcard` and `Wxpr::QualifiedWildcard`, add `wildcard()` expr fn [#8105](https://github.com/apache/arrow-datafusion/pull/8105) (alamb) +- docs: show creation of DFSchema [#8132](https://github.com/apache/arrow-datafusion/pull/8132) (wjones127) +- feat: support UDAF in substrait producer/consumer [#8119](https://github.com/apache/arrow-datafusion/pull/8119) (waynexia) +- Improve documentation site to make it easier to find communication on Slack/Discord [#8138](https://github.com/apache/arrow-datafusion/pull/8138) (alamb) diff --git a/dev/release/README.md b/dev/release/README.md index ac180632367cf..53487678aa693 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -82,7 +82,7 @@ You will need a GitHub Personal Access Token for the following steps. Follow [these instructions](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token) to generate one if you do not already have one. -The changelog is generated using a Python script. There is a depency on `PyGitHub`, which can be installed using pip: +The changelog is generated using a Python script. There is a dependency on `PyGitHub`, which can be installed using pip: ```bash pip3 install PyGitHub @@ -284,7 +284,11 @@ of the following crates: - [datafusion-expr](https://crates.io/crates/datafusion-expr) - [datafusion-physical-expr](https://crates.io/crates/datafusion-physical-expr) - [datafusion-proto](https://crates.io/crates/datafusion-proto) -- [datafusion-row](https://crates.io/crates/datafusion-row) +- [datafusion-execution](https://crates.io/crates/datafusion-execution) +- [datafusion-physical-plan](https://crates.io/crates/datafusion-physical-plan) +- [datafusion-sql](https://crates.io/crates/datafusion-sql) +- [datafusion-optimizer](https://crates.io/crates/datafusion-optimizer) +- [datafusion-substrait](https://crates.io/crates/datafusion-substrait) Download and unpack the official release tarball @@ -308,10 +312,10 @@ dot -Tsvg dev/release/crate-deps.dot > dev/release/crate-deps.svg (cd datafusion/common && cargo publish) (cd datafusion/expr && cargo publish) (cd datafusion/sql && cargo publish) -(cd datafusion/row && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/execution && cargo publish) +(cd datafusion/physical-plan && cargo publish) (cd datafusion/core && cargo publish) (cd datafusion/proto && cargo publish) (cd datafusion/substrait && cargo publish) @@ -385,15 +389,16 @@ You can include mention crates.io and PyPI version URLs in the email if applicab ``` We have published new versions of DataFusion to crates.io: -https://crates.io/crates/datafusion/8.0.0 -https://crates.io/crates/datafusion-cli/8.0.0 -https://crates.io/crates/datafusion-common/8.0.0 -https://crates.io/crates/datafusion-expr/8.0.0 -https://crates.io/crates/datafusion-optimizer/8.0.0 -https://crates.io/crates/datafusion-physical-expr/8.0.0 -https://crates.io/crates/datafusion-proto/8.0.0 -https://crates.io/crates/datafusion-row/8.0.0 -https://crates.io/crates/datafusion-sql/8.0.0 +https://crates.io/crates/datafusion/28.0.0 +https://crates.io/crates/datafusion-cli/28.0.0 +https://crates.io/crates/datafusion-common/28.0.0 +https://crates.io/crates/datafusion-expr/28.0.0 +https://crates.io/crates/datafusion-optimizer/28.0.0 +https://crates.io/crates/datafusion-physical-expr/28.0.0 +https://crates.io/crates/datafusion-proto/28.0.0 +https://crates.io/crates/datafusion-sql/28.0.0 +https://crates.io/crates/datafusion-execution/28.0.0 +https://crates.io/crates/datafusion-substrait/28.0.0 ``` ### Add the release to Apache Reporter diff --git a/dev/release/crate-deps.dot b/dev/release/crate-deps.dot index a2199befaf8ed..618eb56afb75b 100644 --- a/dev/release/crate-deps.dot +++ b/dev/release/crate-deps.dot @@ -30,13 +30,20 @@ digraph G { datafusion_physical_expr -> datafusion_common datafusion_physical_expr -> datafusion_expr - datafusion_row -> datafusion_common + datafusion_execution -> datafusion_common + datafusion_execution -> datafusion_expr + + datafusion_physical_plan -> datafusion_common + datafusion_physical_plan -> datafusion_execution + datafusion_physical_plan -> datafusion_expr + datafusion_physical_plan -> datafusion_physical_expr datafusion -> datafusion_common + datafusion -> datafusion_execution datafusion -> datafusion_expr datafusion -> datafusion_optimizer datafusion -> datafusion_physical_expr - datafusion -> datafusion_row + datafusion -> datafusion_physical_plan datafusion -> datafusion_sql datafusion_proto -> datafusion diff --git a/dev/release/crate-deps.svg b/dev/release/crate-deps.svg index f55a5fcd7b246..a7c7b7fe4acd0 100644 --- a/dev/release/crate-deps.svg +++ b/dev/release/crate-deps.svg @@ -1,175 +1,217 @@ - - - + + G - + datafusion_common - -datafusion_common + +datafusion_common datafusion_expr - -datafusion_expr + +datafusion_expr datafusion_expr->datafusion_common - - + + datafusion_sql - -datafusion_sql + +datafusion_sql datafusion_sql->datafusion_common - - + + datafusion_sql->datafusion_expr - - + + datafusion_optimizer - -datafusion_optimizer + +datafusion_optimizer datafusion_optimizer->datafusion_common - - + + datafusion_optimizer->datafusion_expr - - + + datafusion_physical_expr - -datafusion_physical_expr + +datafusion_physical_expr datafusion_physical_expr->datafusion_common - - + + datafusion_physical_expr->datafusion_expr - - + + - + -datafusion_row - -datafusion_row +datafusion_execution + +datafusion_execution - + -datafusion_row->datafusion_common - - +datafusion_execution->datafusion_common + + - + + +datafusion_execution->datafusion_expr + + + + +datafusion_physical_plan + +datafusion_physical_plan + + + +datafusion_physical_plan->datafusion_common + + + + + +datafusion_physical_plan->datafusion_expr + + + + + +datafusion_physical_plan->datafusion_physical_expr + + + + + +datafusion_physical_plan->datafusion_execution + + + + + datafusion - -datafusion + +datafusion - + datafusion->datafusion_common - - + + - + datafusion->datafusion_expr - - + + - + datafusion->datafusion_sql - - + + - + datafusion->datafusion_optimizer - - + + - + datafusion->datafusion_physical_expr - - + + - - -datafusion->datafusion_row - - + + +datafusion->datafusion_execution + + + + + +datafusion->datafusion_physical_plan + + - + datafusion_proto - -datafusion_proto + +datafusion_proto - + datafusion_proto->datafusion - - + + - + datafusion_substrait - -datafusion_substrait + +datafusion_substrait - + datafusion_substrait->datafusion - - + + - + datafusion_cli - -datafusion_cli + +datafusion_cli - + datafusion_cli->datafusion - - + + diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index ff9e8d4754b2a..f419bdb3a1ac7 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -57,6 +57,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): bugs = [] docs = [] enhancements = [] + performance = [] # categorize the pull requests based on GitHub labels print("Categorizing pull requests", file=sys.stderr) @@ -79,6 +80,8 @@ def generate_changelog(repo, repo_name, tag1, tag2): breaking.append((pull, commit)) elif 'bug' in labels or cc_type == 'fix': bugs.append((pull, commit)) + elif 'performance' in labels or cc_type == 'perf': + performance.append((pull, commit)) elif 'enhancement' in labels or cc_type == 'feat': enhancements.append((pull, commit)) elif 'documentation' in labels or cc_type == 'docs': @@ -87,6 +90,7 @@ def generate_changelog(repo, repo_name, tag1, tag2): # produce the changelog content print("Generating changelog content", file=sys.stderr) print_pulls(repo_name, "Breaking changes", breaking) + print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) print_pulls(repo_name, "Fixed bugs", bugs) print_pulls(repo_name, "Documentation updates", docs) diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index e02ecc93b8843..f99d6e15e869f 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -115,6 +115,7 @@ python/requirements*.txt benchmarks/queries/* benchmarks/expected-plans/* benchmarks/data/* +datafusion-cli/tests/data/* ci/* **/*.svg **/*.csv @@ -133,4 +134,6 @@ datafusion/proto/src/generated/pbjson.rs datafusion/proto/src/generated/prost.rs .github/ISSUE_TEMPLATE/bug_report.yml .github/ISSUE_TEMPLATE/feature_request.yml -.github/workflows/docs.yaml \ No newline at end of file +.github/workflows/docs.yaml +**/node_modules/* +datafusion/wasmtest/pkg/* \ No newline at end of file diff --git a/dev/release/release-crates.sh b/dev/release/release-crates.sh index 658ec88b899da..00ce77a86749f 100644 --- a/dev/release/release-crates.sh +++ b/dev/release/release-crates.sh @@ -32,11 +32,12 @@ if ! [ git rev-parse --is-inside-work-tree ]; then cd datafusion/common && cargo publish cd datafusion/expr && cargo publish cd datafusion/sql && cargo publish - cd datafusion/row && cargo publish cd datafusion/physical-expr && cargo publish cd datafusion/optimizer && cargo publish cd datafusion/core && cargo publish cd datafusion/proto && cargo publish + cd datafusion/execution && cargo publish + cd datafusion/substrait && cargo publish cd datafusion-cli && cargo publish --no-verify else echo "Crates must be released from the source tarball that was voted on, not from the repo" diff --git a/dev/update_datafusion_versions.py b/dev/update_datafusion_versions.py index fd4bfadb9ed0f..19701b813671e 100755 --- a/dev/update_datafusion_versions.py +++ b/dev/update_datafusion_versions.py @@ -35,12 +35,15 @@ 'datafusion-execution': 'datafusion/execution/Cargo.toml', 'datafusion-optimizer': 'datafusion/optimizer/Cargo.toml', 'datafusion-physical-expr': 'datafusion/physical-expr/Cargo.toml', + 'datafusion-physical-plan': 'datafusion/physical-plan/Cargo.toml', 'datafusion-proto': 'datafusion/proto/Cargo.toml', - 'datafusion-row': 'datafusion/row/Cargo.toml', 'datafusion-substrait': 'datafusion/substrait/Cargo.toml', 'datafusion-sql': 'datafusion/sql/Cargo.toml', + 'datafusion-sqllogictest': 'datafusion/sqllogictest/Cargo.toml', + 'datafusion-wasmtest': 'datafusion/wasmtest/Cargo.toml', 'datafusion-benchmarks': 'benchmarks/Cargo.toml', 'datafusion-examples': 'datafusion-examples/Cargo.toml', + 'datafusion-docs': 'docs/Cargo.toml', } def update_workspace_version(new_version: str): diff --git a/docs/.gitignore b/docs/.gitignore index e1ba8440c124b..e2a54c053edf9 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -18,3 +18,4 @@ build temp venv/ +.python-version diff --git a/datafusion/row/Cargo.toml b/docs/Cargo.toml similarity index 75% rename from datafusion/row/Cargo.toml rename to docs/Cargo.toml index 73ad9c2874d8f..4d01466924f99 100644 --- a/datafusion/row/Cargo.toml +++ b/docs/Cargo.toml @@ -16,9 +16,9 @@ # under the License. [package] -name = "datafusion-row" -description = "Row backed by raw bytes for DataFusion query engine" -keywords = [ "arrow", "query", "sql" ] +name = "datafusion-docs-tests" +description = "DataFusion Documentation Tests" +publish = false version = { workspace = true } edition = { workspace = true } readme = { workspace = true } @@ -26,14 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = { workspace = true } - -[lib] -name = "datafusion_row" -path = "src/lib.rs" +rust-version = "1.70" [dependencies] -arrow = { workspace = true } -datafusion-common = { path = "../common", version = "26.0.0" } -paste = "^1.0" -rand = "0.8" +datafusion = { path = "../datafusion/core", version = "33.0.0", default-features = false } diff --git a/docs/source/_static/theme_overrides.css b/docs/source/_static/theme_overrides.css index 838eab067afcb..3b1b86daac6aa 100644 --- a/docs/source/_static/theme_overrides.css +++ b/docs/source/_static/theme_overrides.css @@ -49,7 +49,7 @@ code { } /* This is the bootstrap CSS style for "table-striped". Since the theme does -not yet provide an easy way to configure this globaly, it easier to simply +not yet provide an easy way to configure this globally, it easier to simply include this snippet here than updating each table in all rst files to add ":class: table-striped" */ @@ -59,7 +59,7 @@ add ":class: table-striped" */ /* Limit the max height of the sidebar navigation section. Because in our -custimized template, there is more content above the navigation, i.e. +customized template, there is more content above the navigation, i.e. larger logo: if we don't decrease the max-height, it will overlap with the footer. Details: 8rem for search box etc*/ diff --git a/docs/source/conf.py b/docs/source/conf.py index 29bae4b8acb2d..3fa6c6091d6fa 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,7 +34,7 @@ # -- Project information ----------------------------------------------------- project = 'Arrow DataFusion' -copyright = '2022, Apache Software Foundation' +copyright = '2023, Apache Software Foundation' author = 'Arrow DataFusion Authors' @@ -102,7 +102,13 @@ html_logo = "_static/images/DataFusion-Logo-Background-White.png" -html_css_files = ["theme_overrides.css"] +html_css_files = [ + "theme_overrides.css" +] + +html_js_files = [ + ("https://buttons.github.io/buttons.js", {'async': 'true', 'defer': 'true'}), +] html_sidebars = { "**": ["docs-sidebar.html"], @@ -112,4 +118,4 @@ myst_heading_anchors = 3 # enable nice rendering of checkboxes for the task lists -myst_enable_extensions = [ "tasklist"] +myst_enable_extensions = ["colon_fence", "deflist", "tasklist"] diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 11e0e4e0f0eaa..8678aa534baf0 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -26,15 +26,25 @@ All participation in the Apache Arrow DataFusion project is governed by the Apache Software Foundation's [code of conduct](https://www.apache.org/foundation/policies/conduct.html). +## GitHub + The vast majority of communication occurs in the open on our -[github repository](https://github.com/apache/arrow-datafusion). +[github repository](https://github.com/apache/arrow-datafusion) in the form of tickets, issues, discussions, and Pull Requests. + +## Slack and Discord -## Questions? +We use the Slack and Discord platforms for informal discussions and coordination. These are great places to +meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and +decisions are made fully in the open, on GitHub. -### Mailing list +Most of us use the `#arrow-datafusion` and `#arrow-rust` channels in the [ASF Slack workspace](https://s.apache.org/slack-invite) . +Unfortunately, due to spammers, the ASF Slack workspace requires an invitation to join. To get an invitation, +request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https://discord.gg/Qw5gKqHxUM). -We use arrow.apache.org's `dev@` mailing list for project management, release -coordination and design discussions +## Mailing list + +We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other +than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). @@ -42,33 +52,3 @@ coordination and design discussions When emailing the dev list, please make sure to prefix the subject line with a `[DataFusion]` tag, e.g. `"[DataFusion] New API for remote data sources"`, so that the appropriate people in the Apache Arrow community notice the message. - -### Slack and Discord - -We use the official [ASF](https://s.apache.org/slack-invite) Slack workspace -for informal discussions and coordination. This is a great place to meet other -contributors and get guidance on where to contribute. Join us in the -`#arrow-rust` channel. - -We also have a backup Arrow Rust Discord -server ([invite link](https://discord.gg/Qw5gKqHxUM)) in case you are not able -to join the Slack workspace. If you need an invite to the Slack workspace, you -can also ask for one in our Discord server. - -### Sync up video calls - -We have biweekly sync calls every other Thursdays at both 04:00 UTC -and 16:00 UTC (starting September 30, 2021) depending on if there are -items on the agenda to discuss and someone being willing to host. - -Please see the [agenda](https://docs.google.com/document/d/1atCVnoff5SR4eM4Lwf2M1BBJTY6g3_HUNR6qswYJW_U/edit) -for the video call link, add topics and to see what others plan to discuss. - -The goals of these calls are: - -1. Help "put a face to the name" of some of other contributors we are working with -2. Discuss / synchronize on the goals and major initiatives from different stakeholders to identify areas where more alignment is needed - -No decisions are made on the call and anything of substance will be discussed on the mailing list or in github issues / google docs. - -We will send a summary of all sync ups to the dev@arrow.apache.org mailing list. diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 5bb4e26b9525b..8d69ade83d72e 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -151,7 +151,7 @@ Tests for code in an individual module are defined in the same source file with ### sqllogictests Tests -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/core/tests/sqllogictests) which are run like any other Rust test using `cargo test --test sqllogictests`. +DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/arrow-datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. `sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. @@ -221,9 +221,11 @@ Below is a checklist of what you need to do to add a new scalar function to Data - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_physical_expr`/`create_physical_fun` mapping the built-in to the implementation - tests to the function. -- In [core/tests/sql](../../../datafusion/core/tests/sql), add a new test where the function is called through SQL against well known data and returns the expected result. +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) - In [expr/src/expr_fn.rs](../../../datafusion/expr/src/expr_fn.rs), add: - a new entry of the `unary_scalar_expr!` macro for the new function. +- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) ### How to add a new aggregate function @@ -241,7 +243,9 @@ Below is a checklist of what you need to do to add a new aggregate function to D - a new line in `signature` with the signature of the function (number and types of its arguments) - a new line in `create_aggregate_expr` mapping the built-in to the implementation - tests to the function. -- In [tests/sql](../../../datafusion/core/tests/sql), add a new test where the function is called through SQL against well known data and returns the expected result. +- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) +- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md) ### How to display plans graphically diff --git a/docs/source/contributor-guide/quarterly_roadmap.md b/docs/source/contributor-guide/quarterly_roadmap.md index 77b22852bf500..26c503f7e506e 100644 --- a/docs/source/contributor-guide/quarterly_roadmap.md +++ b/docs/source/contributor-guide/quarterly_roadmap.md @@ -21,6 +21,12 @@ A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding. +## 2023 Q4 + +- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/arrow-datafusion/issues/6569) +- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/arrow-datafusion/issues/6980) +- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/arrow-datafusion/issues/6782) + ## 2022 Q2 ### DataFusion Core diff --git a/docs/source/contributor-guide/roadmap.md b/docs/source/contributor-guide/roadmap.md index 8413fef20d2df..a7e81555b77a5 100644 --- a/docs/source/contributor-guide/roadmap.md +++ b/docs/source/contributor-guide/roadmap.md @@ -19,100 +19,27 @@ under the License. # Roadmap -This document describes high level goals of the DataFusion and -Ballista development community. It is not meant to restrict -possibilities, but rather help newcomers understand the broader -context of where the community is headed, and inspire -additional contributions. - -DataFusion and Ballista are part of the [Apache -Arrow](https://arrow.apache.org/) project and governed by the Apache -Software Foundation governance model. These projects are entirely -driven by volunteers, and we welcome contributions for items not on -this roadmap. However, before submitting a large PR, we strongly -suggest you start a conversation using a github issue or the -dev@arrow.apache.org mailing list to make review efficient and avoid -surprises. - -## DataFusion - -DataFusion's goal is to become the embedded query engine of choice -for new analytic applications, by leveraging the unique features of -[Rust](https://www.rust-lang.org/) and [Apache Arrow](https://arrow.apache.org/) -to provide: - -1. Best-in-class single node query performance -2. A Declarative SQL query interface compatible with PostgreSQL -3. A Dataframe API, similar to those offered by Pandas and Spark -4. A Procedural API for programmatically creating and running execution plans -5. High performance, data race free, ergonomic extensibility points at at every layer - -### Additional SQL Language Features - -- Decimal Support [#122](https://github.com/apache/arrow-datafusion/issues/122) -- Complete support list on [status](https://github.com/apache/arrow-datafusion/blob/main/README.md#status) -- Timestamp Arithmetic [#194](https://github.com/apache/arrow-datafusion/issues/194) -- SQL Parser extension point [#533](https://github.com/apache/arrow-datafusion/issues/533) -- Support for nested structures (fields, lists, structs) [#119](https://github.com/apache/arrow-datafusion/issues/119) -- Run all queries from the TPCH benchmark (see [milestone](https://github.com/apache/arrow-datafusion/milestone/2) for more details) - -### Query Optimizer - -- More sophisticated cost based optimizer for join ordering -- Implement advanced query optimization framework (Tokomak) [#440](https://github.com/apache/arrow-datafusion/issues/440) -- Finer optimizations for group by and aggregate functions - -### Datasources - -- Better support for reading data from remote filesystems (e.g. S3) without caching it locally [#907](https://github.com/apache/arrow-datafusion/issues/907) [#1060](https://github.com/apache/arrow-datafusion/issues/1060) -- Improve performances of file format datasources (parallelize file listings, async Arrow readers, file chunk prefetching capability...) - -### Runtime / Infrastructure - -- Migrate to some sort of arrow2 based implementation (see [milestone](https://github.com/apache/arrow-datafusion/milestone/3) for more details) -- Add DataFusion to h2oai/db-benchmark [#147](https://github.com/apache/arrow-datafusion/issues/147) -- Improve build time [#348](https://github.com/apache/arrow-datafusion/issues/348) - -### Resource Management - -- Finer grain control and limit of runtime memory [#587](https://github.com/apache/arrow-datafusion/issues/587) and CPU usage [#54](https://github.com/apache/arrow-datafusion/issues/64) - -### Python Interface - -TBD - -### DataFusion CLI (`datafusion-cli`) - -Note: There are some additional thoughts on a datafusion-cli vision on [#1096](https://github.com/apache/arrow-datafusion/issues/1096#issuecomment-939418770). - -- Better abstraction between REPL parsing and queries so that commands are separated and handled correctly -- Connect to the `Statistics` subsystem and have the cli print out more stats for query debugging, etc. -- Improved error handling for interactive use and shell scripting usage -- publishing to apt, brew, and possible NuGet registry so that people can use it more easily -- adopt a shorter name, like dfcli? - -## Ballista - -Ballista is a distributed compute platform based on Apache Arrow and DataFusion. It provides a query scheduler that -breaks a physical plan into stages and tasks and then schedules tasks for execution across the available executors -in the cluster. - -Having Ballista as part of the DataFusion codebase helps ensure that DataFusion remains suitable for distributed -compute. For example, it helps ensure that physical query plans can be serialized to protobuf format and that they -remain language-agnostic so that executors can be built in languages other than Rust. - -### Ballista Roadmap - -### Move query scheduler into DataFusion - -The Ballista scheduler has some advantages over DataFusion query execution because it doesn't try to eagerly execute -the entire query at once but breaks it down into a directionally-acyclic graph (DAG) of stages and executes a -configurable number of stages and tasks concurrently. It should be possible to push some of this logic down to -DataFusion so that the same scheduler can be used to scale across cores in-process and across nodes in a cluster. - -### Implement execution-time cost-based optimizations based on statistics - -After the execution of a query stage, accurate statistics are available for the resulting data. These statistics -could be leveraged by the scheduler to optimize the query during execution. For example, when performing a hash join -it is desirable to load the smaller side of the join into memory and in some cases we cannot predict which side will -be smaller until execution time. +The [project introduction](../user-guide/introduction) explains the +overview and goals of DataFusion, and our development efforts largely +align to that vision. + +## Planning `EPIC`s + +DataFusion uses [GitHub +issues](https://github.com/apache/arrow-datafusion/issues) to track +planned work. We collect related tickets using tracking issues labeled +with `[EPIC]` which contain discussion and links to more detailed items. + +Epics offer a high level roadmap of what the DataFusion +community is thinking about. The epics are not meant to restrict +possibilities, but rather help the community see where development is +headed, align our work, and inspire additional contributions. + +As this project is entirely driven by volunteers, we welcome +contributions for items not currently covered by epics. However, +before submitting a large PR, we strongly suggest and request you +start a conversation using a github issue or the +[dev@arrow.apache.org](mailto:dev@arrow.apache.org) mailing list to +make review efficient and avoid surprises. + +[The current list of `EPIC`s can be found here](https://github.com/apache/arrow-datafusion/issues?q=is%3Aissue+is%3Aopen+epic). diff --git a/docs/source/index.rst b/docs/source/index.rst index 4f45771173bfa..3853716617162 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,16 @@ Apache Arrow DataFusion ======================= +.. Code from https://buttons.github.io/ +.. raw:: html + +

+ + Star + + Fork +

+ DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in `Rust `_, using the `Apache Arrow `_ in-memory format. @@ -33,11 +43,12 @@ community. The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion. -The `developer’s guide`_ contains information on how to contribute. +Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html .. _datafusion-examples: https://github.com/apache/arrow-datafusion/tree/master/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide +.. _communication: contributor-guide/communication.html .. _toc.links: .. toctree:: @@ -63,6 +74,22 @@ The `developer’s guide`_ contains information on how to contribute. user-guide/configs user-guide/faq +.. _toc.library-user-guide: + +.. toctree:: + :maxdepth: 1 + :caption: Library User Guide + + library-user-guide/index + library-user-guide/using-the-sql-api + library-user-guide/working-with-exprs + library-user-guide/using-the-dataframe-api + library-user-guide/building-logical-plans + library-user-guide/catalogs + library-user-guide/adding-udfs + library-user-guide/custom-table-providers + library-user-guide/extending-operators + .. _toc.contributor-guide: .. toctree:: diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md new file mode 100644 index 0000000000000..1e710bc321a2f --- /dev/null +++ b/docs/source/library-user-guide/adding-udfs.md @@ -0,0 +1,434 @@ + + +# Adding User Defined Functions: Scalar/Window/Aggregate + +User Defined Functions (UDFs) are functions that can be used in the context of DataFusion execution. + +This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. + +| UDF Type | Description | Example | +| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs) | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs) | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs) | + +First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. + +## Adding a Scalar UDF + +A Scalar UDF is a function that takes a row of data and returns a single value. For example, this function takes a single i64 and returns a single i64 with 1 added to it: + +```rust +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, Int64Array}; +use datafusion::common::Result; + +use datafusion::common::cast::as_int64_array; + +pub fn add_one(args: &[ArrayRef]) -> Result { + // Error handling omitted for brevity + + let i64s = as_int64_array(&args[0])?; + + let new_array = i64s + .iter() + .map(|array_elem| array_elem.map(|value| value + 1)) + .collect::(); + + Ok(Arc::new(new_array)) +} +``` + +For brevity, we'll skipped some error handling, but e.g. you may want to check that `args.len()` is the expected number of arguments. + +This "works" in isolation, i.e. if you have a slice of `ArrayRef`s, you can call `add_one` and it will return a new `ArrayRef` with 1 added to each value. + +```rust +let input = vec![Some(1), None, Some(3)]; +let input = Arc::new(Int64Array::from(input)) as ArrayRef; + +let result = add_one(&[input]).unwrap(); +let result = result.as_any().downcast_ref::().unwrap(); + +assert_eq!(result, &Int64Array::from(vec![Some(2), None, Some(4)])); +``` + +The challenge however is that DataFusion doesn't know about this function. We need to register it with DataFusion so that it can be used in the context of a query. + +### Registering a Scalar UDF + +To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::physical_plan::functions::make_scalar_function; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +let udf = create_udf( + "add_one", + vec![DataType::Int64], + Arc::new(DataType::Int64), + Volatility::Immutable, + make_scalar_function(add_one), +); +``` + +A few things to note: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Int64` argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Int64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- The fifth argument is the function implementation. This is the function that we defined above. + +That gives us a `ScalarUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let mut ctx = SessionContext::new(); + +ctx.register_udf(udf); +``` + +At this point, you can use the `add_one` function in your query: + +```rust +let sql = "SELECT add_one(1)"; + +let df = ctx.sql(&sql).await.unwrap(); +``` + +## Adding a Window UDF + +Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. + +For example, we will declare a user defined window function that computes a moving average. + +```rust +use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +use datafusion::logical_expr::{PartitionEvaluator}; +use datafusion::common::ScalarValue; +use datafusion::error::Result; +/// This implements the lowest level evaluation for a window function +/// +/// It handles calculating the value of the window function for each +/// distinct values of `PARTITION BY` +#[derive(Clone, Debug)] +struct MyPartitionEvaluator {} + +impl MyPartitionEvaluator { + fn new() -> Self { + Self {} + } +} + +/// Different evaluation methods are called depending on the various +/// settings of WindowUDF. This example uses the simplest and most +/// general, `evaluate`. See `PartitionEvaluator` for the other more +/// advanced uses. +impl PartitionEvaluator for MyPartitionEvaluator { + /// Tell DataFusion the window function varies based on the value + /// of the window frame. + fn uses_window_frame(&self) -> bool { + true + } + + /// This function is called once per input row. + /// + /// `range`specifies which indexes of `values` should be + /// considered for the calculation. + /// + /// Note this is the SLOWEST, but simplest, way to evaluate a + /// window function. It is much faster to implement + /// evaluate_all or evaluate_all_with_rank, if possible + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &std::ops::Range, + ) -> Result { + // Again, the input argument is an array of floating + // point numbers to calculate a moving average + let arr: &Float64Array = values[0].as_ref().as_primitive::(); + + let range_len = range.end - range.start; + + // our smoothing function will average all the values in the + let output = if range_len > 0 { + let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); + Some(sum / range_len as f64) + } else { + None + }; + + Ok(ScalarValue::Float64(output)) + } +} + +/// Create a `PartitionEvalutor` to evaluate this function on a new +/// partition. +fn make_partition_evaluator() -> Result> { + Ok(Box::new(MyPartitionEvaluator::new())) +} +``` + +### Registering a Window UDF + +To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udwf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDWF. We also declare its signature: +let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), +); +``` + +The `create_udwf` has five arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- **The second argument** is the `DataType` of input array (attention: this is not a list of arrays). I.e. in this case, the function accepts `Float64` as argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Float64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- **The fifth argument** is the function implementation. This is the function that we defined above. + +That gives us a `WindowUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udwf(smooth_it); +``` + +At this point, you can use the `smooth_it` function in your query: + +For example, if we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like + +```csv +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +... +``` + +Then, we can query like below: + +```rust +use datafusion::datasource::file_format::options::CsvReadOptions; +// register csv table first +let csv_path = "cars.csv".to_string(); +ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; +// do query with smooth_it +let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; +// print the results +df.show().await?; +``` + +the output will be like: + +```csv ++-------+-------+--------------------+---------------------+ +| car | speed | smooth_speed | time | ++-------+-------+--------------------+---------------------+ +| green | 10.0 | 10.0 | 1996-04-12T12:05:03 | +| green | 10.3 | 10.15 | 1996-04-12T12:05:04 | +| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 | +| green | 10.5 | 10.3 | 1996-04-12T12:05:06 | +| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 | +| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 | +| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 | +| green | 15.0 | 11.65 | 1996-04-12T12:05:10 | +| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 | +| green | 15.2 | 12.35 | 1996-04-12T12:05:12 | +| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 | +| green | 2.0 | 11.125 | 1996-04-12T12:05:14 | +| red | 20.0 | 20.0 | 1996-04-12T12:05:03 | +| red | 20.3 | 20.15 | 1996-04-12T12:05:04 | +... +``` + +## Adding an Aggregate UDF + +Aggregate UDFs are functions that take a group of rows and return a single value. These are akin to SQL's `SUM` or `COUNT` functions. + +For example, we will declare a single-type, single return type UDAF that computes the geometric mean. + +```rust +use datafusion::arrow::array::ArrayRef; +use datafusion::scalar::ScalarValue; +use datafusion::{error::Result, physical_plan::Accumulator}; + +/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. +#[derive(Debug)] +struct GeometricMean { + n: u32, + prod: f64, +} + +impl GeometricMean { + // how the struct is initialized + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + // DataFusion expects this function to return the final value of this aggregator. + // in this case, this is the formula of the geometric mean + fn evaluate(&self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) + }) + } + + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, + // that can be used to perform these operations on arrays instead of single values. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} +``` + +### registering an Aggregate UDF + +To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier. + +```rust +use datafusion::logical_expr::{Volatility, create_udaf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; + +// here is where we define the UDAF. We also declare its signature: +let geometric_mean = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), +); +``` + +The `create_udaf` has six arguments to check: + +- The first argument is the name of the function. This is the name that will be used in SQL queries. +- The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. in this case, the function accepts a single `Float64` argument. +- The third argument is the return type of the function. I.e. in this case, the function returns an `Int64`. +- The fourth argument is the volatility of the function. In short, this is used to determine if the function's performance can be optimized in some situations. In this case, the function is `Immutable` because it always returns the same value for the same input. A random number generator would be `Volatile` because it returns a different value for the same input. +- The fifth argument is the function implementation. This is the function that we defined above. +- The sixth argument is the description of the state, which will by passed between execution stages. + +That gives us a `AggregateUDF` that we can register with the `SessionContext`: + +```rust +use datafusion::execution::context::SessionContext; + +let ctx = SessionContext::new(); + +ctx.register_udaf(geometric_mean); +``` + +Then, we can query like below: + +```rust +let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; +``` diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md new file mode 100644 index 0000000000000..fe922d8eaeb11 --- /dev/null +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -0,0 +1,149 @@ + + +# Building Logical Plans + +A logical plan is a structured representation of a database query that describes the high-level operations and +transformations needed to retrieve data from a database or data source. It abstracts away specific implementation +details and focuses on the logical flow of the query, including operations like filtering, sorting, and joining tables. + +This logical plan serves as an intermediate step before generating an optimized physical execution plan. This is +explained in more detail in the [Query Planning and Execution Overview] section of the [Architecture Guide]. + +## Building Logical Plans Manually + +DataFusion's [LogicalPlan] is an enum containing variants representing all the supported operators, and also +contains an `Extension` variant that allows projects building on DataFusion to add custom logical operators. + +It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as follows, but is is +much easier to use the [LogicalPlanBuilder], which is described in the next section. + +Here is an example of building a logical plan directly: + + + +```rust +// create a logical table source +let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), +]); +let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + +// create a TableScan plan +let projection = None; // optional projection +let filters = vec![]; // optional filters to push down +let fetch = None; // optional LIMIT +let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, +)?); + +// create a Filter plan that evaluates `id > 500` that wraps the TableScan +let filter_expr = col("id").gt(lit(500)); +let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); + +// print the plan +println!("{}", plan.display_indent_schema()); +``` + +This example produces the following plan: + +``` +Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] + TableScan: person [id:Int32;N, name:Utf8;N] +``` + +## Building Logical Plans with LogicalPlanBuilder + +DataFusion logical plans can be created using the [LogicalPlanBuilder] struct. There is also a [DataFrame] API which is +a higher-level API that delegates to [LogicalPlanBuilder]. + +The following associated functions can be used to create a new builder: + +- `empty` - create an empty plan with no fields +- `values` - create a plan from a set of literal values +- `scan` - create a plan representing a table scan +- `scan_with_filters` - create a plan representing a table scan with filters + +Once the builder is created, transformation methods can be called to declare that further operations should be +performed on the plan. Note that all we are doing at this stage is building up the logical plan structure. No query +execution will be performed. + +Here are some examples of transformation methods, but for a full list, refer to the [LogicalPlanBuilder] API documentation. + +- `filter` +- `limit` +- `sort` +- `distinct` +- `join` + +The following example demonstrates building the same simple query plan as the previous example, with a table scan followed by a filter. + + + +```rust +// create a logical table source +let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), +]); +let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + +// optional projection +let projection = None; + +// create a LogicalPlanBuilder for a table scan +let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + +// perform a filter operation and build the plan +let plan = builder + .filter(col("id").gt(lit(500)))? // WHERE id > 500 + .build()?; + +// print the plan +println!("{}", plan.display_indent_schema()); +``` + +This example produces the following plan: + +``` +Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] + TableScan: person [id:Int32;N, name:Utf8;N] +``` + +## Table Sources + +The previous example used a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also +suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. However, if you +want to use a [TableSource] that can be executed in DataFusion then you will need to use [DefaultTableSource], which is a +wrapper for a [TableProvider]. + +[query planning and execution overview]: https://docs.rs/datafusion/latest/datafusion/index.html#query-planning-and-execution-overview +[architecture guide]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture +[logicalplan]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html +[logicalplanbuilder]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalPlanBuilder.html +[dataframe]: using-the-dataframe-api.md +[logicaltablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalTableSource.html +[defaulttablesource]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html +[tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +[tablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.TableSource.html diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md new file mode 100644 index 0000000000000..e53d163663502 --- /dev/null +++ b/docs/source/library-user-guide/catalogs.md @@ -0,0 +1,216 @@ + + +# Catalogs, Schemas, and Tables + +This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs). + +## General Concepts + +CatalogList, Catalogs, schemas, and tables are organized in a hierarchy. A CatalogList contains catalogs, a catalog contains schemas and a schema contains tables. + +DataFusion comes with a basic in memory catalog functionality in the [`catalog` module]. You can use these in memory implementations as is, or extend DataFusion with your own catalog implementations, for example based on local files or files on remote object storage. + +[`catalog` module]: https://docs.rs/datafusion/latest/datafusion/catalog/index.html + +Similarly to other concepts in DataFusion, you'll implement various traits to create your own catalogs, schemas, and tables. The following sections describe the traits you'll need to implement. + +The `CatalogList` trait has methods to register new catalogs, get a catalog by name and list all catalogs .The `CatalogProvider` trait has methods to set a schema to a name, get a schema by name, and list all schemas. The `SchemaProvider`, which can be registered with a `CatalogProvider`, has methods to set a table to a name, get a table by name, list all tables, deregister a table, and check for a table's existence. The `TableProvider` trait has methods to scan underlying data and use it in DataFusion. The `TableProvider` trait is covered in more detail [here](./custom-table-providers.md). + +In the following example, we'll implement an in memory catalog, starting with the `SchemaProvider` trait as we need one to register with the `CatalogProvider`. Finally we will implement `CatalogList` to register the `CatalogProvider`. + +## Implementing `MemorySchemaProvider` + +The `MemorySchemaProvider` is a simple implementation of the `SchemaProvider` trait. It stores state (i.e. tables) in a `DashMap`, which then underlies the `SchemaProvider` trait. + +```rust +pub struct MemorySchemaProvider { + tables: DashMap>, +} +``` + +`tables` is the key-value pair described above. The underlying state could also be another data structure or other storage mechanism such as a file or transactional database. + +Then we implement the `SchemaProvider` trait for `MemorySchemaProvider`. + +```rust +#[async_trait] +impl SchemaProvider for MemorySchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.tables + .iter() + .map(|table| table.key().clone()) + .collect() + } + + async fn table(&self, name: &str) -> Option> { + self.tables.get(name).map(|table| table.value().clone()) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + if self.table_exist(name.as_str()) { + return Err(DataFusionError::Execution(format!( + "The table {name} already exists" + ))); + } + Ok(self.tables.insert(name, table)) + } + + fn deregister_table(&self, name: &str) -> Result>> { + Ok(self.tables.remove(name).map(|(_, table)| table)) + } + + fn table_exist(&self, name: &str) -> bool { + self.tables.contains_key(name) + } +} +``` + +Without getting into a `CatalogProvider` implementation, we can create a `MemorySchemaProvider` and register `TableProvider`s with it. + +```rust +let schema_provider = Arc::new(MemorySchemaProvider::new()); +let table_provider = _; // create a table provider + +schema_provider.register_table("table_name".to_string(), table_provider); + +let table = schema_provider.table("table_name").unwrap(); +``` + +### Asynchronous `SchemaProvider` + +It's often useful to fetch metadata about which tables are in a schema, from a remote source. For example, a schema provider could fetch metadata from a remote database. To support this, the `SchemaProvider` trait has an asynchronous `table` method. + +The trait is roughly the same except for the `table` method, and the addition of the `#[async_trait]` attribute. + +```rust +#[async_trait] +impl SchemaProvider for Schema { + async fn table(&self, name: &str) -> Option> { + // fetch metadata from remote source + } +} +``` + +## Implementing `MemoryCatalogProvider` + +As mentioned, the `CatalogProvider` can manage the schemas in a catalog, and the `MemoryCatalogProvider` is a simple implementation of the `CatalogProvider` trait. It stores schemas in a `DashMap`. + +```rust +pub struct MemoryCatalogProvider { + schemas: DashMap>, +} +``` + +With that the `CatalogProvider` trait can be implemented. + +```rust +impl CatalogProvider for MemoryCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.schemas.iter().map(|s| s.key().clone()).collect() + } + + fn schema(&self, name: &str) -> Option> { + self.schemas.get(name).map(|s| s.value().clone()) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + Ok(self.schemas.insert(name.into(), schema)) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + /// `cascade` is not used here, but can be used to control whether + /// to delete all tables in the schema or not. + if let Some(schema) = self.schema(name) { + let (_, removed) = self.schemas.remove(name).unwrap(); + Ok(Some(removed)) + } else { + Ok(None) + } + } +} +``` + +Again, this is fairly straightforward, as there's an underlying data structure to store the state, via key-value pairs. + +## Implementing `MemoryCatalogList` + +```rust +pub struct MemoryCatalogList { + /// Collection of catalogs containing schemas and ultimately TableProviders + pub catalogs: DashMap>, +} +``` + +With that the `CatalogList` trait can be implemented. + +```rust +impl CatalogList for MemoryCatalogList { + fn as_any(&self) -> &dyn Any { + self + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.catalogs.insert(name, catalog) + } + + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.catalogs.get(name).map(|c| c.value().clone()) + } +} +``` + +Like other traits, it also maintains the mapping of the Catalog's name to the CatalogProvider. + +## Recap + +To recap, you need to: + +1. Implement the `TableProvider` trait to create a table provider, or use an existing one. +2. Implement the `SchemaProvider` trait to create a schema provider, or use an existing one. +3. Implement the `CatalogProvider` trait to create a catalog provider, or use an existing one. +4. Implement the `CatalogList` trait to create a CatalogList, or use an existing one. diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md new file mode 100644 index 0000000000000..9da207da68f32 --- /dev/null +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -0,0 +1,177 @@ + + +# Custom Table Provider + +Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The `TableProvider` and associated traits, have methods that allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source. + +This section will also touch on how to have DataFusion use the new `TableProvider` implementation. + +## Table Provider and Scan + +The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. + +### Scan + +As mentioned, `scan` returns an execution plan, and in particular a `Result>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it. + +#### Execution Plan + +The `ExecutionPlan` trait at its core is a way to get a stream of batches. The aptly-named `execute` method returns a `Result`, which should be a stream of `RecordBatch`es that can be sent across threads, and has a schema that matches the data to be contained in those batches. + +There are many different types of `SendableRecordBatchStream` implemented in DataFusion -- you can use a pre existing one, such as `MemoryStream` (if your `RecordBatch`es are all in memory) or implement your own custom logic, depending on your usecase. + +Looking at the [example in this repo][ex], the execute method: + +```rust +struct CustomExec { + db: CustomDataSource, + projected_schema: SchemaRef, +} + +impl ExecutionPlan for CustomExec { + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + let users: Vec = { + let db = self.db.inner.lock().unwrap(); + db.data.values().cloned().collect() + }; + + let mut id_array = UInt8Builder::with_capacity(users.len()); + let mut account_array = UInt64Builder::with_capacity(users.len()); + + for user in users { + id_array.append_value(user.id); + account_array.append_value(user.bank_account); + } + + Ok(Box::pin(MemoryStream::try_new( + vec![RecordBatch::try_new( + self.projected_schema.clone(), + vec![ + Arc::new(id_array.finish()), + Arc::new(account_array.finish()), + ], + )?], + self.schema(), + None, + )?)) + } +} +``` + +This: + +1. Gets the users from the database +2. Constructs the individual output arrays (columns) +3. Returns a `MemoryStream` of a single `RecordBatch` with the arrays + +I.e. returns the "physical" data. For other examples, refer to the [`CsvExec`][csv] and [`ParquetExec`][parquet] for more complex implementations. + +With the `ExecutionPlan` implemented, we can now implement the `scan` method of the `TableProvider`. + +#### Scan Revisited + +The `scan` method of the `TableProvider` returns a `Result>`. We can use the `Arc` to return a reference-counted pointer to the `ExecutionPlan` we implemented. In the example, this is done by: + +```rust +impl CustomDataSource { + pub(crate) async fn create_physical_plan( + &self, + projections: Option<&Vec>, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(CustomExec::new(projections, schema, self.clone()))) + } +} + +#[async_trait] +impl TableProvider for CustomDataSource { + async fn scan( + &self, + _state: &SessionState, + projection: Option<&Vec>, + // filters and limit can be used here to inject some push-down operations if needed + _filters: &[Expr], + _limit: Option, + ) -> Result> { + return self.create_physical_plan(projection, self.schema()).await; + } +} +``` + +With this, and the implementation of the omitted methods, we can now use the `CustomDataSource` as a `TableProvider` in DataFusion. + +##### Additional `TableProvider` Methods + +`scan` has no default implementation, so it needed to be written. There are other methods on the `TableProvider` that have default implementations, but can be overridden if needed to provide additional functionality. + +###### `supports_filters_pushdown` + +The `supports_filters_pushdown` method can be overridden to indicate which filter expressions support being pushed down to the data source and within that the specificity of the pushdown. + +This returns a `Vec` of `TableProviderFilterPushDown` enums where each enum represents a filter that can be pushed down. The `TableProviderFilterPushDown` enum has three variants: + +- `TableProviderFilterPushDown::Unsupported` - the filter cannot be pushed down +- `TableProviderFilterPushDown::Exact` - the filter can be pushed down and the data source can guarantee that the filter will be applied completely to all rows. This is the highest performance option. +- `TableProviderFilterPushDown::Inexact` - the filter can be pushed down, but the data source cannot guarantee that the filter will be applied to all rows. DataFusion will apply `Inexact` filters again after the scan to ensure correctness. + +For filters that can be pushed down, they'll be passed to the `scan` method as the `filters` parameter and they can be made use of there. + +## Using the Custom Table Provider + +In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `SessionContext`. + +```rust +let mut ctx = SessionContext::new(); + +let custom_table_provider = CustomDataSource::new(); +ctx.register_table("custom_table", Arc::new(custom_table_provider)); +``` + +This will allow you to use the custom table provider in DataFusion. For example, you could use it in a SQL query to get a `DataFrame`. + +```rust +let df = ctx.sql("SELECT id, bank_account FROM custom_table")?; +``` + +## Recap + +To recap, in order to implement a custom table provider, you need to: + +1. Implement the `TableProvider` trait +2. Implement the `ExecutionPlan` trait +3. Register the `TableProvider` with the `SessionContext` + +## Next Steps + +As mentioned the [csv] and [parquet] implementations are good examples of how to implement a `TableProvider`. The [example in this repo][ex] is a good example of how to implement a `TableProvider` that uses a custom data source. + +More abstractly, see the following traits for more information on how to implement a custom `TableProvider` for a file format: + +- `FileOpener` - a trait for opening a file and inferring the schema +- `FileFormat` - a trait for reading a file format +- `ListingTableProvider` - a useful trait for implementing a `TableProvider` that lists files in a directory + +[ex]: https://github.com/apache/arrow-datafusion/blob/a5e86fae3baadbd99f8fd0df83f45fde22f7b0c6/datafusion-examples/examples/custom_datasource.rs#L214C1-L276 +[csv]: https://github.com/apache/arrow-datafusion/blob/a5e86fae3baadbd99f8fd0df83f45fde22f7b0c6/datafusion/core/src/datasource/physical_plan/csv.rs#L57-L70 +[parquet]: https://github.com/apache/arrow-datafusion/blob/a5e86fae3baadbd99f8fd0df83f45fde22f7b0c6/datafusion/core/src/datasource/physical_plan/parquet.rs#L77-L104 diff --git a/docs/source/library-user-guide/extending-operators.md b/docs/source/library-user-guide/extending-operators.md new file mode 100644 index 0000000000000..631bdc67975a4 --- /dev/null +++ b/docs/source/library-user-guide/extending-operators.md @@ -0,0 +1,22 @@ + + +# Extending DataFusion's operators: custom LogicalPlan and Execution Plans + +Coming soon diff --git a/docs/source/library-user-guide/index.md b/docs/source/library-user-guide/index.md new file mode 100644 index 0000000000000..47257e0c926e7 --- /dev/null +++ b/docs/source/library-user-guide/index.md @@ -0,0 +1,26 @@ + + +# Introduction + +The library user guide explains how to use the DataFusion library as a dependency in your Rust project. Please check out the user-guide for more details on how to use DataFusion's SQL and DataFrame APIs, or the contributor guide for details on how to contribute to DataFusion. + +If you haven't reviewed the [architecture section in the docs][docs], it's a useful place to get the lay of the land before starting down a specific path. + +[docs]: https://docs.rs/datafusion/latest/datafusion/#architecture diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md new file mode 100644 index 0000000000000..c4f4ecd4f1370 --- /dev/null +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -0,0 +1,147 @@ + + +# Using the DataFrame API + +## What is a DataFrame + +`DataFrame` in `DataFrame` is modeled after the Pandas DataFrame interface, and is a thin wrapper over LogicalPlan that adds functionality for building and executing those plans. + +```rust +pub struct DataFrame { + session_state: SessionState, + plan: LogicalPlan, +} +``` + +You can build up `DataFrame`s using its methods, similarly to building `LogicalPlan`s using `LogicalPlanBuilder`: + +```rust +let df = ctx.table("users").await?; + +// Create a new DataFrame sorted by `id`, `bank_account` +let new_df = df.select(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])?; + +// Build the same plan using the LogicalPlanBuilder +let plan = LogicalPlanBuilder::from(&df.to_logical_plan()) + .project(vec![col("id"), col("bank_account")])? + .sort(vec![col("id")])? + .build()?; +``` + +You can use `collect` or `execute_stream` to execute the query. + +## How to generate a DataFrame + +You can directly use the `DataFrame` API or generate a `DataFrame` from a SQL query. + +For example, to use `sql` to construct `DataFrame`: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; +``` + +To construct `DataFrame` using the API: + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(create_memtable()?))?; +let dataframe = ctx + .table("users") + .filter(col("a").lt_eq(col("b")))? + .sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; +``` + +## Collect / Streaming Exec + +DataFusion `DataFrame`s are "lazy", meaning they do not do any processing until they are executed, which allows for additional optimizations. + +When you have a `DataFrame`, you can run it in one of three ways: + +1. `collect` which executes the query and buffers all the output into a `Vec` +2. `streaming_exec`, which begins executions and returns a `SendableRecordBatchStream` which incrementally computes output on each call to `next()` +3. `cache` which executes the query and buffers the output into a new in memory DataFrame. + +You can just collect all outputs once like: + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let batches = df.collect().await?; +``` + +You can also use stream output to incrementally generate output one `RecordBatch` at a time + +```rust +let ctx = SessionContext::new(); +let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +let mut stream = df.execute_stream().await?; +while let Some(rb) = stream.next().await { + println!("{rb:?}"); +} +``` + +# Write DataFrame to Files + +You can also serialize `DataFrame` to a file. For now, `Datafusion` supports write `DataFrame` to `csv`, `json` and `parquet`. + +When writing a file, DataFusion will execute the DataFrame and stream the results to a file. + +For example, to write a csv_file + +```rust +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +dataframe + .write_csv("user_dataframe.csv", DataFrameWriteOptions::default(), None) + .await; +``` + +and the file will look like (Example Output): + +``` +id,bank_account +1,9000 +``` + +## Transform between LogicalPlan and DataFrame + +As shown above, `DataFrame` is just a very thin wrapper of `LogicalPlan`, so you can easily go back and forth between them. + +```rust +// Just combine LogicalPlan with SessionContext and you get a DataFrame +let ctx = SessionContext::new(); +// Register the in-memory table containing the data +ctx.register_table("users", Arc::new(mem_table))?; +let dataframe = ctx.sql("SELECT * FROM users;").await?; + +// get LogicalPlan in dataframe +let plan = dataframe.logical_plan().clone(); + +// construct a DataFrame with LogicalPlan +let new_df = DataFrame::new(ctx.state(), plan); +``` diff --git a/docs/source/library-user-guide/using-the-sql-api.md b/docs/source/library-user-guide/using-the-sql-api.md new file mode 100644 index 0000000000000..f4e85ee4e3a92 --- /dev/null +++ b/docs/source/library-user-guide/using-the-sql-api.md @@ -0,0 +1,22 @@ + + +# Using the SQL API + +Coming Soon diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md new file mode 100644 index 0000000000000..96be8ef7f1aeb --- /dev/null +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -0,0 +1,185 @@ + + +# Working with `Expr`s + + + +`Expr` is short for "expression". It is a core abstraction in DataFusion for representing a computation, and follows the standard "expression tree" abstraction found in most compilers and databases. + +For example, the SQL expression `a + b` would be represented as an `Expr` with a `BinaryExpr` variant. A `BinaryExpr` has a left and right `Expr` and an operator. + +As another example, the SQL expression `a + b * c` would be represented as an `Expr` with a `BinaryExpr` variant. The left `Expr` would be `a` and the right `Expr` would be another `BinaryExpr` with a left `Expr` of `b` and a right `Expr` of `c`. As a classic expression tree, this would look like: + +```text + ┌────────────────────┐ + │ BinaryExpr │ + │ op: + │ + └────────────────────┘ + ▲ ▲ + ┌───────┘ └────────────────┐ + │ │ +┌────────────────────┐ ┌────────────────────┐ +│ Expr::Col │ │ BinaryExpr │ +│ col: a │ │ op: * │ +└────────────────────┘ └────────────────────┘ + ▲ ▲ + ┌────────┘ └─────────┐ + │ │ + ┌────────────────────┐ ┌────────────────────┐ + │ Expr::Col │ │ Expr::Col │ + │ col: b │ │ col: c │ + └────────────────────┘ └────────────────────┘ +``` + +As the writer of a library, you can use `Expr`s to represent computations that you want to perform. This guide will walk you through how to make your own scalar UDF as an `Expr` and how to rewrite `Expr`s to inline the simple UDF. + +## Creating and Evaluating `Expr`s + +Please see [expr_api.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs) for well commented code for creating, evaluating, simplifying, and analyzing `Expr`s. + +## A Scalar UDF Example + +We'll use a `ScalarUDF` expression as our example. This necessitates implementing an actual UDF, and for ease we'll use the same example from the [adding UDFs](./adding-udfs.md) guide. + +So assuming you've written that function, you can use it to create an `Expr`: + +```rust +let add_one_udf = create_udf( + "add_one", + vec![DataType::Int64], + Arc::new(DataType::Int64), + Volatility::Immutable, + make_scalar_function(add_one), // <-- the function we wrote +); + +// make the expr `add_one(5)` +let expr = add_one_udf.call(vec![lit(5)]); + +// make the expr `add_one(my_column)` +let expr = add_one_udf.call(vec![col("my_column")]); +``` + +If you'd like to learn more about `Expr`s, before we get into the details of creating and rewriting them, you can read the [expression user-guide](./../user-guide/expressions.md). + +## Rewriting `Expr`s + +[rewrite_expr.rs](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. + +Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: + +- Simplifying `Expr`s to make them easier to evaluate +- Optimizing `Expr`s to make them faster to evaluate +- Converting `Expr`s to other forms, e.g. converting a `BinaryExpr` to a `CastExpr` + +In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten as a `BinaryExpr` with a `Literal` of 1. We're effectively inlining the UDF. + +### Rewriting with `transform` + +To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::No` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::Yes` is used to wrap the new `Expr`. + +```rust +fn rewrite_add_one(expr: Expr) -> Result { + expr.transform(&|expr| { + Ok(match expr { + Expr::ScalarUDF(scalar_fun) if scalar_fun.fun.name == "add_one" => { + let input_arg = scalar_fun.args[0].clone(); + let new_expression = input_arg + lit(1i64); + + Transformed::Yes(new_expression) + } + _ => Transformed::No(expr), + }) + }) +} +``` + +### Creating an `OptimizerRule` + +In DataFusion, an `OptimizerRule` is a trait that supports rewriting`Expr`s that appear in various parts of the `LogicalPlan`. It follows DataFusion's general mantra of trait implementations to drive behavior. + +We'll call our rule `AddOneInliner` and implement the `OptimizerRule` trait. The `OptimizerRule` trait has two methods: + +- `name` - returns the name of the rule +- `try_optimize` - takes a `LogicalPlan` and returns an `Option`. If the rule is able to optimize the plan, it returns `Some(LogicalPlan)` with the optimized plan. If the rule is not able to optimize the plan, it returns `None`. + +```rust +struct AddOneInliner {} + +impl OptimizerRule for AddOneInliner { + fn name(&self) -> &str { + "add_one_inliner" + } + + fn try_optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // Map over the expressions and rewrite them + let new_expressions = plan + .expressions() + .into_iter() + .map(|expr| rewrite_add_one(expr)) + .collect::>>()?; + + let inputs = plan.inputs().into_iter().cloned().collect::>(); + + let plan = plan.with_new_exprs(&new_expressions, &inputs); + + plan.map(Some) + } +} +``` + +Note the use of `rewrite_add_one` which is mapped over `plan.expressions()` to rewrite the expressions, then `plan.with_new_exprs` is used to create a new `LogicalPlan` with the rewritten expressions. + +We're almost there. Let's just test our rule works properly. + +## Testing the Rule + +Testing the rule is fairly simple, we can create a SessionState with our rule and then create a DataFrame and run a query. The logical plan will be optimized by our rule. + +```rust +use datafusion::prelude::*; + +let rules = Arc::new(AddOneInliner {}); +let state = ctx.state().with_optimizer_rules(vec![rules]); + +let ctx = SessionContext::with_state(state); +ctx.register_udf(add_one); + +let sql = "SELECT add_one(1) AS added_one"; +let plan = ctx.sql(sql).await?.logical_plan(); + +println!("{:?}", plan); +``` + +This results in the following output: + +```text +Projection: Int64(1) + Int64(1) AS added_one + EmptyRelation +``` + +I.e. the `add_one` UDF has been inlined into the projection. + +## Conclusion + +In this guide, we've seen how to create `Expr`s programmatically and how to rewrite them. This is useful for simplifying and optimizing `Expr`s. We've also seen how to test our rule to ensure it works properly. diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index 53cceb8d0af1c..525ab090ce514 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -17,55 +17,12 @@ under the License. --> -# `datafusion-cli` +# Command line SQL console The DataFusion CLI is a command-line interactive SQL utility for executing queries against any supported data files. It is a convenient way to try DataFusion's SQL support with your own data. -## Example - -Create a CSV file to query. - -```shell -$ echo "a,b" > data.csv -$ echo "1,2" >> data.csv -``` - -Query that single file (the CLI also supports parquet, compressed csv, avro, json and more) - -```shell -$ datafusion-cli -DataFusion CLI v17.0.0 -❯ select * from 'data.csv'; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -+---+---+ -1 row in set. Query took 0.007 seconds. -``` - -You can also query directories of files with compatible schemas: - -```shell -$ ls data_dir/ -data.csv data2.csv -``` - -```shell -$ datafusion-cli -DataFusion CLI v16.0.0 -❯ select * from 'data_dir'; -+---+---+ -| a | b | -+---+---+ -| 3 | 4 | -| 1 | 2 | -+---+---+ -2 rows in set. Query took 0.007 seconds. -``` - ## Installation ### Install and run using Cargo @@ -74,7 +31,9 @@ The easiest way to install DataFusion CLI a spin is via `cargo install datafusio ### Install and run using Homebrew (on MacOS) -DataFusion CLI can also be installed via Homebrew (on MacOS). Install it as any other pre-built software like this: +DataFusion CLI can also be installed via Homebrew (on MacOS). If you don't have Homebrew installed, you can check how to install it [here](https://docs.brew.sh/Installation). + +Install it as any other pre-built software like this: ```bash brew install datafusion @@ -89,6 +48,34 @@ brew install datafusion datafusion-cli ``` +### Install and run using PyPI + +DataFusion CLI can also be installed via PyPI. You can check how to install PyPI [here](https://pip.pypa.io/en/latest/installation/). + +Install it as any other pre-built software like this: + +```bash +pip3 install datafusion +# Defaulting to user installation because normal site-packages is not writeable +# Collecting datafusion +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (9.6 kB) +# Collecting pyarrow>=11.0.0 (from datafusion) +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (3.0 kB) +# Requirement already satisfied: numpy>=1.16.6 in /Users/Library/Python/3.9/lib/python/site-packages (from pyarrow>=11.0.0->datafusion) (1.23.4) +# Downloading datafusion-33.0.0-cp38-abi3-macosx_11_0_arm64.whl (13.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 3.6 MB/s eta 0:00:00 +# Downloading pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl (24.0 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 36.4 MB/s eta 0:00:00 +# Installing collected packages: pyarrow, datafusion +# Attempting uninstall: pyarrow +# Found existing installation: pyarrow 10.0.1 +# Uninstalling pyarrow-10.0.1: +# Successfully uninstalled pyarrow-10.0.1 +# Successfully installed datafusion-33.0.0 pyarrow-14.0.1 + +datafusion-cli +``` + ### Run using Docker There is no officially published Docker image for the DataFusion CLI, so it is necessary to build from source @@ -118,28 +105,102 @@ USAGE: datafusion-cli [OPTIONS] OPTIONS: - -c, --batch-size The batch size of each query, or use DataFusion default - -f, --file ... Execute commands from file(s), then exit - --format [default: table] [possible values: csv, tsv, table, json, - nd-json] - -h, --help Print help information - -p, --data-path Path to your data, default to current directory - -q, --quiet Reduce printing other than the results and work quietly - -r, --rc ... Run the provided files on startup instead of ~/.datafusionrc - -V, --version Print version information + -b, --batch-size + The batch size of each query, or use DataFusion default + + -c, --command ... + Execute the given command string(s), then exit + + -f, --file ... + Execute commands from file(s), then exit + + --format + [default: table] [possible values: csv, tsv, table, json, nd-json] + + -h, --help + Print help information + + -m, --memory-limit + The memory pool limitation (e.g. '10g'), default to None (no limit) + + --maxrows + The max number of rows to display for 'Table' format + [default: 40] [possible values: numbers(0/10/...), inf(no limit)] + + --mem-pool-type + Specify the memory pool type 'greedy' or 'fair', default to 'greedy' + + -p, --data-path + Path to your data, default to current directory + + -q, --quiet + Reduce printing other than the results and work quietly + + -r, --rc ... + Run the provided files on startup instead of ~/.datafusionrc + + -V, --version + Print version information ``` -## Selecting files directly +## Querying data from the files directly Files can be queried directly by enclosing the file or directory name in single `'` quotes as shown in the example. +## Example + +Create a CSV file to query. + +```shell +$ echo "a,b" > data.csv +$ echo "1,2" >> data.csv +``` + +Query that single file (the CLI also supports parquet, compressed csv, avro, json and more) + +```shell +$ datafusion-cli +DataFusion CLI v17.0.0 +❯ select * from 'data.csv'; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | ++---+---+ +1 row in set. Query took 0.007 seconds. +``` + +You can also query directories of files with compatible schemas: + +```shell +$ ls data_dir/ +data.csv data2.csv +``` + +```shell +$ datafusion-cli +DataFusion CLI v16.0.0 +❯ select * from 'data_dir'; ++---+---+ +| a | b | ++---+---+ +| 3 | 4 | +| 1 | 2 | ++---+---+ +2 rows in set. Query took 0.007 seconds. +``` + +## Creating external tables + It is also possible to create a table backed by files by explicitly -via `CREATE EXTERNAL TABLE` as shown below. +via `CREATE EXTERNAL TABLE` as shown below. Filemask wildcards supported ## Registering Parquet Data Sources -Parquet data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. It is not necessary to provide schema information for Parquet files. +Parquet data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. The schema information will be derived automatically. + +Register a single file parquet datasource ```sql CREATE EXTERNAL TABLE taxi @@ -147,6 +208,22 @@ STORED AS PARQUET LOCATION '/mnt/nyctaxi/tripdata.parquet'; ``` +Register a single folder parquet datasource. All files inside must be valid parquet files! + +```sql +CREATE EXTERNAL TABLE taxi +STORED AS PARQUET +LOCATION '/mnt/nyctaxi/'; +``` + +Register a single folder parquet datasource by specifying a wildcard for files to read + +```sql +CREATE EXTERNAL TABLE taxi +STORED AS PARQUET +LOCATION '/mnt/nyctaxi/*.parquet'; +``` + ## Registering CSV Data Sources CSV data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. @@ -350,11 +427,13 @@ Available commands inside DataFusion CLI are: - Show configuration options +`SHOW ALL [VERBOSE]` + ```SQL > show all; +-------------------------------------------------+---------+ -| name | setting | +| name | value | +-------------------------------------------------+---------+ | datafusion.execution.batch_size | 8192 | | datafusion.execution.coalesce_batches | true | @@ -367,6 +446,21 @@ Available commands inside DataFusion CLI are: ``` +- Show specific configuration option + +`SHOW xyz.abc.qwe [VERBOSE]` + +```SQL +> show datafusion.execution.batch_size; + ++-------------------------------------------------+---------+ +| name | value | ++-------------------------------------------------+---------+ +| datafusion.execution.batch_size | 8192 | ++-------------------------------------------------+---------+ + +``` + - Set configuration options ```SQL @@ -385,12 +479,12 @@ For example, to set `datafusion.execution.batch_size` to `1024` you would set the `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable appropriately: -```shell +```SQL $ DATAFUSION_EXECUTION_BATCH_SIZE=1024 datafusion-cli DataFusion CLI v12.0.0 ❯ show all; +-------------------------------------------------+---------+ -| name | setting | +| name | value | +-------------------------------------------------+---------+ | datafusion.execution.batch_size | 1024 | | datafusion.execution.coalesce_batches | true | @@ -405,13 +499,13 @@ DataFusion CLI v12.0.0 You can change the configuration options using `SET` statement as well -```shell +```SQL $ datafusion-cli DataFusion CLI v13.0.0 ❯ show datafusion.execution.batch_size; +---------------------------------+---------+ -| name | setting | +| name | value | +---------------------------------+---------+ | datafusion.execution.batch_size | 8192 | +---------------------------------+---------+ @@ -422,7 +516,7 @@ DataFusion CLI v13.0.0 ❯ show datafusion.execution.batch_size; +---------------------------------+---------+ -| name | setting | +| name | value | +---------------------------------+---------+ | datafusion.execution.batch_size | 1024 | +---------------------------------+---------+ diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 32001b9664279..d5a43e429e099 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -35,44 +35,74 @@ Values are parsed according to the [same rules used in casts from Utf8](https:// If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. -| key | default | description | -| ---------------------------------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | false | If the file has a header | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded | -| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to true, file groups will be repartitioned to achieve maximum parallelism. Currently supported only for Parquet format in which case multiple row groups from the same file may be read concurrently. If false then each row group is read serially, though different files may be read in parallel. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | false | If the file has a header | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded | +| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.compression | zstd(3) | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | Sets maximum number of rows in a row group | +| datafusion.execution.parquet.created_by | datafusion version 33.0.0 | Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | +| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 8ebf4cc678e13..4484b2c510197 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -44,7 +44,7 @@ let df = df.filter(col("a").lt_eq(col("b")))? .aggregate(vec![col("a")], vec![min(col("b"))])? .limit(0, Some(100))?; // Print results -df.show(); +df.show().await?; ``` The DataFrame API is well documented in the [API reference on docs.rs](https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html). diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 69eeb902861ab..a7557f9b0bc3f 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -19,14 +19,17 @@ # Example Usage -In this example some simple processing is performed on the [`example.csv`](../../../datafusion/core/tests/data/example.csv) file. +In this example some simple processing is performed on the [`example.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/example.csv) file. + +Even [`more code examples`](https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples) attached to the project. ## Update `Cargo.toml` -Add the following to your `Cargo.toml` file: +Find latest available Datafusion version on [DataFusion's +crates.io] page. Add the dependency to your `Cargo.toml` file: ```toml -datafusion = "22" +datafusion = "31" tokio = "1.0" ``` @@ -110,7 +113,7 @@ unexpectedly. [`arrow`]: https://docs.rs/arrow/latest/arrow/ [`parquet`]: https://docs.rs/parquet/latest/parquet/ [datafusion's crates.io]: https://crates.io/crates/datafusion -[datafusion `25.0.0` dependencies]: https://crates.io/crates/datafusion/25.0.0/dependencies +[datafusion `26.0.0` dependencies]: https://crates.io/crates/datafusion/26.0.0/dependencies ## Identifiers and Capitalization @@ -184,10 +187,6 @@ DataFusion is designed to be extensible at all points. To that end, you can prov - [x] User Defined `LogicalPlan` nodes - [x] User Defined `ExecutionPlan` nodes -## Rust Version Compatibility - -This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - ## Optimized Configuration For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is @@ -230,3 +229,37 @@ with `native` or at least `avx2`. ```shell RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release ``` + +## Enable backtraces + +By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error, +like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file: + +```toml +datafusion = { version = "31.0.0", features = ["backtrace"]} +``` + +Set environment [`variables`] https://doc.rust-lang.org/std/backtrace/index.html#environment-variables + +```bash +RUST_BACKTRACE=1 ./target/debug/datafusion-cli +DataFusion CLI v31.0.0 +❯ select row_numer() over (partition by a order by a) from (select 1 a); +Error during planning: Invalid function 'row_numer'. +Did you mean 'ROW_NUMBER'? + +backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5 + 1: std::backtrace_rs::backtrace::trace_unsynchronized + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 + 2: std::backtrace::Backtrace::create + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13 + 3: std::backtrace::Backtrace::capture + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9 + 4: datafusion_common::error::DataFusionError::get_back_trace + at /arrow-datafusion/datafusion/common/src/error.rs:436:30 + 5: datafusion_sql::expr::function::>::sql_function_to_expr + ............ +``` + +Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 07f5923a6a344..b8689e5567415 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -22,60 +22,94 @@ DataFrame methods such as `select` and `filter` accept one or more logical expressions and there are many functions available for creating logical expressions. These are documented below. -Expressions can be chained together using a fluent-style API: +:::{tip} +Most functions and methods may receive and return an `Expr`, which can be chained together using a fluent-style API: ```rust // create the expression `(a > 6) AND (b < 7)` col("a").gt(lit(6)).and(col("b").lt(lit(7))) ``` +::: + ## Identifiers -| Function | Notes | -| -------- | -------------------------------------------- | -| col | Reference a column in a dataframe `col("a")` | +| Syntax | Description | +| ---------- | -------------------------------------------- | +| col(ident) | Reference a column in a dataframe `col("a")` | + +:::{note} +ident +: A type which implement `Into` trait +::: ## Literal Values -| Function | Notes | -| -------- | -------------------------------------------------- | -| lit | Literal value such as `lit(123)` or `lit("hello")` | +| Syntax | Description | +| ---------- | -------------------------------------------------- | +| lit(value) | Literal value such as `lit(123)` or `lit("hello")` | + +:::{note} +value +: A type which implement `Literal` +::: ## Boolean Expressions -| Function | Notes | -| -------- | ----------------------------------------- | -| and | `and(expr1, expr2)` or `expr1.and(expr2)` | -| or | `or(expr1, expr2)` or `expr1.or(expr2)` | -| not | `not(expr)` or `expr.not()` | +| Syntax | Description | +| ------------------- | ----------- | +| and(x, y), x.and(y) | Logical AND | +| or(x, y), x.or(y) | Logical OR | +| !x, not(x), x.not() | Logical NOT | + +:::{note} +`!` is a bitwise or logical complement operator in Rust, but it only works as a logical NOT in expression API. +::: + +:::{note} +Since `&&` and `||` are existed as logical operators in Rust, but those are not overloadable and not works with expression API. +::: -## Bitwise expressions +## Bitwise Expressions -| Function | Notes | -| ------------------- | ------------------------------------------------------------------------- | -| bitwise_and | `bitwise_and(expr1, expr2)` or `expr1.bitwise_and(expr2)` | -| bitwise_or | `bitwise_or(expr1, expr2)` or `expr1.bitwise_or(expr2)` | -| bitwise_xor | `bitwise_xor(expr1, expr2)` or `expr1.bitwise_xor(expr2)` | -| bitwise_shift_right | `bitwise_shift_right(expr1, expr2)` or `expr1.bitwise_shift_right(expr2)` | -| bitwise_shift_left | `bitwise_shift_left(expr1, expr2)` or `expr1.bitwise_shift_left(expr2)` | +| Syntax | Description | +| ------------------------------------------- | ----------- | +| x & y, bitwise_and(x, y), x.bitand(y) | AND | +| x \| y, bitwise_or(x, y), x.bitor(y) | OR | +| x ^ y, bitwise_xor(x, y), x.bitxor(y) | XOR | +| x << y, bitwise_shift_left(x, y), x.shl(y) | Left shift | +| x >> y, bitwise_shift_right(x, y), x.shr(y) | Right shift | ## Comparison Expressions -| Function | Notes | -| -------- | --------------------- | -| eq | `expr1.eq(expr2)` | -| gt | `expr1.gt(expr2)` | -| gt_eq | `expr1.gt_eq(expr2)` | -| lt | `expr1.lt(expr2)` | -| lt_eq | `expr1.lt_eq(expr2)` | -| not_eq | `expr1.not_eq(expr2)` | +| Syntax | Description | +| ----------- | --------------------- | +| x.eq(y) | Equal | +| x.not_eq(y) | Not Equal | +| x.gt(y) | Greater Than | +| x.gt_eq(y) | Greater Than or Equal | +| x.lt(y) | Less Than | +| x.lt_eq(y) | Less Than or Equal | + +:::{note} +Comparison operators (`<`, `<=`, `==`, `>=`, `>`) could be overloaded by the `PartialOrd` and `PartialEq` trait in Rust, +but these operators always return a `bool` which makes them not work with the expression API. +::: + +## Arithmetic Expressions + +| Syntax | Description | +| ---------------- | -------------- | +| x + y, x.add(y) | Addition | +| x - y, x.sub(y) | Subtraction | +| x \* y, x.mul(y) | Multiplication | +| x / y, x.div(y) | Division | +| x % y, x.rem(y) | Remainder | +| -x, x.neg() | Negation | ## Math Functions -In addition to the math functions listed here, some Rust operators are implemented for expressions, allowing -expressions such as `col("a") + col("b")` to be used. - -| Function | Notes | +| Syntax | Description | | --------------------- | ------------------------------------------------- | | abs(x) | absolute value | | acos(x) | inverse cosine | @@ -94,11 +128,14 @@ expressions such as `col("a") + col("b")` to be used. | factorial(x) | factorial | | floor(x) | nearest integer less than or equal to argument | | gcd(x, y) | greatest common divisor | +| isnan(x) | predicate determining whether NaN/-NaN or not | +| iszero(x) | predicate determining whether 0.0/-0.0 or not | | lcm(x, y) | least common multiple | | ln(x) | natural logarithm | | log(base, x) | logarithm of x for a particular base | | log10(x) | base 10 logarithm | | log2(x) | base 2 logarithm | +| nanvl(x, y) | returns x if x is not NaN otherwise returns y | | pi() | approximate value of π | | power(base, exponent) | base raised to the power of exponent | | radians(x) | converts degrees to radians | @@ -111,11 +148,10 @@ expressions such as `col("a") + col("b")` to be used. | tanh(x) | hyperbolic tangent | | trunc(x) | truncate toward zero | -### Math functions usage notes: - +:::{note} Unlike to some databases the math functions in Datafusion works the same way as Rust math functions, avoiding failing on corner cases e.g -``` +```sql ❯ select log(-1), log(0), sqrt(-1); +----------------+---------------+-----------------+ | log(Int64(-1)) | log(Int64(0)) | sqrt(Int64(-1)) | @@ -124,27 +160,19 @@ Unlike to some databases the math functions in Datafusion works the same way as +----------------+---------------+-----------------+ ``` -## Bitwise Operators - -| Operator | Notes | -| -------- | ----------------------------------------------- | -| & | Bitwise AND => `(expr1 & expr2)` | -| | | Bitwise OR => (expr1 | expr2) | -| # | Bitwise XOR => `(expr1 # expr2)` | -| << | Bitwise left shift => `(expr1 << expr2)` | -| >> | Bitwise right shift => `(expr1 << expr2)` | +::: ## Conditional Expressions -| Function | Notes | -| -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| coalesce | Returns the first of its arguments that is not null. Null is returned only if all arguments are null. It is often used to substitute a default value for null values when data is retrieved for display. | -| case | CASE expression. Example: `case(expr).when(expr, expr).when(expr, expr).otherwise(expr).end()`. | -| nullif | Returns a null value if `value1` equals `value2`; otherwise it returns `value1`. This can be used to perform the inverse operation of the `coalesce` expression. | +| Syntax | Description | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| coalesce([value, ...]) | Returns the first of its arguments that is not null. Null is returned only if all arguments are null. It is often used to substitute a default value for null values when data is retrieved for display. | +| case(expr)
    .when(expr)
    .end(),
case(expr)
    .when(expr)
    .otherwise(expr) | CASE expression. The expression may chain multiple `when` expressions and end with an `end` or `otherwise` expression. Example:
case(col("a") % lit(3))
    .when(lit(0), lit("A"))
    .when(lit(1), lit("B"))
    .when(lit(2), lit("C"))
    .end()
or, end with `otherwise` to match any other conditions:
case(col("b").gt(lit(100)))
    .when(lit(true), lit("value > 100"))
    .otherwise(lit("value <= 100"))
| +| nullif(value1, value2) | Returns a null value if `value1` equals `value2`; otherwise it returns `value1`. This can be used to perform the inverse operation of the `coalesce` expression. | ## String Expressions -| Function | Notes | +| Syntax | Description | | ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ascii(character) | Returns a numeric representation of the character (`character`). Example: `ascii('a') -> 97` | | bit_length(text) | Returns the length of the string (`text`) in bits. Example: `bit_length('spider') -> 48` | @@ -179,34 +207,51 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Array Expressions -| Function | Notes | -| ------------------------------------ | --------------------------------------------------------------- | -| array_append(array, element) | Appends an element to the end of an array. | -| array_concat(array[, ..., array_n]) | Concatenates arrays. | -| array_dims(array) | Returns an array of the array's dimensions. | -| array_fill(element, array) | Returns an array filled with copies of the given value. | -| array_length(array, dimension) | Returns the length of the array dimension. | -| array_ndims(array) | Returns the number of dimensions of the array. | -| array_position(array, element) | Searches for an element in the array, returns first occurrence. | -| array_positions(array, element) | Searches for an element in the array, returns all occurrences. | -| array_prepend(array, element) | Prepends an element to the beginning of an array. | -| array_remove(array, element) | Removes all elements equal to the given value from the array. | -| array_replace(array, from, to) | Replaces a specified element with another specified element. | -| array_to_string(array, delimeter) | Converts each element to its text representation. | -| cardinality(array) | Returns the total number of elements in the array. | -| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. | -| trim_array(array, n) | Removes the last n elements from the array. | +| Syntax | Description | +| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | +| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | +| array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | +| array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | +| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | +| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | +| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | +| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | +| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | +| array_pop_front(array) | Returns the array without the first element. `array_pop_front([1, 2, 3]) -> [2, 3]` | +| array_pop_back(array) | Returns the array without the last element. `array_pop_back([1, 2, 3]) -> [1, 2]` | +| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | +| array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | +| array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | +| array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | +| array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | +| array_remove_all(array, element) | Removes all elements from the array equal to the given value. `array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | +| array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. `array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | +| array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | +| array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | +| array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | +| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` | +| array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | +| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | +| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | +| range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | +| trim_array(array, n) | Deprecated | ## Regular Expressions -| Function | Notes | +| Syntax | Description | | -------------- | ----------------------------------------------------------------------------- | | regexp_match | Matches a regular expression against a string and returns matched substrings. | | regexp_replace | Replaces strings that match a regular expression | ## Temporal Expressions -| Function | Notes | +| Syntax | Description | | -------------------- | ------------------------------------------------------ | | date_part | Extracts a subfield from the date. | | date_trunc | Truncates the date to a specified level of precision. | @@ -219,7 +264,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Other Expressions -| Function | Notes | +| Syntax | Description | | ---------------------------- | ---------------------------------------------------------------------------------------------------------- | | array([value1, ...]) | Returns an array of fixed size with each argument (`[value1, ...]`) on it. | | in_list(expr, list, negated) | Returns `true` if (`expr`) belongs or not belongs (`negated`) to a list (`list`), otherwise returns false. | @@ -232,7 +277,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Aggregate Functions -| Function | Notes | +| Syntax | Description | | ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | | avg(expr) | Сalculates the average value for `expr`. | | approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | @@ -256,7 +301,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Subquery Expressions -| Function | Notes | +| Syntax | Description | | --------------- | --------------------------------------------------------------------------------------------- | | exists | Creates an `EXISTS` subquery expression | | in_subquery | `df1.filter(in_subquery(col("foo"), df2))?` is the equivalent of the SQL `WHERE foo IN ` | @@ -266,7 +311,7 @@ Unlike to some databases the math functions in Datafusion works the same way as ## User-Defined Function Expressions -| Function | Notes | +| Syntax | Description | | ----------- | ------------------------------------------------------------------------- | | create_udf | Creates a new UDF with a specific signature and specific return type. | | create_udaf | Creates a new UDAF with a specific signature, state type and return type. | diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 23157d3f36870..da250fbb1f9c0 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -22,8 +22,20 @@ DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) in-memory format. +DataFusion is part of the [Apache Arrow](https://arrow.apache.org/) +project. -DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. +DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, [python bindings], extensive customization, a great community, and more. + +[python bindings]: https://github.com/apache/arrow-datafusion-python + +## Project Goals + +DataFusion aims to be the query engine of choice for new, fast +data centric systems such as databases, dataframe libraries, machine +learning and streaming applications by leveraging the unique features +of [Rust](https://www.rust-lang.org/) and [Apache +Arrow](https://arrow.apache.org/). ## Features @@ -34,24 +46,34 @@ DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchm - Many extension points: user defined scalar/aggregate/window functions, DataSources, SQL, other query languages, custom plan and execution nodes, optimizer passes, and more. - Streaming, asynchronous IO directly from popular object stores, including AWS S3, - Azure Blob Storage, and Google Cloud Storage. Other storage systems are supported via the - `ObjectStore` trait. + Azure Blob Storage, and Google Cloud Storage (Other storage systems are supported via the + `ObjectStore` trait). - [Excellent Documentation](https://docs.rs/datafusion/latest) and a [welcoming community](https://arrow.apache.org/datafusion/contributor-guide/communication.html). -- A state of the art query optimizer with projection and filter pushdown, sort aware optimizations, - automatic join reordering, expression coercion, and more. -- Permissive Apache 2.0 License, Apache Software Foundation governance -- Written in [Rust](https://www.rust-lang.org/), a modern system language with development - productivity similar to Java or Golang, the performance of C++, and - [loved by programmers everywhere](https://insights.stackoverflow.com/survey/2021#technology-most-loved-dreaded-and-wanted). -- Support for [Substrait](https://substrait.io/) for query plan serialization, making it easier to integrate DataFusion - with other projects, and to pass plans across language boundaries. +- A state of the art query optimizer with expression coercion and + simplification, projection and filter pushdown, sort and distribution + aware optimizations, automatic join reordering, and more. +- Permissive Apache 2.0 License, predictable and well understood + [Apache Software Foundation](https://www.apache.org/) governance. +- Implementation in [Rust](https://www.rust-lang.org/), a modern + system language with development productivity similar to Java or + Golang, the performance of C++, and [loved by programmers + everywhere](https://insights.stackoverflow.com/survey/2021#technology-most-loved-dreaded-and-wanted). +- Support for [Substrait](https://substrait.io/) query plans, to + easily pass plans across language and system boundaries. ## Use Cases DataFusion can be used without modification as an embedded SQL engine or can be customized and used as a foundation for -building new systems. Here are some examples of systems built using DataFusion: +building new systems. + +While most current usecases are "analytic" or (throughput) some +components of DataFusion such as the plan representations, are +suitable for "streaming" and "transaction" style systems (low +latency). + +Here are some example systems built using DataFusion: - Specialized Analytical Database systems such as [CeresDB] and more general Apache Spark like system such a [Ballista]. - New query language engines such as [prql-query] and accelerators such as [VegaFusion] @@ -59,40 +81,50 @@ building new systems. Here are some examples of systems built using DataFusion: - SQL support to another library, such as [dask sql] - Streaming data platforms such as [Synnada] - Tools for reading / sorting / transcoding Parquet, CSV, AVRO, and JSON files such as [qv] -- A faster Spark runtime replacement [Blaze] +- Native Spark runtime replacement such as [Blaze] -By using DataFusion, the projects are freed to focus on their specific +By using DataFusion, projects are freed to focus on their specific features, and avoid reimplementing general (but still necessary) features such as an expression representation, standard optimizations, -execution plans, file format support, etc. +parellelized streaming execution plans, file format support, etc. ## Known Users -Here are some of the projects known to use DataFusion: +Here are some active projects using DataFusion: + + +- [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine -- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core - [CeresDB](https://github.com/CeresDB/ceresdb) Distributed Time-Series Database -- [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python -- [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion +- [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake -- [Flock](https://github.com/flock-lab/flock) - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database +- [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline +- [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML - [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [qv](https://github.com/timvw/qv) Quickly view your data - [bdt](https://github.com/andygrove/bdt) Boring Data Tool +- [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await - [ROAPI](https://github.com/roapi/roapi) - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Synnada](https://synnada.ai/) Streaming-first framework for data products -- [Tensorbase](https://github.com/tensorbase/tensorbase) - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [ZincObserve](https://github.com/zinclabs/zincobserve) Distributed cloud native observability platform +Here are some less active projects that used DataFusion: + +- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core +- [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) +- [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion +- [Flock](https://github.com/flock-lab/flock) +- [Tensorbase](https://github.com/tensorbase/tensorbase) + [ballista]: https://github.com/apache/arrow-ballista [blaze]: https://github.com/blaze-init/blaze [ceresdb]: https://github.com/CeresDB/ceresdb @@ -119,7 +151,7 @@ Here are some of the projects known to use DataFusion: ## Integrations and Extensions There are a number of community projects that extend DataFusion or -provide integrations with other systems. +provide integrations with other systems, some of which are described below: ### Language Bindings @@ -137,5 +169,5 @@ provide integrations with other systems. - _High Performance_: Leveraging Rust and Arrow's memory model, DataFusion is very fast. - _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet and Flight), DataFusion works well with the rest of the big data ecosystem -- _Easy to Embed_: Allowing extension at almost any point in its design, DataFusion can be tailored for your specific usecase -- _High Quality_: Extensively tested, both by itself and with the rest of the Arrow ecosystem, DataFusion can be used as the foundation for production systems. +- _Easy to Embed_: Allowing extension at almost any point in its design, and published regularly as a crate on [crates.io](http://crates.io), DataFusion can be integrated and tailored for your specific usecase. +- _High Quality_: Extensively tested, both by itself and with the rest of the Arrow ecosystem, DataFusion can and is used as the foundation for production systems. diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 132ba47e24615..427a7bf130a77 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -245,6 +245,15 @@ last_value(expression [ORDER BY expression]) - [var](#var) - [var_pop](#var_pop) - [var_samp](#var_samp) +- [regr_avgx](#regr_avgx) +- [regr_avgy](#regr_avgy) +- [regr_count](#regr_count) +- [regr_intercept](#regr_intercept) +- [regr_r2](#regr_r2) +- [regr_slope](#regr_slope) +- [regr_sxx](#regr_sxx) +- [regr_syy](#regr_syy) +- [regr_sxy](#regr_sxy) ### `corr` @@ -384,6 +393,142 @@ var_samp(expression) - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `regr_slope` + +Returns the slope of the linear regression line for non-null pairs in aggregate columns. +Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. + +``` +regr_slope(expression1, expression2) +``` + +#### Arguments + +- **expression_y**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_avgx` + +Computes the average of the independent variable (input) `expression_x` for the non-null paired data points. + +``` +regr_avgx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_avgy` + +Computes the average of the dependent variable (output) `expression_y` for the non-null paired data points. + +``` +regr_avgy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_count` + +Counts the number of non-null paired data points. + +``` +regr_count(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_intercept` + +Computes the y-intercept of the linear regression line. For the equation \(y = kx + b\), this function returns `b`. + +``` +regr_intercept(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_r2` + +Computes the square of the correlation coefficient between the independent and dependent variables. + +``` +regr_r2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_sxx` + +Computes the sum of squares of the independent variable. + +``` +regr_sxx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_syy` + +Computes the sum of squares of the dependent variable. + +``` +regr_syy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_sxy` + +Computes the sum of products of paired data points. + +``` +regr_sxy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + ## Approximate - [approx_distinct](#approx_distinct) diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 1d3455abc2365..caa08b8bae63d 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -96,6 +96,9 @@ For example, to cast the output of `now()` to a `Timestamp` with second precisio | ------------ | :------------- | | `BYTEA` | `Binary` | +You can create binary literals using a hex string literal such as +`X'1234` to create a `Binary` value of two bytes, `0x12` and `0x34`. + ## Unsupported SQL Types | SQL Data Type | Arrow DataType | @@ -148,7 +151,6 @@ The following types are supported by the `arrow_typeof` function: | `Interval(YearMonth)` | | `Interval(DayTime)` | | `Interval(MonthDayNano)` | -| `Interval(MonthDayNano)` | | `FixedSizeBinary()` (e.g. `FixedSizeBinary(16)`) | | `Decimal128(, )` e.g. `Decimal128(3, 10)` | | `Decimal256(, )` e.g. `Decimal256(3, 10)` | diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index 0dcc4517b55ac..b67d323126996 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -19,6 +19,9 @@ # DDL +DDL stands for "Data Definition Language" and relates to creating and +modifying catalog objects such as Tables. + ## CREATE DATABASE Create catalog with specified name. @@ -61,7 +64,7 @@ STORED AS [ DELIMITER ] [ COMPRESSION TYPE ] [ PARTITIONED BY () ] -[ WITH ORDER () +[ WITH ORDER () ] [ OPTIONS () ] LOCATION @@ -74,9 +77,11 @@ LOCATION := ( , ...) ``` -`file_type` is one of `CSV`, `PARQUET`, `AVRO` or `JSON` +For a detailed list of write related options which can be passed in the OPTIONS key_value_list, see [Write Options](write_options). + +`file_type` is one of `CSV`, `ARROW`, `PARQUET`, `AVRO` or `JSON` -`LOCATION ` specfies the location to find the data. It can be +`LOCATION ` specifies the location to find the data. It can be a path to a file or directory of partitioned files locally or on an object store. @@ -99,6 +104,16 @@ WITH HEADER ROW LOCATION '/path/to/aggregate_simple.csv'; ``` +It is also possible to use compressed files, such as `.csv.gz`: + +```sql +CREATE EXTERNAL TABLE test +STORED AS CSV +COMPRESSION TYPE GZIP +WITH HEADER ROW +LOCATION '/path/to/aggregate_simple.csv.gz'; +``` + It is also possible to specify the schema manually. ```sql diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md new file mode 100644 index 0000000000000..c3226936e7ac0 --- /dev/null +++ b/docs/source/user-guide/sql/dml.md @@ -0,0 +1,88 @@ + + +# DML + +DML stands for "Data Manipulation Language" and relates to inserting +and modifying data in tables. + +## COPY + +Copies the contents of a table or query to file(s). Supported file +formats are `parquet`, `csv`, and `json` and can be inferred based on +filename if writing to a single file. + +
+COPY { table_name | query } TO 'file_name' [ ( option [, ... ] ) ]
+
+ +For a detailed list of valid OPTIONS, see [Write Options](write_options). + +Copy the contents of `source_table` to `file_name.json` in JSON format: + +```sql +> COPY source_table TO 'file_name.json'; ++-------+ +| count | ++-------+ +| 2 | ++-------+ +``` + +Copy the contents of `source_table` to one or more Parquet formatted +files in the `dir_name` directory: + +```sql +> COPY source_table TO 'dir_name' (FORMAT parquet, SINGLE_FILE_OUTPUT false); ++-------+ +| count | ++-------+ +| 2 | ++-------+ +``` + +Run the query `SELECT * from source ORDER BY time` and write the +results (maintaining the order) to a parquet file named +`output.parquet` with a maximum parquet row group size of 10MB: + +```sql +> COPY (SELECT * from source ORDER BY time) TO 'output.parquet' (ROW_GROUP_LIMIT_BYTES 10000000); ++-------+ +| count | ++-------+ +| 2 | ++-------+ +``` + +## INSERT + +Insert values into a table. + +
+INSERT INTO table_name { VALUES ( expression [, ...] ) [, ...] | query }
+
+ +```sql +> INSERT INTO target_table VALUES (1, 'Foo'), (2, 'Bar'); ++-------+ +| count | ++-------+ +| 2 | ++-------+ +``` diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index dab27960168d5..04d1fc228f816 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -25,9 +25,12 @@ SQL Reference select subqueries ddl + dml explain information_schema + operators aggregate_functions window_functions scalar_functions sql_status + write_options diff --git a/docs/source/user-guide/sql/operators.md b/docs/source/user-guide/sql/operators.md new file mode 100644 index 0000000000000..265e56bb2c348 --- /dev/null +++ b/docs/source/user-guide/sql/operators.md @@ -0,0 +1,413 @@ + + +# Operators + +## Numerical Operators + +- [+ (plus)](#id1) +- [- (minus)](#id2) +- [\* (multiply)](#id3) +- [/ (divide)](#id4) +- [% (modulo)](#id5) + +### `+` + +Addition + +```sql +> SELECT 1 + 2; ++---------------------+ +| Int64(1) + Int64(2) | ++---------------------+ +| 3 | ++---------------------+ +``` + +### `-` + +Subtraction + +```sql +> SELECT 4 - 3; ++---------------------+ +| Int64(4) - Int64(3) | ++---------------------+ +| 1 | ++---------------------+ +``` + +### `*` + +Multiplication + +```sql +> SELECT 2 * 3; ++---------------------+ +| Int64(2) * Int64(3) | ++---------------------+ +| 6 | ++---------------------+ +``` + +### `/` + +Division (integer division truncates toward zero) + +```sql +> SELECT 8 / 4; ++---------------------+ +| Int64(8) / Int64(4) | ++---------------------+ +| 2 | ++---------------------+ +``` + +### `%` + +Modulo (remainder) + +```sql +> SELECT 7 % 3; ++---------------------+ +| Int64(7) % Int64(3) | ++---------------------+ +| 1 | ++---------------------+ +``` + +## Comparison Operators + +- [= (equal)](#id6) +- [!= (not equal)](#id7) +- [< (less than)](#id8) +- [<= (less than or equal to)](#id9) +- [> (greater than)](#id10) +- [>= (greater than or equal to)](#id11) +- [IS DISTINCT FROM](#is-distinct-from) +- [IS NOT DISTINCT FROM](#is-not-distinct-from) +- [~ (regex match)](#id12) +- [~\* (regex case-insensitive match)](#id13) +- [!~ (not regex match)](#id14) +- [!~\* (not regex case-insensitive match)](#id15) + +### `=` + +Equal + +```sql +> SELECT 1 = 1; ++---------------------+ +| Int64(1) = Int64(1) | ++---------------------+ +| true | ++---------------------+ +``` + +### `!=` + +Not Equal + +```sql +> SELECT 1 != 2; ++----------------------+ +| Int64(1) != Int64(2) | ++----------------------+ +| true | ++----------------------+ +``` + +### `<` + +Less Than + +```sql +> SELECT 3 < 4; ++---------------------+ +| Int64(3) < Int64(4) | ++---------------------+ +| true | ++---------------------+ +``` + +### `<=` + +Less Than or Equal To + +```sql +> SELECT 3 <= 3; ++----------------------+ +| Int64(3) <= Int64(3) | ++----------------------+ +| true | ++----------------------+ +``` + +### `>` + +Greater Than + +```sql +> SELECT 6 > 5; ++---------------------+ +| Int64(6) > Int64(5) | ++---------------------+ +| true | ++---------------------+ +``` + +### `>=` + +Greater Than or Equal To + +```sql +> SELECT 5 >= 5; ++----------------------+ +| Int64(5) >= Int64(5) | ++----------------------+ +| true | ++----------------------+ +``` + +### `IS DISTINCT FROM` + +Guarantees the result of a comparison is `true` or `false` and not an empty set + +```sql +> SELECT 0 IS DISTINCT FROM NULL; ++--------------------------------+ +| Int64(0) IS DISTINCT FROM NULL | ++--------------------------------+ +| true | ++--------------------------------+ +``` + +### `IS NOT DISTINCT FROM` + +The negation of `IS DISTINCT FROM` + +```sql +> SELECT NULL IS NOT DISTINCT FROM NULL; ++--------------------------------+ +| NULL IS NOT DISTINCT FROM NULL | ++--------------------------------+ +| true | ++--------------------------------+ +``` + +### `~` + +Regex Match + +```sql +> SELECT 'datafusion' ~ '^datafusion(-cli)*'; ++-------------------------------------------------+ +| Utf8("datafusion") ~ Utf8("^datafusion(-cli)*") | ++-------------------------------------------------+ +| true | ++-------------------------------------------------+ +``` + +### `~*` + +Regex Case-Insensitive Match + +```sql +> SELECT 'datafusion' ~* '^DATAFUSION(-cli)*'; ++--------------------------------------------------+ +| Utf8("datafusion") ~* Utf8("^DATAFUSION(-cli)*") | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` + +### `!~` + +Not Regex Match + +```sql +> SELECT 'datafusion' !~ '^DATAFUSION(-cli)*'; ++--------------------------------------------------+ +| Utf8("datafusion") !~ Utf8("^DATAFUSION(-cli)*") | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` + +### `!~*` + +Not Regex Case-Insensitive Match + +```sql +> SELECT 'datafusion' !~* '^DATAFUSION(-cli)+'; ++---------------------------------------------------+ +| Utf8("datafusion") !~* Utf8("^DATAFUSION(-cli)+") | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +``` + +## Logical Operators + +- [AND](#and) +- [OR](#or) + +### `AND` + +Logical And + +```sql +> SELECT true AND true; ++---------------------------------+ +| Boolean(true) AND Boolean(true) | ++---------------------------------+ +| true | ++---------------------------------+ +``` + +### `OR` + +Logical Or + +```sql +> SELECT false OR true; ++---------------------------------+ +| Boolean(false) OR Boolean(true) | ++---------------------------------+ +| true | ++---------------------------------+ +``` + +## Bitwise Operators + +- [& (bitwise and)](#id16) +- [| (bitwise or)](#id17) +- [# (bitwise xor)](#id18) +- [>> (bitwise shift right)](#id19) +- [<< (bitwise shift left)](#id20) + +### `&` + +Bitwise And + +```sql +> SELECT 5 & 3; ++---------------------+ +| Int64(5) & Int64(3) | ++---------------------+ +| 1 | ++---------------------+ +``` + +### `|` + +Bitwise Or + +```sql +> SELECT 5 | 3; ++---------------------+ +| Int64(5) | Int64(3) | ++---------------------+ +| 7 | ++---------------------+ +``` + +### `#` + +Bitwise Xor (interchangeable with `^`) + +```sql +> SELECT 5 # 3; ++---------------------+ +| Int64(5) # Int64(3) | ++---------------------+ +| 6 | ++---------------------+ +``` + +### `>>` + +Bitwise Shift Right + +```sql +> SELECT 5 >> 3; ++----------------------+ +| Int64(5) >> Int64(3) | ++----------------------+ +| 0 | ++----------------------+ +``` + +### `<<` + +Bitwise Shift Left + +```sql +> SELECT 5 << 3; ++----------------------+ +| Int64(5) << Int64(3) | ++----------------------+ +| 40 | ++----------------------+ +``` + +## Other Operators + +- [|| (string concatenation)](#id21) +- [@> (array contains)](#id22) +- [<@ (array is contained by)](#id23) + +### `||` + +String Concatenation + +```sql +> SELECT 'Hello, ' || 'DataFusion!'; ++----------------------------------------+ +| Utf8("Hello, ") || Utf8("DataFusion!") | ++----------------------------------------+ +| Hello, DataFusion! | ++----------------------------------------+ +``` + +### `@>` + +Array Contains + +```sql +> SELECT make_array(1,2,3) @> make_array(1,3); ++-------------------------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3)) @> make_array(Int64(1),Int64(3)) | ++-------------------------------------------------------------------------+ +| true | ++-------------------------------------------------------------------------+ +``` + +### `<@` + +Array Is Contained By + +```sql +> SELECT make_array(1,3) <@ make_array(1,2,3); ++-------------------------------------------------------------------------+ +| make_array(Int64(1),Int64(3)) <@ make_array(Int64(1),Int64(2),Int64(3)) | ++-------------------------------------------------------------------------+ +| true | ++-------------------------------------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 4b2a556806b7c..9a9bec9df77b3 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -38,11 +38,14 @@ - [factorial](#factorial) - [floor](#floor) - [gcd](#gcd) +- [isnan](#isnan) +- [iszero](#iszero) - [lcm](#lcm) - [ln](#ln) - [log](#log) - [log10](#log10) - [log2](#log2) +- [nanvl](#nanvl) - [pi](#pi) - [power](#power) - [pow](#pow) @@ -282,6 +285,32 @@ gcd(expression_x, expression_y) - **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `isnan` + +Returns true if a given number is +NaN or -NaN otherwise returns false. + +``` +isnan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `iszero` + +Returns true if a given number is +0.0 or -0.0 otherwise returns false. + +``` +iszero(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + ### `lcm` Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. @@ -353,6 +382,22 @@ log2(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `nanvl` + +Returns the first argument if it's not _NaN_. +Returns the second argument otherwise. + +``` +nanvl(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: Numeric expression to return if it's not _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. + ### `pi` Returns an approximate value of π. @@ -396,7 +441,6 @@ radians(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. - ======= ### `random` @@ -412,13 +456,15 @@ random() Rounds a number to the nearest integer. ``` -round(numeric_expression) +round(numeric_expression[, decimal_places]) ``` #### Arguments - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **decimal_places**: Optional. The number of decimal places to round to. + Defaults to 0. ### `signum` @@ -502,10 +548,10 @@ tanh(numeric_expression) ### `trunc` -Truncates a number toward zero (at the decimal point). +Truncates a number to a whole number or truncated to the specified decimal places. ``` -trunc(numeric_expression) +trunc(numeric_expression[, decimal_places]) ``` #### Arguments @@ -513,6 +559,12 @@ trunc(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **decimal_places**: Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`. + ## Conditional Functions - [coalesce](#coalesce) @@ -583,6 +635,10 @@ nullif(expression1, expression2) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) +- [overlay](#overlay) +- [levenshtein](#levenshtein) +- [substr_index](#substr_index) +- [find_in_set](#find_in_set) ### `ascii` @@ -1068,6 +1124,106 @@ Returns UUID v4 string value which is unique per row. uuid() ``` +### `overlay` + +Returns the string which is replaced by another string from the specified position and specified count length. +For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` + +``` +overlay(str PLACING substr FROM pos [FOR count]) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **substr**: the string to replace part of str. +- **pos**: the start position to replace of str. +- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. + +### `levenshtein` + +Returns the Levenshtein distance between the two given strings. +For example, `levenshtein('kitten', 'sitting') = 3` + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. +- **delim**: the string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. +For example, `find_in_set('b', 'a,b,c,d') = 2` + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + +## Binary String Functions + +- [decode](#decode) +- [encode](#encode) + +### `encode` + +Encode binary data into a textual representation. + +``` +encode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing string or binary data + +- **format**: Supported formats are: `base64`, `hex` + +**Related functions**: +[decode](#decode) + +### `decode` + +Decode binary data from textual representation in string. + +``` +decode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing encoded string data + +- **format**: Same arguments as [encode](#encode) + +**Related functions**: +[encode](#encode) + ## Regular Expression Functions Apache DataFusion uses the POSIX regular expression syntax and @@ -1127,6 +1283,7 @@ regexp_replace(str, regexp, replacement, flags) - [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) - [to_timestamp_seconds](#to_timestamp_seconds) +- [to_timestamp_nanos](#to_timestamp_nanos) - [from_unixtime](#from_unixtime) ### `now` @@ -1299,10 +1456,14 @@ extract(field FROM source) ### `to_timestamp` -Converts a value to RFC3339 nanosecond timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 nanosecond timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). +Supports strings, integer, unsigned integer, and double types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. +Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. ``` to_timestamp(expression) @@ -1315,10 +1476,11 @@ to_timestamp(expression) ### `to_timestamp_millis` -Converts a value to RFC3339 millisecond timestamp format (`YYYY-MM-DDT00:00:00.000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` to_timestamp_millis(expression) @@ -1331,13 +1493,26 @@ to_timestamp_millis(expression) ### `to_timestamp_micros` -Converts a value to RFC3339 microsecond timestamp format (`YYYY-MM-DDT00:00:00.000000Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. + +``` +to_timestamp_nanos(expression) +``` + +### `to_timestamp_nanos` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` -to_timestamp_micros(expression) +to_timestamp_nanos(expression) ``` #### Arguments @@ -1347,10 +1522,11 @@ to_timestamp_micros(expression) ### `to_timestamp_seconds` -Converts a value to RFC3339 second timestamp format (`YYYY-MM-DDT00:00:00Z`). -Supports timestamp, integer, and unsigned integer types as input. -Integers and unsigned integers are parsed as Unix nanosecond timestamps and -return the corresponding RFC3339 timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). +Supports strings, integer, and unsigned integer types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` to_timestamp_seconds(expression) @@ -1364,8 +1540,8 @@ to_timestamp_seconds(expression) ### `from_unixtime` Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Input is parsed as a Unix nanosecond timestamp and returns the corresponding -RFC3339 timestamp. +Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) +return the corresponding timestamp. ``` from_unixtime(expression) @@ -1379,20 +1555,66 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) +- [array_cat](#array_cat) - [array_concat](#array_concat) +- [array_contains](#array_contains) - [array_dims](#array_dims) -- [array_fill](#array_fill) +- [array_element](#array_element) +- [array_extract](#array_extract) +- [array_indexof](#array_indexof) +- [array_join](#array_join) - [array_length](#array_length) - [array_ndims](#array_ndims) +- [array_prepend](#array_prepend) +- [array_pop_front](#array_pop_front) +- [array_pop_back](#array_pop_back) - [array_position](#array_position) - [array_positions](#array_positions) -- [array_prepend](#array_prepend) +- [array_push_back](#array_push_back) +- [array_push_front](#array_push_front) +- [array_repeat](#array_repeat) - [array_remove](#array_remove) +- [array_remove_n](#array_remove_n) +- [array_remove_all](#array_remove_all) - [array_replace](#array_replace) +- [array_replace_n](#array_replace_n) +- [array_replace_all](#array_replace_all) +- [array_slice](#array_slice) - [array_to_string](#array_to_string) - [cardinality](#cardinality) +- [empty](#empty) +- [list_append](#list_append) +- [list_sort](#list_sort) +- [list_cat](#list_cat) +- [list_concat](#list_concat) +- [list_dims](#list_dims) +- [list_element](#list_element) +- [list_extract](#list_extract) +- [list_indexof](#list_indexof) +- [list_join](#list_join) +- [list_length](#list_length) +- [list_ndims](#list_ndims) +- [list_prepend](#list_prepend) +- [list_position](#list_position) +- [list_positions](#list_positions) +- [list_push_back](#list_push_back) +- [list_push_front](#list_push_front) +- [list_repeat](#list_repeat) +- [list_remove](#list_remove) +- [list_remove_n](#list_remove_n) +- [list_remove_all](#list_remove_all) +- [list_replace](#list_replace) +- [list_replace_n](#list_replace_n) +- [list_replace_all](#list_replace_all) +- [list_slice](#list_slice) +- [list_to_string](#list_to_string) - [make_array](#make_array) +- [make_list](#make_list) +- [string_to_array](#string_to_array) +- [string_to_list](#string_to_list) - [trim_array](#trim_array) +- [range](#range) ### `array_append` @@ -1408,275 +1630,1136 @@ array_append(array, element) Can be a constant, column, or function, and any combination of array operators. - **element**: Element to append to the array. -### `array_concat` - -Concatenates arrays. +#### Example ``` -array_concat(array[, ..., array_n]) +❯ select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ ``` -#### Arguments +#### Aliases -- **array**: Array expression to concatenate. - Can be a constant, column, or function, and any combination of array operators. -- **array_n**: Subsequent array column or literal array to concatenate. +- array_push_back +- list_append +- list_push_back -### `array_dims` +### `array_sort` -Returns an array of the array's dimensions. +Sort array. ``` -array_dims(array) +array_sort(array, desc, nulls_first) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). -### `array_fill` - -Returns an array filled with copies of the given value. +#### Example ``` -array_fill(element, array) +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ ``` -#### Arguments +#### Aliases -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to copy to the array. +- list_sort -### `array_length` +### `array_cat` -Returns the length of the array dimension. +_Alias of [array_concat](#array_concat)._ + +### `array_concat` + +Concatenates arrays. ``` -array_length(array, dimension) +array_concat(array[, ..., array_n]) ``` #### Arguments -- **array**: Array expression. +- **array**: Array expression to concatenate. Can be a constant, column, or function, and any combination of array operators. -- **dimension**: Array dimension. - -### `array_ndims` +- **array_n**: Subsequent array column or literal array to concatenate. -Returns the number of dimensions of the array. +#### Example ``` -array_ndims(array, element) +❯ select array_concat([1, 2], [3, 4], [5, 6]); ++---------------------------------------------------+ +| array_concat(List([1,2]),List([3,4]),List([5,6])) | ++---------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++---------------------------------------------------+ ``` -#### Arguments +#### Aliases -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- array_cat +- list_cat +- list_concat -### `array_position` +### `array_has` -Returns a string with an input string repeated a specified number. +Returns true if the array contains the element ``` -array_position(array, element) -array_position(array, element, index) +array_has(array, element) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. -- **index**: Index at which to start searching. +- **element**: Scalar or Array expression. + Can be a constant, column, or function, and any combination of array operators. -### `array_positions` +### `array_has_all` -Searches for an element in the array, returns all occurrences. +Returns true if all elements of sub-array exist in array ``` -array_positions(array, element) +array_has_all(array, sub-array) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for positions in the array. +- **sub-array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. -### `array_prepend` +### `array_has_any` -Prepends an element to the beginning of an array. +Returns true if any elements exist in both arrays ``` -array_prepend(element, array) +array_has_any(array, sub-array) ``` #### Arguments -- **element**: Element to prepend to the array. - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **sub-array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. -### `array_remove` +### `array_dims` -Removes all elements equal to the given value from the array. +Returns an array of the array's dimensions. ``` -array_remove(array, element) +array_dims(array) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. - -### `array_replace` -Replaces a specified element with another specified element. +#### Example ``` -array_replace(array, from, to) +❯ select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ ``` -#### Arguments +#### Aliases -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. +- list_dims -### `array_to_string` +### `array_element` -Converts each element to its text representation. +Extracts the element with the index n from the array. ``` -array_to_string(array, delimeter) +array_element(array, index) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **delimeter**: Array element separator. +- **index**: Index to extract the element from the array. -### `cardinality` - -Returns the total number of elements in the array. +#### Example ``` -cardinality(array) +❯ select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ ``` -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +#### Aliases -### `make_array` +- array_extract +- list_element +- list_extract -Returns an Arrow array using the specified input expressions. +### `array_empty` -``` -make_array(expression1[, ..., expression_n]) -``` +### `array_extract` -#### Arguments +_Alias of [array_element](#array_element)._ -- **expression_n**: Expression to include in the output array. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. +### `array_fill` -### `trim_array` +Returns an array filled with copies of the given value. -Removes the last n elements from the array. +DEPRECATED: use `array_repeat` instead! ``` -trim_array(array, n) +array_fill(element, array) ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **n**: Element to trim the array. +- **element**: Element to copy to the array. -## Hashing Functions +### `flatten` -- [digest](#digest) -- [md5](#md5) -- [sha224](#sha224) -- [sha256](#sha256) -- [sha384](#sha384) -- [sha512](#sha512) +Converts an array of arrays to a flat array -### `digest` +- Applies to any depth of nested arrays +- Does not change arrays that are already flat -Computes the binary hash of an expression using the specified algorithm. +The flattened array contains all the elements from all source arrays. + +#### Arguments + +- **array**: Array expression + Can be a constant, column, or function, and any combination of array operators. ``` -digest(expression, algorithm) +flatten(array) ``` -#### Arguments +### `array_indexof` -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **algorithm**: String expression specifying algorithm to use. - Must be one of: +_Alias of [array_position](#array_position)._ - - md5 - - sha224 - - sha256 - - sha384 - - sha512 - - blake2s - - blake2b - - blake3 +### `array_join` -### `md5` +_Alias of [array_to_string](#array_to_string)._ -Computes an MD5 128-bit checksum for a string expression. +### `array_length` + +Returns the length of the array dimension. ``` -md5(expression) +array_length(array, dimension) ``` #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha224` +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **dimension**: Array dimension. -Computes the SHA-224 hash of a binary string. +#### Example ``` -sha224(expression) +❯ select array_length([1, 2, 3, 4, 5]); ++---------------------------------+ +| array_length(List([1,2,3,4,5])) | ++---------------------------------+ +| 5 | ++---------------------------------+ ``` -#### Arguments +#### Aliases -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- list_length -### `sha256` +### `array_ndims` -Computes the SHA-256 hash of a binary string. +Returns the number of dimensions of the array. ``` -sha256(expression) +array_ndims(array, element) ``` #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. -### `sha384` +#### Example + +``` +❯ select array_ndims([[1, 2, 3], [4, 5, 6]]); ++----------------------------------+ +| array_ndims(List([1,2,3,4,5,6])) | ++----------------------------------+ +| 2 | ++----------------------------------+ +``` + +#### Aliases + +- list_ndims + +### `array_prepend` + +Prepends an element to the beginning of an array. + +``` +array_prepend(element, array) +``` + +#### Arguments + +- **element**: Element to prepend to the array. +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_prepend(1, [2, 3, 4]); ++---------------------------------------+ +| array_prepend(Int64(1),List([2,3,4])) | ++---------------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------------+ +``` + +#### Aliases + +- array_push_front +- list_prepend +- list_push_front + +### `array_pop_front` + +Returns the array without the first element. + +``` +array_pop_first(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_pop_first([1, 2, 3]); ++-------------------------------+ +| array_pop_first(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +``` + +### `array_pop_back` + +Returns the array without the last element. + +``` +array_pop_back(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_pop_back([1, 2, 3]); ++-------------------------------+ +| array_pop_back(List([1,2,3])) | ++-------------------------------+ +| [1, 2] | ++-------------------------------+ +``` + +### `array_position` + +Returns a string with an input string repeated a specified number. + +``` +array_position(array, element) +array_position(array, element, index) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to search for position in the array. +- **index**: Index at which to start searching. + +#### Example + +``` +❯ select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +``` + +#### Aliases + +- array_indexof +- list_indexof +- list_position + +### `array_positions` + +Searches for an element in the array, returns all occurrences. + +``` +array_positions(array, element) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to search for positions in the array. + +#### Example + +``` +❯ select array_positions([1, 2, 2, 3, 1, 4], 2); ++-----------------------------------------------+ +| array_positions(List([1,2,2,3,1,4]),Int64(2)) | ++-----------------------------------------------+ +| [2, 3] | ++-----------------------------------------------+ +``` + +#### Aliases + +- list_positions + +### `array_push_back` + +_Alias of [array_append](#array_append)._ + +### `array_push_front` + +_Alias of [array_prepend](#array_prepend)._ + +### `array_repeat` + +Returns an array containing element `count` times. + +``` +array_repeat(element, count) +``` + +#### Arguments + +- **element**: Element expression. + Can be a constant, column, or function, and any combination of array operators. +- **count**: Value of how many times to repeat the element. + +#### Example + +``` +❯ select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +``` + +``` +❯ select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +``` + +### `array_remove` + +Removes the first element from the array equal to the given value. + +``` +array_remove(array, element) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. + +#### Example + +``` +❯ select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +``` + +#### Aliases + +- list_remove + +### `array_remove_n` + +Removes the first `max` elements from the array equal to the given value. + +``` +array_remove_n(array, element, max) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. +- **max**: Number of first occurrences to remove. + +#### Example + +``` +❯ select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); ++---------------------------------------------------------+ +| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) | ++---------------------------------------------------------+ +| [1, 3, 2, 1, 4] | ++---------------------------------------------------------+ +``` + +#### Aliases + +- list_remove_n + +### `array_remove_all` + +Removes all elements from the array equal to the given value. + +``` +array_remove_all(array, element) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. + +#### Example + +``` +❯ select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); ++--------------------------------------------------+ +| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) | ++--------------------------------------------------+ +| [1, 3, 1, 4] | ++--------------------------------------------------+ +``` + +#### Aliases + +- list_remove_all + +### `array_replace` + +Replaces the first occurrence of the specified element with another specified element. + +``` +array_replace(array, from, to) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. + +#### Example + +``` +❯ select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5); ++--------------------------------------------------------+ +| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++--------------------------------------------------------+ +| [1, 5, 2, 3, 2, 1, 4] | ++--------------------------------------------------------+ +``` + +#### Aliases + +- list_replace + +### `array_replace_n` + +Replaces the first `max` occurrences of the specified element with another specified element. + +``` +array_replace_n(array, from, to, max) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. +- **max**: Number of first occurrences to replace. + +#### Example + +``` +❯ select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace_n + +### `array_replace_all` + +Replaces all occurrences of the specified element with another specified element. + +``` +array_replace_all(array, from, to) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. + +#### Example + +``` +❯ select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5); ++------------------------------------------------------------+ +| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++------------------------------------------------------------+ +| [1, 5, 5, 3, 5, 1, 4] | ++------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace_all + +### `array_slice` + +Returns a slice of the array. + +``` +array_slice(array, begin, end) +``` + +#### Example + +``` +❯ select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +``` + +#### Aliases + +- list_slice + +### `array_to_string` + +Converts each element to its text representation. + +``` +array_to_string(array, delimiter) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **delimiter**: Array element separator. + +#### Example + +``` +❯ select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +``` + +#### Aliases + +- array_join +- list_join +- list_to_string + +### `array_union` + +Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. + +``` +array_union(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_union + +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +❯ select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +``` + +--- + +#### Aliases + +- list_except + +### `cardinality` + +Returns the total number of elements in the array. + +``` +cardinality(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); ++--------------------------------------+ +| cardinality(List([1,2,3,4,5,6,7,8])) | ++--------------------------------------+ +| 8 | ++--------------------------------------+ +``` + +### `empty` + +Returns 1 for an empty array or 0 for a non-empty array. + +``` +empty(array) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +#### Example + +``` +❯ select empty([1]); ++------------------+ +| empty(List([1])) | ++------------------+ +| 0 | ++------------------+ +``` + +### `list_append` + +_Alias of [array_append](#array_append)._ + +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + +### `list_cat` + +_Alias of [array_concat](#array_concat)._ + +### `list_concat` + +_Alias of [array_concat](#array_concat)._ + +### `list_dims` + +_Alias of [array_dims](#array_dims)._ + +### `list_element` + +_Alias of [array_element](#array_element)._ + +### `list_extract` + +_Alias of [array_element](#array_element)._ + +### `list_indexof` + +_Alias of [array_position](#array_position)._ + +### `list_join` + +_Alias of [array_to_string](#array_to_string)._ + +### `list_length` + +_Alias of [array_length](#array_length)._ + +### `list_ndims` + +_Alias of [array_ndims](#array_ndims)._ + +### `list_prepend` + +_Alias of [array_prepend](#array_prepend)._ + +### `list_position` + +_Alias of [array_position](#array_position)._ + +### `list_positions` + +_Alias of [array_positions](#array_positions)._ + +### `list_push_back` + +_Alias of [array_append](#array_append)._ + +### `list_push_front` + +_Alias of [array_prepend](#array_prepend)._ + +### `list_repeat` + +_Alias of [array_repeat](#array_repeat)._ + +### `list_remove` + +_Alias of [array_remove](#array_remove)._ + +### `list_remove_n` + +_Alias of [array_remove_n](#array_remove_n)._ + +### `list_remove_all` + +_Alias of [array_remove_all](#array_remove_all)._ + +### `list_replace` + +_Alias of [array_replace](#array_replace)._ + +### `list_replace_n` + +_Alias of [array_replace_n](#array_replace_n)._ + +### `list_replace_all` + +_Alias of [array_replace_all](#array_replace_all)._ + +### `list_slice` + +_Alias of [array_slice](#array_slice)._ + +### `list_to_string` + +_Alias of [list_to_string](#list_to_string)._ + +### `make_array` + +Returns an Arrow array using the specified input expressions. + +``` +make_array(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression_n**: Expression to include in the output array. + Can be a constant, column, or function, and any combination of arithmetic or + string operators. + +#### Example + +``` +❯ select make_array(1, 2, 3, 4, 5); ++----------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | ++----------------------------------------------------------+ +| [1, 2, 3, 4, 5] | ++----------------------------------------------------------+ +``` + +#### Aliases + +- make_list + +### `make_list` + +_Alias of [make_array](#make_array)._ + +### `string_to_array` + +Splits a string in to an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL. + +``` +starts_with(str, delimiter[, null_str]) +``` + +#### Arguments + +- **str**: String expression to split. +- **delimiter**: Delimiter string to split on. +- **null_str**: Substring values to be replaced with `NULL` + +#### Aliases + +- string_to_list + +### `string_to_list` + +_Alias of [string_to_array](#string_to_array)._ + +### `trim_array` + +Removes the last n elements from the array. + +DEPRECATED: use `array_slice` instead! + +``` +trim_array(array, n) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **n**: Element to trim the array. + +### `range` + +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` + +The range start..end contains all values with start <= x < end. It is empty if start >= end. + +Step can not be 0 (then the range will be nonsense.). + +#### Arguments + +- **start**: start of the range +- **end**: end of the range (not included) +- **step**: increase by step (can not be 0) + +## Struct Functions + +- [struct](#struct) + +### `struct` + +Returns an Arrow struct using the specified input expressions. +Fields in the returned struct use the `cN` naming convention. +For example: `c0`, `c1`, `c2`, etc. + +``` +struct(expression1[, ..., expression_n]) +``` + +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `c0` and `c1`: + +```sql +❯ select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ + +❯ select struct(a, b) from t; ++-----------------+ +| struct(t.a,t.b) | ++-----------------+ +| {c0: 1, c1: 2} | +| {c0: 3, c1: 4} | ++-----------------+ +``` + +#### Arguments + +- **expression_n**: Expression to include in the output struct. + Can be a constant, column, or function, and any combination of arithmetic or + string operators. + +## Hashing Functions + +- [digest](#digest) +- [md5](#md5) +- [sha224](#sha224) +- [sha256](#sha256) +- [sha384](#sha384) +- [sha512](#sha512) + +### `digest` + +Computes the binary hash of an expression using the specified algorithm. + +``` +digest(expression, algorithm) +``` + +#### Arguments + +- **expression**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. +- **algorithm**: String expression specifying algorithm to use. + Must be one of: + + - md5 + - sha224 + - sha256 + - sha384 + - sha512 + - blake2s + - blake2b + - blake3 + +### `md5` + +Computes an MD5 128-bit checksum for a string expression. + +``` +md5(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. + +### `sha224` + +Computes the SHA-224 hash of a binary string. + +``` +sha224(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. + +### `sha256` + +Computes the SHA-256 hash of a binary string. + +``` +sha256(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. + +### `sha384` Computes the SHA-384 hash of a binary string. @@ -1706,7 +2789,6 @@ sha512(expression) - [arrow_cast](#arrow_cast) - [arrow_typeof](#arrow_typeof) -- [struct](#struct) ### `arrow_cast` @@ -1721,12 +2803,28 @@ arrow_cast(expression, datatype) - **expression**: Expression to cast. Can be a constant, column, or function, and any combination of arithmetic or string operators. -- **datatype**: [Arrow data type](https://arrow.apache.org/datafusion/user-guide/sql/data_types.html) - to cast to. +- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name + to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] + +#### Example + +``` +❯ select arrow_cast(-5, 'Int8') as a, + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, + arrow_cast('bar', 'LargeUtf8') as c, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d + ; ++----+-----+-----+---------------------------+ +| a | b | c | d | ++----+-----+-----+---------------------------+ +| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | ++----+-----+-----+---------------------------+ +1 row in set. Query took 0.001 seconds. +``` ### `arrow_typeof` -Returns the underlying Arrow data type of the expression: +Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression: ``` arrow_typeof(expression) @@ -1738,18 +2836,14 @@ arrow_typeof(expression) Can be a constant, column, or function, and any combination of arithmetic or string operators. -### `struct` - -Returns an Arrow struct using the specified input expressions. -Fields in the returned struct use the `cN` naming convention. -For example: `c0`, `c1`, `c2`, etc. +#### Example ``` -struct(expression1[, ..., expression_n]) +❯ select arrow_typeof('foo'), arrow_typeof(1); ++---------------------------+------------------------+ +| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | ++---------------------------+------------------------+ +| Utf8 | Int64 | ++---------------------------+------------------------+ +1 row in set. Query took 0.001 seconds. ``` - -#### Arguments - -- **expression_n**: Expression to include in the output struct. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/sql/sql_status.md index 6075a23330a8f..709534adf46ec 100644 --- a/docs/source/user-guide/sql/sql_status.md +++ b/docs/source/user-guide/sql/sql_status.md @@ -34,97 +34,61 @@ ## SQL Support -- [x] Projection -- [x] Filter (WHERE) -- [x] Filter post-aggregate (HAVING) -- [x] Limit -- [x] Aggregate -- [x] Common math functions -- [x] cast -- [x] try_cast +- [x] Projection (`SELECT`) +- [x] Filter (`WHERE`) +- [x] Filter post-aggregate (`HAVING`) +- [x] Sorting (`ORDER BY`) +- [x] Limit (`LIMIT` +- [x] Aggregate (`GROUP BY`) +- [x] cast /try_cast - [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) -- Postgres compatible String functions - - [x] ascii - - [x] bit_length - - [x] btrim - - [x] char_length - - [x] character_length - - [x] chr - - [x] concat - - [x] concat_ws - - [x] initcap - - [x] left - - [x] length - - [x] lpad - - [x] ltrim - - [x] octet_length - - [x] regexp_replace - - [x] repeat - - [x] replace - - [x] reverse - - [x] right - - [x] rpad - - [x] rtrim - - [x] split_part - - [x] starts_with - - [x] strpos - - [x] substr - - [x] to_hex - - [x] translate - - [x] trim -- Conditional functions - - [x] nullif - - [x] case - - [x] coalesce -- Approximation functions - - [x] approx_distinct - - [x] approx_median - - [x] approx_percentile_cont - - [x] approx_percentile_cont_with_weight -- Common date/time functions - - [ ] Basic date functions - - [ ] Basic time functions - - [x] Basic timestamp functions - - [x] [to_timestamp](./scalar_functions.md#to_timestamp) - - [x] [to_timestamp_millis](./scalar_functions.md#to_timestamp_millis) - - [x] [to_timestamp_micros](./scalar_functions.md#to_timestamp_micros) - - [x] [to_timestamp_seconds](./scalar_functions.md#to_timestamp_seconds) - - [x] [extract](./scalar_functions.md#extract) - - [x] [date_part](./scalar_functions.md#date_part) -- nested functions - - [x] Array of columns +- [x] [String Functions](./scalar_functions.md#string-functions) +- [x] [Conditional Functions](./scalar_functions.md#conditional-functions) +- [x] [Time and Date Functions](./scalar_functions.md#time-and-date-functions) +- [x] [Math Functions](./scalar_functions.md#math-functions) +- [x] [Aggregate Functions](./aggregate_functions.md) (`SUM`, `MEDIAN`, and many more) - [x] Schema Queries - - [x] SHOW TABLES - - [x] SHOW COLUMNS FROM
- - [x] SHOW CREATE TABLE - - [x] information_schema.{tables, columns, views} - - [ ] information_schema other views -- [x] Sorting -- [ ] Nested types -- [ ] Lists + - [x] `SHOW TABLES` + - [x] `SHOW COLUMNS FROM
` + - [x] `SHOW CREATE TABLE ` + - [x] Basic SQL [Information Schema](./information_schema.md) (`TABLES`, `VIEWS`, `COLUMNS`) + - [ ] Full SQL [Information Schema](./information_schema.md) support +- [ ] Support for nested types (`ARRAY`/`LIST` and `STRUCT`. See [#2326](https://github.com/apache/arrow-datafusion/issues/2326) for details) + - [x] Read support + - [x] Write support + - [x] Field access (`col['field']` and [`col[1]`]) + - [x] [Array Functions](./scalar_functions.md#array-functions) + - [ ] [Struct Functions](./scalar_functions.md#struct-functions) + - [x] `struct` + - [ ] [Postgres JSON operators](https://github.com/apache/arrow-datafusion/issues/6631) (`->`, `->>`, etc.) - [x] Subqueries -- [x] Common table expressions -- [x] Set Operations - - [x] UNION ALL - - [x] UNION - - [x] INTERSECT - - [x] INTERSECT ALL - - [x] EXCEPT - - [x] EXCEPT ALL -- [x] Joins - - [x] INNER JOIN - - [x] LEFT JOIN - - [x] RIGHT JOIN - - [x] FULL JOIN - - [x] CROSS JOIN -- [ ] Window - - [x] Empty window - - [x] Common window functions - - [x] Window with PARTITION BY clause - - [x] Window with ORDER BY clause - - [ ] Window with FILTER clause - - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361) - - [ ] UDF and UDAF for window functions +- [x] Common Table Expressions (CTE) +- [x] Set Operations (`UNION [ALL]`, `INTERSECT [ALL]`, `EXCEPT[ALL]`) +- [x] Joins (`INNER`, `LEFT`, `RIGHT`, `FULL`, `CROSS`) +- [x] Window Functions + - [x] Empty (`OVER()`) + - [x] Partitioning and ordering: (`OVER(PARTITION BY <..> ORDER BY <..>)`) + - [x] Custom Window (`ORDER BY time ROWS BETWEEN 2 PRECEDING AND 0 FOLLOWING)`) + - [x] User Defined Window and Aggregate Functions +- [x] Catalogs + - [x] Schemas (`CREATE / DROP SCHEMA`) + - [x] Tables (`CREATE / DROP TABLE`, `CREATE TABLE AS SELECT`) +- [ ] Data Insert + - [x] `INSERT INTO` + - [ ] `COPY .. INTO ..` + - [x] CSV + - [ ] JSON + - [ ] Parquet + - [ ] Avro + +## Runtime + +- [x] Streaming Grouping +- [x] Streaming Window Evaluation +- [x] Memory limits enforced +- [x] Spilling (to disk) Sort +- [ ] Spilling (to disk) Grouping +- [ ] Spilling (to disk) Joins ## Data Sources @@ -132,8 +96,7 @@ In addition to allowing arbitrary datasources via the `TableProvider` trait, DataFusion includes built in support for the following formats: - [x] CSV -- [x] Parquet primitive types -- [x] Parquet nested types +- [x] Parquet (for all primitive and nested types) - [x] JSON - [x] Avro - [x] Arrow diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md new file mode 100644 index 0000000000000..941484e84efd0 --- /dev/null +++ b/docs/source/user-guide/sql/write_options.md @@ -0,0 +1,131 @@ + + +# Write Options + +DataFusion supports customizing how data is written out to disk as a result of a `COPY` or `INSERT INTO` query. There are a few special options, file format (e.g. CSV or parquet) specific options, and parquet column specific options. Options can also in some cases be specified in multiple ways with a set order of precedence. + +## Specifying Options and Order of Precedence + +Write related options can be specified in the following ways: + +- Session level config defaults +- `CREATE EXTERNAL TABLE` options +- `COPY` option tuples + +For a list of supported session level config defaults see [Configuration Settings](configs). These defaults apply to all write operations but have the lowest level of precedence. + +If inserting to an external table, table specific write options can be specified when the table is created using the `OPTIONS` clause: + +```sql +CREATE EXTERNAL TABLE +my_table(a bigint, b bigint) +STORED AS csv +COMPRESSION TYPE gzip +WITH HEADER ROW +DELIMITER ';' +LOCATION '/test/location/my_csv_table/' +OPTIONS( +CREATE_LOCAL_PATH 'true', +NULL_VALUE 'NAN' +); +``` + +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. CREATE_LOCAL_PATH is a special option that indicates if DataFusion should create local file paths when writing new files if they do not already exist. This option is useful if you wish to create an external table from scratch, using only DataFusion SQL statements. Finally, NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. + +Finally, options can be passed when running a `COPY` command. + +```sql +COPY source_table +TO 'test/table_with_options' +(format parquet, +single_file_output false, +compression snappy, +'compression::col1' 'zstd(5)', +) +``` + +In this example, we write the entirety of `source_table` out to a folder of parquet files. The option `single_file_output` set to false, indicates that the destination path should be interpreted as a folder to which the query will output multiple files. One parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the parquet file will use `ZSTD` compression codec with compression level `5`. In general, parquet options which support column specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. + +## Available Options + +### COPY Specific Options + +The following special options are specific to the `COPY` command. + +| Option | Description | Default Value | +| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| SINGLE_FILE_OUTPUT | If true, COPY query will write output to a single file. Otherwise, multiple files will be written to a directory in parallel. | true | +| FORMAT | Specifies the file format COPY query will write out. If single_file_output is false or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A | + +### CREATE EXTERNAL TABLE Specific Options + +The following special options are specific to creating an external table. + +| Option | Description | Default Value | +| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | +| SINGLE_FILE | If true, indicates that this external table is backed by a single file. INSERT INTO queries will append to this file. | false | +| CREATE_LOCAL_PATH | If true, the folder or file backing this table will be created on the local file system if it does not already exist when running INSERT INTO queries. | false | +| INSERT_MODE | Determines if INSERT INTO queries should append to existing files or append new files to an existing directory. Valid values are append_to_file, append_new_files, and error. Note that "error" will block inserting data into this table. | CSV and JSON default to append_to_file. Parquet defaults to append_new_files | + +### JSON Format Specific Options + +The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. + +| Option | Description | Default Value | +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | + +### CSV Format Specific Options + +The following options are available when writing CSV files. Note: if any unsupported options is specified an error will be raised and the query will fail. + +| Option | Description | Default Value | +| --------------- | --------------------------------------------------------------------------------------------------------------------------------- | ---------------- | +| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | +| HEADER | Sets if the CSV file should include column headers | false | +| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file | arrow-rs default | +| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file | arrow-rs default | +| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file | arrow-rs default | +| RFC3339 | If true, uses RFC339 format for date and time encodings | arrow-rs default | +| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | arrow-rs default | +| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | arrow-rs default | + +### Parquet Format Specific Options + +The following options are available when writing parquet files. If any unsupported option is specified an error will be raised and the query will fail. If a column specific option is specified for a column which does not exist, the option will be ignored without error. For default values, see: [Configuration Settings](https://arrow.apache.org/datafusion/user-guide/configs.html). + +| Option | Can be Column Specific? | Description | +| ---------------------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------- | +| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use | +| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group | +| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes | +| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch | +| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) | +| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes | +| CREATED_BY | No | Sets the "created by" property in the parquet file | +| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. | +| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. | +| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. | +| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) | +| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. | +| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. | +| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. | +| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. | +| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. | diff --git a/docs/src/lib.rs b/docs/src/lib.rs new file mode 100644 index 0000000000000..f73132468ec9e --- /dev/null +++ b/docs/src/lib.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod library_logical_plan; diff --git a/docs/src/library_logical_plan.rs b/docs/src/library_logical_plan.rs new file mode 100644 index 0000000000000..3550039415706 --- /dev/null +++ b/docs/src/library_logical_plan.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::error::Result; +use datafusion::logical_expr::builder::LogicalTableSource; +use datafusion::logical_expr::{Filter, LogicalPlan, LogicalPlanBuilder, TableScan}; +use datafusion::prelude::*; +use std::sync::Arc; + +#[test] +fn plan_1() -> Result<()> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // create a TableScan plan + let projection = None; // optional projection + let filters = vec![]; // optional filters to push down + let fetch = None; // optional LIMIT + let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, + )?); + + // create a Filter plan that evaluates `id > 500` and wraps the TableScan + let filter_expr = col("id").gt(lit(500)); + let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); + + // print the plan + println!("{}", plan.display_indent_schema()); + + Ok(()) +} + +#[test] +fn plan_builder_1() -> Result<()> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // optional projection + let projection = None; + + // create a LogicalPlanBuilder for a table scan + let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + + // perform a filter that evaluates `id > 500`, and build the plan + let plan = builder.filter(col("id").gt(lit(500)))?.build()?; + + // print the plan + println!("{}", plan.display_indent_schema()); + + Ok(()) +} diff --git a/parquet-testing b/parquet-testing index a11fc8f148f8a..e45cd23f784aa 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit a11fc8f148f8a7a89d9281cc0da3eb9d56095fbf +Subproject commit e45cd23f784aab3d6bf0701f8f4e621469ed3be7 diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 19ffb5940322b..b9c4db17c0981 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "test-utils" version = "0.1.0" -edition = "2021" +edition = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -26,4 +26,4 @@ edition = "2021" arrow = { workspace = true } datafusion-common = { path = "../datafusion/common" } env_logger = "0.10.0" -rand = "0.8" +rand = { workspace = true } diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index dfd878275181c..0c3668d2f8c0f 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -38,7 +38,7 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { .collect() } -/// extract values from batches and sort them +/// extract i32 values from batches and sort them pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { let mut values: Vec<_> = partitions .iter() @@ -70,13 +70,23 @@ pub fn add_empty_batches( } /// "stagger" batches: split the batches into random sized batches +/// +/// For example, if the input batch has 1000 rows, [`stagger_batch`] might return +/// multiple batches +/// ```text +/// [ +/// RecordBatch(123 rows), +/// RecordBatch(234 rows), +/// RecordBatch(634 rows), +/// ] +/// ``` pub fn stagger_batch(batch: RecordBatch) -> Vec { let seed = 42; stagger_batch_with_seed(batch, seed) } -/// "stagger" batches: split the batches into random sized batches -/// using the specified value for a rng seed +/// "stagger" batches: split the batches into random sized batches using the +/// specified value for a rng seed. See [`stagger_batch`] for more detail. pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec { let mut batches = vec![]; diff --git a/testing b/testing index e81d0c6de3594..98fceecd024dc 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit e81d0c6de35948b3be7984af8e00413b314cde6e +Subproject commit 98fceecd024dccd2f8a00e32fc144975f218acf4